Add gradient checkpointing strategy
farkguidao committed Jan 25, 2022
1 parent 125d0f7 commit 61f4dbd
Showing 4 changed files with 55 additions and 45 deletions.
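
This commit wires PyTorch's gradient checkpointing into the AGAT forward loop: a layer's intermediate activations are discarded after the forward pass and recomputed during backward, trading extra compute for lower peak memory. Below is a minimal, repository-independent sketch of the underlying torch.utils.checkpoint.checkpoint pattern (the toy block, shapes, and names are illustrative, not taken from this commit):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# Toy stand-in for one layer; any nn.Module (or callable) works.
block = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
x = torch.randn(8, 32, requires_grad=True)

# Plain forward: intermediate activations are kept in memory for backward.
y_plain = block(x)

# Checkpointed forward: activations inside `block` are dropped and
# recomputed during backward, cutting peak memory at the cost of compute.
y_ckpt = checkpoint(block, x)
y_ckpt.sum().backward()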
85 changes: 46 additions & 39 deletions models/AGAT.py
@@ -1,32 +1,21 @@
import torch
from torch import nn
from torch_scatter import scatter_softmax, scatter_add
from torch.utils.checkpoint import checkpoint
from torch_sparse import spmm

from dataloader.link_pre_dataloader import LinkPredictionDataloader


class AGAT(nn.Module):
def __init__(self,type_num,d_model,L,dropout=.1):
def __init__(self,type_num,d_model,L,use_gradient_checkpointing=False,dropout=.1):
super(AGAT, self).__init__()
self.type_num = type_num
self.d_model = d_model
self.L = L
self.theta = nn.ParameterList()
self.we = nn.ParameterList()
self.wr = nn.ParameterList()
self.dropout = nn.ModuleList()
for i in range(L):
theta = nn.Parameter(torch.FloatTensor(type_num,3*d_model))
we = nn.Parameter(torch.FloatTensor(d_model,d_model))
wr = nn.Parameter(torch.FloatTensor(d_model,d_model))
nn.init.xavier_uniform_(theta)
nn.init.xavier_uniform_(we)
nn.init.xavier_uniform_(wr)
self.theta.append(theta)
self.we.append(we)
self.wr.append(wr)
self.dropout.append(nn.Dropout(dropout))
self.use_gradient_checkpointing = use_gradient_checkpointing
self.layer_list = nn.ModuleList([AGATLayer(type_num,d_model) for i in range(L)])
self.dropout = nn.ModuleList([nn.Dropout(dropout) for i in range(L)])
self.relu = nn.ReLU(inplace=True)
def forward(self,x,edge_index,edge_type,edge_feature,mask=None):
'''
@@ -38,45 +27,62 @@ def forward(self,x,edge_index,edge_type,edge_feature,mask=None):
'''
N,d,E,eT = x.shape[0],x.shape[1],edge_type.shape[0],edge_feature.shape[0]
x = x.expand(self.type_num,N,d)
# edge_feature = edge_feature.expand(self.type_num,eT,d)

for i in range(self.L):
x,edge_feature = self.layer_forward(x,edge_index,edge_type,edge_feature,mask,
self.theta[i][:,:d],self.theta[i][:,d:2*d],self.theta[i][:,2*d:],self.wr[i],self.we[i])
if self.use_gradient_checkpointing:
x, edge_feature = checkpoint(self.layer_list[i],x,edge_index,edge_type,edge_feature,mask)
else:
x, edge_feature = self.layer_list[i](x,edge_index,edge_type,edge_feature,mask)
if i == self.L-1:
break
x = self.dropout[i](self.relu(x))
edge_feature = self.relu(edge_feature)

return x

def layer_forward(self,x,edge_index,edge_type,edge_feature,mask,theta_g,theta_hi,theta_hj,wr,we):

class AGATLayer(nn.Module):
def __init__(self,type_num,d_model):
super(AGATLayer, self).__init__()
self.type_num = type_num
self.d_model = d_model
self.theta_g = nn.Parameter(torch.FloatTensor(type_num, d_model))
self.theta_hi = nn.Parameter(torch.FloatTensor(type_num, d_model))
self.theta_hj = nn.Parameter(torch.FloatTensor(type_num, d_model))
self.we = nn.Parameter(torch.FloatTensor(d_model, d_model))
self.wr = nn.Parameter(torch.FloatTensor(d_model, d_model))
nn.init.xavier_uniform_(self.theta_g)
nn.init.xavier_uniform_(self.theta_hi)
nn.init.xavier_uniform_(self.theta_hj)
nn.init.xavier_uniform_(self.we)
nn.init.xavier_uniform_(self.wr)
def forward(self,x,edge_index,edge_type,edge_feature,mask):
'''
:param x: type_num,N,d_model
:param x:
:param edge_index:
:param edge_type:
:param edge_feature: edge_type,d_model
:param edge_feature:
:param mask:
:param theta:
:param wr:
:param we:
:return:
'''
row,col = edge_index[0],edge_index[1]
theta_g, theta_hi, theta_hj, wr, we = self.theta_g,self.theta_hi,self.theta_hj,self.wr,self.we
row, col = edge_index[0], edge_index[1]
# compute the r_g component
r_g = (edge_feature.unsqueeze(0) * theta_g.unsqueeze(1)).sum(-1).index_select(1,edge_type) #t,et->t,E
r_hi = (x * theta_hi.unsqueeze(1)).sum(-1).index_select(1,row) # t,N->t,E
r_hj = (x * theta_hj.unsqueeze(1)).sum(-1).index_select(1,col) # t,N->t,E
r = r_g+r_hi+r_hj
r_g = (edge_feature.unsqueeze(0) * theta_g.unsqueeze(1)).sum(-1).index_select(1, edge_type) # t,et->t,E
r_hi = (x * theta_hi.unsqueeze(1)).sum(-1).index_select(1, row) # t,N->t,E
r_hj = (x * theta_hj.unsqueeze(1)).sum(-1).index_select(1, col) # t,N->t,E
r = r_g + r_hi + r_hj

# h = x.index_select(1,col) # t,E,d
# r = (torch.cat([path,h],dim=-1) * theta.unsqueeze(1)).sum(-1) #t,E
if mask is not None:
pass
r = scatter_softmax(r,row,dim=-1) #t,E
edge_feature = edge_feature @ wr # et,d_model
v_g = torch.sigmoid(edge_feature).index_select(0,edge_type).unsqueeze(0) #1,E,d_model
v_h = (x @ we).index_select(1,col)
out = r.unsqueeze(-1)*v_g*v_h
out = scatter_add(out, row, dim=-2)# t,N,d_model
r = scatter_softmax(r, row, dim=-1) # t,E
edge_feature = edge_feature @ wr # et,d_model
v_g = torch.sigmoid(edge_feature).index_select(0, edge_type).unsqueeze(0) # 1,E,d_model
v_h = (x @ we).index_select(1, col)
out = r.unsqueeze(-1) * v_g * v_h
out = scatter_add(out, row, dim=-2) # t,N,d_model
return out, edge_feature

if __name__ == '__main__':
@@ -86,8 +92,9 @@ def layer_forward(self,x,edge_index,edge_type,edge_feature,mask,theta_g,theta_hi
N = edge_index.max()+1
# path = torch.randn(E,16)
# feature = torch.randn(N,16)
model = AGAT(4,32,3,0.1).cuda()
model = AGAT(4,64,3,True,0.1).cuda()
# rs = model(feature.cuda(),path.cuda(),edge_index.cuda())
x = torch.randn(N,32)
edge_feature = torch.randn(3,32)

x = torch.randn(N,64)
edge_feature = torch.randn(3,64)
rs = model(x.cuda(),edge_index.cuda(),path.cuda(),edge_feature.cuda())
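
The refactor that makes checkpointing possible: the per-layer parameters (theta, we, wr) move out of AGAT's ParameterLists into a standalone AGATLayer module, so each layer is a single callable that checkpoint() can wrap inside the forward loop, toggled by the new use_gradient_checkpointing flag. A hypothetical smoke test mirroring the __main__ block above (the import path, graph construction, and shapes are assumptions, not part of the diff):

import torch
from models.AGAT import AGAT  # assumed import path for this repository

N, d_model = 100, 64
# Give every node an outgoing edge so scatter_add yields a full (type_num, N, d_model) tensor.
row = torch.arange(N).repeat(5)                   # 5 outgoing edges per node
col = torch.randint(0, N, (row.numel(),))
edge_index = torch.stack([row, col])              # 2, E
edge_type = torch.randint(0, 3, (row.numel(),))   # relation id per edge
edge_feature = torch.randn(3, d_model)            # one feature vector per relation
x = torch.randn(N, d_model, requires_grad=True)   # reentrant checkpoint needs an input that requires grad

model = AGAT(type_num=4, d_model=d_model, L=3, use_gradient_checkpointing=True, dropout=0.1)
out = model(x, edge_index, edge_type, edge_feature)  # type_num, N, d_model
out.sum().backward()                                 # layers are recomputed here instead of cached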
9 changes: 5 additions & 4 deletions models/LinkPreTask.py
@@ -11,7 +11,7 @@


class LinkPredictionTask(pl.LightningModule):
def __init__(self,edge_index,edge_type,feature,N,use_feature,feature_dim,d_model,type_num, L,neg_num,dropout,lr,wd):
def __init__(self,edge_index,edge_type,feature,N,use_feature,feature_dim,d_model,type_num, L,use_gradient_checkpointing,neg_num,dropout,lr,wd):
super(LinkPredictionTask, self).__init__()
# engineering components
self.save_hyperparameters(ignore=['edge_index','edge_type','feature','N'])
@@ -22,8 +22,8 @@ def __init__(self,edge_index,edge_type,feature,N,use_feature,feature_dim,d_model
self.register_buffer('feature',feature)
self.fc_node = nn.Linear(feature_dim, d_model)
else:
self.feature = nn.Parameter(torch.FloatTensor(N,d_model))
nn.init.xavier_uniform_(self.feature)
self.feature = nn.Parameter(torch.randn(N,d_model))
# nn.init.xavier_uniform_(self.feature)
self.loss2 = nn.CrossEntropyLoss()
self.loss1 = NCELoss(N)
self.val_best_auc = 0
@@ -35,7 +35,7 @@ def __init__(self,edge_index,edge_type,feature,N,use_feature,feature_dim,d_model
self.fc_edge = nn.Linear(type_num+1,d_model)
self.w = nn.Parameter(torch.FloatTensor(N,d_model))
nn.init.xavier_uniform_(self.w)
self.agat = AGAT(type_num,d_model,L,dropout)
self.agat = AGAT(type_num,d_model,L,use_gradient_checkpointing,dropout)

def get_em(self,mask=None):
if self.hparams.use_feature:
@@ -118,6 +118,7 @@ def forward(self,inputs,weights,labels,neg_num):
target = weights[torch.cat([labels,neg_batch],dim=0)]
label = torch.zeros(target.shape[0],device=inputs.device)
label[:labels.shape[0]]=1
# bs,d_model-> bs*(neg_num+1),d_model
source = inputs.repeat((neg_num+1,1))
return self.bce((source*target).sum(dim=-1),label)
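
LinkPredictionTask simply threads the flag through to AGAT; note that use_gradient_checkpointing is inserted after L in the positional signature, so any positional caller must be updated. The learned node embedding also switches from Xavier-uniform to standard-normal initialisation. A hypothetical construction with dummy inputs (values mirror settings/yot_settings.yaml where visible; the import path, graph tensors, and dropout value are made up; real inputs come from the dataloader):

import torch
from models.LinkPreTask import LinkPredictionTask  # assumed import path

# Dummy graph; in the real pipeline these come from the dataloader.
N, E, type_num = 1000, 5000, 5
edge_index = torch.randint(0, N, (2, E))
edge_type = torch.randint(0, type_num, (E,))

task = LinkPredictionTask(
    edge_index, edge_type, feature=None, N=N,
    use_feature=False, feature_dim=1156, d_model=16, type_num=type_num, L=3,
    use_gradient_checkpointing=True, neg_num=5, dropout=0.1, lr=0.01, wd=0.0001,
)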

3 changes: 2 additions & 1 deletion settings/ama_settings.yaml
@@ -11,6 +11,7 @@ model:
d_model: 32
type_num: 2
L: 3
use_gradient_checkpointing: False
neg_num: 5
lr: 0.01
wd: 0.0001
@@ -19,7 +20,7 @@ callback:
monitor: 'val_auc'
mode: 'max'
train:
max_epochs: 2
max_epochs: 100
gpus: 1
# reload_dataloaders_every_n_epochs: 1
# resume_from_checkpoint: 'lightning_logs/version_0/checkpoints/epoch=96-step=6789.ckpt'
3 changes: 2 additions & 1 deletion settings/yot_settings.yaml
@@ -8,9 +8,10 @@ model:
# use_feature,feature_dim,d_model,type_num, L,neg_num,dropout,lr,wd
use_feature: False
feature_dim: 1156
d_model: 32
d_model: 16
type_num: 5
L: 3
use_gradient_checkpointing: True
neg_num: 5
lr: 0.01
wd: 0.0001
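
Both settings files gain the new use_gradient_checkpointing key: it stays off for the Amazon config (whose max_epochs also rises from 2 to 100) and is switched on for the YouTube config, whose d_model drops from 32 to 16. A hypothetical way to read the flag, assuming the training entry point loads these YAML files with PyYAML (the actual script may do this differently):

import yaml

with open('settings/yot_settings.yaml') as f:
    cfg = yaml.safe_load(f)
print(cfg['model']['use_gradient_checkpointing'])  # True here, False in ama_settings.yaml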
