Commit

Further optimize GPU memory usage for dense graphs
farkguidao committed Jan 25, 2022
1 parent 1aac78a commit 8e0acc8
Showing 17 changed files with 28 additions and 66 deletions.
2 changes: 1 addition & 1 deletion dataloader/link_pre_dataloader.py
@@ -42,7 +42,7 @@ def read_data(self):
         self.train_dataset = TensorDataset(self.edge_index.T[mask],self.edge_type[mask],self.edge_id[mask])
 
     def train_dataloader(self) -> TRAIN_DATALOADERS:
-        return DataLoader(self.train_dataset,self.batch_size,shuffle=True,num_workers=self.num_workers)
+        return DataLoader(self.train_dataset,self.batch_size,shuffle=True,num_workers=self.num_workers,drop_last=True)
 
     def test_dataloader(self) -> EVAL_DATALOADERS:
         return DataLoader(self.test_dataset,batch_size=len(self.test_dataset))
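Note on the drop_last=True change above: with a fixed batch size, dropping the final short batch keeps every training step's tensor shapes, and hence peak activation memory, constant. A minimal standalone sketch with toy data, not the project's dataset:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.arange(10))  # 10 toy samples
    for (batch,) in DataLoader(ds, batch_size=4, drop_last=True):
        print(batch.shape)  # torch.Size([4]) twice; the 2-sample tail batch is dropped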
Binary file not shown.
9 changes: 0 additions & 9 deletions lightning_logs/version_0/hparams.yaml

This file was deleted.

Binary file not shown.
9 changes: 0 additions & 9 deletions lightning_logs/version_1/hparams.yaml

This file was deleted.

Binary file not shown.
9 changes: 0 additions & 9 deletions lightning_logs/version_2/hparams.yaml

This file was deleted.

Binary file not shown.
9 changes: 0 additions & 9 deletions lightning_logs/version_3/hparams.yaml

This file was deleted.

Binary file not shown.
9 changes: 0 additions & 9 deletions lightning_logs/version_4/hparams.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion main.py
@@ -47,7 +47,7 @@ def test(parser):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--setting_path',type=str,default='settings/yot_settings.yaml')
+    parser.add_argument('--setting_path',type=str,default='settings/ama_settings.yaml')
     parser.add_argument("--test", action='store_true', help='test or train')
     temp_args, _ = parser.parse_known_args()
     if temp_args.test:
15 changes: 12 additions & 3 deletions models/AGAT.py
@@ -65,6 +65,9 @@ def forward(self,x,edge_index,edge_type,edge_feature,mask):
         :param mask:
         :return:
         '''
+        E = edge_type.shape[0]
+        et = edge_feature.shape[0]
+        T, N, d = x.shape
         theta_g, theta_hi, theta_hj, wr, we = self.theta_g,self.theta_hi,self.theta_hj,self.wr,self.we
         row, col = edge_index[0], edge_index[1]
         # compute the r_g component
@@ -79,9 +82,15 @@ def forward(self,x,edge_index,edge_type,edge_feature,mask):
             pass
         r = scatter_softmax(r, row, dim=-1)  # t,E
         edge_feature = edge_feature @ wr  # et,d_model
-        v_g = torch.sigmoid(edge_feature).index_select(0, edge_type).unsqueeze(0)  # 1,E,d_model
-        v_h = (x @ we).index_select(1, col)
-        out = r.unsqueeze(-1) * v_g * v_h
+        if E > 10 * et * N:
+            v_g = torch.sigmoid(edge_feature).view(1, et, 1, d)
+            v_h = (x @ we).view(T, 1, N, d)
+            v = (v_g * v_h)[:, edge_type, col]  # T,E,d
+        else:
+            v_g = torch.sigmoid(edge_feature).index_select(0, edge_type).unsqueeze(0)  # 1,E,d_model
+            v_h = (x @ we).index_select(1, col)
+            v = v_g * v_h
+        out = r.unsqueeze(-1) * v
         out = scatter_add(out, row, dim=-2)  # t,N,d_model
         return out, edge_feature

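Why the new branch saves memory on dense graphs: the original path materializes two (T,E,d) tensors, one entry per edge, while the dense path builds a single (T,et,N,d) type-by-node product and gathers rows from it, which is smaller whenever the edge count E greatly exceeds et*N; the diff uses E > 10*et*N as the cutoff. A standalone sketch with toy shapes standing in for the model's real tensors, checking that both paths agree:

    import torch

    T, N, d, et, E = 2, 5, 4, 3, 40
    x_we = torch.randn(T, N, d)          # stands in for x @ we
    edge_feature = torch.randn(et, d)    # stands in for edge_feature @ wr
    edge_type = torch.randint(0, et, (E,))
    col = torch.randint(0, N, (E,))

    # original path: two (T,E,d)-sized intermediates
    v_sparse = torch.sigmoid(edge_feature).index_select(0, edge_type).unsqueeze(0) \
               * x_we.index_select(1, col)

    # dense-graph path: one (T,et,N,d) product, then a single gather per edge
    v_dense = (torch.sigmoid(edge_feature).view(1, et, 1, d)
               * x_we.view(T, 1, N, d))[:, edge_type, col]

    assert torch.allclose(v_sparse, v_dense)  # identical results, different peak memory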
3 changes: 1 addition & 2 deletions settings/ama_settings.yaml
@@ -8,7 +8,7 @@ model:
   # use_feature,feature_dim,d_model,type_num, L,neg_num,dropout,lr,wd
   use_feature: True
   feature_dim: 1156
-  d_model: 32
+  d_model: 128
   type_num: 2
   L: 3
   use_gradient_checkpointing: False
@@ -23,4 +23,3 @@ train:
   max_epochs: 100
   gpus: 1
   # reload_dataloaders_every_n_epochs: 1
-  # resume_from_checkpoint: 'lightning_logs/version_0/checkpoints/epoch=96-step=6789.ckpt'
5 changes: 2 additions & 3 deletions settings/tiw_settings.yaml
@@ -8,7 +8,7 @@ model:
   # use_feature,feature_dim,d_model,type_num, L,neg_num,dropout,lr,wd
   use_feature: False
   feature_dim: 1156
-  d_model: 32
+  d_model: 128
   type_num: 4
   L: 3
   use_gradient_checkpointing: False
@@ -20,7 +20,6 @@ callback:
   monitor: 'val_auc'
   mode: 'max'
 train:
-  max_epochs: 50
+  max_epochs: 100
   gpus: 1
   # reload_dataloaders_every_n_epochs: 1
-  # resume_from_checkpoint: 'lightning_logs/version_0/checkpoints/epoch=96-step=6789.ckpt'
6 changes: 3 additions & 3 deletions settings/yot_settings.yaml
@@ -1,14 +1,14 @@
 data:
   # datapath,batch_size,is_dir=False,num_workers=0
-  batch_size: 4096
+  batch_size: 16384
   datapath: 'data/youtube/all_data.pkl'
   is_dir: False
   num_workers: 0
 model:
   # use_feature,feature_dim,d_model,type_num, L,neg_num,dropout,lr,wd
   use_feature: False
   feature_dim: 1156
-  d_model: 16
+  d_model: 32
   type_num: 5
   L: 3
   use_gradient_checkpointing: True
@@ -22,5 +22,5 @@ callback:
 train:
   max_epochs: 50
   gpus: 1
+  precision: 16
   # reload_dataloaders_every_n_epochs: 1
-  # resume_from_checkpoint: 'lightning_logs/version_0/checkpoints/epoch=96-step=6789.ckpt'
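The added precision: 16 line presumably flows into PyTorch Lightning's Trainer, enabling native mixed-precision training (float16 activations with float32 master weights), which roughly halves activation memory and fits the commit's memory theme. A minimal sketch, under the assumption that the train: block is passed through as Trainer keyword arguments:

    import pytorch_lightning as pl

    # mirrors the yot_settings train: section; model and datamodule come from the repo
    trainer = pl.Trainer(max_epochs=50, gpus=1, precision=16)
    # trainer.fit(model, datamodule=dm)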
16 changes: 8 additions & 8 deletions utils/dataprepare.py
@@ -97,13 +97,13 @@ def get_train_sparse_adj(path,old2new,is_dir):
 # has_feature = True
 # is_dir = False
 
-base_path = '../data/youtube'
-node_num = 2000
-has_feature = False
-is_dir = False
-
-# base_path = '../data/twitter'
-# node_num = 10000
+# base_path = '../data/youtube'
+# node_num = 2000
 # has_feature = False
-# is_dir = True
+# is_dir = False
+
+base_path = '../data/twitter'
+node_num = 10000
+has_feature = False
+is_dir = True
 do(base_path,node_num,has_feature,is_dir)
