Merge pull request #165 from aryankeluskar/master
Migrated to a newer version of PyTorch Lightning.
mortonjt authored Nov 11, 2024
2 parents 20b5272 + b022f17 commit ec661fa
Showing 2 changed files with 8 additions and 22 deletions.
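Context for the change: PyTorch Lightning 2.0 removed the validation_epoch_end(outputs) hook, which previously received the collected return values of every validation_step. Under the new API, the module keeps its own buffer of per-step outputs (here self.validation_step_outputs) and aggregates them itself in on_validation_epoch_end(), which takes no arguments. The trainer.py diff below follows that pattern; the utils.py change appears to be a small NumPy compatibility fix bundled into the same commit.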
2 changes: 1 addition & 1 deletion deepblast/dataset/utils.py
@@ -406,7 +406,7 @@ def gap_mask(states: str, sparse=False):
     if sparse:
         return mat
     else:
-        return mat.toarray().astype(np.bool)
+        return mat.toarray().astype(bool)


 def window(seq, n=2):
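A note on this fix: np.bool was a deprecated alias for Python's built-in bool and was removed in NumPy 1.24, so astype(np.bool) now raises an AttributeError. A minimal sketch of the before/after behavior, assuming gap_mask builds a SciPy sparse matrix (its toarray() call suggests it does):

import numpy as np
from scipy.sparse import coo_matrix

# Toy stand-in for the matrix gap_mask constructs.
mat = coo_matrix(([1, 1], ([0, 1], [1, 0])), shape=(2, 2))

dense = mat.toarray().astype(bool)   # works on every NumPy version
# mat.toarray().astype(np.bool)      # AttributeError on NumPy >= 1.24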
28 changes: 7 additions & 21 deletions deepblast/trainer.py
@@ -50,6 +50,7 @@ def __init__(self, batch_size=20,
                  ):

         super(DeepBLAST, self).__init__()
+        self.validation_step_outputs = []
         self.save_hyperparameters(ignore=['lm', 'tokenizer'])

         if device == 'gpu':  # this is for users, in case they specify gpu
@@ -74,6 +75,7 @@ def __init__(self, batch_size=20,
             n_input, n_units, n_embed, n_layers, dropout=dropout, lm=lm,
             alignment_mode=alignment_mode,
             device=device)
+        self.tokenizer = tokenizer

     def align(self, x, y):
         x_code = get_sequence(x, self.tokenizer)[0].to(self.device)
@@ -236,6 +238,7 @@ def validation_step(self, batch, batch_idx):
         predA, theta, gap = self.aligner(seq, order)
         x, xlen, y, ylen = unpack_sequences(seq, order)
         loss = self.compute_loss(xlen, ylen, predA, A, P, G, theta)
+        self.validation_step_outputs.append(loss)

         assert torch.isnan(loss).item() is False

@@ -291,27 +294,10 @@ def test_step(self, batch, batch_idx):
         statistics['key_name'] = other_names
         return statistics

-    def validation_epoch_end(self, outputs):
-        loss_f = lambda x: x['validation_loss']
-        losses = list(map(loss_f, outputs))
-        loss = sum(losses) / len(losses)
-        self.logger.experiment.add_scalar('val_loss', loss, self.global_step)
-        # self.log('validation_loss') = loss
-
-        # metrics = ['val_tp', 'val_fp', 'val_fn', 'val_perc_id',
-        #            'val_ppv', 'val_fnr', 'val_fdr']
-        # scores = []
-        # for i, m in enumerate(metrics):
-        #     loss_f = lambda x: x['log'][m]
-        #     losses = list(map(loss_f, outputs))
-        #     scalar = sum(losses) / len(losses)
-        #     scores.append(scalar)
-        #     self.logger.experiment.add_scalar(m, scalar, self.global_step)
-
-        tensorboard_logs = dict(
-            [('val_loss', loss)]  # + list(zip(metrics, scores))
-        )
-        return {'val_loss': loss, 'log': tensorboard_logs}
+    def on_validation_epoch_end(self):
+        epoch_average = torch.stack(self.validation_step_outputs).mean()
+        self.log("validation_epoch_average", epoch_average)
+        self.validation_step_outputs.clear()  # free memory

     def configure_optimizers(self):
         # Freeze language model
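For reference, a minimal self-contained sketch of the hook pattern this diff adopts; the tiny linear model and loss below are illustrative placeholders, not DeepBLAST's real aligner:

import torch
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)
        # Replaces the `outputs` argument that Lightning passed to
        # validation_epoch_end before 2.0.
        self.validation_step_outputs = []

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.validation_step_outputs.append(loss)
        return loss

    def on_validation_epoch_end(self):
        # The new hook receives no outputs; aggregate the buffer manually.
        epoch_average = torch.stack(self.validation_step_outputs).mean()
        self.log("validation_epoch_average", epoch_average)
        self.validation_step_outputs.clear()  # free memory

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

Note that self.log() also supersedes the manual self.logger.experiment.add_scalar('val_loss', ...) call in the deleted validation_epoch_end; Lightning now handles reduction and TensorBoard writing through its standard logging API.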
