Skip to content

Commit

Permalink
添加linkrank任务
Browse files Browse the repository at this point in the history
  • Loading branch information
farkguidao committed Feb 8, 2022
1 parent f319134 commit 9f54fe5
Show file tree
Hide file tree
Showing 6 changed files with 428 additions and 11 deletions.
3 changes: 1 addition & 2 deletions dataloader/link_pre_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
from utils.sparse_utils import *
from torch_sparse import coalesce
class LinkPredictionDataloader(pl.LightningDataModule):
    """Lightning data module for the link-prediction task.

    Loads all tensors eagerly at construction time via ``read_data``.
    """

    def __init__(self, datapath, batch_size, num_workers=0):
        """
        :param datapath: path to the pickled dataset file
        :param batch_size: batch size for all dataloaders
        :param num_workers: number of dataloader worker processes
        """
        super(LinkPredictionDataloader, self).__init__()
        self.datapath = datapath
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Read data eagerly so datasets exist before training starts.
        self.read_data()
def read_data(self):
Expand Down
14 changes: 13 additions & 1 deletion dataloader/link_rank_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,16 @@
from torch_sparse import coalesce
from dataloader.link_pre_dataloader import LinkPredictionDataloader
class LinkRankDataloader(LinkPredictionDataloader):
    """Data module for the link-ranking (KG completion) task.

    Reuses ``LinkPredictionDataloader``'s loader plumbing but reads a
    CompGCN-style data file: evaluation triples with filtered-label
    masks plus the training graph's edge list.
    """

    def read_data(self):
        """Load tensors from ``self.datapath`` and build the train/val/test datasets."""
        data = torch.load(self.datapath)
        # This task uses no node features.
        self.feature_data = None
        # Evaluation sets: (triples, filtered-label mask) pairs.
        self.val_dataset = TensorDataset(data['val_triple'], data['val_label'])
        self.test_dataset = TensorDataset(data['test_triple'], data['test_label'])
        self.edge_index, self.edge_type = data['edge_index'], data['edge_type']
        # N = number of entities, E = number of training edges.
        self.N, self.E = data['p'].num_ent, self.edge_index.shape[1]
        self.edge_id = torch.arange(self.E)
        # Mask out self-loop edges: type ids >= 2 * num_rel are loops
        # (ids < num_rel are originals, the next num_rel are reverses).
        mask = self.edge_type < (data['p'].num_rel * 2)
        self.train_dataset = TensorDataset(self.edge_index.T[mask], self.edge_type[mask], self.edge_id[mask])
5 changes: 4 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@
import torch

from dataloader.link_pre_dataloader import LinkPredictionDataloader
from dataloader.link_rank_dataloader import LinkRankDataloader
from dataloader.node_cla_dataloader import NodeClassificationDataloader
from models.LinkPreTask import LinkPredictionTask
from models.LinkRankTask import LinkRankTask
from models.NodeCLTask import NodeClassificationTask
import pytorch_lightning as pl
import yaml
import argparse

# Registry mapping a task name (the YAML `task` field) to its
# (LightningDataModule class, LightningModule class) pair.
TASK = {
    'link_pre':(LinkPredictionDataloader,LinkPredictionTask),
    'link_rank':(LinkRankDataloader,LinkRankTask),
    'simi_node_CL':(NodeClassificationDataloader,NodeClassificationTask)
}

Expand Down Expand Up @@ -56,7 +59,7 @@ def test(parser):

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--setting_path',type=str,default='settings/pub_settings.yaml')
parser.add_argument('--setting_path',type=str,default='settings/wn_settings.yaml')
parser.add_argument("--test", action='store_true', help='test or train')
temp_args, _ = parser.parse_known_args()
if temp_args.test:
Expand Down
97 changes: 90 additions & 7 deletions models/LinkRankTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,103 @@ def __init__(self, edge_index, edge_type, feature, N, aggregator, use_feature, f
use_gradient_checkpointing, neg_num, dropout, lr, wd):
super().__init__(edge_index, edge_type, feature, N, aggregator, use_feature, feature_dim, d_model, type_num, L,
use_gradient_checkpointing, neg_num, dropout, lr, wd)
self.val_result = {'mrr':-np.inf}

def training_step(self, batch, *args, **kwargs) -> STEP_OUTPUT:
    """One training step: ranking loss over target nodes, plus a
    relation-type loss when the aggregator is 'agat'."""
    pos_edge, pos_edge_type, edge_id = batch
    # Embeddings with this batch's edges masked out: (type_num, N, d_model).
    em = self.get_em(mask=edge_id)
    src, dst = pos_edge[:, 0], pos_edge[:, 1]
    rank_loss = self.loss1(inputs=em[pos_edge_type, src], weights=self.w,
                           labels=dst, neg_num=self.hparams.neg_num)
    self.log('loss1', rank_loss, prog_bar=True)
    total = rank_loss
    if self.hparams.aggregator == 'agat':
        # Relation-type logits: (bs, type_num).
        logits = (em[:, src] * self.w[dst].unsqueeze(0)).sum(-1).T
        type_loss = self.loss2(logits, pos_edge_type)
        self.log('loss2', type_loss, prog_bar=True)
        self.log('loss_all', rank_loss + type_loss, prog_bar=True)
        total = total + type_loss
    return total

