Skip to content

Commit

Permalink
提高evalute效率
Browse files Browse the repository at this point in the history
  • Loading branch information
farkguidao committed Feb 9, 2022
1 parent 9f54fe5 commit fcc8dcc
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 22 deletions.
27 changes: 18 additions & 9 deletions models/LinkRankTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ def training_step(self, batch, *args, **kwargs) -> STEP_OUTPUT:
neg_num=self.hparams.neg_num)
self.log('loss1', l1, prog_bar=True)
loss = l1
if self.hparams.aggregator == 'agat':
logits = (em[:, source] * self.w[target].unsqueeze(0)).sum(-1).T # bs,t
l2 = self.loss2(logits, pos_edge_type)
self.log('loss2', l2, prog_bar=True)
self.log('loss_all', l1 + l2, prog_bar=True)
loss = loss + l2
# if self.hparams.aggregator == 'agat':
# logits = (em[:, source] * self.w[target].unsqueeze(0)).sum(-1).T # bs,t
# l2 = self.loss2(logits, pos_edge_type)
# self.log('loss2', l2, prog_bar=True)
# self.log('loss_all', l1 + l2, prog_bar=True)
# loss = loss + l2
return loss

def evalute(self,obj,pred,label):
Expand All @@ -44,9 +44,14 @@ def evalute(self,obj,pred,label):
results={}
b_range = torch.arange(pred.size()[0], device=self.device)
target_pred = pred[b_range, obj]
pred = torch.where(label.byte(), -torch.ones_like(pred) * 10000000, pred)
pred[label] = -np.inf
# torch.where(label.byte(), -torch.ones_like(pred) * 10000000, pred)
pred[b_range, obj] = target_pred
ranks = 1 + torch.argsort(torch.argsort(pred, dim=1, descending=True), dim=1, descending=False)[b_range, obj]
pred = torch.argsort(pred, dim=1, descending=True)
ranks = 1+torch.argmax((pred==obj.unsqueeze(-1)).byte(),dim=1)
# pred = torch.argsort(pred, dim=1, descending=False)
# ranks = 1 + pred[b_range, obj]
del pred
ranks = ranks.float()
results['count'] = torch.numel(ranks)
results['mr'] = torch.sum(ranks).item()
Expand Down Expand Up @@ -89,7 +94,10 @@ def get_evalute_result(self,batch):
head,rel,tail = triple[:,0],triple[:,1],triple[:,2]
em = self.get_em()
pred = em[rel,head] @ self.w.T #bs*2,N

# pred = []
# for i in range(bs*2):
# pred.append((em[rel[i]]*self.w[head[i]]).sum(-1))
# pred = torch.stack(pred)
left_reslut = self.evalute(tail[:bs],pred[:bs],label[:bs])
right_reslut = self.evalute(tail[bs:],pred[bs:],label[bs:])
result = self.get_combined_results(left_reslut,right_reslut)
Expand All @@ -100,6 +108,7 @@ def validation_step(self, batch, *args, **kwargs) -> Optional[STEP_OUTPUT]:
if self.val_result['mrr']<result['mrr']:
self.val_result = result
self.log_dict(result)
self.log('-mrr',result['mrr'],prog_bar=True)


def test_step(self, batch, *args, **kwargs) -> Optional[STEP_OUTPUT]:
Expand Down
14 changes: 6 additions & 8 deletions settings/pub_settings.yaml
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
task: 'simi_node_CL'
data:
# datapath,batch_size,is_dir=False,num_workers=0
batch_size: 16
batch_size: 32
datapath: 'data/PubMed/all_data.pkl'
num_workers: 0
model:
# use_feature,feature_dim,d_model,type_num, L,neg_num,dropout,lr,wd
aggregator: 'agat'
use_feature: False
use_feature: True
feature_dim: 200
d_model: 32
d_model: 64
type_num: 8
L: 3
use_gradient_checkpointing: True
lr: 0.01
wd: 0.005
lr: 0.005
wd: 0.000
dropout: 0.1
callback:
monitor: 'micro-f1'
mode: 'max'
train:
max_epochs: 50
max_epochs: 100
gpus: 1
# reload_dataloaders_every_n_epochs: 1
11 changes: 6 additions & 5 deletions settings/wn_settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,19 @@ model:
aggregator: 'agat'
use_feature: False
feature_dim: 1156
d_model: 32
d_model: 64
# reverse
type_num: 22
neg_num: 1
neg_num: 10
L: 3
use_gradient_checkpointing: True
lr: 0.001
wd: 0.00
lr: 0.005
wd: 0.000
dropout: 0.1
callback:
monitor: 'mrr'
mode: 'max'
train:
max_epochs: 50
max_epochs: 200
gpus: 1
precision: 16

0 comments on commit fcc8dcc

Please sign in to comment.