diff --git a/sample.py b/sample.py index 401153f..1ace1f5 100644 --- a/sample.py +++ b/sample.py @@ -3,7 +3,7 @@ import pandas as pd import torch import torch.nn.functional as F -from rdkit import RDLogger +from rdkit import RDLogger, Chem import model from utils import * @@ -102,7 +102,8 @@ def get_sample(net: model.CharRNN, size: int, prime: str = 'B', top_k: Optional[ return ''.join(chars) -def get_sample_frame(net: model.CharRNN, size: int, prime: str = 'B', top_k: Optional[int] = None) -> pd.DataFrame: +def get_sample_frame(net: model.CharRNN, size: int, prime: str = 'B', top_k: Optional[int] = None, + verbose: bool = True) -> pd.DataFrame: """ Wrapper for sampling the net, splitting the output into SMILES string, converting to RDKit mols, checking validty, and computing descriptors @@ -112,6 +113,7 @@ def get_sample_frame(net: model.CharRNN, size: int, prime: str = 'B', top_k: Opt size (int): Sample this many characters prime (str): Prime net with string top_k (Optional[int]): Pick from top K characters + verbose (bool) Returns: sample (pd.DataFrame) @@ -120,16 +122,39 @@ def get_sample_frame(net: model.CharRNN, size: int, prime: str = 'B', top_k: Opt net.eval() sample = get_sample(net, size=size, prime=prime, top_k=top_k).split('\n') sample = pd.DataFrame(sample, columns=['SMILES']) - sample['set'] = 'prior' sample['ROMol'] = sample.SMILES.map(Chem.MolFromSmiles) sample = sample[sample.ROMol.notna()] num_valid = sample.ROMol.notna().sum() num_invalid = sample.shape[0] - num_valid - print(f'Valid molecules {num_valid}/{num_valid + num_invalid}') + if verbose: + print(f'Valid molecules {num_valid}/{num_valid + num_invalid}') # Compute descriptors of samples for desc in descriptors: sample[desc] = sample.ROMol.map(descriptors[desc]) return sample + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser(description='Sample trained SmilesLSTM model') + parser.add_argument('-m', '--model', help='Trained model checkpoint (.pt file)', required=True) + parser.add_argument('--hidden', help='Hidden units (default: %(default)d)', required=False, default=56, type=int) + parser.add_argument('--layers', help='Layers (default: %(default)d)', required=False, default=2, type=int) + parser.add_argument('-s', '--size', help='Sample this many characters (default: %(default)d)', required=False, + default=100000, + type=int) + parser.add_argument('-o', '--output', help='Output CSV file', required=True, type=str) + args = parser.parse_args() + + train_on_gpu = torch.cuda.is_available() + device = 'cuda' if train_on_gpu else 'cpu' + + net = model.CharRNN(chars, n_hidden=args.hidden, n_layers=args.layers) + net.load_state_dict(torch.load(args.model, map_location=torch.device(device))) + + net_sample = get_sample_frame(net, size=args.size, verbose=False) + net_sample.drop(columns=['ROMol']).to_csv(args.output, index=False)