def evalute(self, obj, pred, label):
    '''
    Filtered ranking evaluation for one batch (code adapted from CompGCN).

    :param obj: (bs,) gold target entity ids
    :param pred: (bs, N) scores over all entities
    :param label: (bs, N) 0/1 mask of all known true targets; other true
        answers are filtered out before ranking the gold target
    :return: dict with 'count', 'mr', 'mrr' and 'hits@1'..'hits@10',
        summed (not averaged) over the batch
    '''
    results = {}
    b_range = torch.arange(pred.size()[0], device=self.device)
    target_pred = pred[b_range, obj]
    # Filtered setting: push every known-true entity's score far below the
    # rest, then restore the gold target's own score.
    # (.bool() replaces the deprecated .byte() mask for torch.where.)
    pred = torch.where(label.bool(), -torch.ones_like(pred) * 10000000, pred)
    pred[b_range, obj] = target_pred
    # Double argsort yields each entity's rank; +1 makes ranks 1-based.
    ranks = 1 + torch.argsort(torch.argsort(pred, dim=1, descending=True), dim=1, descending=False)[b_range, obj]
    ranks = ranks.float()
    results['count'] = torch.numel(ranks)
    results['mr'] = torch.sum(ranks).item()
    results['mrr'] = torch.sum(1.0 / ranks).item()
    for k in range(10):
        results['hits@{}'.format(k + 1)] = torch.numel(ranks[ranks <= (k + 1)])
    return results

def get_combined_results(self, left_results, right_results):
    '''
    Merge tail-prediction (left) and head-prediction (right) metric dicts
    into one dict of per-triple averages (code adapted from CompGCN).

    :param left_results: summed metrics from tail prediction
    :param right_results: summed metrics from head prediction
    :return: dict with left_/right_/combined mr, mrr and hits@1..10
    '''
    combined = {}
    n = float(left_results['count'])
    both = 2 * n

    for side, res in (('left', left_results), ('right', right_results)):
        combined[f'{side}_mr'] = round(res['mr'] / n, 5)
        combined[f'{side}_mrr'] = round(res['mrr'] / n, 5)
    combined['mr'] = round((left_results['mr'] + right_results['mr']) / both, 5)
    combined['mrr'] = round((left_results['mrr'] + right_results['mrr']) / both, 5)

    for k in range(1, 11):
        key = f'hits@{k}'
        combined[f'left_{key}'] = round(left_results[key] / n, 5)
        combined[f'right_{key}'] = round(right_results[key] / n, 5)
        combined[key] = round((left_results[key] + right_results[key]) / both, 5)
    return combined

def get_evalute_result(self, batch):
    '''
    Score one evaluation batch against all entities.

    The first half of the batch predicts tail entities (-> left result),
    the second half predicts head entities (-> right result).
    :param batch: (triple, label) pair of tensors
    :return: combined metrics dict
    '''
    triple, label = batch
    half = triple.shape[0] // 2
    head, rel, tail = triple[:, 0], triple[:, 1], triple[:, 2]
    em = self.get_em()
    scores = em[rel, head] @ self.w.T  # (bs*2, N)

    left_result = self.evalute(tail[:half], scores[:half], label[:half])
    right_result = self.evalute(tail[half:], scores[half:], label[half:])
    return self.get_combined_results(left_result, right_result)

def validation_step(self, batch, *args, **kwargs) -> Optional[STEP_OUTPUT]:
    """Evaluate one validation batch and remember the best-MRR metrics."""
    metrics = self.get_evalute_result(batch)
    # Keep the best validation result (by MRR) seen so far.
    if metrics['mrr'] > self.val_result['mrr']:
        self.val_result = metrics
    self.log_dict(metrics)

def test_step(self, batch, *args, **kwargs) -> Optional[STEP_OUTPUT]:
    """Evaluate one test batch and stash the metrics on the module."""
    self.test_result = self.get_evalute_result(batch)

def on_test_end(self) -> None:
    """Print the final test metrics and persist them to <log_dir>/best_result.txt."""
    super().on_test_end()
    result_path = self.trainer.log_dir + '/best_result.txt'
    with open(result_path, mode='w') as f:
        print('test_result:', self.test_result)
        f.write(str(self.test_result))

def on_fit_end(self) -> None:
    """Print the best validation metrics and persist them to <log_dir>/best_val_result.txt."""
    super().on_fit_end()
    result_path = self.trainer.log_dir + '/best_val_result.txt'
    with open(result_path, mode='w') as f:
        print('val_result:', self.val_result)
        f.write(str(self.val_result))
26 changes: 26 additions & 0 deletions settings/wn_settings.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
task: 'link_rank'
data:
  # LinkRankDataloader arguments: datapath, batch_size, num_workers=0
  # (the old is_dir argument was removed from the dataloader API)
  batch_size: 4096
  datapath: 'data/WN18RR/all_data.pkl'
  num_workers: 0
model:
  # LinkRankTask arguments: aggregator, use_feature, feature_dim, d_model,
  # type_num, L, neg_num, dropout, lr, wd, use_gradient_checkpointing
  aggregator: 'agat'
  use_feature: False
  feature_dim: 1156
  d_model: 32
  # Edge-type count including the reverse relation types.
  type_num: 22
  neg_num: 1
  L: 3
  use_gradient_checkpointing: True
  lr: 0.001
  wd: 0.00
  dropout: 0.1
callback:
  # Checkpoint on the logged validation MRR, maximizing it.
  monitor: 'mrr'
  mode: 'max'
train:
  max_epochs: 50
  gpus: 1
Loading

0 comments on commit 9f54fe5

Please sign in to comment.