Commit: Add SGAT

farkguidao committed Jan 30, 2022
1 parent 4382efe commit 466e41a
Showing 6 changed files with 84 additions and 22 deletions.
2 changes: 1 addition & 1 deletion main.py
@@ -47,7 +47,7 @@ def test(parser):

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--setting_path',type=str,default='settings/ama_settings.yaml')
+    parser.add_argument('--setting_path',type=str,default='settings/yot_settings.yaml')
     parser.add_argument("--test", action='store_true', help='test or train')
     temp_args, _ = parser.parse_known_args()
     if temp_args.test:
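Note: with the default now pointing at the YouTube config, other datasets can still be selected at launch, e.g. python main.py --setting_path settings/ama_settings.yaml (add --test to run evaluation only).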
29 changes: 19 additions & 10 deletions models/LinkPreTask.py
@@ -11,7 +11,7 @@


 class LinkPredictionTask(pl.LightningModule):
-    def __init__(self,edge_index,edge_type,feature,N,use_feature,feature_dim,d_model,type_num, L,use_gradient_checkpointing,neg_num,dropout,lr,wd):
+    def __init__(self,edge_index,edge_type,feature,N,aggregator,use_feature,feature_dim,d_model,type_num, L,use_gradient_checkpointing,neg_num,dropout,lr,wd):
         super(LinkPredictionTask, self).__init__()
         # engineering components
         self.save_hyperparameters(ignore=['edge_index','edge_type','feature','N','degree'])
@@ -37,15 +37,21 @@ def __init__(self,edge_index,edge_type,feature,N,use_feature,feature_dim,d_model
         self.fc_edge = nn.Linear(type_num+1,d_model)
         self.w = nn.Parameter(torch.FloatTensor(N,d_model))
         nn.init.xavier_uniform_(self.w)
-        self.agat = AGAT(type_num,d_model,L,use_gradient_checkpointing,dropout)

+        if aggregator=='agat':
+            self.agat = AGAT(type_num,d_model,L,use_gradient_checkpointing,dropout)
+        elif aggregator=='sgat':
+            self.sgat = AGAT(1,d_model,L,use_gradient_checkpointing,dropout)
     def get_em(self,mask=None):
         if self.hparams.use_feature:
             feature = self.fc_node(self.feature)
         else:
             feature = self.feature
         edge_feature = self.fc_edge(self.edge_feature)
-        em = self.agat(feature,self.edge_index,self.edge_type,edge_feature,mask)
+        if self.hparams.aggregator=='agat':
+            em = self.agat(feature,self.edge_index,self.edge_type,edge_feature,mask)
+        elif self.hparams.aggregator=='sgat':
+            em = self.sgat(feature,self.edge_index,self.edge_type,edge_feature,mask)\
+                .expand(self.hparams.type_num,feature.shape[0],self.hparams.d_model)
         return em

     def training_step(self, batch,*args, **kwargs) -> STEP_OUTPUT:
@@ -54,15 +60,18 @@ def training_step(self, batch,*args, **kwargs) -> STEP_OUTPUT:
         source = pos_edge[:,0]
         target = pos_edge[:,1]
         l1 = self.loss1(inputs=em[pos_edge_type-1,source],weights=self.w,labels=target,neg_num=self.hparams.neg_num)
+        self.log('loss1', l1, prog_bar=True)
+        loss=l1
         # em[:,source] #t,bs,d
         # self.w[target].unsqueeze(0) #1, bs,d
         # (em[:, source] * self.w[target].unsqueeze(0)).sum(-1) #t,bs
-        logits = (em[:, source] * self.w[target].unsqueeze(0)).sum(-1).T # bs,t
-        l2 = self.loss2(logits,pos_edge_type-1)
-        self.log('loss1', l1, prog_bar=True)
-        self.log('loss2', l2, prog_bar=True)
-        self.log('loss_all', l1+l2, prog_bar=True)
-        return l1+l2
+        if self.hparams.aggregator=='agat':
+            logits = (em[:, source] * self.w[target].unsqueeze(0)).sum(-1).T # bs,t
+            l2 = self.loss2(logits,pos_edge_type-1)
+            self.log('loss2', l2, prog_bar=True)
+            self.log('loss_all', l1+l2, prog_bar=True)
+            loss = loss+l2
+        return loss

     def validation_step(self, batch,*args, **kwargs) -> Optional[STEP_OUTPUT]:
         em = self.get_em()
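For intuition, here is a minimal standalone shape sketch (not part of the commit; sizes are hypothetical, and AGAT is assumed to return one (type_num, N, d_model) embedding stack, as the code above implies): the 'sgat' branch runs a single-relation AGAT and broadcasts its (1, N, d_model) output across all type slots, while the 'agat'-only loss2 scores each positive edge against every edge type.

import torch

# Hypothetical sizes mirroring the hyperparameters above.
type_num, N, d_model, bs = 5, 100, 32, 8

# 'sgat' branch: expand() broadcasts the single-relation output to all
# type_num slots without copying memory.
sgat_out = torch.randn(1, N, d_model)
em = sgat_out.expand(type_num, N, d_model)
assert em.shape == (type_num, N, d_model)

# loss2 logits ('agat' branch only): em[:, source] is (type_num, bs, d_model),
# w[target].unsqueeze(0) is (1, bs, d_model); summing over d_model and
# transposing yields (bs, type_num), one score per edge type per edge.
w = torch.randn(N, d_model)
source = torch.randint(0, N, (bs,))
target = torch.randint(0, N, (bs,))
logits = (em[:, source] * w[target].unsqueeze(0)).sum(-1).T
assert logits.shape == (bs, type_num)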
50 changes: 50 additions & 0 deletions plan.py
@@ -0,0 +1,50 @@
import os

import torch
import pytorch_lightning as pl
import yaml
from dataloader.link_pre_dataloader import LinkPredictionDataloader
from models.LinkPreTask import LinkPredictionTask
# Utility for running a batch of experiments back to back overnight
def get_trainer_model_dataloader_from_dir(settings):
    dl = LinkPredictionDataloader(**settings['data'])
    model = LinkPredictionTask(dl.edge_index, dl.edge_type, dl.feature_data, dl.N, **settings['model'])
    checkpoint_callback = pl.callbacks.ModelCheckpoint(**settings['callback'])
    trainer = pl.Trainer(callbacks=[checkpoint_callback], **settings['train'])
    return trainer, model, dl

def plan(base_settings,model_replace_key,model_replace_values):
    '''
    :param base_settings: base configuration
    :param model_replace_key: hyperparameter to replace
    :param model_replace_values: list of values for that hyperparameter
    :return:
    '''
    for v in model_replace_values:
        base_settings['model'][model_replace_key] = v
        print('--------------------------------------------------')
        print(model_replace_key, '=', v, 'has started!')
        trainer,model,dl=get_trainer_model_dataloader_from_dir(base_settings)
        trainer.fit(model,dl)
        # test
        # load the checkpointed weights
        ckpt_path = trainer.log_dir + '/checkpoints/' + os.listdir(trainer.log_dir + '/checkpoints')[0]
        state_dict = torch.load(ckpt_path)['state_dict']
        model.load_state_dict(state_dict)
        trainer.test(model, dl.test_dataloader())
        print(model_replace_key, '=', v, 'has finished! results in',trainer.log_dir)
        print('--------------------------------------------------')
        del trainer
        del model
        del dl
    print('finish plan!')

if __name__ == '__main__':
    yaml_path = 'settings/yot_settings.yaml'
    key = 'L'
    values = [1,2,3,4,5,6]
    # key = 'lam'
    # values = [1.,0.5,0.3,0.05,0.01,0.001]
    with open(yaml_path) as f:
        settings = dict(yaml.load(f,yaml.FullLoader))
    plan(settings,key,values)
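As a usage sketch (hypothetical sweep values, not part of the commit), the same driver can sweep any other key under 'model', for example the learning rate:

# Hypothetical alternative sweep over the learning rate for the Amazon config.
import yaml
from plan import plan

with open('settings/ama_settings.yaml') as f:
    settings = dict(yaml.load(f, yaml.FullLoader))
plan(settings, 'lr', [0.01, 0.005, 0.001])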
7 changes: 4 additions & 3 deletions settings/ama_settings.yaml
@@ -6,14 +6,15 @@ data:
   num_workers: 0
 model:
   # use_feature,feature_dim,d_model,type_num, L,neg_num,dropout,lr,wd
+  aggregator: 'agat'
   use_feature: True
   feature_dim: 1156
   d_model: 64
   type_num: 2
-  L: 4
+  L: 6
   use_gradient_checkpointing: False
-  neg_num: 4
-  lr: 0.01
+  neg_num: 1
+  lr: 0.005
   wd: 0.000
   dropout: 0.1
 callback:
11 changes: 6 additions & 5 deletions settings/tiw_settings.yaml
@@ -6,20 +6,21 @@ data:
   num_workers: 0
 model:
   # use_feature,feature_dim,d_model,type_num, L,neg_num,dropout,lr,wd
+  aggregator: 'agat'
   use_feature: False
   feature_dim: 1156
-  d_model: 128
+  d_model: 64
   type_num: 4
-  L: 3
+  L: 4
   use_gradient_checkpointing: False
-  neg_num: 5
+  neg_num: 1
   lr: 0.005
-  wd: 0.0001
+  wd: 0.000
   dropout: 0.1
 callback:
   monitor: 'val_auc'
   mode: 'max'
 train:
-  max_epochs: 100
+  max_epochs: 50
   gpus: 1
   # reload_dataloaders_every_n_epochs: 1
7 changes: 4 additions & 3 deletions settings/yot_settings.yaml
@@ -3,16 +3,17 @@ data:
   batch_size: 16384
   datapath: 'data/youtube/all_data.pkl'
   is_dir: False
-  num_workers: 0
+  num_workers: 8
 model:
   # use_feature,feature_dim,d_model,type_num, L,neg_num,dropout,lr,wd
+  aggregator: 'agat'
   use_feature: False
   feature_dim: 1156
   d_model: 32
   type_num: 5
-  L: 3
+  L: 6
   use_gradient_checkpointing: True
-  neg_num: 5
+  neg_num: 1
   lr: 0.01
   wd: 0.0001
   dropout: 0.1
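Note that all three updated settings files keep aggregator: 'agat'; to exercise the new SGAT branch, that key would presumably be set to 'sgat' under model.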
