Add semi-supervised node classification task
farkguidao committed Feb 6, 2022
1 parent 466e41a commit f319134
Showing 9 changed files with 332 additions and 5 deletions.
9 changes: 9 additions & 0 deletions dataloader/link_rank_dataloader.py
@@ -0,0 +1,9 @@
import torch
import pytorch_lightning as pl
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data import TensorDataset, DataLoader
from utils.sparse_utils import *
from torch_sparse import coalesce
from dataloader.link_pre_dataloader import LinkPredictionDataloader
class LinkRankDataloader(LinkPredictionDataloader):
pass
33 changes: 33 additions & 0 deletions dataloader/node_cla_dataloader.py
@@ -0,0 +1,33 @@
import torch
import pytorch_lightning as pl
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data import TensorDataset, DataLoader
from utils.sparse_utils import *
from torch_sparse import coalesce
class NodeClassificationDataloader(pl.LightningDataModule):
def __init__(self,datapath,batch_size,num_workers=0):
super(NodeClassificationDataloader, self).__init__()
self.datapath = datapath
self.batch_size = batch_size
self.num_workers = num_workers
self.read_data()
def read_data(self):
data = torch.load(self.datapath)
train_data, test_data, feature_data = data['train_data'], data['test_data'], data['feature_data']
self.feature_data = feature_data
self.test_dataset = TensorDataset(test_data)
self.train_dataset = TensorDataset(train_data)
self.edge_index, self.edge_type = data['edge_index'], data['edge_type']
self.N, self.E = self.edge_index.max() + 1, self.edge_index.shape[1]

def train_dataloader(self) -> TRAIN_DATALOADERS:
return DataLoader(self.train_dataset,self.batch_size,shuffle=True,num_workers=self.num_workers,drop_last=True)

def test_dataloader(self) -> EVAL_DATALOADERS:
return DataLoader(self.test_dataset,batch_size=len(self.test_dataset))

def val_dataloader(self) -> EVAL_DATALOADERS:
return DataLoader(self.test_dataset,batch_size=len(self.test_dataset))

if __name__ == '__main__':
dataloader = NodeClassificationDataloader('../data/Aifb/all_data.pkl', 64)
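
For reference, a minimal sketch of the bundle this dataloader expects at datapath. The keys match what read_data() loads above; the shapes are inferred from how NodeCLTask indexes the tensors (label rows as [node_id, label]) and are illustrative assumptions, not part of the commit:

import torch

# hypothetical toy bundle with the keys read_data() expects
all_data = {
    'edge_index': torch.tensor([[0, 1, 2], [1, 2, 0]]),  # (2, E) long tensor
    'edge_type': torch.tensor([1, 1, 2]),                # (E,) long tensor
    'feature_data': None,                                # or an (N, feature_dim) float tensor when use_feature is True
    'train_data': torch.tensor([[0, 1], [1, 0]]),        # rows of [node_id, label]
    'test_data': torch.tensor([[2, 1]]),
}
torch.save(all_data, 'all_data.pkl')  # same format dataprepare4nodeCL.py writes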
15 changes: 12 additions & 3 deletions main.py
@@ -3,17 +3,26 @@
 import torch
 
 from dataloader.link_pre_dataloader import LinkPredictionDataloader
+from dataloader.node_cla_dataloader import NodeClassificationDataloader
 from models.LinkPreTask import LinkPredictionTask
+from models.NodeCLTask import NodeClassificationTask
 import pytorch_lightning as pl
 import yaml
 import argparse
 
+TASK = {
+    'link_pre':(LinkPredictionDataloader,LinkPredictionTask),
+    'simi_node_CL':(NodeClassificationDataloader,NodeClassificationTask)
+}
+
 def get_trainer_model_dataloader_from_yaml(yaml_path):
     with open(yaml_path) as f:
         settings = dict(yaml.load(f,yaml.FullLoader))
 
-    dl = LinkPredictionDataloader(**settings['data'])
-    model = LinkPredictionTask(dl.edge_index,dl.edge_type,dl.feature_data,dl.N, **settings['model'])
+    DATALOADER,MODEL = TASK[settings['task']]
+
+    dl = DATALOADER(**settings['data'])
+    model = MODEL(dl.edge_index,dl.edge_type,dl.feature_data,dl.N, **settings['model'])
     checkpoint_callback = pl.callbacks.ModelCheckpoint(**settings['callback'])
     trainer = pl.Trainer(callbacks=[checkpoint_callback], **settings['train'])
     return trainer,model,dl
@@ -47,7 +56,7 @@ def test(parser):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--setting_path',type=str,default='settings/yot_settings.yaml')
+    parser.add_argument('--setting_path',type=str,default='settings/pub_settings.yaml')
     parser.add_argument("--test", action='store_true', help='test or train')
     temp_args, _ = parser.parse_known_args()
     if temp_args.test:
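
With the TASK registry in place, a run is dispatched by the task key in the YAML settings. A minimal sketch of the resulting call chain, assuming the new aifb_settings.yaml below (trainer.fit with a datamodule is the standard Lightning entry point; whether main.py's train path adds anything beyond this is not shown in this hunk):

from main import get_trainer_model_dataloader_from_yaml

# 'simi_node_CL' in the YAML selects (NodeClassificationDataloader, NodeClassificationTask)
trainer, model, dl = get_trainer_model_dataloader_from_yaml('settings/aifb_settings.yaml')
trainer.fit(model, datamodule=dl)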
34 changes: 34 additions & 0 deletions models/LinkRankTask.py
@@ -0,0 +1,34 @@
from typing import Optional

import torch
import torchmetrics
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from torch import nn
import pytorch_lightning as pl
import numpy as np

from models.AGAT import AGAT
from models.LinkPreTask import LinkPredictionTask
class LinkRankTask(LinkPredictionTask):

def __init__(self, edge_index, edge_type, feature, N, aggregator, use_feature, feature_dim, d_model, type_num, L,
use_gradient_checkpointing, neg_num, dropout, lr, wd):
super().__init__(edge_index, edge_type, feature, N, aggregator, use_feature, feature_dim, d_model, type_num, L,
use_gradient_checkpointing, neg_num, dropout, lr, wd)

def training_step(self, batch, *args, **kwargs) -> STEP_OUTPUT:
return super().training_step(batch, *args, **kwargs)

def validation_step(self, batch, *args, **kwargs) -> Optional[STEP_OUTPUT]:
        # exclude entries that are not in the training set
return super().validation_step(batch, *args, **kwargs)

def test_step(self, batch, *args, **kwargs) -> Optional[STEP_OUTPUT]:

return super().test_step(batch, *args, **kwargs)

def on_test_end(self) -> None:
super().on_test_end()

def on_fit_end(self) -> None:
super().on_fit_end()
102 changes: 102 additions & 0 deletions models/NodeCLTask.py
@@ -0,0 +1,102 @@
from typing import Optional, Union, List

import torch
import torchmetrics
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from torch import nn
import pytorch_lightning as pl
import numpy as np
from torch.optim import Optimizer

from models.AGAT import AGAT

class NodeClassificationTask(pl.LightningModule):
def __init__(self, edge_index, edge_type, feature, N, aggregator, use_feature, feature_dim, d_model, type_num, L,
use_gradient_checkpointing, dropout, lr, wd):
super(NodeClassificationTask, self).__init__()
self.save_hyperparameters(ignore=['edge_index','edge_type','feature','N','degree'])
self.register_buffer('edge_index', edge_index)
self.register_buffer('edge_type', edge_type)
edge_type_num = edge_type.max()+1
self.register_buffer('edge_feature', torch.eye(edge_type_num))
self.fc_edge = nn.Linear(edge_type_num, d_model)

if use_feature:
self.register_buffer('feature',feature)
self.fc_node = nn.Linear(feature_dim, d_model)
else:
self.feature = nn.Parameter(torch.randn(N,d_model))

self.w = nn.Parameter(torch.FloatTensor(type_num, d_model))
nn.init.xavier_uniform_(self.w)
if aggregator == 'agat':
self.agat = AGAT(type_num, d_model, L, use_gradient_checkpointing, dropout)
elif aggregator == 'sgat':
self.sgat = AGAT(1, d_model, L, use_gradient_checkpointing, dropout)

self.loss = nn.CrossEntropyLoss()
self.max_macro_F1 = -np.inf
self.max_micro_F1 = -np.inf
self.micro_f1_cal = torchmetrics.F1(num_classes=type_num,average='micro',multiclass=True)
self.macro_f1_cal = torchmetrics.F1(num_classes=type_num,average='macro',multiclass=True)

    def evaluate(self,pre,label):
micro_F1 = self.micro_f1_cal(pre,label)
macro_F1 = self.macro_f1_cal(pre,label)
if self.max_micro_F1 < micro_F1:
self.max_micro_F1 = micro_F1
self.max_macro_F1 = macro_F1
self.log('micro-f1',micro_F1,prog_bar=True)
self.log('macro-f1',macro_F1,prog_bar=True)
self.micro_f1_cal.reset()
self.macro_f1_cal.reset()

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(),
lr=self.hparams.lr,
weight_decay=self.hparams.wd)
return optimizer

def get_em(self):
if self.hparams.use_feature:
feature = self.fc_node(self.feature)
else:
feature = self.feature
edge_feature = self.fc_edge(self.edge_feature)
if self.hparams.aggregator=='agat':
em = self.agat(feature,self.edge_index,self.edge_type,edge_feature)
elif self.hparams.aggregator=='sgat':
em = self.sgat(feature,self.edge_index,self.edge_type,edge_feature)\
.expand(self.hparams.type_num,feature.shape[0],self.hparams.d_model)
        return em  # (type_num, N, d_model)

def forward(self,node_id):
em = self.get_em()
node_em = em[:,node_id].transpose(0,1) #bs,t,d
logits = (node_em * self.w).sum(-1) # bs,t
return logits

def training_step(self,batch, *args, **kwargs) -> STEP_OUTPUT:
data = batch[0]
node_id,label = data[:,0],data[:,1]
pre = self(node_id)
loss = self.loss(pre,label)
self.log('loss',loss,prog_bar=True)
return loss

def validation_step(self,batch, *args, **kwargs) -> Optional[STEP_OUTPUT]:
data = batch[0]
node_id, label = data[:, 0], data[:, 1]
pre = self(node_id)
        self.evaluate(pre,label)

def test_step(self,batch, *args, **kwargs) -> Optional[STEP_OUTPUT]:
return self.validation_step(batch)

def on_fit_end(self) -> None:
with open(self.trainer.log_dir + '/best_result.txt', mode='w') as f:
result = {'micro-f1': float(self.max_micro_F1), 'macro-f1': float(self.max_macro_F1)}
print('test_result:', result)
f.write(str(result))
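
The scoring in forward() amounts to one dot product per class: get_em() returns a (type_num, N, d_model) stack of per-type node embeddings, the batch rows are gathered and transposed to (bs, type_num, d_model), and each class vector in w scores its own slice. A minimal shape-checking sketch with random tensors (all sizes illustrative):

import torch

type_num, N, d_model, bs = 4, 10, 32, 3
em = torch.randn(type_num, N, d_model)    # stands in for get_em() output
w = torch.randn(type_num, d_model)        # one scoring vector per class
node_id = torch.tensor([0, 5, 7])         # a batch of node ids

node_em = em[:, node_id].transpose(0, 1)  # (bs, type_num, d_model)
logits = (node_em * w).sum(-1)            # (bs, type_num): per-class dot products
assert logits.shape == (bs, type_num)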

25 changes: 25 additions & 0 deletions settings/aifb_settings.yaml
@@ -0,0 +1,25 @@
task: 'simi_node_CL'
data:
# datapath,batch_size,is_dir=False,num_workers=0
batch_size: 16
datapath: 'data/Aifb/all_data.pkl'
num_workers: 0
model:
# use_feature,feature_dim,d_model,type_num, L,neg_num,dropout,lr,wd
aggregator: 'agat'
use_feature: False
feature_dim: 1156
d_model: 32
type_num: 4
L: 3
use_gradient_checkpointing: False
lr: 0.01
wd: 0.005
dropout: 0.1
callback:
monitor: 'micro-f1'
mode: 'max'
train:
max_epochs: 50
gpus: 1
# reload_dataloaders_every_n_epochs: 1
4 changes: 2 additions & 2 deletions settings/ama_settings.yaml
@@ -9,9 +9,9 @@ model:
   aggregator: 'agat'
   use_feature: True
   feature_dim: 1156
-  d_model: 64
+  d_model: 32
   type_num: 2
-  L: 6
+  L: 2
   use_gradient_checkpointing: False
   neg_num: 1
   lr: 0.005
25 changes: 25 additions & 0 deletions settings/pub_settings.yaml
@@ -0,0 +1,25 @@
task: 'simi_node_CL'
data:
# datapath,batch_size,is_dir=False,num_workers=0
batch_size: 16
datapath: 'data/PubMed/all_data.pkl'
num_workers: 0
model:
# use_feature,feature_dim,d_model,type_num, L,neg_num,dropout,lr,wd
aggregator: 'agat'
use_feature: False
feature_dim: 200
d_model: 32
type_num: 8
L: 3
use_gradient_checkpointing: True
lr: 0.01
wd: 0.005
dropout: 0.1
callback:
monitor: 'micro-f1'
mode: 'max'
train:
max_epochs: 50
gpus: 1
# reload_dataloaders_every_n_epochs: 1
90 changes: 90 additions & 0 deletions utils/dataprepare4nodeCL.py
@@ -0,0 +1,90 @@
import numpy as np
import pandas as pd
import torch
from torch_sparse import coalesce
import pickle


def do(base_path,node_num,has_feature,is_dir):
    '''
    Build and save the all_data.pkl bundle used by the node classification task.
    :param base_path: dataset directory containing link.dat, label.dat and label.dat.test (plus node.dat when has_feature)
    :param node_num: total number of nodes in the graph
    :param has_feature: whether node features should be read from node.dat
    :param is_dir: whether the graph is directed
    :return:
    '''
edge_index,edge_type = get_train_sparse_adj(base_path+'/link.dat',is_dir)
train_data = get_label_data(base_path+'/label.dat')
test_data = get_label_data(base_path+'/label.dat.test')
if has_feature:
feature_data = get_feature_data(base_path+'/node.dat')
else:
feature_data = None

all_data = {'edge_index':edge_index,
'edge_type':edge_type,
'feature_data':feature_data,
'train_data':train_data,
'test_data':test_data}
torch.save(all_data,base_path+'/all_data.pkl')

def get_feature_data(path):
node_df = pd.read_csv(path, sep='\t', header=None, quoting=3)
dd = node_df[3].str.split(',', expand=True).astype(np.float32)
data = torch.from_numpy(dd.to_numpy(dtype=np.float32))
return data


def get_label_data(path):
df = pd.read_csv(path, sep='\t', index_col=None, header=None)
df = df[[0,3]]
data = torch.from_numpy(df.to_numpy(dtype=np.int64))
return data


def get_train_sparse_adj(path,is_dir):
df = pd.read_csv(path, sep='\t', index_col=None, header=None)
    # replacement successful
data = torch.from_numpy(df.to_numpy(dtype=np.int64))
data = data[:,[2,0,1]] # [edge_type,row,col]
type_num = data[:,0].max()
N = data[:,1:].max()+1
self_loop_index = torch.stack([torch.arange(N),torch.arange(N)])
self_loop_type = torch.zeros(N,dtype=torch.long)
    print('adding edges: self-loops, num =', N)
edge_index = [self_loop_index]
edge_type = [self_loop_type]
for type_id in range(1,type_num+1):
        # for each edge type, add reverse edges and deduplicate
index = data[:,0]==type_id
i = data[index,1:].T
v = torch.ones(i.shape[1],dtype=torch.long)
if not is_dir:
            # add reverse edges to turn the undirected graph into a directed one
i = torch.cat([i, i[[1, 0]]], dim=1)
v = torch.cat([v, v], dim=0)
        # deduplicate
i,v = coalesce(i,v,N,N)
v[:] = type_id
        print('adding edges: type %d, num = %d' % (type_id, v.shape[0]))
edge_index.append(i)
edge_type.append(v)
edge_index = torch.cat(edge_index,dim=1)
edge_type = torch.cat(edge_type,dim=0)
    print('total number of edges in the training set:', edge_index.shape[1])
# edge_index,edge_type = data[:, 1:].transpose(0, 1), data[:, 0]
# train_adj = torch.sparse_coo_tensor(data[:, 1:].transpose(0, 1), data[:, 0])
return edge_index,edge_type

if __name__ == '__main__':

# base_path = '../data/PubMed'
# node_num = 63109
# has_feature = True
# is_dir = True

base_path = '../data/Aifb'
node_num = 8285
has_feature = False
is_dir = True
do(base_path,node_num,has_feature,is_dir)
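
The reverse-and-deduplicate step relies on torch_sparse.coalesce, which sorts index pairs and merges duplicates (summing their values); get_train_sparse_adj then overwrites the merged values with the edge type, so multiplicities are discarded. A small standalone example of that behaviour:

import torch
from torch_sparse import coalesce

# the reverse of (1, 0) duplicates the existing edge (0, 1)
i = torch.tensor([[0, 1, 0], [1, 0, 1]])
v = torch.ones(3, dtype=torch.long)
i, v = coalesce(i, v, 3, 3)  # duplicates merged, values summed
print(i)  # tensor([[0, 1], [1, 0]])
print(v)  # tensor([2, 1]) -- the script then sets v[:] = type_id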
