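"""Train a character-level SMILES LSTM on ChEMBL 28, then finetune it on
ADORA2a actives, sampling 100k SMILES from each model along the way."""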
import os

import numpy as np
import pandas as pd
import torch

import model
import sample
import train
import utils
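# model, sample, train and utils are project-local modules, assumed to provide
# the CharRNN class, get_sample_frame(), the train() loop, and the character
# vocabulary (chars/char2int) plus a preconfigured logger used throughout.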
logger = utils.logger
# Check if GPU is available
train_on_gpu = torch.cuda.is_available()
device = 'cuda' if train_on_gpu else 'cpu'
logger.info(f'Running on {device}')
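# Models and CSVs are written to output/; create it up front, since neither
# DataFrame.to_csv nor torch.save creates missing directories.
os.makedirs('output', exist_ok=True)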
# --------------- SETUP -----------------------------
config = {
    'n_hidden': 128,          # LSTM hidden state size
    'n_layers': 4,            # number of stacked LSTM layers
    'batch_size': 128,        # sequences per mini-batch
    'seq_length': 50,         # characters per training chunk
    'n_epochs': 100,          # epochs for the prior (unbiased) model
    'n_epochs_finetune': 50,  # epochs for finetuning on actives
    'lr': 0.001               # optimizer learning rate
}
logger.info(f'Config {config}')
# --------------- PRIOR MODEL -----------------------------
# Setup model
logger.info('Instantiating the model')
net = model.CharRNN(utils.chars, n_hidden=config['n_hidden'], n_layers=config['n_layers'])
logger.info(net)
# Load training data
logger.info('Loading and processing input data')
chemreps = pd.read_csv("input/chembl_28_chemreps.csv.gz")
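# Keep only SMILES of at most 100 characters to bound training sequence length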
chemreps = chemreps[chemreps.canonical_smiles.str.len() <= 100]
# Encode the text
text = '\n'.join(chemreps.canonical_smiles.values)
encoded = np.array([utils.char2int[ch] for ch in text])
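# Optional sanity check -- assumes utils also exposes int2char, the inverse of
# char2int (a sketch; not guaranteed by the original utils module):
# assert ''.join(utils.int2char[i] for i in encoded[:50]) == text[:50]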
# Train
logger.info('Training')
train_info = train.train(
    net, encoded,
    epochs=config['n_epochs'],
    batch_size=config['batch_size'],
    seq_length=config['seq_length'],
    lr=config['lr'],
    print_every=10000
)
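# train.train is assumed to return one (epoch, step, train_loss, val_loss)
# record per logging step; the column names below rely on that ordering.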
train_info = pd.DataFrame(train_info, columns=['epoch','step','train_loss','val_loss'])
train_info.to_csv('output/Smiles-LSTM_ChEMBL28_prior_info.csv', index=False)
# Sample model
logger.info('Sampling the unbiased model')
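# sample.get_sample_frame is assumed to return a DataFrame of generated SMILES
# (including an RDKit ROMol column, dropped before the CSV exports below);
# prime='B' seeds each generated sequence with the character 'B'.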
sample_prior = sample.get_sample_frame(net, size=100000, prime='B')
sample_prior['set'] = 'prior'
# Save prior model and sample output
logger.info('Saving the unbiased model and its sample output')
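# Only the state_dict (weights) is saved; the same CharRNN architecture is
# re-instantiated below before the weights are loaded back for finetuning.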
torch.save(net.state_dict(), 'output/Smiles-LSTM_ChEMBL28_prior.pt')
sample_prior.drop(columns=['ROMol']).to_csv('output/Smiles-LSTM_ChEMBL28_prior.csv')
# --------------- FINE TUNING -----------------------------
# Setup model
logger.info('Reloading the unbiased model for finetuning')
net = model.CharRNN(utils.chars, n_hidden=config['n_hidden'], n_layers=config['n_layers'])
# map_location ensures the checkpoint loads correctly with or without a GPU
net.load_state_dict(torch.load('output/Smiles-LSTM_ChEMBL28_prior.pt', map_location=torch.device(device)))
logger.info(net)
# Load training data
logger.info('Loading and processing input data for finetuning')
data = pd.read_csv('input/ChEMBL_ADORA2a_IC50-Ki.csv.gz')
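# pChEMBL >= 7 keeps compounds with measured potency of 100 nM or better
# (pChEMBL is the negative log10 of the molar activity value)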
data = data[data['pChEMBL Value'] >= 7]
# Encode the text
actives = '\n'.join(data.Smiles)
encoded = np.array([utils.char2int[ch] for ch in actives])
# Train
logger.info('Finetuning')
train_info = train.train(
    net, encoded,
    epochs=config['n_epochs_finetune'],
    batch_size=config['batch_size'],
    seq_length=config['seq_length'],
    lr=config['lr'],
    print_every=10000
)
train_info = pd.DataFrame(train_info, columns=['epoch','step','train_loss','val_loss'])
train_info.to_csv('output/Smiles-LSTM_ChEMBL28_finetune_info.csv', index=False)
# Sample model
logger.info('Sampling the finetuned model')
sample_ft = sample.get_sample_frame(net, size=100000, prime='B')
sample_ft['set'] = 'finetune'
# Save finetuned model and sample output
logger.info('Saving the finetuned model and its sample output')
torch.save(net.state_dict(), 'output/Smiles-LSTM_ChEMBL28_finetune.pt')
sample_ft.drop(columns=['ROMol']).to_csv('output/Smiles-LSTM_ChEMBL28_finetune.csv')
# Combine samples from the prior and finetuned models and save; the 'set'
# column distinguishes the two for downstream comparison
sample_both = pd.concat([sample_prior, sample_ft])
sample_both.drop(columns=['ROMol']).to_csv('output/Smiles-LSTM_ChEMBL28_both.csv')
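# Downstream usage (a sketch): either checkpoint can be reloaded into a fresh
# CharRNN built with the same config, mirroring the finetuning reload above:
#   net = model.CharRNN(utils.chars, n_hidden=config['n_hidden'], n_layers=config['n_layers'])
#   net.load_state_dict(torch.load('output/Smiles-LSTM_ChEMBL28_finetune.pt', map_location=torch.device(device)))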