-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
466e41a
commit f319134
Showing
9 changed files
with
332 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import torch | ||
import pytorch_lightning as pl | ||
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS | ||
from torch.utils.data import TensorDataset, DataLoader | ||
from utils.sparse_utils import * | ||
from torch_sparse import coalesce | ||
from dataloader.link_pre_dataloader import LinkPredictionDataloader | ||
class LinkRankDataloader(LinkPredictionDataloader):
    """Data module for the link-ranking task.

    Defines nothing of its own: construction, data loading and every
    dataloader hook are inherited unchanged from ``LinkPredictionDataloader``.
    """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import torch | ||
import pytorch_lightning as pl | ||
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS | ||
from torch.utils.data import TensorDataset, DataLoader | ||
from utils.sparse_utils import * | ||
from torch_sparse import coalesce | ||
class NodeClassificationDataloader(pl.LightningDataModule):
    """Lightning data module for node classification.

    The file at ``datapath`` is read once, at construction time, and is
    expected to hold a dict with the keys ``train_data``, ``test_data``,
    ``feature_data``, ``edge_index`` and ``edge_type``.

    Attributes set by :meth:`read_data`:
        feature_data: value stored under ``feature_data``.
        train_dataset / test_dataset: ``TensorDataset`` wrappers around the
            train and test tensors.
        edge_index / edge_type: graph structure, stored as loaded.
        N: ``edge_index.max() + 1`` (largest node id plus one).
        E: ``edge_index.shape[1]`` (number of edges).
    """

    def __init__(self, datapath, batch_size, num_workers=0):
        """Store the loader settings and immediately load the data file.

        :param datapath: path handed to ``torch.load``.
        :param batch_size: batch size of the training loader.
        :param num_workers: worker count of the training loader.
        """
        super().__init__()
        self.datapath = datapath
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.read_data()

    def read_data(self):
        """Load the serialized dict and populate the datasets and graph fields."""
        # NOTE(review): torch.load may unpickle arbitrary objects; only point
        # this at trusted files.
        payload = torch.load(self.datapath)
        train_split = payload['train_data']
        test_split = payload['test_data']
        self.feature_data = payload['feature_data']
        self.test_dataset = TensorDataset(test_split)
        self.train_dataset = TensorDataset(train_split)
        self.edge_index = payload['edge_index']
        self.edge_type = payload['edge_type']
        self.N = self.edge_index.max() + 1
        self.E = self.edge_index.shape[1]

    def _full_test_loader(self):
        """Return a loader that yields the whole test set as a single batch."""
        return DataLoader(self.test_dataset, batch_size=len(self.test_dataset))

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """Shuffled training loader; the last incomplete batch is dropped."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
        )

    def test_dataloader(self) -> EVAL_DATALOADERS:
        """Test loader: the full test set in one batch."""
        return self._full_test_loader()

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """Validation loader: serves the test set as well, in one batch."""
        return self._full_test_loader()
|
||
if __name__ == '__main__':
    # Smoke check: building the data module loads the Aifb data file.
    dataloader = NodeClassificationDataloader(
        datapath='../data/Aifb/all_data.pkl',
        batch_size=64,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from typing import Optional | ||
|
||
import torch | ||
import torchmetrics | ||
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT | ||
from torch import nn | ||
import pytorch_lightning as pl | ||
import numpy as np | ||
|
||
from models.AGAT import AGAT | ||
from models.LinkPreTask import LinkPredictionTask | ||
class LinkRankTask(LinkPredictionTask):
    """Link-ranking task built on top of ``LinkPredictionTask``.

    Every hook below forwards straight to the parent implementation; the
    overrides exist as extension points and add no behavior of their own.
    """

    def __init__(self, edge_index, edge_type, feature, N, aggregator, use_feature, feature_dim, d_model, type_num, L,
                 use_gradient_checkpointing, neg_num, dropout, lr, wd):
        """Forward all constructor arguments, in order, to the parent task."""
        super().__init__(
            edge_index, edge_type, feature, N, aggregator, use_feature, feature_dim, d_model, type_num, L,
            use_gradient_checkpointing, neg_num, dropout, lr, wd,
        )

    def training_step(self, batch, *args, **kwargs) -> STEP_OUTPUT:
        """Delegate the training step to the parent task."""
        step_output = super().training_step(batch, *args, **kwargs)
        return step_output

    def validation_step(self, batch, *args, **kwargs) -> Optional[STEP_OUTPUT]:
        """Delegate the validation step to the parent task."""
        # Original note: "drop the ones that are not in the training set".
        # No such filtering is done here; the call below only delegates.
        step_output = super().validation_step(batch, *args, **kwargs)
        return step_output

    def test_step(self, batch, *args, **kwargs) -> Optional[STEP_OUTPUT]:
        """Delegate the test step to the parent task."""
        step_output = super().test_step(batch, *args, **kwargs)
        return step_output

    def on_test_end(self) -> None:
        """Run the parent's end-of-test hook."""
        super().on_test_end()

    def on_fit_end(self) -> None:
        """Run the parent's end-of-fit hook."""
        super().on_fit_end()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
