fine-tuning might work now, commit for HPC attempt
Xeadriel committed May 18, 2024
1 parent 2b1306d commit 2a03a36
Showing 2 changed files with 28 additions and 46 deletions.
3 changes: 0 additions & 3 deletions firelang/models/_fireword.py
@@ -227,9 +227,6 @@ def loss_skipgram(
         loss = Loss()

         logits = x1 * x2
-        print("########################################")
-        print(type(logits))
-        print("########################################")
         loss_sim = F.binary_cross_entropy_with_logits(
             logits, labels.float(), reduction="none"
         )
71 changes: 28 additions & 43 deletions scripts/fineTune.py
@@ -30,6 +30,7 @@

 total_timer = Timer(elapsed, "total")

+device = "cuda"

 @total_timer
 def fineTune(args):
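
Note that the new module-level device is hard-coded; on a CPU-only node (or a mis-scheduled HPC job) every torch.zeros(..., device=device) call below would raise. A common fallback, shown here as a sketch rather than as part of the commit:

import torch  # already imported by the script; repeated for completeness

device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU is visible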
@@ -110,10 +111,10 @@ def fineTune(args):
         loss = Loss()

         logits = predictSentencePairs(model, iterationPairs)

         lossSim = F.binary_cross_entropy_with_logits(
-            FireTensor(torch.tensor(logits, dtype=torch.float, device=device, requires_grad=True)),
-            FireTensor(torch.tensor(iterationLabels, dtype=torch.float, device=device, requires_grad=True)), reduction="none"
+            torch.tensor(logits, dtype=torch.float, device=device, requires_grad=True),
+            torch.tensor(iterationLabels, dtype=torch.float, device=device, requires_grad=False), reduction="none"
         )
         loss.add("sim", lossSim)
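
One caveat with this hunk: re-wrapping the inputs in fresh torch.tensor(...) calls creates leaf tensors detached from the model's computation graph, so even with requires_grad=True this loss cannot push gradients back into model.parameters(). A minimal sketch of the usual pattern, assuming predictSentencePairs returns a tensor that already carries the graph (as the rewritten implementation further down suggests):

logits = predictSentencePairs(model, iterationPairs)  # (N, 1) tensor on `device`, still on the graph
labels = torch.tensor(iterationLabels, dtype=torch.float, device=device)  # targets need no grad

lossSim = F.binary_cross_entropy_with_logits(
    logits.reshape(-1), labels.reshape(-1), reduction="none"
)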

@@ -129,11 +130,11 @@
         else:
             steploss.backward()

-        grad_norm = (
-            torch.cat([p.grad.data.reshape(-1) for p in model.parameters()])
-            .norm()
-            .item()
-        )
+        # grad_norm = (
+        #     torch.cat([p.grad.data.reshape(-1) for p in model.parameters()])
+        #     .norm()
+        #     .item()
+        # )

         if args.profile:
             logger.debug("----- backward -----")
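
Commenting out grad_norm is presumably a workaround for parameters whose .grad is None, which makes the torch.cat crash; note it also leaves the logger.info call further down printing the literal text grad=grad_norm:.3g. If the metric is still wanted, a guarded sketch (an assumption, not part of the commit):

grads = [p.grad.reshape(-1) for p in model.parameters() if p.grad is not None]
grad_norm = torch.cat(grads).norm().item() if grads else 0.0  # 0.0 when nothing received a gradient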
@@ -143,6 +144,7 @@
         with Timer(elapsed, "optim", sync_cuda=True):
             with Timer(elapsed, "step"):
                 for name, p in model.named_parameters():
+                    print(name, p, p.grad)
                     isnan = p.grad.isnan()
                     isinf = p.grad.isinf()
                     isinvalid = isnan | isinf
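
The new print and the isnan/isinf checks both assume every named parameter received a gradient; p.grad is None for parameters that played no part in the loss, and p.grad.isnan() would then raise AttributeError. A guard for the top of the loop body, offered as a sketch:

if p.grad is None:  # skip parameters that received no gradient
    continue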
@@ -193,7 +195,7 @@
             model.save(args.savedir)

         logger.info(
-            f"Iter {i}. Loss={loss}; grad={grad_norm:.3g}; "
+            f"Iter {i}. Loss={loss}; grad=grad_norm:.3g; "
             f"lr={scheduler.get_last_lr()[0]:.3g}; "
             f"sim={simscore:.3f}%"
             f"loss={total_loss}"
@@ -247,40 +249,23 @@ def predictSentencePairs(
         w: sif_alpha / (sif_alpha + prob) for w, prob in probs.items()
     }

-    sents1 = pairs[0]
-    sents2 = pairs[1]
-    allsents = sents1 + sents2
-    allsents = [
-        [w for w in sent if w in sif_weights and w != vocab.unk]
-        for sent in allsents
-    ]
-
-    """ similarity """
-    with Timer(elapsed, "similarity", sync_cuda=True):
-        simmat = sentence_simmat(model, allsents, sif_weights)
-        print("#######################################################################################")
-        print(type(simmat))
-        print("#######################################################################################")
-    """ regularization: sim(i,j) <- sim(i,j) - 0.5 * (sim(i,i) + sim(j,j))
-        halved bc (9)"""
-    with Timer(elapsed, "regularization"):
-        diag = np.diag(simmat)
-        simmat = simmat - 0.5 * (diag.reshape(-1, 1) + diag.reshape(1, -1))
-
-    """ smoothing by standardization """
-    with Timer(elapsed, "smooth"):
-        mean1 = np.mean(simmat, axis=1, keepdims=True)
-        std1 = np.std(simmat, axis=1, keepdims=True)
-        mean0 = np.mean(simmat, axis=0, keepdims=True)
-        std0 = np.std(simmat, axis=0, keepdims=True)
-        simmat = (simmat - (mean1 + mean0) / 2) / (std0 * std1) ** 0.5
-
-    N = len(pairs[0])
-    preds = [simmat[i, i + N] for i in range(N)]
-    preds = np.exp(preds)
-    preds = np.array(preds)
-
-    return preds
+    sents1 = [[w for w in sent if w in sif_weights and w != vocab.unk] for sent in pairs[0]]
+    sents2 = [[w for w in sent if w in sif_weights and w != vocab.unk] for sent in pairs[1]]
+
+    similarities = torch.zeros((len(sents1), 1), requires_grad=False, device=device)
+
+    for i in range(len(sents1)):
+        sent1 = sents1[i]
+        sent2 = sents2[i]
+
+        similarity = torch.zeros((1, 1), requires_grad=False, device=device)
+        for word1 in sent1:
+            for word2 in sent2:
+                similarity[0] += model[word1] * model[word2] * sif_weights[word1] * sif_weights[word2]
+
+        similarities[i] += similarity[0]
+
+    return similarities

 def set_seed(seed):
     random.seed(seed)
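
The rewritten predictSentencePairs computes, for each sentence pair, the SIF-weighted sum of all cross-sentence word similarities: sim(s1, s2) = sum over w1 in s1 and w2 in s2 of sif_weights[w1] * sif_weights[w2] * (model[w1] * model[w2]). Each model[w1] * model[w2] is a separate model evaluation, so the cost grows with len(s1) * len(s2) per pair, where the removed sentence_simmat path scored all sentences in one batched call. A sketch of the same reduction without in-place writes into a requires_grad=False buffer, assuming model[w1] * model[w2] yields a one-element tensor as the loop above implies:

def pair_similarity(sent1, sent2):
    # Out-of-place accumulation keeps autograd tracking explicit.
    score = torch.zeros((1, 1), device=device)
    for w1 in sent1:
        for w2 in sent2:
            score = score + model[w1] * model[w2] * sif_weights[w1] * sif_weights[w2]
    return score

similarities = torch.cat([pair_similarity(s1, s2) for s1, s2 in zip(sents1, sents2)])  # (N, 1)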
