Skip to content

Commit

Permalink
撤销‘根据度进行负采样’,添加F1
Browse files Browse the repository at this point in the history
  • Loading branch information
farkguidao committed Jan 26, 2022
1 parent a699240 commit 4382efe
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
26 changes: 16 additions & 10 deletions models/LinkPreTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class LinkPredictionTask(pl.LightningModule):
def __init__(self,edge_index,edge_type,feature,N,degree,use_feature,feature_dim,d_model,type_num, L,use_gradient_checkpointing,neg_num,dropout,lr,wd):
def __init__(self,edge_index,edge_type,feature,N,use_feature,feature_dim,d_model,type_num, L,use_gradient_checkpointing,neg_num,dropout,lr,wd):
super(LinkPredictionTask, self).__init__()
# 工程类组件
self.save_hyperparameters(ignore=['edge_index','edge_type','feature','N','degree'])
Expand All @@ -25,11 +25,13 @@ def __init__(self,edge_index,edge_type,feature,N,degree,use_feature,feature_dim,
self.feature = nn.Parameter(torch.randn(N,d_model))
# nn.init.xavier_uniform_(self.feature)
self.loss2 = nn.CrossEntropyLoss()
self.loss1 = NCELoss(N,degree)
self.loss1 = NCELoss(N)
self.val_best_auc = 0
self.val_best_aupr = 0
self.val_best_f1 = 0
self.test_best_auc = 0
self.test_best_aupr = 0
self.test_best_f1 = 0
#

self.fc_edge = nn.Linear(type_num+1,d_model)
Expand Down Expand Up @@ -70,11 +72,14 @@ def validation_step(self, batch,*args, **kwargs) -> Optional[STEP_OUTPUT]:
score = torch.sigmoid(score)
auc = torchmetrics.functional.auroc(score, label, pos_label=1)
aupr = torchmetrics.functional.average_precision(score, label, pos_label=1)
f1 = torchmetrics.functional.f1(score,label)
if auc > self.val_best_auc:
self.val_best_auc = auc
self.val_best_aupr = aupr
self.val_best_f1 = f1
self.log('val_auc', auc, prog_bar=True)
self.log('val_aupr', aupr, prog_bar=True)
self.log('val_f1', f1, prog_bar=True)

def test_step(self, batch,*args, **kwargs) -> Optional[STEP_OUTPUT]:
em = self.get_em()
Expand All @@ -84,18 +89,21 @@ def test_step(self, batch,*args, **kwargs) -> Optional[STEP_OUTPUT]:
score = torch.sigmoid(score)
auc = torchmetrics.functional.auroc(score, label, pos_label=1)
aupr = torchmetrics.functional.average_precision(score, label, pos_label=1)
f1 = torchmetrics.functional.f1(score, label)
if auc > self.test_best_auc:
self.test_best_auc = auc
self.test_best_aupr = aupr
self.test_best_f1 = f1

def on_test_end(self) -> None:
with open(self.trainer.log_dir + '/best_result.txt', mode='w') as f:
result = {'auc': float(self.test_best_auc), 'aupr': float(self.test_best_aupr)}
result = {'auc': float(self.test_best_auc), 'aupr': float(self.test_best_aupr),'f1': float(self.test_best_f1)}
print('test_result:', result)
f.write(str(result))
# 结束时存储最优结果
# 结束时存储最优验证结果
def on_fit_end(self) -> None:
with open(self.trainer.log_dir + '/val_best_result.txt', mode='w') as f:
result = {'auc': float(self.val_best_auc), 'aupr': float(self.val_best_aupr)}
result = {'auc': float(self.val_best_auc), 'aupr': float(self.val_best_aupr),'f1':float(self.val_best_f1)}
print('val_best_result:', result)
f.write(str(result))

Expand All @@ -108,15 +116,13 @@ def configure_optimizers(self):
return optimizer

class NCELoss(nn.Module):
def __init__(self,N,degree):
def __init__(self,N):
super(NCELoss, self).__init__()
self.N = N
self.register_buffer('degree',degree)
self.bce=nn.BCEWithLogitsLoss()
def forward(self,inputs,weights,labels,neg_num):
# neg_batch = torch.randint(0, self.N, (neg_num*inputs.shape[0],),
# dtype=torch.long,device=inputs.device)
neg_batch = torch.multinomial(self.degree,neg_num*inputs.shape[0],True)
neg_batch = torch.randint(0, self.N, (neg_num*inputs.shape[0],),
dtype=torch.long,device=inputs.device)
target = weights[torch.cat([labels,neg_batch],dim=0)]
label = torch.zeros(target.shape[0],device=inputs.device)
label[:labels.shape[0]]=1
Expand Down
6 changes: 3 additions & 3 deletions settings/ama_settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ model:
# use_feature,feature_dim,d_model,type_num, L,neg_num,dropout,lr,wd
use_feature: True
feature_dim: 1156
d_model: 128
d_model: 64
type_num: 2
L: 3
L: 4
use_gradient_checkpointing: False
neg_num: 4
lr: 0.01
wd: 0.0001
wd: 0.000
dropout: 0.1
callback:
monitor: 'val_auc'
Expand Down

0 comments on commit 4382efe

Please sign in to comment.