Update dataset.py
The WISE Lab @ Rutgers CS authored Dec 19, 2020
1 parent edc8b30 commit 6ec238c
Showing 1 changed file with 33 additions and 6 deletions.
39 changes: 33 additions & 6 deletions src/utils/dataset.py
@@ -38,14 +38,24 @@ def random_split_data(all_data_file, dataset_name, vt_ratio=0.1, copy_files=None
Randomly split an already generated dataset file *.all.csv -> *.train.csv, *.validation.csv, *.test.csv
:param all_data_file: pre-processed data file *.all.csv
:param dataset_name: a name for the dataset
:param vt_ratio: proportion of the data used for the validation and test sets
:param copy_files: files to copy along with the dataset
:param copy_suffixes: suffixes to append to the copied files' names
:return: pandas dataframes: training set, validation set, test set
"""
dir_name = os.path.join(DATASET_DIR, dataset_name)
print('random_split_data', dir_name)
# If the dataset folder dataset_name does not exist, create it
if not os.path.exists(dir_name):
os.mkdir(dir_name)
all_data = pd.read_csv(all_data_file, sep=SEP)
vt_size = int(len(all_data) * vt_ratio)
@@ -61,6 +71,7 @@ def random_split_data(all_data_file, dataset_name, vt_ratio=0.1, copy_files=None
test_set.to_csv(os.path.join(dir_name, dataset_name + TEST_SUFFIX), index=False, sep=SEP)

# Copy the user and item feature files
if copy_files is not None:
if type(copy_files) is str:
copy_files = [copy_files]
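
The sampling step between computing vt_size and writing the three CSV files is collapsed in this diff. As a minimal sketch only, assuming a shuffle-then-slice strategy (the function name, seed value, and slicing below are assumptions, not the repository's hidden code), the split could look like:

import pandas as pd

def random_split_sketch(all_data: pd.DataFrame, vt_ratio: float = 0.1, seed: int = 2018):
    # Rows reserved for each of the validation and test sets (assumed equal sizes)
    vt_size = int(len(all_data) * vt_ratio)
    # Shuffle once; a fixed random_state keeps the split reproducible
    shuffled = all_data.sample(frac=1, random_state=seed)
    validation_set = shuffled.iloc[:vt_size]
    test_set = shuffled.iloc[vt_size:2 * vt_size]
    # Remaining rows form the training set, restored to original order
    train_set = shuffled.iloc[2 * vt_size:].sort_index()
    return train_set, validation_set, test_set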
@@ -116,13 +127,17 @@ def leave_out_by_time_df(all_df, leave_n=1, warm_n=5, split_n=1, max_user=-1):
for uid in total_uids:
group = gb_uid.get_group(uid)
found, found_idx = 0, -1
# Walk the user's history backwards until a positive example is found
for idx in reversed(group.index):
if group.loc[idx, LABEL] > 0:
found_idx = idx
found += 1
if found >= leave_n:
break
# If a positive example was found, this example and the negative examples after it all go into the test set
if found > 0:
split_df.append(group.loc[found_idx:])
split_df = pd.concat(split_df).sort_index()
all_df = all_df.drop(split_df.index)
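
Taken together, the loop above is the heart of leave_out_by_time_df. A self-contained sketch of the same rule (the uid and label column names are assumptions standing in for the repository's constants, and this is an illustration rather than the exact code):

import pandas as pd

def leave_out_last_positives(all_df: pd.DataFrame, leave_n: int = 1,
                             uid_col: str = 'uid', label_col: str = 'label'):
    # Per user: hold out the last leave_n positive interactions plus any
    # negative interactions recorded after them
    held_out = []
    for uid, group in all_df.groupby(uid_col):
        found, found_idx = 0, -1
        # Walk this user's history backwards until leave_n positives are seen
        for idx in reversed(group.index):
            if group.loc[idx, label_col] > 0:
                found_idx = idx
                found += 1
                if found >= leave_n:
                    break
        # Hold out the earliest found positive and everything after it
        if found > 0:
            held_out.append(group.loc[found_idx:])
    split_df = pd.concat(held_out).sort_index()
    return all_df.drop(split_df.index), split_df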
@@ -133,18 +148,29 @@ def leave_out_by_time_csv(all_data_file, dataset_name, leave_n=1, warm_n=5, u_f=

def leave_out_by_time_csv(all_data_file, dataset_name, leave_n=1, warm_n=5, u_f=None, i_f=None):
"""
Assumes the interactions in all_data are sorted by timestamp; following that order, each user's last interactions are split off into the validation and test sets
:param all_data_file: pre-processed data file *.all.csv, with interactions sorted by timestamp
:param dataset_name: a name for the dataset
:param leave_n: number of each user's interactions to hold out for the validation and test sets
:param warm_n: require that each test user has at least warm_n interactions in the training set; otherwise all of that user's interactions stay in the training set
:param u_f: user feature file *.user.csv
:param i_f: item feature file *.item.csv
:return: pandas dataframes: training set, validation set, test set
"""
dir_name = os.path.join(DATASET_DIR, dataset_name)
print('leave_out_by_time_csv', dir_name, leave_n, warm_n)
# If the dataset folder dataset_name does not exist, create it
if not os.path.exists(dir_name):
os.mkdir(dir_name)
all_data = pd.read_csv(all_data_file, sep=SEP)

@@ -159,6 +185,7 @@ def leave_out_by_time_csv(all_data_file, dataset_name, leave_n=1, warm_n=5, u_f=
validation_set.to_csv(os.path.join(dir_name, dataset_name + VALIDATION_SUFFIX), index=False, sep=SEP)
test_set.to_csv(os.path.join(dir_name, dataset_name + TEST_SUFFIX), index=False, sep=SEP)
# Copy the user and item feature files
if u_f is not None:
copyfile(u_f, os.path.join(dir_name, dataset_name + USER_SUFFIX))
if i_f is not None:
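For reference, a hypothetical call (the file paths and dataset name are made up for illustration) that produces a leave-one-out split:

train_df, validation_df, test_df = leave_out_by_time_csv(
    all_data_file='ml100k.all.csv',    # interactions already sorted by time
    dataset_name='ml100k-leave-1-out',
    leave_n=1,   # hold out each user's last positive interaction
    warm_n=5,    # require at least 5 training interactions per held-out user
    u_f='ml100k.user.csv',
    i_f='ml100k.item.csv')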
