- added src/models.py to include the MPNN model class that will be used for all benchmarking experiments
- added src/util.py to include helper functions for training, evaluation, and recording metrics
- added src/dataset.py to include helper functions for loading Physical chemistry datasets from MoleculeNet
- updated the README.md

Signed-off-by: Akhil Akella <aakella@swing.lcrc.anl.gov>
1 parent a928059, commit 49916ae
Showing 4 changed files with 418 additions and 2 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,77 @@ | ||
# gnn_uncertainty_ensembles | ||
The project will focus on the design and development of methods to automate the development of graph neural network ensembles and use them for uncertainty quantification | ||
# MetalgPy + gnnNAS | ||
|
||
## About | ||
The project focuses on using the general-purpose library [MetalgPy](https://github.com/deephyper/metalgpy) to write symbolized ML programs that can leverage graph hyperparameters for better surrogate model fitting. Our goal was to use `MetalgPy` to search for a representation learning algorithm for graph structures. | |
|
||
## Packages | ||
|
||
- `PyTorch` | ||
- `PyTorch-Geometric` | ||
- `MetalgPy` | ||
|
||
```shell | ||
# Install Pytorch | ||
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116 | ||
|
||
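# Install PyTorch Geometric | |
# TORCH is not defined above; set it to the installed PyTorch version first | |
# (assumption: the wheel index has a page for this exact version string) | |
TORCH=$(python -c "import torch; print(torch.__version__)") | |
pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html | |
pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html | |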
pip install -q git+https://github.com/pyg-team/pytorch_geometric.git | ||
|
||
# Install DeepHyper/MetalgPy | ||
pip install -q metalgpy | ||
|
||
# Install rdkit for the datasets | ||
pip install -q rdkit-pypi | ||
``` | ||
|
||
## Datasets | ||
|
||
We use three benchmark datasets: | |
|
||
- `GNN Benchmark Dataset` | ||
- `Planetoid-1` | ||
- `MoleculeNet` | ||
|
||
#### GNN Benchmark Dataset | ||
|
||
A variety of artificially and semi-artificially generated graph datasets. It is composed of datasets such as `PATTERN`, `CLUSTER`, `MNIST`, `CIFAR-10`, `TSP`, `CSL`. | ||
|
||
`Reference`: https://arxiv.org/abs/2003.00982 | ||
`Resource`: https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html | ||
|
||
#### Planetoid-1: | ||
|
||
The `Planetoid` dataset comprises the citation network datasets `Cora`, `Citeseer`, and `Pubmed`, three benchmark datasets used for semi-supervised node classification. Each graph contains bag-of-words representations of documents and the citation links between them. | |
|
||
`Reference`: https://arxiv.org/pdf/1603.08861.pdf | ||
`Resource`: https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html | ||
|
||
#### MoleculeNet: | ||
|
||
`MoleculeNet` is a benchmark designed for testing machine learning methods on molecular properties. To facilitate the development of molecular machine learning methods, it curates a number of dataset collections and provides a suite of software that implements many known featurizations and previously proposed algorithms. All methods and datasets are integrated as parts of the open source DeepChem package (MIT license). | |
|
||
Within `MoleculeNet`, we are interested in benchmarking the Quantum Mechanics and Physical chemistry datasets. | |
|
||
`Quantum Mechanics`: | ||
- QM7/QM7b (structure): Electronic properties (atomization energy, HOMO/LUMO, etc.) determined using ab-initio density functional theory (DFT). | |
- QM8 (structure): Electronic spectra and excited state energy of small molecules calculated by multiple quantum mechanical methods. | |
- QM9 (structure): Geometric, energetic, electronic and thermodynamic properties of DFT-modelled small molecules. | ||
|
||
`Physical chemistry`: | ||
- ESOL: Water solubility data (log solubility in mols per litre) for common organic small molecules. | |
- FreeSolv: Experimental and calculated hydration free energy of small molecules in water. | |
- Lipophilicity: Experimental results of octanol/water distribution coefficient (logD at pH 7.4). | |
|
||
`Reference`: https://moleculenet.org/datasets-1 | ||
`Resource`: https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html | ||
|
||
## Results | ||
TBA | ||
|
||
## Author | ||
[Akhil Pandey](https://github.com/akhilpandey95) | ||
|
||
## Supervisor | ||
[Prasanna Balaprakash](https://github.com/pbalapra) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# This Source Code Form is subject to the terms of the | ||
# BSD 2-Clause "Simplified" License. If a copy of the same | ||
# was not distributed with this file, You can obtain one at | ||
# https://github.com/akhilpandey95/gnnNAS/blob/master/LICENSE. | ||
|
||
import torch | ||
import numpy as np | ||
import torch_geometric as pyg | ||
|
||
# define the helper method to load dataset | ||
def load_molnet_phys_chem_data(name, batch_size, training_split, seed=2022): | |
    """ | |
    Load the specific Graph dataset from MoleculeNet | |
    Parameters | |
    ---------- | |
    arg1 | name: str | |
        Name of the dataset to import from Pytorch Geometric MoleculeNet dataloader. | |
    arg2 | batch_size: int | |
        Batch size for creating the train/test dataloaders. | |
    arg3 | training_split: float | |
        Fraction of samples (between 0 and 1) to be kept in training set. | |
    arg4 | seed: int | |
        Torch Random seed to ensure reproducibility. Default value is 2022 | |
    Returns | |
    ------- | |
    Tuple of the shuffled dataset and the train/test dataloaders | |
        (torch_geometric.datasets.molecule_net.MoleculeNet, DataLoader, DataLoader) | |
    """ | |
    # load the dataset | |
    dataset = pyg.datasets.MoleculeNet(root='/tmp/Molnet', name=name) | |
|
||
    # set the seed | |
    torch.manual_seed(seed) | |
|
||
    # shuffle the data | |
    dataset = dataset.shuffle() | |
|
||
    # set a stop index for gathering train data | |
    stop_index = int(np.floor(training_split * len(dataset))) | |
|
||
    # separate training data | |
    train_dataset = dataset[0:stop_index] | |
|
||
    # separate test data | |
    test_dataset = dataset[stop_index:] | |
|
||
    # create dataloaders for train and test samples | |
    train_loader = pyg.loader.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) | |
    test_loader = pyg.loader.DataLoader(test_dataset, batch_size=batch_size, shuffle=False) | |
|
||
    return dataset, train_loader, test_loader | |
|
||
|
||
|
||
|
||
|
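The split logic in the loader above can be checked without downloading a dataset. The following is a minimal sketch, not the committed code: `split_indices` is a hypothetical helper, and Python's `random.Random` stands in for `torch.manual_seed` plus `dataset.shuffle()`. It applies the same stop-index rule, `floor(training_split * n_samples)`, to a list of indices.

```python
import math
import random


def split_indices(n_samples, training_split, seed=2022):
    """Shuffle sample indices reproducibly and split them into train/test parts."""
    # a seeded generator keeps the permutation reproducible
    # (assumption: stand-in for torch.manual_seed + dataset.shuffle)
    rng = random.Random(seed)
    indices = list(range(n_samples))
    rng.shuffle(indices)

    # same stop index rule as the helper: floor(training_split * n_samples)
    stop_index = int(math.floor(training_split * n_samples))
    return indices[:stop_index], indices[stop_index:]


train_idx, test_idx = split_indices(n_samples=10, training_split=0.8)
print(len(train_idx), len(test_idx))
```

Because the seed fixes the permutation, calling the helper twice with the same arguments gives the same split, which is the property the `seed` argument is meant to provide.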
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
# This Source Code Form is subject to the terms of the | ||
# BSD 2-Clause "Simplified" License. If a copy of the same | ||
# was not distributed with this file, You can obtain one at | ||
# https://github.com/akhilpandey95/gnnNAS/blob/master/LICENSE. | ||
|
||
import torch | ||
import numpy as np | ||
import torch_geometric as pyg | ||
|
||
class MPNN(torch.nn.Module): | |
    """ | |
    Creates an MPNN model in pytorch geometric | |
    """ | |
    def __init__( | |
        self, | |
        n_node_features: int, | |
        n_edge_features: int, | |
        n_hidden: int, | |
        n_output: int, | |
        MPNN_inp: torch.nn.Module, | |
        MPNN_hidden: torch.nn.Module, | |
        n_conv_blocks: int, | |
        skip_connection: str="plain") -> None: | |
        """ | |
        Build the MPNN model | |
        Parameters | |
        ---------- | |
        arg1 | n_node_features: int | |
            Number of features at node level | |
        arg2 | n_edge_features: int | |
            Number of features at edge level (currently unused by the model) | |
        arg3 | n_hidden: int | |
            Number of hidden activations | |
        arg4 | n_output: int | |
            Number of output activations | |
        arg5 | MPNN_inp: torch.nn.Module | |
            Layer class used to encode the node features, called as MPNN_inp(n_node_features, n_hidden) | |
        arg6 | MPNN_hidden: torch.nn.Module | |
            Message passing layer class used in every conv block, called as MPNN_hidden(in_channels, n_hidden) | |
        arg7 | n_conv_blocks: int | |
            Number of convolutional blocks | |
        arg8 | skip_connection: str | |
            Block type passed to DeepGCNLayer ("plain", "res", "res+" or "dense"). Default value is "plain" | |
        Returns | |
        ------- | |
        Nothing | |
            None | |
        """ | |
        # initialise the parent torch.nn.Module | |
        super().__init__() | |
|
||
        # set the growth dimension | |
        self.growth_dimension = n_hidden | |
|
||
        # encode the node information | |
        self.node_encoder = MPNN_inp(n_node_features, n_hidden) | |
|
||
        # add the ability to add one or more conv layers | |
        conv_blocks = [] | |
|
||
        # ability to add one or more conv blocks | |
        for block in range(n_conv_blocks): | |
            # dense blocks concatenate their input, so the conv input grows by n_hidden per block | |
            if skip_connection == "dense": | |
                self.growth_dimension = n_hidden + (n_hidden * block) | |
            conv = MPNN_hidden(self.growth_dimension, n_hidden) | |
            norm = torch.nn.LayerNorm(n_hidden, elementwise_affine=True) | |
            act = torch.nn.ReLU(inplace=True) | |
            layer = pyg.nn.DeepGCNLayer(conv, norm, act, block=skip_connection) | |
            conv_blocks.append(layer) | |
|
||
        # group all the conv layers | |
        self.conv_layers = torch.nn.ModuleList(conv_blocks) | |
|
||
        # add the linear layers for flattening the output from MPNN | |
        self.flatten = torch.nn.Sequential( | |
            torch.nn.Linear(self.growth_dimension, n_hidden), | |
            torch.nn.ReLU(), | |
            torch.nn.Linear(n_hidden, n_output)) | |
|
||
    def forward(self, | |
                x: torch.Tensor, | |
                edge_index: torch.Tensor, | |
                batch_idx: torch.Tensor) -> torch.Tensor: | |
        """ | |
        Process the MPNN model | |
        Parameters | |
        ---------- | |
        arg1 | x: torch.Tensor | |
            Input features at node level | |
        arg2 | edge_index: torch.Tensor | |
            Index pairs of vertices | |
        arg3 | batch_idx: torch.Tensor | |
            Batch vector assigning each node to its graph | |
        Returns | |
        ------- | |
        Tensor | |
            torch.Tensor | |
        """ | |
        # obtain the input | |
        if isinstance(self.node_encoder, pyg.nn.MessagePassing): | |
            x = self.node_encoder(x, edge_index) | |
        else: | |
            x = self.node_encoder(x) | |
|
||
        # pass the node information to the conv layer | |
        x = self.conv_layers[0].conv(x, edge_index) | |
|
||
        # process the layers: blocks 0..n-2 are applied in order, so the conv of | |
        # the first block runs twice and the final block is not used | |
        for layer in self.conv_layers[:-1]: | |
            x = layer(x, edge_index) | |
|
||
        # obtain the output from the MPNN final layer | |
        y = pyg.nn.global_add_pool(x, batch=batch_idx) | |
|
||
        # pass the output to the linear output layer | |
        out = self.flatten(y) | |
|
||
        # return the output | |
        return out | |
|
||
|
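The `dense` skip connection changes the feature width from block to block, so it is worth tracing the widths by hand. The sketch below uses plain Python and two hypothetical helpers, `conv_input_widths` and `trace_forward_width`, neither of which is part of the commit. It assumes that each conv emits `n_hidden` features and that a dense block concatenates its input with the conv output. The trace follows the order of calls in `forward` as written in the diff: the conv of block 0 on its own, then blocks `0 .. n_conv_blocks - 2`.

```python
def conv_input_widths(n_hidden, n_conv_blocks, skip_connection="plain"):
    """Input width handed to each block's conv in __init__ (the growth dimension)."""
    if skip_connection == "dense":
        return [n_hidden + n_hidden * block for block in range(n_conv_blocks)]
    return [n_hidden] * n_conv_blocks


def trace_forward_width(n_hidden, n_conv_blocks, skip_connection="plain"):
    """Follow the node feature width through forward() as written in the diff."""
    expected = conv_input_widths(n_hidden, n_conv_blocks, skip_connection)
    width = n_hidden  # width after the node encoder
    # the conv of block 0 is applied on its own and maps n_hidden -> n_hidden
    assert width == expected[0], "conv input width mismatch"
    width = n_hidden
    # the loop then applies blocks 0 .. n_conv_blocks - 2
    for block in range(n_conv_blocks - 1):
        assert width == expected[block], "conv input width mismatch"
        # each conv emits n_hidden features; a dense block concatenates its input
        width = width + n_hidden if skip_connection == "dense" else n_hidden
    return width


widths = conv_input_widths(64, 3, "dense")
final_width = trace_forward_width(64, 3, "dense")
print(widths, final_width)
```

Under these assumptions the final width equals the last growth dimension, which is the input size of the first `Linear` layer in `self.flatten`. Changing the loop to cover a different set of blocks would require changing the growth dimension rule as well.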