diff --git a/models/AGAT.py b/models/AGAT.py
index 1f12d99..6722f16 100644
--- a/models/AGAT.py
+++ b/models/AGAT.py
@@ -1,32 +1,21 @@
 import torch
 from torch import nn
 from torch_scatter import scatter_softmax, scatter_add
+from torch.utils.checkpoint import checkpoint
 from torch_sparse import spmm
 from dataloader.link_pre_dataloader import LinkPredictionDataloader
 
 
 class AGAT(nn.Module):
-    def __init__(self,type_num,d_model,L,dropout=.1):
+    def __init__(self,type_num,d_model,L,use_gradient_checkpointing=False,dropout=.1):
         super(AGAT, self).__init__()
         self.type_num = type_num
         self.d_model = d_model
         self.L = L
-        self.theta = nn.ParameterList()
-        self.we = nn.ParameterList()
-        self.wr = nn.ParameterList()
-        self.dropout = nn.ModuleList()
-        for i in range(L):
-            theta = nn.Parameter(torch.FloatTensor(type_num,3*d_model))
-            we = nn.Parameter(torch.FloatTensor(d_model,d_model))
-            wr = nn.Parameter(torch.FloatTensor(d_model,d_model))
-            nn.init.xavier_uniform_(theta)
-            nn.init.xavier_uniform_(we)
-            nn.init.xavier_uniform_(wr)
-            self.theta.append(theta)
-            self.we.append(we)
-            self.wr.append(wr)
-            self.dropout.append(nn.Dropout(dropout))
+        self.use_gradient_checkpointing = use_gradient_checkpointing
+        self.layer_list = nn.ModuleList([AGATLayer(type_num,d_model) for i in range(L)])
+        self.dropout = nn.ModuleList([nn.Dropout(dropout) for i in range(L)])
         self.relu = nn.ReLU(inplace=True)
 
     def forward(self,x,edge_index,edge_type,edge_feature,mask=None):
         '''
@@ -38,45 +27,62 @@ def forward(self,x,edge_index,edge_type,edge_feature,mask=None):
         '''
         N,d,E,eT = x.shape[0],x.shape[1],edge_type.shape[0],edge_feature.shape[0]
         x = x.expand(self.type_num,N,d)
-        # edge_feature = edge_feature.expand(self.type_num,eT,d)
+
         for i in range(self.L):
-            x,edge_feature = self.layer_forward(x,edge_index,edge_type,edge_feature,mask,
-                self.theta[i][:,:d],self.theta[i][:,d:2*d],self.theta[i][:,2*d:],self.wr[i],self.we[i])
+            if self.use_gradient_checkpointing:
+                x, edge_feature = checkpoint(self.layer_list[i],x,edge_index,edge_type,edge_feature,mask)
+            else:
+                x, edge_feature = self.layer_list[i](x,edge_index,edge_type,edge_feature,mask)
             if i == self.L-1:
                 break
             x = self.dropout[i](self.relu(x))
             edge_feature = self.relu(edge_feature)
+
         return x
 
-    def layer_forward(self,x,edge_index,edge_type,edge_feature,mask,theta_g,theta_hi,theta_hj,wr,we):
+
+class AGATLayer(nn.Module):
+    def __init__(self,type_num,d_model):
+        super(AGATLayer, self).__init__()
+        self.type_num = type_num
+        self.d_model = d_model
+        self.theta_g = nn.Parameter(torch.FloatTensor(type_num, d_model))
+        self.theta_hi = nn.Parameter(torch.FloatTensor(type_num, d_model))
+        self.theta_hj = nn.Parameter(torch.FloatTensor(type_num, d_model))
+        self.we = nn.Parameter(torch.FloatTensor(d_model, d_model))
+        self.wr = nn.Parameter(torch.FloatTensor(d_model, d_model))
+        nn.init.xavier_uniform_(self.theta_g)
+        nn.init.xavier_uniform_(self.theta_hi)
+        nn.init.xavier_uniform_(self.theta_hj)
+        nn.init.xavier_uniform_(self.we)
+        nn.init.xavier_uniform_(self.wr)
 
+    def forward(self,x,edge_index,edge_type,edge_feature,mask):
         '''
-        :param x: type_num,N,d_model
+        :param x:
         :param edge_index:
         :param edge_type:
-        :param edge_feature: edge_type,d_model
+        :param edge_feature:
         :param mask:
-        :param theta:
-        :param wr:
-        :param we:
         :return:
         '''
-        row,col = edge_index[0],edge_index[1]
+        theta_g, theta_hi, theta_hj, wr, we = self.theta_g,self.theta_hi,self.theta_hj,self.wr,self.we
+        row, col = edge_index[0], edge_index[1]
         # compute the r_g component
-        r_g = (edge_feature.unsqueeze(0) * theta_g.unsqueeze(1)).sum(-1).index_select(1,edge_type) #t,et->t,E
-        r_hi = (x * theta_hi.unsqueeze(1)).sum(-1).index_select(1,row) # t,N->t,E
-        r_hj = (x * theta_hj.unsqueeze(1)).sum(-1).index_select(1,col) # t,N->t,E
-        r = r_g+r_hi+r_hj
+        r_g = (edge_feature.unsqueeze(0) * theta_g.unsqueeze(1)).sum(-1).index_select(1, edge_type)  # t,et->t,E
+        r_hi = (x * theta_hi.unsqueeze(1)).sum(-1).index_select(1, row)  # t,N->t,E
+        r_hj = (x * theta_hj.unsqueeze(1)).sum(-1).index_select(1, col)  # t,N->t,E
+        r = r_g + r_hi + r_hj
         # h = x.index_select(1,col) # t,E,d
         # r = (torch.cat([path,h],dim=-1) * theta.unsqueeze(1)).sum(-1) #t,E
         if mask is not None:
             pass
-        r = scatter_softmax(r,row,dim=-1) #t,E
-        edge_feature = edge_feature @ wr # et,d_model
-        v_g = torch.sigmoid(edge_feature).index_select(0,edge_type).unsqueeze(0) #1,E,d_model
-        v_h = (x @ we).index_select(1,col)
-        out = r.unsqueeze(-1)*v_g*v_h
-        out = scatter_add(out, row, dim=-2)# t,N,d_model
+        r = scatter_softmax(r, row, dim=-1)  # t,E
+        edge_feature = edge_feature @ wr  # et,d_model
+        v_g = torch.sigmoid(edge_feature).index_select(0, edge_type).unsqueeze(0)  # 1,E,d_model
+        v_h = (x @ we).index_select(1, col)
+        out = r.unsqueeze(-1) * v_g * v_h
+        out = scatter_add(out, row, dim=-2)  # t,N,d_model
         return out, edge_feature
 
 if __name__ == '__main__':
@@ -86,8 +92,9 @@ def layer_forward(self,x,edge_index,edge_type,edge_feature,mask,theta_g,theta_hi
     N = edge_index.max()+1
     # path = torch.randn(E,16)
     # feature = torch.randn(N,16)
-    model = AGAT(4,32,3,0.1).cuda()
+    model = AGAT(4,64,3,True,0.1).cuda()
     # rs = model(feature.cuda(),path.cuda(),edge_index.cuda())
-    x = torch.randn(N,32)
-    edge_feature = torch.randn(3,32)
+
+    x = torch.randn(N,64)
+    edge_feature = torch.randn(3,64)
     rs = model(x.cuda(),edge_index.cuda(),path.cuda(),edge_feature.cuda())
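Reviewer note on the AGAT refactor above: torch.utils.checkpoint.checkpoint takes a callable plus its tensor arguments, frees the intermediate activations after the forward pass, and re-runs the callable during backward to regenerate them. Packaging layer_forward as the standalone AGATLayer module is what gives checkpoint a self-contained callable to replay. A minimal sketch of the trade-off, with nn.Linear standing in for AGATLayer (toy sizes, not from this repo):

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    layer = nn.Linear(64, 64)                      # stand-in for one AGATLayer
    x = torch.randn(4096, 64, requires_grad=True)  # checkpointing needs a grad-tracking input

    y_plain = layer(x)             # keeps activations alive until backward
    y_ckpt = checkpoint(layer, x)  # frees them; recomputes layer(x) during backward
    y_ckpt.sum().backward()        # gradients match the uncheckpointed path

The saving is per-layer activation memory in exchange for one extra forward pass per layer during backward, which is presumably why the diff exposes it as an opt-in constructor flag rather than enabling it unconditionally.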
diff --git a/models/LinkPreTask.py b/models/LinkPreTask.py
index 8fe8eb7..44dc2fb 100644
--- a/models/LinkPreTask.py
+++ b/models/LinkPreTask.py
@@ -11,7 +11,7 @@
 class LinkPredictionTask(pl.LightningModule):
-    def __init__(self,edge_index,edge_type,feature,N,use_feature,feature_dim,d_model,type_num, L,neg_num,dropout,lr,wd):
+    def __init__(self,edge_index,edge_type,feature,N,use_feature,feature_dim,d_model,type_num, L,use_gradient_checkpointing,neg_num,dropout,lr,wd):
         super(LinkPredictionTask, self).__init__()
         # engineering components
         self.save_hyperparameters(ignore=['edge_index','edge_type','feature','N'])
@@ -22,8 +22,8 @@ def __init__(self,edge_index,edge_type,feature,N,use_feature,feature_dim,d_model
             self.register_buffer('feature',feature)
             self.fc_node = nn.Linear(feature_dim, d_model)
         else:
-            self.feature = nn.Parameter(torch.FloatTensor(N,d_model))
-            nn.init.xavier_uniform_(self.feature)
+            self.feature = nn.Parameter(torch.randn(N,d_model))
+            # nn.init.xavier_uniform_(self.feature)
         self.loss2 = nn.CrossEntropyLoss()
         self.loss1 = NCELoss(N)
         self.val_best_auc = 0
@@ -35,7 +35,7 @@ def __init__(self,edge_index,edge_type,feature,N,use_feature,feature_dim,d_model
         self.fc_edge = nn.Linear(type_num+1,d_model)
         self.w = nn.Parameter(torch.FloatTensor(N,d_model))
         nn.init.xavier_uniform_(self.w)
-        self.agat = AGAT(type_num,d_model,L,dropout)
+        self.agat = AGAT(type_num,d_model,L,use_gradient_checkpointing,dropout)
 
     def get_em(self,mask=None):
         if self.hparams.use_feature:
@@ -118,6 +118,7 @@ def forward(self,inputs,weights,labels,neg_num):
         target = weights[torch.cat([labels,neg_batch],dim=0)]
         label = torch.zeros(target.shape[0],device=inputs.device)
         label[:labels.shape[0]]=1
+        # bs,d_model -> bs*(neg_num+1),d_model
         source = inputs.repeat((neg_num+1,1))
         return self.bce((source*target).sum(dim=-1),label)
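On the comment added in LinkPredictionTask.forward: Tensor.repeat tiles the whole batch rather than interleaving rows (unlike repeat_interleave), so the first bs rows of source line up with the positive targets that torch.cat([labels,neg_batch]) places at the front, and the remaining bs*neg_num rows pair with the negatives, assuming neg_batch is laid out in the same tiled order. A small shape check with toy sizes (hypothetical, not repo code):

    import torch

    bs, d_model, neg_num = 3, 4, 2
    inputs = torch.randn(bs, d_model)

    source = inputs.repeat((neg_num + 1, 1))             # bs,d_model -> bs*(neg_num+1),d_model
    assert source.shape == (bs * (neg_num + 1), d_model)
    assert torch.equal(source[:bs], inputs)              # first bs rows pair with the positives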
diff --git a/settings/ama_settings.yaml b/settings/ama_settings.yaml
index 242dd9b..bd35e5f 100644
--- a/settings/ama_settings.yaml
+++ b/settings/ama_settings.yaml
@@ -11,6 +11,7 @@ model:
   d_model: 32
   type_num: 2
   L: 3
+  use_gradient_checkpointing: False
   neg_num: 5
   lr: 0.01
   wd: 0.0001
@@ -19,7 +20,7 @@ callback:
   monitor: 'val_auc'
   mode: 'max'
 train:
-  max_epochs: 2
+  max_epochs: 100
   gpus: 1
 #  reload_dataloaders_every_n_epochs: 1
 #  resume_from_checkpoint: 'lightning_logs/version_0/checkpoints/epoch=96-step=6789.ckpt'
diff --git a/settings/yot_settings.yaml b/settings/yot_settings.yaml
index 62477bc..476e3e5 100644
--- a/settings/yot_settings.yaml
+++ b/settings/yot_settings.yaml
@@ -8,9 +8,10 @@ model:
 #  use_feature,feature_dim,d_model,type_num, L,neg_num,dropout,lr,wd
   use_feature: False
   feature_dim: 1156
-  d_model: 32
+  d_model: 16
   type_num: 5
   L: 3
+  use_gradient_checkpointing: True
   neg_num: 5
   lr: 0.01
   wd: 0.0001
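Config note: yot_settings.yaml turns the flag on while also halving d_model, consistent with a memory-motivated change; ama_settings.yaml keeps it off. The argument-order comment at the top of the yot model section still lists the old signature without use_gradient_checkpointing, and since the new key sits between L and neg_num in both __init__ signatures, positional construction from the YAML is order-sensitive; keyword construction sidesteps that. A hedged sketch (this loading snippet is an assumption, not code from this repo):

    import yaml
    from models.AGAT import AGAT

    with open('settings/yot_settings.yaml') as f:
        cfg = yaml.safe_load(f)['model']

    agat = AGAT(type_num=cfg['type_num'], d_model=cfg['d_model'], L=cfg['L'],
                use_gradient_checkpointing=cfg['use_gradient_checkpointing'],
                dropout=cfg.get('dropout', 0.1))  # dropout key assumed; default mirrors the signature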