Skip to content

Commit

Permalink
clear
Browse files Browse the repository at this point in the history
  • Loading branch information
farkguidao committed Feb 23, 2022
1 parent fcc8dcc commit 6f9a898
Show file tree
Hide file tree
Showing 26 changed files with 111 additions and 617 deletions.
2 changes: 0 additions & 2 deletions .gitignore

This file was deleted.

81 changes: 81 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# AGAT

Source code for paper "M2GCN: Multi-Modal Graph Convolutional Network for Polypharmacy Side Effects Discovery"

## Requirements

The code has been tested under Python 3.8, with the following packages installed (along with their dependencies):

- torch >= 1.9.0
- pytorch-lightning >= 1.4.4
- torchmetrics >= 0.5.0
- torch-scatter >= 2.0.9
- torch-sparse >= 0.6.12
- numpy
- pandas
- tqdm
- yaml

## unzip

**Since git limits the size of a single file upload (<25M), we divide the datasets and the pre-trained models into multiple volumes. Please unzip the files in the directories `data`and`lightning_logs` first.**

```
cd ./data
sh do_unzip.sh
cd ../lightning_logs
sh do_unzip.sh
```

## Files in the folder

- **/data:** Store the dataset and prepared data.
- **/dataloader:** Codes of the dataloader.
- **/models:** Codes of the AGAT model , link-prediction task and simi-node-classification task .
- **/utils:** Codes for data prepareing and some other utils.
- **/lightning_logs:** Store the trained model parameters, setting files, checkpoints, logs and results.
- **main.py:** The main entrance of running.

## Basic usage

### Train AGAT

train M2GCN by

```
python main.py
```

The default configuration file is `setting/settings.yaml`.

And if you want to adjust the hyperparameters of the model, you can modify it in `.setting/settings.yaml`, or create a similar configuration file, and specify `--setting_path` like this:

```
python main.py --setting_path yourpath.yaml
```

Checkpoints, logs, and results during training will be stored in the directory: `./lightning_logs/version_0`

And you can run `tensorboard --logdir lightning_logs/version_0` to monitor the training progress.

### Link Prediction with pre-trained model

You can predict the interaction between drugs through the pre-trained model we provide.

**Since git limits the size of a single file upload (<25M), we divide the pre-trained model into multiple volumes. Please unzip the files in the directory `./lightning_logs/pre-trained/checkpoints/` first.**

Load the pre-trained model and predict the test dataset by:

```
python main.py --test --ckpt_path ./lightning_logs/pre-trained/checkpoints/pre-trained.ckpt
```

The result(auc,aupr) will be stored in the directory: `./lightning_logs/version_0`

If you want to load your trained model to predict the test data set, you only need to change `--ckpt_path`like this:

```
python main.py --test --ckpt_path yourpath.ckpt
```

PS: Keep the configuration file unchanged during training and testing.
5 changes: 5 additions & 0 deletions data/do_unzip.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/bin/bash
zip -s 0 myfile.zip --out file.zip
zip -F file.zip --out file-large.zip
unzip file-large.zip

Binary file added data/myfile.z01
Binary file not shown.
Binary file added data/myfile.z02
Binary file not shown.
Binary file added data/myfile.z03
Binary file not shown.
Binary file added data/myfile.z04
Binary file not shown.
Binary file added data/myfile.z05
Binary file not shown.
Binary file added data/myfile.zip
Binary file not shown.
21 changes: 0 additions & 21 deletions dataloader/link_rank_dataloader.py

This file was deleted.

5 changes: 5 additions & 0 deletions lightning_logs/do_unzip.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/bin/bash
zip -s 0 myfile.zip --out file.zip
zip -F file.zip --out file-large.zip
unzip file-large.zip

Binary file added lightning_logs/myfile.z01
Binary file not shown.
Binary file added lightning_logs/myfile.z02
Binary file not shown.
Binary file added lightning_logs/myfile.z03
Binary file not shown.
Binary file added lightning_logs/myfile.zip
Binary file not shown.
11 changes: 1 addition & 10 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
import os

import torch

from dataloader.link_pre_dataloader import LinkPredictionDataloader
from dataloader.link_rank_dataloader import LinkRankDataloader
from dataloader.node_cla_dataloader import NodeClassificationDataloader
from models.LinkPreTask import LinkPredictionTask
from models.LinkRankTask import LinkRankTask
from models.NodeCLTask import NodeClassificationTask
import pytorch_lightning as pl
import yaml
import argparse

TASK = {
'link_pre':(LinkPredictionDataloader,LinkPredictionTask),
'link_rank':(LinkRankDataloader,LinkRankTask),
'simi_node_CL':(NodeClassificationDataloader,NodeClassificationTask)
}

Expand All @@ -32,11 +28,6 @@ def get_trainer_model_dataloader_from_yaml(yaml_path):


def train(parser):
# dl=NSDataloader(batch_size=512*32)
# model = M2GCNModel(N=dl.num_nodes,adj_list=dl.adj_list,lam=0.5)
# checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor='auc',mode='max')
# trainer = pl.Trainer(max_epochs=10,callbacks=[checkpoint_callback],gpus=1,reload_dataloaders_every_n_epochs=1)
# trainer.fit(model,dl)
args = parser.parse_args()
setting_path = args.setting_path
trainer,model,dl = get_trainer_model_dataloader_from_yaml(setting_path)
Expand All @@ -59,7 +50,7 @@ def test(parser):

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--setting_path',type=str,default='settings/wn_settings.yaml')
parser.add_argument('--setting_path',type=str,help='model setting file path')
parser.add_argument("--test", action='store_true', help='test or train')
temp_args, _ = parser.parse_known_args()
if temp_args.test:
Expand Down
7 changes: 4 additions & 3 deletions models/LinkPreTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class LinkPredictionTask(pl.LightningModule):
def __init__(self,edge_index,edge_type,feature,N,aggregator,use_feature,feature_dim,d_model,type_num, L,use_gradient_checkpointing,neg_num,dropout,lr,wd):
def __init__(self,edge_index,edge_type,feature,N,aggregator,use_feature,feature_dim,d_model,type_num,lambed, L,use_gradient_checkpointing,neg_num,dropout,lr,wd):
super(LinkPredictionTask, self).__init__()
# 工程类组件
self.save_hyperparameters(ignore=['edge_index','edge_type','feature','N','degree'])
Expand Down Expand Up @@ -68,9 +68,10 @@ def training_step(self, batch,*args, **kwargs) -> STEP_OUTPUT:
if self.hparams.aggregator=='agat':
logits = (em[:, source] * self.w[target].unsqueeze(0)).sum(-1).T # bs,t
l2 = self.loss2(logits,pos_edge_type-1)
loss = loss + self.hparams.lambed * l2
self.log('loss2', l2, prog_bar=True)
self.log('loss_all', l1+l2, prog_bar=True)
loss = loss+l2
self.log('loss_all', loss, prog_bar=True)

return loss

def validation_step(self, batch,*args, **kwargs) -> Optional[STEP_OUTPUT]:
Expand Down
126 changes: 0 additions & 126 deletions models/LinkRankTask.py

This file was deleted.

25 changes: 0 additions & 25 deletions settings/aifb_settings.yaml

This file was deleted.

26 changes: 0 additions & 26 deletions settings/ama_settings.yaml

This file was deleted.

23 changes: 0 additions & 23 deletions settings/pub_settings.yaml

This file was deleted.

Loading

0 comments on commit 6f9a898

Please sign in to comment.