Skip to content

Commit

Permalink
add CLI interface to sample.py
Browse files Browse the repository at this point in the history
  • Loading branch information
gmattedi committed Sep 15, 2021
1 parent b64999c commit 464139f
Showing 1 changed file with 29 additions and 4 deletions.
33 changes: 29 additions & 4 deletions sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pandas as pd
import torch
import torch.nn.functional as F
from rdkit import RDLogger
from rdkit import RDLogger, Chem

import model
from utils import *
Expand Down Expand Up @@ -102,7 +102,8 @@ def get_sample(net: model.CharRNN, size: int, prime: str = 'B', top_k: Optional[
return ''.join(chars)


def get_sample_frame(net: model.CharRNN, size: int, prime: str = 'B', top_k: Optional[int] = None) -> pd.DataFrame:
def get_sample_frame(net: model.CharRNN, size: int, prime: str = 'B', top_k: Optional[int] = None,
verbose: bool = True) -> pd.DataFrame:
"""
Wrapper for sampling the net, splitting the output into SMILES string, converting to
RDKit mols, checking validty, and computing descriptors
Expand All @@ -112,6 +113,7 @@ def get_sample_frame(net: model.CharRNN, size: int, prime: str = 'B', top_k: Opt
size (int): Sample this many characters
prime (str): Prime net with string
top_k (Optional[int]): Pick from top K characters
verbose (bool)
Returns:
sample (pd.DataFrame)
Expand All @@ -120,16 +122,39 @@ def get_sample_frame(net: model.CharRNN, size: int, prime: str = 'B', top_k: Opt
net.eval()
sample = get_sample(net, size=size, prime=prime, top_k=top_k).split('\n')
sample = pd.DataFrame(sample, columns=['SMILES'])
sample['set'] = 'prior'
sample['ROMol'] = sample.SMILES.map(Chem.MolFromSmiles)
sample = sample[sample.ROMol.notna()]

num_valid = sample.ROMol.notna().sum()
num_invalid = sample.shape[0] - num_valid
print(f'Valid molecules {num_valid}/{num_valid + num_invalid}')
if verbose:
print(f'Valid molecules {num_valid}/{num_valid + num_invalid}')

# Compute descriptors of samples
for desc in descriptors:
sample[desc] = sample.ROMol.map(descriptors[desc])

return sample


if __name__ == '__main__':
import argparse

parser = argparse.ArgumentParser(description='Sample trained SmilesLSTM model')
parser.add_argument('-m', '--model', help='Trained model checkpoint (.pt file)', required=True)
parser.add_argument('--hidden', help='Hidden units (default: %(default)d)', required=False, default=56, type=int)
parser.add_argument('--layers', help='Layers (default: %(default)d)', required=False, default=2, type=int)
parser.add_argument('-s', '--size', help='Sample this many characters (default: %(default)d)', required=False,
default=100000,
type=int)
parser.add_argument('-o', '--output', help='Output CSV file', required=True, type=str)
args = parser.parse_args()

train_on_gpu = torch.cuda.is_available()
device = 'cuda' if train_on_gpu else 'cpu'

net = model.CharRNN(chars, n_hidden=args.hidden, n_layers=args.layers)
net.load_state_dict(torch.load(args.model, map_location=torch.device(device)))

net_sample = get_sample_frame(net, size=args.size, verbose=False)
net_sample.drop(columns=['ROMol']).to_csv(args.output, index=False)

0 comments on commit 464139f

Please sign in to comment.