# FMLP-Rec

The source code for our WWW 2022 paper [**"Filter-enhanced MLP is All You Need for Sequential Recommendation"**]().

## Requirements

* Install Python and PyTorch (>= 1.8). We use Python 3.7 and PyTorch 1.8.
* If you plan to use GPU computation, install CUDA.

## Overview

**FMLP-Rec** stacks multiple **Filter-enhanced Blocks** to produce the representation of sequential user preference for recommendation. The key difference between our approach and SASRec is that we replace the multi-head self-attention structure in the Transformer with a novel filter structure. You can turn FMLP-Rec into SASRec by adding the `--no_filters` flag when running the code.

![avatar](fig/model.png)
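
For intuition, here is a minimal, self-contained sketch of the learnable frequency-domain filter that replaces self-attention in each block. This is only an illustration of the idea; the actual layer implementations live in `modules.py`, and the class name, constructor signature, and use of `nn.LayerNorm` below are simplified stand-ins rather than the repository code.

```
import torch
import torch.nn as nn

class FilterLayerSketch(nn.Module):
    """Illustrative filter layer: FFT along the item dimension, multiply by a
    learnable complex filter, inverse FFT, then dropout + residual + LayerNorm."""
    def __init__(self, max_seq_length, hidden_size, dropout_prob=0.5):
        super().__init__()
        # One learnable complex weight per (frequency, hidden dim), stored as (real, imag) pairs.
        self.complex_weight = nn.Parameter(
            torch.randn(1, max_seq_length // 2 + 1, hidden_size, 2) * 0.02)
        self.out_dropout = nn.Dropout(dropout_prob)
        self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-12)

    def forward(self, input_tensor):  # (batch, max_seq_length, hidden)
        _, seq_len, _ = input_tensor.shape
        x = torch.fft.rfft(input_tensor, dim=1, norm='ortho')    # to the frequency domain
        x = x * torch.view_as_complex(self.complex_weight)       # learnable per-frequency filter
        x = torch.fft.irfft(x, n=seq_len, dim=1, norm='ortho')   # back to the item dimension
        return self.layer_norm(self.out_dropout(x) + input_tensor)
```

Stacking such a filter sub-layer with the usual feed-forward sub-layer, in place of multi-head self-attention, gives the filter-enhanced block illustrated in the figure above.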

## Datasets

We use eight datasets in our paper, all of which have been uploaded to [Google Drive](https://drive.google.com/drive/folders/1omfrWZiYwmj3eFpIpb-8O29wbt4SVGzP?usp=sharing) and [Baidu Netdisk](https://pan.baidu.com/s/1we2eJ_Vz9SM33PoRqPNijQ?pwd=kzq2).

Place the downloaded dataset in the `data` folder. A session-based dataset should additionally be placed in a subfolder named after the dataset.

If you want to use your own dataset, please follow the steps below:
1. Prepare a sample file in which each line contains a user id followed by its 99 negative samples, and name it `YOUR_DATASET_sample.txt` (the expected format is shown below). For a session-based dataset, only the validation set and the test set need to be sampled.
2. Place your dataset and the sample file in the `data` folder. For a session-based dataset, a subfolder named after the dataset is needed.
3. Add the name of your dataset to the corresponding data list in `utils.py`, according to its data type.
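
Each line of the sample file pairs a user index with its 99 negative item ids, separated by spaces, e.g. `0 53 1024 331 ...` (the ids here are made up for illustration). This is the format produced by the negative-sample generation script included in this commit.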

## Quick-Start

After downloading the source code, you can train the model by specifying only the dataset name:
```
python main.py --data_name=[data_name]
```

If you want to change the hyper-parameters, just set the additional command-line arguments as you need. For example:
```
python main.py --data_name=Beauty --num_hidden_layers=4 --batch_size=512
```

You can also evaluate a saved model from the command line:
```
python main.py --data_name=Beauty --do_eval --load_model=FMLPRec-Beauty-4eval
```

Additional hyper-parameters can be specified; detailed information about them is available via:
```
python main.py --help
```

## Contact

If you have any questions about our paper or code, please send an email to ishyu@outlook.com.

## Acknowledgement

Our code is developed based on [S3-Rec](https://github.com/RUCAIBox/CIKM2020-S3Rec).

# Generate negative-sample files for a session-based dataset: for every user sequence in the
# validation/test split, sample 99 negative items that do not appear in that sequence.

import random
import argparse


def get_item_size(data_file):
    # Scan the training split to find the largest item id, so sampling covers the full item range.
    lines = open(data_file).readlines()
    lines = lines[1:]  # skip the header line
    item_set = set()
    for line in lines:
        user, items = line.strip().split(' ', 1)
        items = [int(item) for item in items.split()]
        item_set = item_set | set(items)
    max_item = max(item_set)
    item_size = max_item + 1
    return item_size


def get_user_seqs_and_gene_sample(data_file, item_size):
    # Read one split and draw 99 negative samples for each user sequence.
    lines = open(data_file).readlines()
    lines = lines[1:]  # skip the header line
    user_seq = []
    for line in lines:
        user, items = line.strip().split(' ', 1)
        items = [int(item) for item in items.split()]
        user_seq.append(items)

    sample_seq = []
    for i in range(len(lines)):
        sample_list = neg_sample(set(user_seq[i]), item_size)
        sample_seq.append(sample_list)

    return sample_seq


def neg_sample(item_set, item_size):  # randint is inclusive on both ends
    # Draw 99 distinct items the user has not interacted with.
    sample_list = []
    for _ in range(99):
        item = random.randint(1, item_size - 1)
        while (item in item_set) or (item in sample_list):
            item = random.randint(1, item_size - 1)
        sample_list.append(item)
    return sample_list


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--data_dir', default='./', type=str)
    parser.add_argument('--data_name', default='nowplaying', type=str)
    args = parser.parse_args()
    args.data_file = args.data_dir + args.data_name + '/' + args.data_name + '.train.inter'
    args.data_file_eval = args.data_dir + args.data_name + '/' + args.data_name + '.valid.inter'
    args.data_file_test = args.data_dir + args.data_name + '/' + args.data_name + '.test.inter'

    args.sample_file_eval = args.data_dir + args.data_name + '/' + args.data_name + '_valid_sample.txt'
    args.sample_file_test = args.data_dir + args.data_name + '/' + args.data_name + '_test_sample.txt'

    # The item range is determined from the training split, then samples are drawn
    # for the validation and test splits and written one user per line.
    item_size = get_item_size(args.data_file)

    neg_sample_eval = get_user_seqs_and_gene_sample(args.data_file_eval, item_size)
    output = open(args.sample_file_eval, 'w')
    for i in range(len(neg_sample_eval)):
        output.write(str(i))
        for k in neg_sample_eval[i]:
            output.write(' ' + str(k))
        output.write('\n')
    output.close()

    neg_sample_test = get_user_seqs_and_gene_sample(args.data_file_test, item_size)
    output = open(args.sample_file_test, 'w')
    for i in range(len(neg_sample_test)):
        output.write(str(i))
        for k in neg_sample_test[i]:
            output.write(' ' + str(k))
        output.write('\n')
    output.close()


if __name__ == '__main__':
    main()
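
For example, assuming this script is saved as `generate_sample.py` (its file name is not visible in this commit view) inside the `data` folder next to a `nowplaying/` directory, running `python generate_sample.py --data_dir=./ --data_name=nowplaying` writes `nowplaying/nowplaying_valid_sample.txt` and `nowplaying/nowplaying_test_sample.txt`.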

# @Time : 2022/2/14
# @Author : Hui Yu
# @Email : ishyu@outlook.com

import torch
import random
from torch.utils.data import Dataset


class FMLPRecDataset(Dataset):
    def __init__(self, args, user_seq, test_neg_items=None, data_type='train'):
        self.args = args
        self.user_seq = []
        self.max_len = args.max_seq_length

        if data_type == 'train':
            # Data augmentation: every prefix of the truncated training sequence becomes a sample.
            for seq in user_seq:
                input_ids = seq[-(self.max_len + 2):-2]  # drop the last two items (validation/test targets)
                for i in range(len(input_ids)):
                    self.user_seq.append(input_ids[:i + 1])
        elif data_type == 'valid':
            # Drop the last item (the test target); the new last item becomes the validation target.
            for sequence in user_seq:
                self.user_seq.append(sequence[:-1])
        else:
            # Test: use the full sequence; the last item is the target.
            self.user_seq = user_seq

        self.test_neg_items = test_neg_items
        self.data_type = data_type
        self.max_len = args.max_seq_length

    def __len__(self):
        return len(self.user_seq)

    def __getitem__(self, index):
        items = self.user_seq[index]
        input_ids = items[:-1]
        answer = items[-1]  # the last item is the prediction target

        seq_set = set(items)
        neg_answer = neg_sample(seq_set, self.args.item_size)

        # Left-pad the input sequence with 0 (the padding id) up to max_len.
        pad_len = self.max_len - len(input_ids)
        input_ids = [0] * pad_len + input_ids
        input_ids = input_ids[-self.max_len:]
        assert len(input_ids) == self.max_len

        if self.test_neg_items is not None:
            test_samples = self.test_neg_items[index]

            cur_tensors = (
                torch.tensor(index, dtype=torch.long),  # user_id for testing
                torch.tensor(input_ids, dtype=torch.long),
                torch.tensor(answer, dtype=torch.long),
                torch.tensor(neg_answer, dtype=torch.long),
                torch.tensor(test_samples, dtype=torch.long),
            )
        else:
            cur_tensors = (
                torch.tensor(index, dtype=torch.long),  # user_id for testing
                torch.tensor(input_ids, dtype=torch.long),
                torch.tensor(answer, dtype=torch.long),
                torch.tensor(neg_answer, dtype=torch.long),
            )

        return cur_tensors


def neg_sample(item_set, item_size):  # randint is inclusive on both ends
    # Draw a single negative item that the user has not interacted with.
    item = random.randint(1, item_size - 1)
    while item in item_set:
        item = random.randint(1, item_size - 1)
    return item
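
A minimal sketch of how this dataset is typically wrapped for training (illustration only; in this repository the dataloaders are built by `get_dataloder` in `utils.py`, so the sampler choice and variable names below are assumptions):

```
from torch.utils.data import DataLoader, RandomSampler

# `args` must provide max_seq_length, item_size and batch_size (see main.py);
# `user_seq` is the list of per-user item-id sequences read from the data file.
train_dataset = FMLPRecDataset(args, user_seq, data_type='train')
train_dataloader = DataLoader(train_dataset,
                              sampler=RandomSampler(train_dataset),
                              batch_size=args.batch_size)

for user_ids, input_ids, answers, neg_answers in train_dataloader:
    # input_ids: (batch_size, max_seq_length), left-padded with 0
    break
```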
@@ -0,0 +1,116 @@ | ||
# @Time : 2022/2/13 | ||
# @Author : Hui Yu | ||
# @Email : ishyu@outlook.com | ||
|
||
import os | ||
import torch | ||
import argparse | ||
import numpy as np | ||
|
||
from models import FMLPRecModel | ||
from trainers import FMLPRecTrainer | ||
from utils import EarlyStopping, check_path, set_seed, get_local_time, get_seq_dic, get_dataloder, get_rating_matrix | ||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--data_dir", default="./data/", type=str) | ||
parser.add_argument("--output_dir", default="output/", type=str) | ||
parser.add_argument("--data_name", default="Beauty", type=str) | ||
parser.add_argument("--do_eval", action="store_true") | ||
parser.add_argument("--load_model", default=None, type=str) | ||
|
||
# model args | ||
parser.add_argument("--model_name", default="FMLPRec", type=str) | ||
parser.add_argument("--hidden_size", default=64, type=int, help="hidden size of model") | ||
parser.add_argument("--num_hidden_layers", default=2, type=int, help="number of filter-enhanced blocks") | ||
parser.add_argument("--num_attention_heads", default=2, type=int) | ||
parser.add_argument("--hidden_act", default="gelu", type=str) # gelu relu | ||
parser.add_argument("--attention_probs_dropout_prob", default=0.5, type=float) | ||
parser.add_argument("--hidden_dropout_prob", default=0.5, type=float) | ||
parser.add_argument("--initializer_range", default=0.02, type=float) | ||
parser.add_argument("--max_seq_length", default=50, type=int) | ||
parser.add_argument("--no_filters", action="store_true", help="if no filters, filter layers transform to self-attention") | ||
|
||
# train args | ||
parser.add_argument("--lr", default=0.001, type=float, help="learning rate of adam") | ||
parser.add_argument("--batch_size", default=256, type=int, help="number of batch_size") | ||
parser.add_argument("--epochs", default=200, type=int, help="number of epochs") | ||
parser.add_argument("--no_cuda", action="store_true") | ||
parser.add_argument("--log_freq", default=1, type=int, help="per epoch print res") | ||
parser.add_argument("--full_sort", action="store_true") | ||
parser.add_argument("--patience", default=10, type=int, help="how long to wait after last time validation loss improved") | ||
|
||
parser.add_argument("--seed", default=42, type=int) | ||
parser.add_argument("--weight_decay", default=0.0, type=float, help="weight_decay of adam") | ||
parser.add_argument("--adam_beta1", default=0.9, type=float, help="adam first beta value") | ||
parser.add_argument("--adam_beta2", default=0.999, type=float, help="adam second beta value") | ||
parser.add_argument("--gpu_id", default="0", type=str, help="gpu_id") | ||
parser.add_argument("--variance", default=5, type=float) | ||
|
||
args = parser.parse_args() | ||
|
||
set_seed(args.seed) | ||
check_path(args.output_dir) | ||
|
||
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id | ||
args.cuda_condition = torch.cuda.is_available() and not args.no_cuda | ||
|
||
seq_dic, max_item = get_seq_dic(args) | ||
|
||
args.item_size = max_item + 1 | ||
|
||
# save model args | ||
cur_time = get_local_time() | ||
if args.no_filters: | ||
args.model_name = "SASRec" | ||
args_str = f'{args.model_name}-{args.data_name}-{cur_time}' | ||
args.log_file = os.path.join(args.output_dir, args_str + '.txt') | ||
print(str(args)) | ||
with open(args.log_file, 'a') as f: | ||
f.write(str(args) + '\n') | ||
|
||
# save model | ||
args.checkpoint_path = os.path.join(args.output_dir, args_str + '.pt') | ||
|
||
train_dataloader, eval_dataloader, test_dataloader = get_dataloder(args,seq_dic) | ||
|
||
model = FMLPRecModel(args=args) | ||
trainer = FMLPRecTrainer(model, train_dataloader, eval_dataloader, | ||
test_dataloader, args) | ||
|
||
if args.full_sort: | ||
args.valid_rating_matrix, args.test_rating_matrix = get_rating_matrix(args.data_name, seq_dic, max_item) | ||
|
||
if args.do_eval: | ||
if args.load_model is None: | ||
print(f"No model input!") | ||
exit(0) | ||
else: | ||
args.checkpoint_path = os.path.join(args.output_dir, args.load_model + '.pt') | ||
trainer.load(args.checkpoint_path) | ||
print(f"Load model from {args.checkpoint_path} for test!") | ||
scores, result_info = trainer.test(0, full_sort=args.full_sort) | ||
|
||
else: | ||
early_stopping = EarlyStopping(args.checkpoint_path, patience=args.patience, verbose=True) | ||
for epoch in range(args.epochs): | ||
trainer.train(epoch) | ||
scores, _ = trainer.valid(epoch, full_sort=args.full_sort) | ||
# evaluate on MRR | ||
early_stopping(np.array(scores[-1:]), trainer.model) | ||
if early_stopping.early_stop: | ||
print("Early stopping") | ||
break | ||
|
||
print("---------------Sample 99 results---------------") | ||
# load the best model | ||
trainer.model.load_state_dict(torch.load(args.checkpoint_path)) | ||
scores, result_info = trainer.test(0, full_sort=args.full_sort) | ||
|
||
print(args_str) | ||
print(result_info) | ||
with open(args.log_file, 'a') as f: | ||
f.write(args_str + '\n') | ||
f.write(result_info + '\n') | ||
|
||
main() |
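
With the default `--output_dir`, a run such as `python main.py --data_name=Beauty` writes the argument dump and evaluation results to `output/FMLPRec-Beauty-<time>.txt` and saves the best checkpoint selected by early stopping to `output/FMLPRec-Beauty-<time>.pt`; with `--no_filters`, the `FMLPRec` prefix in these names becomes `SASRec`.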

# @Time : 2022/2/13
# @Author : Hui Yu
# @Email : ishyu@outlook.com

import torch
import torch.nn as nn
from modules import Encoder, LayerNorm


class FMLPRecModel(nn.Module):
    def __init__(self, args):
        super(FMLPRecModel, self).__init__()
        self.args = args
        self.item_embeddings = nn.Embedding(args.item_size, args.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(args.max_seq_length, args.hidden_size)
        self.LayerNorm = LayerNorm(args.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(args.hidden_dropout_prob)
        self.item_encoder = Encoder(args)

        self.apply(self.init_weights)

    def add_position_embedding(self, sequence):
        # Sum item and learnable position embeddings, then apply LayerNorm and dropout.
        seq_length = sequence.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=sequence.device)
        position_ids = position_ids.unsqueeze(0).expand_as(sequence)
        item_embeddings = self.item_embeddings(sequence)
        position_embeddings = self.position_embeddings(position_ids)
        sequence_emb = item_embeddings + position_embeddings
        sequence_emb = self.LayerNorm(sequence_emb)
        sequence_emb = self.dropout(sequence_emb)

        return sequence_emb

    # same as SASRec
    def forward(self, input_ids):
        # Build a causal (lower-triangular) attention mask that also hides padding positions.
        attention_mask = (input_ids > 0).long()
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # torch.int64
        max_len = attention_mask.size(-1)
        attn_shape = (1, max_len, max_len)
        subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1)
        subsequent_mask = (subsequent_mask == 0).unsqueeze(1)
        subsequent_mask = subsequent_mask.long()

        if self.args.cuda_condition:
            subsequent_mask = subsequent_mask.cuda()
        extended_attention_mask = extended_attention_mask * subsequent_mask
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        sequence_emb = self.add_position_embedding(input_ids)

        item_encoded_layers = self.item_encoder(sequence_emb,
                                                extended_attention_mask,
                                                output_all_encoded_layers=True,
                                                )
        sequence_output = item_encoded_layers[-1]

        return sequence_output

    def init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.args.initializer_range)
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
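
A minimal sketch of a forward pass through the model (illustration only; it assumes `modules.py` from this repository is importable, and the `Namespace` fields below mirror the defaults defined in `main.py` with a made-up `item_size`):

```
import torch
from argparse import Namespace

# Hypothetical stand-in for the parsed command-line arguments.
args = Namespace(item_size=1000, hidden_size=64, max_seq_length=50,
                 num_hidden_layers=2, num_attention_heads=2, hidden_act="gelu",
                 hidden_dropout_prob=0.5, attention_probs_dropout_prob=0.5,
                 initializer_range=0.02, no_filters=False, cuda_condition=False)

model = FMLPRecModel(args)
input_ids = torch.randint(1, args.item_size, (8, args.max_seq_length))  # a padded batch of item-id sequences
sequence_output = model(input_ids)  # shape: (8, max_seq_length, hidden_size)
```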