qm9_dataset.py
import numpy as np
import torch
import torch.utils.data

def qm9_collate_batch(batch):
    # Split the batch of (drug1, drug2, label1, label2) samples into columns.
    drug1, drug2, label1, label2 = list(zip(*batch))
    ddi_idxs1, ddi_idxs2 = collate_drug_pairs(drug1, drug2)
    drug1 = (*collate_drugs(drug1), *ddi_idxs1)
    drug2 = (*collate_drugs(drug2), *ddi_idxs2)
    label1 = collate_labels(label1)
    label2 = collate_labels(label2)
    return (*drug1, *drug2, label1, label2)


def collate_drug_pairs(drugs1, drugs2):
    # Cumulative atom counts give each drug's offset into the batched atom tensors.
    n_atom1 = [d['n_atom'] for d in drugs1]
    n_atom2 = [d['n_atom'] for d in drugs2]
    c_atom1 = [sum(n_atom1[:k]) for k in range(len(n_atom1))]
    c_atom2 = [sum(n_atom2[:k]) for k in range(len(n_atom2))]
    # Enumerate every cross-drug atom pair within each (drug1, drug2) pair.
    ddi_seg_i1, ddi_seg_i2, ddi_idx_j1, ddi_idx_j2 = zip(*[
        (i1 + c1, i2 + c2, i2, i1)
        for l1, l2, c1, c2 in zip(n_atom1, n_atom2, c_atom1, c_atom2)
        for i1 in range(l1) for i2 in range(l2)])
    ddi_seg_i1 = torch.LongTensor(ddi_seg_i1)
    ddi_idx_j1 = torch.LongTensor(ddi_idx_j1)
    ddi_seg_i2 = torch.LongTensor(ddi_seg_i2)
    ddi_idx_j2 = torch.LongTensor(ddi_idx_j2)
    return (ddi_seg_i1, ddi_idx_j1), (ddi_seg_i2, ddi_idx_j2)
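

# Illustration (hypothetical sizes, not part of the original file): for a
# single pair with n_atom1 = [2] and n_atom2 = [3], collate_drug_pairs yields
#   ddi_seg_i1 = [0, 0, 0, 1, 1, 1]   (drug1 atoms, offset into the batch)
#   ddi_idx_j1 = [0, 1, 2, 0, 1, 2]   (matching drug2 atoms, local indices)
# and symmetrically for drug2, i.e. every drug1 atom is paired with every
# drug2 atom, with batch offsets applied via c_atom1 / c_atom2.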


def collate_labels(labels):
    concat_labels = torch.Tensor(np.stack(labels))
    return concat_labels


def collate_drugs(drugs):
    # Offsets so bond indices point into the concatenated atom tensors.
    c_atoms = [sum(d['n_atom'] for d in drugs[:k]) for k in range(len(drugs))]
    atom_feat = torch.FloatTensor(np.vstack([d['atom_feat'] for d in drugs]))
    atom_type = torch.LongTensor(np.hstack([d['atom_type'] for d in drugs]))
    bond_type = torch.LongTensor(np.hstack([d['bond_type'] for d in drugs]))
    bond_seg_i = torch.LongTensor(np.hstack([
        np.array(d['bond_seg_i']) + c for d, c in zip(drugs, c_atoms)]))
    bond_idx_j = torch.LongTensor(np.hstack([
        np.array(d['bond_idx_j']) + c for d, c in zip(drugs, c_atoms)]))
    # batch_seg_m maps each atom to the index of its drug within the batch.
    batch_seg_m = torch.LongTensor(np.hstack([
        [k] * d['n_atom'] for k, d in enumerate(drugs)]))
    return batch_seg_m, atom_type, atom_feat, bond_type, bond_seg_i, bond_idx_j


class QM9Dataset(torch.utils.data.Dataset):
    def __init__(
            self,
            graph_dict,
            pairs_dataset=None):
        assert pairs_dataset
        self.graph_dict = graph_dict
        self.graph_idx_list = list(graph_dict.keys())
        self.feeding_insts = pairs_dataset

    def prepare_feeding_insts(self):
        # Could add optimizing logic here.
        return

    def __len__(self):
        return len(self.feeding_insts)

    def __getitem__(self, idx):
        instance = self.feeding_insts[idx]
        # Drug lookup: replace the graph indices with the stored graph dicts.
        instance = self.drug_structure_lookup(instance)
        return instance

    def drug_structure_lookup(self, instance):
        drug_idx1, drug_idx2, label1, label2 = instance
        drug1 = self.graph_dict[drug_idx1]
        drug2 = self.graph_dict[drug_idx2]
        return drug1, drug2, label1, label2
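

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original file). The toy graph
    # dicts below are hypothetical but mirror the keys consumed by
    # collate_drugs ('n_atom', 'atom_feat', 'atom_type', 'bond_type',
    # 'bond_seg_i', 'bond_idx_j'); pairs_dataset entries are assumed to be
    # (drug_idx1, drug_idx2, label1, label2) tuples, as drug_structure_lookup
    # expects.
    toy_graphs = {
        0: dict(n_atom=2, atom_feat=np.zeros((2, 4), dtype=np.float32),
                atom_type=np.array([0, 1]), bond_type=np.array([0]),
                bond_seg_i=[0], bond_idx_j=[1]),
        1: dict(n_atom=3, atom_feat=np.zeros((3, 4), dtype=np.float32),
                atom_type=np.array([1, 1, 2]), bond_type=np.array([0, 1]),
                bond_seg_i=[0, 1], bond_idx_j=[1, 2]),
    }
    toy_pairs = [(0, 1, 0.5, 1.5), (1, 0, 0.1, 0.2)]
    dataset = QM9Dataset(graph_dict=toy_graphs, pairs_dataset=toy_pairs)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=2, collate_fn=qm9_collate_batch)
    for batch in loader:
        # Each batch is a flat tuple of tensors: 8 for drug1, 8 for drug2,
        # plus the two label tensors.
        print([tuple(t.shape) for t in batch])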