from typing import Optional, Union, List | ||
|
||
import torch | ||
import torchmetrics | ||
from pytorch_lightning.core.optimizer import LightningOptimizer | ||
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT | ||
from torch import nn | ||
import pytorch_lightning as pl | ||
import numpy as np | ||
from torch.optim import Optimizer | ||
|
||
from models.AGAT import AGAT | ||
|
||
class NodeClassificationTask(pl.LightningModule):
    """Lightning module that classifies nodes from AGAT graph embeddings.

    Node (and edge-type) features are passed through an ``AGAT`` aggregator;
    the per-type node embeddings are then scored against a learned weight
    matrix ``w`` to give one logit per class. Micro/macro F1 are computed on
    every validation/test batch, and the best micro-F1 seen (with the macro-F1
    of that same evaluation) is written to ``best_result.txt`` at the end of
    ``fit``.
    """

    def __init__(self, edge_index, edge_type, feature, N, aggregator, use_feature, feature_dim, d_model, type_num, L,
                 use_gradient_checkpointing, dropout, lr, wd):
        """Build buffers, projections, the aggregator and the metrics.

        :param edge_index: edge endpoints, registered as a buffer.
        :param edge_type: per-edge type ids, registered as a buffer.
        :param feature: node feature tensor; only used when ``use_feature``.
        :param N: number of nodes; sizes the learned embedding table when
            ``use_feature`` is false.
        :param aggregator: ``'agat'`` or ``'sgat'``.
        :param use_feature: project ``feature`` through a linear layer instead
            of learning a free embedding per node.
        :param feature_dim: input width of the node feature projection.
        :param d_model: embedding width.
        :param type_num: number of classes (also the AGAT type count for
            ``'agat'``).
        :param L: forwarded to ``AGAT`` (presumably the layer count -- confirm
            against models.AGAT).
        :param use_gradient_checkpointing: forwarded to ``AGAT``.
        :param dropout: forwarded to ``AGAT``.
        :param lr: Adam learning rate.
        :param wd: Adam weight decay.
        """
        super(NodeClassificationTask, self).__init__()
        self.save_hyperparameters(ignore=['edge_index', 'edge_type', 'feature', 'N', 'degree'])
        self.register_buffer('edge_index', edge_index)
        self.register_buffer('edge_type', edge_type)
        # Convert to a plain int: torch.eye and nn.Linear are documented to
        # take integer sizes, not 0-dim tensors.
        edge_type_num = int(edge_type.max()) + 1
        # One-hot encoding of each edge type, projected to d_model by fc_edge.
        self.register_buffer('edge_feature', torch.eye(edge_type_num))
        self.fc_edge = nn.Linear(edge_type_num, d_model)

        if use_feature:
            self.register_buffer('feature', feature)
            self.fc_node = nn.Linear(feature_dim, d_model)
        else:
            # No input features: learn a free embedding for every node.
            self.feature = nn.Parameter(torch.randn(N, d_model))

        # Per-class scoring vectors, Xavier-initialised.
        self.w = nn.Parameter(torch.FloatTensor(type_num, d_model))
        nn.init.xavier_uniform_(self.w)
        if aggregator == 'agat':
            self.agat = AGAT(type_num, d_model, L, use_gradient_checkpointing, dropout)
        elif aggregator == 'sgat':
            self.sgat = AGAT(1, d_model, L, use_gradient_checkpointing, dropout)

        self.loss = nn.CrossEntropyLoss()
        self.max_macro_F1 = -np.inf
        self.max_micro_F1 = -np.inf
        self.micro_f1_cal = torchmetrics.F1(num_classes=type_num, average='micro', multiclass=True)
        self.macro_f1_cal = torchmetrics.F1(num_classes=type_num, average='macro', multiclass=True)

    def evalute(self, pre, label):
        """Compute, log and track micro/macro F1 for one batch of predictions.

        The stored best pair is chosen by micro-F1 alone: ``max_macro_F1`` is
        the macro-F1 of the evaluation that had the best micro-F1. Both metric
        objects are reset afterwards so evaluations do not accumulate.
        """
        micro_F1 = self.micro_f1_cal(pre, label)
        macro_F1 = self.macro_f1_cal(pre, label)
        if self.max_micro_F1 < micro_F1:
            self.max_micro_F1 = micro_F1
            self.max_macro_F1 = macro_F1
        self.log('micro-f1', micro_F1, prog_bar=True)
        self.log('macro-f1', macro_F1, prog_bar=True)
        self.micro_f1_cal.reset()
        self.macro_f1_cal.reset()

    def configure_optimizers(self):
        """Return an Adam optimizer using the saved ``lr`` and ``wd``."""
        optimizer = torch.optim.Adam(self.parameters(),
                                     lr=self.hparams.lr,
                                     weight_decay=self.hparams.wd)
        return optimizer

    def get_em(self):
        """Run the configured aggregator over the whole graph.

        :return: node embeddings; for ``'sgat'`` the output is expanded to
            ``(type_num, N, d_model)``, and the ``'agat'`` output is
            presumably the same shape (per the original ``t,N,d_model`` note).
        :raises ValueError: if the ``aggregator`` hyperparameter is neither
            ``'agat'`` nor ``'sgat'``.
        """
        if self.hparams.use_feature:
            feature = self.fc_node(self.feature)
        else:
            feature = self.feature
        edge_feature = self.fc_edge(self.edge_feature)
        aggregator = self.hparams.aggregator
        if aggregator == 'agat':
            em = self.agat(feature, self.edge_index, self.edge_type, edge_feature)
        elif aggregator == 'sgat':
            em = self.sgat(feature, self.edge_index, self.edge_type, edge_feature) \
                .expand(self.hparams.type_num, feature.shape[0], self.hparams.d_model)
        else:
            # Without this branch `em` would be unbound and the failure would
            # surface as an opaque UnboundLocalError.
            raise ValueError(
                "Unknown aggregator %r; expected 'agat' or 'sgat'" % (aggregator,))
        return em  # t,N,d_model

    def forward(self, node_id):
        """Return class logits for the given node ids.

        :param node_id: indices selecting nodes along dim 1 of the embeddings.
        :return: logits of shape (batch, type_num).
        """
        em = self.get_em()
        node_em = em[:, node_id].transpose(0, 1)  # bs,t,d
        # Dot product of each per-type embedding with its class vector.
        logits = (node_em * self.w).sum(-1)  # bs,t
        return logits

    def training_step(self, batch, *args, **kwargs) -> STEP_OUTPUT:
        """Cross-entropy loss on one batch of ``[node_id, label]`` rows."""
        data = batch[0]
        node_id, label = data[:, 0], data[:, 1]
        pre = self(node_id)
        loss = self.loss(pre, label)
        self.log('loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, *args, **kwargs) -> Optional[STEP_OUTPUT]:
        """Score one batch of ``[node_id, label]`` rows and update the F1 logs."""
        data = batch[0]
        node_id, label = data[:, 0], data[:, 1]
        pre = self(node_id)
        self.evalute(pre, label)

    def test_step(self, batch, *args, **kwargs) -> Optional[STEP_OUTPUT]:
        """Testing reuses the validation logic."""
        return self.validation_step(batch)

    def on_fit_end(self) -> None:
        """Print the best F1 pair and write it to ``<log_dir>/best_result.txt``."""
        with open(self.trainer.log_dir + '/best_result.txt', mode='w') as f:
            result = {'micro-f1': float(self.max_micro_F1), 'macro-f1': float(self.max_macro_F1)}
            print('test_result:', result)
            f.write(str(result))
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
task: 'simi_node_CL' | ||
data: | ||
# data module arguments: datapath, batch_size, num_workers=0 | ||
batch_size: 16 | ||
datapath: 'data/Aifb/all_data.pkl' | ||
num_workers: 0 | ||
model: | ||
# model arguments: aggregator, use_feature, feature_dim, d_model, type_num, L, use_gradient_checkpointing, lr, wd, dropout | ||
aggregator: 'agat' | ||
use_feature: False | ||
feature_dim: 1156 | ||
d_model: 32 | ||
type_num: 4 | ||
L: 3 | ||
use_gradient_checkpointing: False | ||
lr: 0.01 | ||
wd: 0.005 | ||
dropout: 0.1 | ||
callback: | ||
monitor: 'micro-f1' | ||
mode: 'max' | ||
train: | ||
max_epochs: 50 | ||
gpus: 1 | ||
# reload_dataloaders_every_n_epochs: 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
task: 'simi_node_CL' | ||
data: | ||
# data module arguments: datapath, batch_size, num_workers=0 | ||
batch_size: 16 | ||
datapath: 'data/PubMed/all_data.pkl' | ||
num_workers: 0 | ||
model: | ||
# model arguments: aggregator, use_feature, feature_dim, d_model, type_num, L, use_gradient_checkpointing, lr, wd, dropout | ||
aggregator: 'agat' | ||
use_feature: False | ||
feature_dim: 200 | ||
d_model: 32 | ||
type_num: 8 | ||
L: 3 | ||
use_gradient_checkpointing: True | ||
lr: 0.01 | ||
wd: 0.005 | ||
dropout: 0.1 | ||
callback: | ||
monitor: 'micro-f1' | ||
mode: 'max' | ||
train: | ||
max_epochs: 50 | ||
gpus: 1 | ||
# reload_dataloaders_every_n_epochs: 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import numpy as np | ||
import pandas as pd | ||
import torch | ||
from torch_sparse import coalesce | ||
import pickle | ||
|
||
|
||
def do(base_path, node_num, has_feature, is_dir):
    """Assemble one dataset directory into a single ``all_data.pkl`` file.

    Reads ``link.dat``, ``label.dat``, ``label.dat.test`` and, when requested,
    ``node.dat`` from ``base_path``, then saves the collected tensors with
    ``torch.save`` to ``base_path + '/all_data.pkl'``.

    :param base_path: dataset directory (joined to file names with ``/``).
    :param node_num: not used by this function.
    :param has_feature: when true, node features are read from ``node.dat``;
        otherwise ``feature_data`` is stored as ``None``.
    :param is_dir: whether the graph is directed (forwarded to
        ``get_train_sparse_adj``).
    :return: None
    """
    graph = get_train_sparse_adj(base_path + '/link.dat', is_dir)
    edge_index, edge_type = graph
    train_data = get_label_data(base_path + '/label.dat')
    test_data = get_label_data(base_path + '/label.dat.test')
    feature_data = get_feature_data(base_path + '/node.dat') if has_feature else None

    all_data = dict(
        edge_index=edge_index,
        edge_type=edge_type,
        feature_data=feature_data,
        train_data=train_data,
        test_data=test_data,
    )
    torch.save(all_data, base_path + '/all_data.pkl')
|
||
def get_feature_data(path):
    """Read node features from a tab-separated file into a float32 tensor.

    Column 3 of every row holds a comma-separated list of numbers; each row
    becomes one row of the returned 2-D tensor.

    :param path: file path or file-like object accepted by ``pd.read_csv``.
    :return: ``torch.float32`` tensor with one row per input line.
    """
    # quoting=3 is csv.QUOTE_NONE: quote characters are treated as plain text.
    table = pd.read_csv(path, sep='\t', header=None, quoting=3)
    pieces = table[3].str.split(',', expand=True)
    matrix = pieces.astype(np.float32).to_numpy(dtype=np.float32)
    return torch.from_numpy(matrix)
|
||
|
||
def get_label_data(path):
    """Read ``(node_id, label)`` pairs from a tab-separated label file.

    Only columns 0 and 3 of the file are kept, in that order.

    :param path: file path or file-like object accepted by ``pd.read_csv``.
    :return: ``torch.int64`` tensor of shape (rows, 2).
    """
    table = pd.read_csv(path, sep='\t', index_col=None, header=None)
    pairs = table[[0, 3]].to_numpy(dtype=np.int64)
    return torch.from_numpy(pairs)
|
||
|
||
def get_train_sparse_adj(path, is_dir):
    """Build the typed edge list of the training graph from a link file.

    The file is tab-separated with integer columns; columns 0 and 1 are the
    edge endpoints and column 2 the edge type. The result starts with one
    self-loop per node (type 0), followed by the de-duplicated edges of each
    type ``1..max_type``.

    NOTE(review): rows whose type is 0 in the file are not copied into the
    result, since type 0 is used for the self-loops -- confirm that the input
    files number their edge types from 1.

    :param path: path of the link file.
    :param is_dir: true for a directed graph; when false every edge is also
        added in the reverse direction before de-duplication.
    :return: ``(edge_index, edge_type)`` -- a (2, E) tensor of endpoints and a
        length-E tensor of type ids.
    """
    frame = pd.read_csv(path, sep='\t', index_col=None, header=None)
    triples = torch.from_numpy(frame.to_numpy(dtype=np.int64))
    triples = triples[:, [2, 0, 1]]  # [edge_type,row,col]
    type_num = triples[:, 0].max()
    N = triples[:, 1:].max() + 1

    # Type 0: one self-loop for every node id in [0, N).
    node_ids = torch.arange(N)
    print('引入边,自环,num=', N)
    edge_index = [torch.stack([node_ids, node_ids])]
    edge_type = [torch.zeros(N, dtype=torch.long)]

    for type_id in range(1, type_num + 1):
        # Per edge type: optionally add reversed edges, then de-duplicate.
        of_this_type = triples[:, 0] == type_id
        pairs = triples[of_this_type, 1:].T
        marks = torch.ones(pairs.shape[1], dtype=torch.long)
        if not is_dir:
            # Undirected input: store both directions explicitly.
            pairs = torch.cat([pairs, pairs[[1, 0]]], dim=1)
            marks = torch.cat([marks, marks], dim=0)
        # De-duplicate; the merged values are overwritten with the type id.
        pairs, marks = coalesce(pairs, marks, N, N)
        marks.fill_(type_id)
        print('引入边,类别 %d,num= %d' % (type_id, marks.shape[0]))
        edge_index.append(pairs)
        edge_type.append(marks)

    edge_index = torch.cat(edge_index, dim=1)
    edge_type = torch.cat(edge_type, dim=0)
    print('训练集总边数:', edge_index.shape[1])
    return edge_index, edge_type
|
||
if __name__ == '__main__':
    # Alternative dataset settings (PubMed):
    # base_path = '../data/PubMed'
    # node_num = 63109
    # has_feature = True
    # is_dir = True

    # Active dataset settings (Aifb).
    base_path = '../data/Aifb'
    node_num = 8285
    has_feature = False
    is_dir = True
    do(base_path=base_path, node_num=node_num, has_feature=has_feature, is_dir=is_dir)