Skip to content

Commit

Permalink
fixed wrong reads
Browse files Browse the repository at this point in the history
  • Loading branch information
Xeadriel committed Jun 3, 2024
1 parent 7b3f514 commit 8a02d86
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions scripts/fineTune.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ def fineTune(args):
labels = []
if args.task == "MRPC":
sentencePairs, labels = prepareMRPCData('scripts/tasks/MRPC/msr_paraphrase_train.csv')
evalPairs = sentencePairs[int(len(sentencePairs[0]) / 10):]
evalLabels = labels[int(len(sentencePairs[0]) / 10):]
evalPairs = [sentencePairs[0][:400], sentencePairs[1][:400]]
evalLabels = labels[:400]
sentencePairs = [sentencePairs[0][400:], sentencePairs[1][400:]]
labels = labels[400:]
elif args.task == "SST-2":
sentencePairs, labels, evalPairs, evalLabels, _, _ = prepareSSTGlueData('scripts/tasks/SSTGLUE/train.tsv',
'scripts/tasks/SSTGLUE/dev.tsv', 'scripts/tasks/SSTGLUE/test.tsv')
Expand All @@ -62,6 +64,9 @@ def fineTune(args):
model = model.to(device)
model.train()

best_simscore = -9999
best_loss = 9999

if args.optimizer == "adamw":
optimizer = AdamW(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay
Expand Down Expand Up @@ -252,8 +257,8 @@ def predictSentencePairs(
sents1 = [ [w for w in sent if w in sif_weights and w != vocab.unk] for sent in pairs[0] ]
sents2 = [ [w for w in sent if w in sif_weights and w != vocab.unk] for sent in pairs[1] ]

similarities = torch.zeros((len(sents1), 1), requires_grad=True, device=device)
similaritiesTmp = torch.zeros((len(sents1), 1), requires_grad=False, device=device)
similarities = torch.zeros(len(sents1), requires_grad=True, device=device)
similaritiesTmp = torch.zeros(len(sents1), requires_grad=False, device=device)

for y in range(len(sents1)):
sent1 = sents1[y]
Expand All @@ -262,9 +267,9 @@ def predictSentencePairs(

for word1 in sent1:
for word2 in sent2:
similaritiesTmp[y] += model[word1] * model[word2] * sif_weights[word1] * sif_weights[word2]
similaritiesTmp[y] += (model[word1] * model[word2] * sif_weights[word1] * sif_weights[word2])[0]

similarities += similaritiesTmp
similarities = similarities + similaritiesTmp

return similarities

Expand Down

0 comments on commit 8a02d86

Please sign in to comment.