Skip to content

Commit

Permalink
添加残差连接
Browse files Browse the repository at this point in the history
  • Loading branch information
farkguidao committed Jan 25, 2022
1 parent 8e0acc8 commit 79bb02d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion models/AGAT.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@ def forward(self,x,edge_index,edge_type,edge_feature,mask=None):
x = x.expand(self.type_num,N,d)

for i in range(self.L):
x_ = x
if self.use_gradient_checkpointing:
x, edge_feature = checkpoint(self.layer_list[i],x,edge_index,edge_type,edge_feature,mask)
else:
x, edge_feature = self.layer_list[i](x,edge_index,edge_type,edge_feature,mask)
if i == self.L-1:
break
x = self.dropout[i](self.relu(x))
x = self.relu(x_+self.dropout[i](x))
edge_feature = self.relu(edge_feature)

return x
Expand Down
4 changes: 2 additions & 2 deletions settings/ama_settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ model:
# use_feature,feature_dim,d_model,type_num, L,neg_num,dropout,lr,wd
use_feature: True
feature_dim: 1156
d_model: 128
d_model: 32
type_num: 2
L: 3
use_gradient_checkpointing: False
Expand All @@ -20,6 +20,6 @@ callback:
monitor: 'val_auc'
mode: 'max'
train:
max_epochs: 100
max_epochs: 50
gpus: 1
# reload_dataloaders_every_n_epochs: 1

0 comments on commit 79bb02d

Please sign in to comment.