Skip to content

Commit

Permalink
add mask
Browse files Browse the repository at this point in the history
  • Loading branch information
farkguidao committed Jan 26, 2022
1 parent 79bb02d commit 247cdea
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 7 deletions.
5 changes: 3 additions & 2 deletions models/AGAT.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import torch
from torch import nn
from torch_scatter import scatter_softmax, scatter_add
Expand Down Expand Up @@ -36,12 +37,11 @@ def forward(self,x,edge_index,edge_type,edge_feature,mask=None):
x, edge_feature = self.layer_list[i](x,edge_index,edge_type,edge_feature,mask)
if i == self.L-1:
break
x = self.relu(x_+self.dropout[i](x))
x = x_+self.relu(self.dropout[i](x))
edge_feature = self.relu(edge_feature)

return x


class AGATLayer(nn.Module):
def __init__(self,type_num,d_model):
super(AGATLayer, self).__init__()
Expand Down Expand Up @@ -80,6 +80,7 @@ def forward(self,x,edge_index,edge_type,edge_feature,mask):
# h = x.index_select(1,col) # t,E,d
# r = (torch.cat([path,h],dim=-1) * theta.unsqueeze(1)).sum(-1) #t,E
if mask is not None:
r = r.index_fill(1,mask,-np.inf)
pass
r = scatter_softmax(r, row, dim=-1) # t,E
edge_feature = edge_feature @ wr # et,d_model
Expand Down
2 changes: 1 addition & 1 deletion models/LinkPreTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_em(self,mask=None):

def training_step(self, batch,*args, **kwargs) -> STEP_OUTPUT:
pos_edge,pos_edge_type,edge_id = batch
em = self.get_em() #type_num,N,d_model
em = self.get_em(mask=edge_id) #type_num,N,d_model
source = pos_edge[:,0]
target = pos_edge[:,1]
l1 = self.loss1(inputs=em[pos_edge_type-1,source],weights=self.w,labels=target,neg_num=self.hparams.neg_num)
Expand Down
6 changes: 3 additions & 3 deletions settings/ama_settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@ model:
# use_feature,feature_dim,d_model,type_num, L,neg_num,dropout,lr,wd
use_feature: True
feature_dim: 1156
d_model: 32
d_model: 128
type_num: 2
L: 3
use_gradient_checkpointing: False
neg_num: 5
neg_num: 4
lr: 0.01
wd: 0.0001
dropout: 0.1
callback:
monitor: 'val_auc'
mode: 'max'
train:
max_epochs: 50
max_epochs: 200
gpus: 1
# reload_dataloaders_every_n_epochs: 1
2 changes: 1 addition & 1 deletion settings/tiw_settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ model:
L: 3
use_gradient_checkpointing: False
neg_num: 5
lr: 0.01
lr: 0.005
wd: 0.0001
dropout: 0.1
callback:
Expand Down

0 comments on commit 247cdea

Please sign in to comment.