Commit
Add gradient checkpointing strategy
farkguidao committed Jan 25, 2022
1 parent 61f4dbd commit 2aed881
Showing 4 changed files with 93 additions and 24 deletions.
37 changes: 20 additions & 17 deletions dataloader/link_pre_dataloader.py
@@ -15,27 +15,30 @@ def __init__(self,datapath,batch_size,is_dir=False,num_workers=0):
def read_data(self):
data = torch.load(self.datapath)
val_data,test_data,feature_data = data['val_data'],data['test_data'],data['feature_data']
self.feature_data = feature_data
self.val_dataset = TensorDataset(val_data)
self.test_dataset = TensorDataset(test_data)
# deduplicate automatically
i,v = data['edge_index'],data['edge_type']
if not self.is_dir:
# convert directed edges to undirected
i = torch.cat([i,i[[1,0]]],dim=1)
v = torch.cat([v,v],dim=0)
n = 1+i.max()
i,v = coalesce(i,v,n,n,op='max')
train_adj = torch.sparse_coo_tensor(i,v).coalesce()
# add type-0 self-loops
train_adj = (train_adj + sparse_diags([0]*train_adj.shape[0])).coalesce()
# # deduplicate automatically
# i,v = data['edge_index'],data['edge_type']
# if not self.is_dir:
# # convert directed edges to undirected
# i = torch.cat([i,i[[1,0]]],dim=1)
# v = torch.cat([v,v],dim=0)
# n = 1+i.max()
# i,v = coalesce(i,v,n,n,op='max')
# train_adj = torch.sparse_coo_tensor(i,v).coalesce()
# # add type-0 self-loops
# train_adj = (train_adj + sparse_diags([0]*train_adj.shape[0])).coalesce()

self.train_adj = train_adj
self.edge_index,self.edge_type,self.edge_id = train_adj.indices(),train_adj.values(),torch.arange(train_adj._nnz())
self.feature_data = feature_data
# self.train_adj = train_adj
# self.edge_index,self.edge_type,self.edge_id = train_adj.indices(),train_adj.values(),torch.arange(train_adj._nnz())
# self.feature_data = feature_data

self.edge_index,self.edge_type = data['edge_index'],data['edge_type']
self.N,self.E = self.edge_index.max()+1,self.edge_index.shape[1]
self.edge_id = torch.arange(self.E)
# mask out self-loops (type 0)
mask = self.edge_type>0

self.N = n
self.train_dataset = TensorDataset(self.edge_index.T[mask],self.edge_type[mask],self.edge_id[mask])

def train_dataloader(self) -> TRAIN_DATALOADERS:
@@ -49,4 +52,4 @@ def val_dataloader(self) -> EVAL_DATALOADERS:


if __name__ == '__main__':
dataloader = LinkPredictionDataloader('../data/amazon/all_data.pkl',64,64)
dataloader = LinkPredictionDataloader('../data/amazon/all_data.pkl',64)
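
For context: read_data now takes edge_index/edge_type precomputed by utils/dataprepare.py (deduplication and type-0 self-loops included) and only masks the self-loops out of the training set. A minimal sketch on toy tensors — the edge values below are hypothetical, for illustration only:

import torch
from torch.utils.data import TensorDataset

# Two type-0 self-loops plus three real edges (types 1 and 2).
edge_index = torch.tensor([[0, 1, 0, 1, 2],
                           [0, 1, 2, 2, 0]])
edge_type = torch.tensor([0, 0, 1, 1, 2])

N, E = edge_index.max() + 1, edge_index.shape[1]
edge_id = torch.arange(E)

mask = edge_type > 0  # self-loops carry type 0, so only real edges remain
train_dataset = TensorDataset(edge_index.T[mask], edge_type[mask], edge_id[mask])
print(len(train_dataset))  # 3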
2 changes: 1 addition & 1 deletion main.py
@@ -47,7 +47,7 @@ def test(parser):

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--setting_path',type=str,default='settings/yot_settings.yaml')
parser.add_argument('--setting_path',type=str,default='settings/ama_settings.yaml')
parser.add_argument("--test", action='store_true', help='test or train')
temp_args, _ = parser.parse_known_args()
if temp_args.test:
26 changes: 26 additions & 0 deletions settings/tiw_settings.yaml
@@ -0,0 +1,26 @@
data:
# datapath,batch_size,is_dir=False,num_workers=0
batch_size: 4096
datapath: 'data/twitter/all_data.pkl'
is_dir: True
num_workers: 0
model:
# use_feature,feature_dim,d_model,type_num, L,neg_num,dropout,lr,wd
use_feature: False
feature_dim: 1156
d_model: 32
type_num: 4
L: 3
use_gradient_checkpointing: False
neg_num: 5
lr: 0.01
wd: 0.0001
dropout: 0.1
callback:
monitor: 'val_auc'
mode: 'max'
train:
max_epochs: 50
gpus: 1
# reload_dataloaders_every_n_epochs: 1
# resume_from_checkpoint: 'lightning_logs/version_0/checkpoints/epoch=96-step=6789.ckpt'
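
The commit title ("Add gradient checkpointing strategy") and the new use_gradient_checkpointing flag suggest the model can trade extra compute for activation memory via torch.utils.checkpoint. The model code is not part of this diff, so the following is only a sketch of how such a flag is commonly wired into an L-layer stack; GNNStack and its layers are assumptions, not the repository's actual classes.

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class GNNStack(nn.Module):
    # Hypothetical L-layer stack, for illustration only.
    def __init__(self, d_model, L, use_gradient_checkpointing=False):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(L))
        self.use_gradient_checkpointing = use_gradient_checkpointing

    def forward(self, x):
        for layer in self.layers:
            if self.use_gradient_checkpointing and self.training:
                # Do not store this layer's activations; recompute them
                # during backward, saving memory at extra compute cost.
                x = checkpoint(layer, x)
            else:
                x = layer(x)
        return x

In tiw_settings.yaml the flag defaults to False, so checkpointing stays opt-in per experiment.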
52 changes: 46 additions & 6 deletions utils/dataprepare.py
@@ -1,10 +1,18 @@
import numpy as np
import pandas as pd
import torch
from torch_sparse import coalesce
import pickle


def do(base_path,node_num,has_feature):
def do(base_path,node_num,has_feature,is_dir):
'''
:param base_path: dataset directory containing train.txt / valid.txt / test.txt
:param node_num: number of nodes in the graph
:param has_feature: whether node feature vectors are available
:param is_dir: whether the graph is directed
:return:
'''

df=pd.read_csv(base_path+'/train.txt',sep=' ', index_col=None,header=None)
node_set = set(df[1].append(df[2]))
@@ -14,7 +22,7 @@ def do(base_path,node_num,has_feature):
old2new[old]=new
new2old[new]=old

edge_index,edge_type = get_train_sparse_adj(base_path+'/train.txt',old2new)
edge_index,edge_type = get_train_sparse_adj(base_path+'/train.txt',old2new,is_dir)
val_data = get_test_data(base_path+'/valid.txt',old2new)
test_data = get_test_data(base_path+'/test.txt',old2new)
if has_feature:
@@ -47,23 +55,55 @@ def get_test_data(path,old2new):
return data


def get_train_sparse_adj(path,old2new):
def get_train_sparse_adj(path,old2new,is_dir):
df = pd.read_csv(path, sep=' ', index_col=None, header=None)
df[1].replace(old2new, inplace=True)
df[2].replace(old2new, inplace=True)
# IDs remapped successfully
# store as a sparse adjacency matrix
data = torch.from_numpy(df.to_numpy(dtype=np.int64))
edge_index,edge_type = data[:, 1:].transpose(0, 1), data[:, 0]
# [edge_type,row,col]
type_num = data[:,0].max()
N = len(old2new)
self_loop_index = torch.stack([torch.arange(N),torch.arange(N)])
self_loop_type = torch.zeros(N,dtype=torch.long)
print('Adding self-loop edges, num=',N)
edge_index = [self_loop_index]
edge_type = [self_loop_type]
for type_id in range(1,type_num+1):
# for each edge type: add reverse edges, then deduplicate
index = data[:,0]==type_id
i = data[index,1:].T
v = torch.ones(i.shape[1],dtype=torch.long)
if not is_dir:
# convert directed edges to undirected
i = torch.cat([i, i[[1, 0]]], dim=1)
v = torch.cat([v, v], dim=0)
# deduplicate
i,v = coalesce(i,v,N,N)
v[:] = type_id
print('Adding edges of type %d, num= %d'%(type_id,v.shape[0]))
edge_index.append(i)
edge_type.append(v)
edge_index = torch.cat(edge_index,dim=1)
edge_type = torch.cat(edge_type,dim=0)
print('Total number of training edges:',edge_index.shape[1])
# edge_index,edge_type = data[:, 1:].transpose(0, 1), data[:, 0]
# train_adj = torch.sparse_coo_tensor(data[:, 1:].transpose(0, 1), data[:, 0])
return edge_index,edge_type

if __name__ == '__main__':
# base_path = '../data/amazon'
# node_num = 10166
# has_feature = True
# is_dir = False

base_path = '../data/youtube'
node_num = 2000
has_feature = False
do(base_path,node_num,has_feature)
is_dir = False

# base_path = '../data/twitter'
# node_num = 10000
# has_feature = False
# is_dir = True
do(base_path,node_num,has_feature,is_dir)
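
As a check of the new dedup logic, here is a minimal sketch of the reverse-then-coalesce step from get_train_sparse_adj on a toy edge list. The edge values are hypothetical; note that torch_sparse.coalesce sums duplicate entries by default, which is why v is overwritten with type_id afterwards.

import torch
from torch_sparse import coalesce

# Toy multigraph: edge (0,1) appears twice for the same type.
i = torch.tensor([[0, 0, 1],
                  [1, 1, 2]])
v = torch.ones(3, dtype=torch.long)

# Adding the reversed edges doubles every entry ...
i = torch.cat([i, i[[1, 0]]], dim=1)
v = torch.cat([v, v], dim=0)

# ... and coalesce merges duplicate (row, col) pairs, so each undirected
# edge survives exactly once before its value is reset to the edge type.
N = 3
i, v = coalesce(i, v, N, N)
v[:] = 2  # e.g. type_id = 2
print(i)
# tensor([[0, 1, 1, 2],
#         [1, 0, 2, 1]])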
