diff --git a/firelang/models/_fireword.py b/firelang/models/_fireword.py
index 3cfd464..e331ac4 100644
--- a/firelang/models/_fireword.py
+++ b/firelang/models/_fireword.py
@@ -227,9 +227,6 @@ def loss_skipgram(
     loss = Loss()
 
     logits = x1 * x2
-    print("########################################")
-    print(type(logits))
-    print("########################################")
     loss_sim = F.binary_cross_entropy_with_logits(
         logits, labels.float(), reduction="none"
     )
diff --git a/scripts/fineTune.py b/scripts/fineTune.py
index 1a7911b..a3549b2 100644
--- a/scripts/fineTune.py
+++ b/scripts/fineTune.py
@@ -30,6 +30,7 @@
 total_timer = Timer(elapsed, "total")
 
+device = "cuda"
 
 @total_timer
 def fineTune(args):
@@ -110,10 +111,10 @@ def fineTune(args):
             loss = Loss()
 
             logits = predictSentencePairs(model, iterationPairs)
-
+
             lossSim = F.binary_cross_entropy_with_logits(
-                FireTensor(torch.tensor(logits, dtype=torch.float, device=device, requires_grad=True)),
-                FireTensor(torch.tensor(iterationLabels, dtype=torch.float, device=device, requires_grad=True)), reduction="none"
+                logits,
+                torch.tensor(iterationLabels, dtype=torch.float, device=device),
+                reduction="none",
             )
             loss.add("sim", lossSim)
@@ -129,11 +130,11 @@ def fineTune(args):
             else:
                 steploss.backward()
 
-            grad_norm = (
-                torch.cat([p.grad.data.reshape(-1) for p in model.parameters()])
-                .norm()
-                .item()
-            )
+            # grad_norm = (
+            #     torch.cat([p.grad.data.reshape(-1) for p in model.parameters()])
+            #     .norm()
+            #     .item()
+            # )
 
             if args.profile:
                 logger.debug("----- backward -----")
@@ -143,6 +144,7 @@ def fineTune(args):
         with Timer(elapsed, "optim", sync_cuda=True):
             with Timer(elapsed, "step"):
                 for name, p in model.named_parameters():
+                    print(name, p, p.grad)
                     isnan = p.grad.isnan()
                     isinf = p.grad.isinf()
                     isinvalid = isnan | isinf
@@ -193,7 +195,7 @@ def fineTune(args):
             model.save(args.savedir)
 
         logger.info(
-            f"Iter {i}. Loss={loss}; grad={grad_norm:.3g}; "
+            f"Iter {i}. Loss={loss}; "
             f"lr={scheduler.get_last_lr()[0]:.3g}; "
             f"sim={simscore:.3f}%"
             f"loss={total_loss}"
@@ -247,40 +249,23 @@ def predictSentencePairs(
         w: sif_alpha / (sif_alpha + prob) for w, prob in probs.items()
     }
-    sents1 = pairs[0]
-    sents2 = pairs[1]
-    allsents = sents1 + sents2
-    allsents = [
-        [w for w in sent if w in sif_weights and w != vocab.unk]
-        for sent in allsents
-    ]
-
-    """ similarity """
-    with Timer(elapsed, "similarity", sync_cuda=True):
-        simmat = sentence_simmat(model, allsents, sif_weights)
-        print("#######################################################################################")
-        print(type(simmat))
-        print("#######################################################################################")
-    """ regularization: sim(i,j) <- sim(i,j) - 0.5 * (sim(i,i) + sim(j,j)) - halved bc (9)"""
-    with Timer(elapsed, "regularization"):
-        diag = np.diag(simmat)
-        simmat = simmat - 0.5 * (diag.reshape(-1, 1) + diag.reshape(1, -1))
-
-    """ smoothing by standardization """
-    with Timer(elapsed, "smooth"):
-        mean1 = np.mean(simmat, axis=1, keepdims=True)
-        std1 = np.std(simmat, axis=1, keepdims=True)
-        mean0 = np.mean(simmat, axis=0, keepdims=True)
-        std0 = np.std(simmat, axis=0, keepdims=True)
-        simmat = (simmat - (mean1 + mean0) / 2) / (std0 * std1) ** 0.5
-
-    N = len(pairs[0])
-    preds = [simmat[i, i + N] for i in range(N)]
-    preds = np.exp(preds)
-    preds = np.array(preds)
-
-    return preds
+    sents1 = [[w for w in sent if w in sif_weights and w != vocab.unk] for sent in pairs[0]]
+    sents2 = [[w for w in sent if w in sif_weights and w != vocab.unk] for sent in pairs[1]]
+
+    similarities = torch.zeros(len(sents1), device=device)
+
+    for i, (sent1, sent2) in enumerate(zip(sents1, sents2)):
+        # SIF-weighted sum of similarities over all cross-sentence word pairs.
+        similarity = torch.zeros(1, device=device)
+        for word1 in sent1:
+            for word2 in sent2:
+                similarity = similarity + model[word1] * model[word2] * sif_weights[word1] * sif_weights[word2]
+        similarities[i] = similarity.squeeze()
+
+    return similarities
 
 def set_seed(seed):
     random.seed(seed)
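
Note on the binary_cross_entropy_with_logits hunk: the loss has to be computed on logits exactly as returned by predictSentencePairs, not on a torch.tensor(...) copy of it. torch.tensor(...) copies the values and severs the autograd graph, so requires_grad=True on the copy only makes the copy a new leaf; backward() stops there, model parameters never receive gradients, and the later p.grad.isnan() check would crash on p.grad=None. A minimal standalone sketch of the difference, with toy names standing in for the FireLang model:

import torch
import torch.nn.functional as F

w = torch.randn(4, requires_grad=True)   # stands in for a model parameter
logits = w * 2.0                         # carries grad history back to w
labels = torch.tensor([1.0, 0.0, 1.0, 0.0])

# Graph-connected logits: backward() reaches w.
F.binary_cross_entropy_with_logits(logits, labels, reduction="none").sum().backward()
print(w.grad is not None)                # True

# Re-wrapped logits: torch.tensor(...) makes a detached copy (a new leaf),
# so backward() stops at the copy and w never receives a gradient.
w.grad = None
rewrapped = torch.tensor(logits.detach(), dtype=torch.float, requires_grad=True)
F.binary_cross_entropy_with_logits(rewrapped, labels, reduction="none").sum().backward()
print(w.grad)                            # None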
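
Likewise, a quick shape-and-gradient check of the rewritten predictSentencePairs contract: the similarities should come back as a flat (N,) tensor so they align elementwise with the (N,) label tensor, and they should keep grad history even though they are written into a preallocated buffer. Everything below (emb, sif_weights, pairs) is a hypothetical scalar stand-in, not the real FireWord API:

import torch
import torch.nn.functional as F

# Toy scalar "embeddings": emb[w1] * emb[w2] stands in for the
# model[word1] * model[word2] similarity used in predictSentencePairs.
emb = {w: torch.randn((), requires_grad=True) for w in "abcd"}
sif_weights = {w: 0.5 for w in emb}
pairs = ([["a", "b"], ["c"]], [["c", "d"], ["a", "d"]])

similarities = torch.zeros(len(pairs[0]))
for i, (sent1, sent2) in enumerate(zip(*pairs)):
    similarity = torch.zeros(1)
    for w1 in sent1:
        for w2 in sent2:
            similarity = similarity + emb[w1] * emb[w2] * sif_weights[w1] * sif_weights[w2]
    similarities[i] = similarity.squeeze()

labels = torch.tensor([1.0, 0.0])
loss = F.binary_cross_entropy_with_logits(similarities, labels, reduction="none")
print(loss.shape)        # torch.Size([2]) -- one loss term per sentence pair
loss.sum().backward()    # indexed writes are recorded, so grads flow back
print(emb["a"].grad)     # non-None: the toy "model" receives a gradient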