
Commit

add MetaHIN model and dbook dataset (#248)
Ying-1106 authored Oct 29, 2024
1 parent d6f0170 commit 9a1f725
Showing 12 changed files with 962 additions and 17 deletions.
29 changes: 29 additions & 0 deletions openhgnn/config.ini
@@ -1,4 +1,33 @@
; #################### add model config here
[MetaHIN]
; before modification (developer-local paths):
; input_dir=/home/zzh/Z测试代码/
; output_dir=/home/zzh/Z测试代码/

input_dir = /openhgnn/dataset/Common_Dataset/
output_dir = /openhgnn/dataset/Common_Dataset/

dataset = dbook
use_cuda = True
file_num = 10
num_location = 453
num_fea_item = 2
num_publisher = 1698
num_fea_user = 1
item_fea_len = 1
embedding_dim = 32
user_embedding_dim = 32
item_embedding_dim = 32
first_fc_hidden_dim = 64
second_fc_hidden_dim = 64
mp_update = 1
local_update = 1
lr = 5e-4
mp_lr = 5e-3
local_lr = 5e-3
batch_size = 32
num_epoch = 50
seed = 13

[FedHGNN]
fea_dim = 64
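For reference, a minimal sketch (not part of this commit) of how the [MetaHIN] values above can be read back with Python's configparser, mirroring the pattern used in openhgnn/config.py below; the working directory is assumed to be the repository root:

import configparser

conf = configparser.ConfigParser()
conf.read("./openhgnn/config.ini", encoding="utf-8")

use_cuda = conf.getboolean("MetaHIN", "use_cuda")   # True
lr = conf.getfloat("MetaHIN", "lr")                 # 5e-4
batch_size = conf.getint("MetaHIN", "batch_size")   # 32
input_dir = conf.get("MetaHIN", "input_dir")        # /openhgnn/dataset/Common_Dataset/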
27 changes: 26 additions & 1 deletion openhgnn/config.py
@@ -41,7 +41,32 @@ def __init__(self, file_path, model, dataset, task, gpu):
        self.patience = conf.getint("General", "patience")
        self.mini_batch_flag = conf.getboolean("General", "mini_batch_flag")
        ############## add config.py #################


        elif self.model_name == "MetaHIN":
            self.use_cuda = conf.getboolean("MetaHIN", "use_cuda")
            self.file_num = conf.getint("MetaHIN", "file_num")
            self.num_location = conf.getint("MetaHIN", "num_location")
            self.num_fea_item = conf.getint("MetaHIN", "num_fea_item")
            self.num_publisher = conf.getint("MetaHIN", "num_publisher")
            self.num_fea_user = conf.getint("MetaHIN", "num_fea_user")
            self.item_fea_len = conf.getint("MetaHIN", "item_fea_len")
            self.embedding_dim = conf.getint("MetaHIN", "embedding_dim")
            self.user_embedding_dim = conf.getint("MetaHIN", "user_embedding_dim")
            self.item_embedding_dim = conf.getint("MetaHIN", "item_embedding_dim")
            self.first_fc_hidden_dim = conf.getint("MetaHIN", "first_fc_hidden_dim")
            self.second_fc_hidden_dim = conf.getint("MetaHIN", "second_fc_hidden_dim")
            self.mp_update = conf.getint("MetaHIN", "mp_update")
            self.local_update = conf.getint("MetaHIN", "local_update")
            self.lr = conf.getfloat("MetaHIN", "lr")
            self.mp_lr = conf.getfloat("MetaHIN", "mp_lr")
            self.local_lr = conf.getfloat("MetaHIN", "local_lr")
            self.batch_size = conf.getint("MetaHIN", "batch_size")
            self.num_epoch = conf.getint("MetaHIN", "num_epoch")
            self.input_dir = conf.get("MetaHIN", "input_dir")
            self.output_dir = conf.get("MetaHIN", "output_dir")
            self.seed = conf.getint("MetaHIN", "seed")


        elif self.model_name == 'FedHGNN':
            self.fea_dim = conf.getint("FedHGNN", "fea_dim")
            self.in_dim = conf.getint("FedHGNN", "in_dim")
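As added context, a minimal sketch of how this configuration object can be constructed; the enclosing class name Config, the task name "recommendation", and gpu=0 are assumptions for illustration and are not taken from this commit:

from openhgnn.config import Config

# hypothetical usage; class name, task name, and gpu id are illustrative assumptions
config = Config(file_path="./openhgnn/config.ini", model="MetaHIN",
                dataset="dbook", task="recommendation", gpu=0)
print(config.lr, config.batch_size, config.input_dir)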
161 changes: 161 additions & 0 deletions openhgnn/dataset/MetaHIN_dataset.py
@@ -0,0 +1,161 @@

import gc
import glob
import os
import pickle

# from DataProcessor import Movielens
from tqdm import tqdm
from multiprocessing import Process, Pool
from multiprocessing.pool import ThreadPool
import numpy as np
import torch


class Meta_DataHelper:
    def __init__(self, input_dir, config):
        self.input_dir = input_dir
        self.config = config
        self.mp_list = ["ub", "ubab", "ubub"]

        from dgl.data.utils import download, extract_archive
        # dbook is the only dataset supported here
        dataset_name = 'dbook'
        self.zip_file = f'./openhgnn/dataset/Common_Dataset/{dataset_name}.zip'
        # extracted under Common_Dataset/dbook_dir
        self.base_dir = './openhgnn/dataset/Common_Dataset/' + dataset_name + '_dir'
        self.url = f'https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/{dataset_name}.zip'
        if not os.path.exists(self.zip_file):
            os.makedirs('./openhgnn/dataset/Common_Dataset/', exist_ok=True)
            download(self.url, path='./openhgnn/dataset/Common_Dataset/')
        if not os.path.exists(self.base_dir):
            os.makedirs(self.base_dir, exist_ok=True)
            extract_archive(self.zip_file, self.base_dir)


    def load_data(self, data_set, state, load_from_file=True):
        # the extracted dbook directory: the dbook folder under the dataset root
        # data_dir = self.input_dir + data_set   # previous version
        data_dir = self.base_dir + '/' + data_set
        supp_xs_s = []
        supp_ys_s = []
        supp_mps_s = []
        query_xs_s = []
        query_ys_s = []
        query_mps_s = []

if data_set == "yelp":
training_set_size = int(
len(glob.glob("{}/{}/*.npy".format(data_dir, state)))
/ self.config.file_num
) # support, query

# load all data
for idx in tqdm(range(training_set_size)):
supp_xs_s.append(
torch.from_numpy(
np.load("{}/{}/support_x_{}.npy".format(data_dir, state, idx))
)
)
supp_ys_s.append(
torch.from_numpy(
np.load("{}/{}/support_y_{}.npy".format(data_dir, state, idx))
)
)
query_xs_s.append(
torch.from_numpy(
np.load("{}/{}/query_x_{}.npy".format(data_dir, state, idx))
)
)
query_ys_s.append(
torch.from_numpy(
np.load("{}/{}/query_y_{}.npy".format(data_dir, state, idx))
)
)

supp_mp_data, query_mp_data = {}, {}
for mp in self.mp_list:
_cur_data = np.load(
"{}/{}/support_{}_{}.npy".format(data_dir, state, mp, idx),
encoding="latin1",
)
supp_mp_data[mp] = [torch.from_numpy(x) for x in _cur_data]
_cur_data = np.load(
"{}/{}/query_{}_{}.npy".format(data_dir, state, mp, idx),
encoding="latin1",
)
query_mp_data[mp] = [torch.from_numpy(x) for x in _cur_data]
supp_mps_s.append(supp_mp_data)
query_mps_s.append(query_mp_data)
        else:  # 'dbook'
            training_set_size = int(
                len(glob.glob("{}/{}/*.pkl".format(data_dir, state)))
                / self.config.file_num
            )  # support, query

            # load all data
            for idx in tqdm(range(training_set_size)):
                support_x = pickle.load(
                    open("{}/{}/support_x_{}.pkl".format(data_dir, state, idx), "rb")
                )
                if support_x.shape[0] > 5:
                    continue
                del support_x
                supp_xs_s.append(
                    pickle.load(
                        open(
                            "{}/{}/support_x_{}.pkl".format(data_dir, state, idx), "rb"
                        )
                    )
                )
                supp_ys_s.append(
                    pickle.load(
                        open(
                            "{}/{}/support_y_{}.pkl".format(data_dir, state, idx), "rb"
                        )
                    )
                )
                query_xs_s.append(
                    pickle.load(
                        open("{}/{}/query_x_{}.pkl".format(data_dir, state, idx), "rb")
                    )
                )
                query_ys_s.append(
                    pickle.load(
                        open("{}/{}/query_y_{}.pkl".format(data_dir, state, idx), "rb")
                    )
                )

                supp_mp_data, query_mp_data = {}, {}
                for mp in self.mp_list:
                    supp_mp_data[mp] = pickle.load(
                        open(
                            "{}/{}/support_{}_{}.pkl".format(data_dir, state, mp, idx),
                            "rb",
                        )
                    )
                    query_mp_data[mp] = pickle.load(
                        open(
                            "{}/{}/query_{}_{}.pkl".format(data_dir, state, mp, idx),
                            "rb",
                        )
                    )
                supp_mps_s.append(supp_mp_data)
                query_mps_s.append(query_mp_data)

        print(
            "#support set: {}, #query set: {}".format(len(supp_xs_s), len(query_xs_s))
        )
        total_data = list(
            zip(supp_xs_s, supp_ys_s, supp_mps_s, query_xs_s, query_ys_s, query_mps_s)
        )  # all training tasks
        del (supp_xs_s, supp_ys_s, supp_mps_s, query_xs_s, query_ys_s, query_mps_s)
        gc.collect()
        return total_data
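A minimal usage sketch for the helper above. The state name "meta_training" follows the original MetaHIN data layout and is an assumption here, as is the lightweight config stand-in; file_num must match how many files make up each task on disk (10 in config.ini above):

from types import SimpleNamespace
from openhgnn.dataset import Meta_DataHelper

# hypothetical config stand-in; only file_num is needed by load_data
cfg = SimpleNamespace(file_num=10)
helper = Meta_DataHelper(input_dir="./openhgnn/dataset/Common_Dataset/", config=cfg)
tasks = helper.load_data(data_set="dbook", state="meta_training")
# each task is (support_x, support_y, support_mps, query_x, query_y, query_mps)
print(len(tasks))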
28 changes: 14 additions & 14 deletions openhgnn/dataset/__init__.py
Expand Up @@ -16,6 +16,7 @@
from .SACN_dataset import *
from .NBF_dataset import NBF_Dataset
from .Ingram_dataset import Ingram_KG_TrainData, Ingram_KG_TestData
from .MetaHIN_dataset import Meta_DataHelper

DATASET_REGISTRY = {}

@@ -90,27 +91,17 @@ def build_dataset(dataset, task, *args, **kwargs):
    if isinstance(dataset, DGLDataset):
        return dataset

    ####### add dataset here

    if dataset == "meirec":
        train_dataloader = get_data_loader("train", batch_size=args[0])
        test_dataloader = get_data_loader("test", batch_size=args[0])
        return train_dataloader, test_dataloader


    if dataset in CLASS_DATASETS:
        return build_dataset_v2(dataset, task)
    if not try_import_task_dataset(task):
        exit(1)

    if dataset == 'NL-100':
    elif dataset == 'NL-100':
        train_dataloader = Ingram_KG_TrainData('',dataset)
        valid_dataloader = Ingram_KG_TestData('', dataset,'valid')
        test_dataloader = Ingram_KG_TestData('',dataset,'test')
        return train_dataloader,valid_dataloader,test_dataloader
    elif dataset == 'meirec':
        train_dataloader = get_data_loader("train", batch_size=args[0])
        test_dataloader = get_data_loader("test", batch_size=args[0])
        return train_dataloader, test_dataloader
    elif dataset == 'AdapropT':
        dataload=AdapropTDataLoader(args)
        return dataload
@@ -119,13 +110,22 @@ def build_dataset(dataset, task, *args, **kwargs):
        return dataload
    elif dataset == 'SACN' or dataset == 'LTE':
        return
    elif dataset == "dbook":
        dataload = Meta_DataHelper(args.input_dir, args)
        return dataload

    #############

    if dataset in CLASS_DATASETS:
        return build_dataset_v2(dataset, task)
    if not try_import_task_dataset(task):
        exit(1)

    _dataset = None
    if dataset in ['aifb', 'mutag', 'bgs', 'am']:
        _dataset = 'rdf_' + task

    ##################### add dataset here
    ########### add dataset here
    elif dataset in ['acm4HGMAE','hgprompt_acm_dblp','acm4FedHGNN']:
        return DATASET_REGISTRY['common_dataset'](dataset, logger=kwargs['logger'],args = kwargs['args'])

@@ -135,7 +135,7 @@ def build_dataset(dataset, task, *args, **kwargs):
    elif dataset in ['dblp4RHINE']:
        _dataset = 'rhine_'+task
        return DATASET_REGISTRY[_dataset](dataset, logger=kwargs['logger'],args = kwargs['args'])
    ######################
    ##########

    elif dataset in ['acm4NSHE', 'acm4GTN', 'academic4HetGNN', 'acm_han', 'acm_han_raw', 'acm4HeCo', 'dblp',
                     'dblp4MAGNN', 'imdb4MAGNN', 'imdb4GTN', 'acm4NARS', 'demo_graph', 'yelp4HeGAN', 'DoubanMovie',

0 comments on commit 9a1f725
