Skip to content

Commit

Permalink
removed match, case
Browse files Browse the repository at this point in the history
  • Loading branch information
Xeadriel committed May 23, 2024
1 parent 2a03a36 commit 907fa72
Show file tree
Hide file tree
Showing 10 changed files with 3,257 additions and 24,416 deletions.
64 changes: 32 additions & 32 deletions scripts/fineTune.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,17 @@ def fineTune(args):

sentencePairs = []
labels = []
match args.task:
case "MRPC":
sentencePairs, labels = prepareMRPCData('scripts/tasks/MRPC/msr_paraphrase_train.csv')
evalPairs = sentencePairs[int(len(sentencePairs[0]) / 10):]
evalLabels = labels[int(len(sentencePairs[0]) / 10):]
case "SST-2":
sentencePairs, labels, evalPairs, evalLabels, _, _ = prepareSSTGlueData('scripts/tasks/SSTGLUE/train.tsv',
if args.task == "MRPC":
sentencePairs, labels = prepareMRPCData('scripts/tasks/MRPC/msr_paraphrase_train.csv')
evalPairs = sentencePairs[int(len(sentencePairs[0]) / 10):]
evalLabels = labels[int(len(sentencePairs[0]) / 10):]
elif args.task == "SST-2":
sentencePairs, labels, evalPairs, evalLabels, _, _ = prepareSSTGlueData('scripts/tasks/SSTGLUE/train.tsv',
'scripts/tasks/SSTGLUE/dev.tsv', 'scripts/tasks/SSTGLUE/test.tsv')
case "RTE":
sentencePairs, labels, evalPairs, evalLabels, _, _ = prepareRTEGlueData('scripts/tasks/RTE/train.tsv',
'scripts/tasks/RTE/dev.tsv', 'scripts/tasks/RTE/test.tsv')
elif args.task == "RTE":
sentencePairs, labels, evalPairs, evalLabels, _, _ = prepareRTEGlueData('scripts/tasks/RTE/train.tsv',
'scripts/tasks/RTE/dev.tsv', 'scripts/tasks/RTE/test.tsv')

indices = list(range(len(sentencePairs[0])))

logger.info(model)
Expand Down Expand Up @@ -113,7 +113,7 @@ def fineTune(args):
logits = predictSentencePairs(model, iterationPairs)

lossSim = F.binary_cross_entropy_with_logits(
torch.tensor(logits, dtype=torch.float, device=device, requires_grad=True),
logits,
torch.tensor(iterationLabels, dtype=torch.float, device=device, requires_grad=False), reduction="none"
)
loss.add("sim", lossSim)
Expand Down Expand Up @@ -173,20 +173,20 @@ def fineTune(args):

"""--------------- similarity benchmark ---------------"""
with Timer(elapsed, "benchmark on evaluation data", sync_cuda=True):
match args.task:
case "MRPC":
#F1 score
simscore = benchmarkMRPC(model, evalPairs, evalLabels)[4]
case "SST-2":
#threshold search based predictions
predictions = predictSSTGlue(model, evalPairs, evalPairs, evalLabels)[1]
#accuracy
simscore = sum([predictions[x] == evalLabels[x] for x in range(len(predictions))]) / len(predictions)
case "RTE":
#threshold search based predictions
predictions = predictRTE(model, evalPairs, evalPairs, evalLabels)[1]
#accuracy
simscore = sum([predictions[x] == evalLabels[x] for x in range(len(predictions))]) / len(predictions)
if args.task == "MRPC":
#F1 score
simscore = benchmarkMRPC(model, evalPairs, evalLabels)[4]
elif args.task == "SST-2":
#threshold search based predictions
predictions = predictSSTGlue(model, evalPairs, evalPairs, evalLabels)[1]
#accuracy
simscore = sum([predictions[x] == evalLabels[x] for x in range(len(predictions))]) / len(predictions)
elif args.task == "RTE":
#threshold search based predictions
predictions = predictRTE(model, evalPairs, evalPairs, evalLabels)[1]
#accuracy
simscore = sum([predictions[x] == evalLabels[x] for x in range(len(predictions))]) / len(predictions)


if simscore > best_simscore:
best_iter = i
Expand Down Expand Up @@ -252,18 +252,18 @@ def predictSentencePairs(
sents1 = [ [w for w in sent if w in sif_weights and w != vocab.unk] for sent in pairs[0] ]
sents2 = [ [w for w in sent if w in sif_weights and w != vocab.unk] for sent in pairs[1] ]

similarities = torch.zeros((len(sents1), 1), requires_grad=False, device=device)
similarities = torch.zeros((len(sents1), 1), requires_grad=True, device=device)

for i in range(len(sents1)):
sent1 = sents1[i]
sent2 = sents2[i]
for y in range(len(sents1)):
sent1 = sents1[y]
sent2 = sents2[y]

similarity = torch.zeros((1, 1), requires_grad=False, device=device)
similarity = 0
for word1 in sent1:
for word2 in sent2:
similarity[0] += model[word1] * model[word2] * sif_weights[word1] * sif_weights[word2]
similarity += model[word1] * model[word2] * sif_weights[word1] * sif_weights[word2]

similarities[i] += similarity[0]
similarities[y] += similarity

return similarities

Expand Down
Loading

0 comments on commit 907fa72

Please sign in to comment.