Commit 0e86f71
Merge pull request #125 from flatironinstitute/ipynb-update
Notebook update
mortonjt authored Mar 7, 2023
2 parents bd00cd0 + d007a05 commit 0e86f71
Showing 5 changed files with 310 additions and 995 deletions.
6 changes: 3 additions & 3 deletions ci/pip_requirements.txt
@@ -151,10 +151,10 @@ threadpoolctl==3.1.0
 tinycss2==1.2.1
 tokenizers==0.13.2
 toolz==0.12.0
-torch==1.13.1+cu116
-torchaudio==0.13.1+cu116
+torch==1.13.1
+torchaudio==0.13.1
 torchmetrics==0.11.0
-torchvision==0.14.1+cu116
+torchvision==0.14.1
 tornado==6.2
 tqdm==4.64.1
 traitlets==5.8.0
109 changes: 0 additions & 109 deletions deepblast/tests/test_trainer.py

This file was deleted.

76 changes: 76 additions & 0 deletions examples/simulation.py
@@ -0,0 +1,76 @@
import os
from deepblast.sim import hmm_alignments
import argparse
import numpy as np
from pytorch_lightning import Trainer
from transformers import T5EncoderModel, T5Tokenizer
from deepblast.trainer import DeepBLAST


# create simulation dataset
hmm = '../data/zf-C2H2.hmm'
n_alignments = 100
np.random.seed(0)
align_df = hmm_alignments(n=40, seed=0, n_alignments=n_alignments, hmmfile=hmm)
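# hmm_alignments returns a pandas DataFrame of simulated pairwise alignments
# drawn from the zf-C2H2 HMM; its columns are renamed below to the fields
# used throughout this script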

cols = [
    'chain1_name', 'chain2_name', 'tmscore1', 'tmscore2', 'rmsd',
    'chain1', 'chain2', 'alignment'
]
align_df.columns = cols

# split into train/test/validation datasets (80/10/10)
parts = n_alignments // 10
train_df = align_df.iloc[:parts * 8]
test_df = align_df.iloc[parts * 8:parts * 9]
valid_df = align_df.iloc[parts * 9:]

# save the files to disk.
if not os.path.exists('data'):
os.mkdir('data')

train_df.to_csv('data/train.txt', sep='\t', index=None, header=None)
test_df.to_csv('data/test.txt', sep='\t', index=None, header=None)
valid_df.to_csv('data/valid.txt', sep='\t', index=None, header=None)

output_dir = 'simulation_results'
if not os.path.exists(output_dir):
os.mkdir(output_dir)

# Load the pretrained ProtTrans (ProtT5-XL-UniRef50) encoder and tokenizer
tokenizer = T5Tokenizer.from_pretrained(
    "Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
lm = T5EncoderModel.from_pretrained(
    "Rostlab/prot_t5_xl_uniref50")


# Create the deepblast model
model = DeepBLAST(
    train_pairs=f'{os.getcwd()}/data/train.txt',
    test_pairs=f'{os.getcwd()}/data/test.txt',
    valid_pairs=f'{os.getcwd()}/data/valid.txt',
    output_directory=output_dir,
    hidden_dim=1024,
    embedding_dim=1024,
    batch_size=10,
    num_workers=10,
    layers=1,
    learning_rate=5e-5,
    loss='cross_entropy',
    lm=lm,
    tokenizer=tokenizer
)
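# the pretrained ProtT5 encoder and its tokenizer are passed in to embed the
# protein sequences; embedding_dim=1024 matches the ProtT5-XL hidden size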

# Fit the DeepBLAST model
trainer = Trainer(
    max_epochs=10,
    limit_train_batches=10,  # short run, we'll only train 10 batches / epoch
    limit_val_batches=10,    # short run, ...
    gpus=1,
    check_val_every_n_epoch=1,
    # profiler=profiler,
    fast_dev_run=True,
    # auto_scale_batch_size='power'
)

trainer.fit(model)
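# fast_dev_run=True makes Lightning run a single train and validation batch as
# a sanity check; remove it (keeping max_epochs / limit_*_batches) for a real
# training run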
