[FEAT] new baselines
yehjin-shin committed Mar 17, 2024
1 parent 397c7d7 commit 704a0d1
Showing 9 changed files with 485 additions and 59 deletions.
60 changes: 35 additions & 25 deletions README.md
@@ -1,13 +1,25 @@
# BSARec
This is the official source code for our AAAI 2024 paper ["An Attentive Inductive Bias for Sequential Recommendation beyond the Self-Attention"](https://arxiv.org/abs/2312.10325)

## Overview
Beyond Self-Attention for Sequential Recommendation (BSARec) leverages the Fourier transform to strike a balance between our inductive bias and self-attention.
![BSARec](fig/model_architecture.pdf)

## Dataset
In our experiments, we utilize six datasets, all stored in the `src/data` folder.
- For the Beauty, Sports, Toys, and Yelp datasets, we employed the datasets downloaded from [this repository](https://github.com/Woeee/FMLP-Rec).
- For ML-1M and LastFM, we processed the data according to the procedure outlined in [this code](https://github.com/RUCAIBox/CIKM2020-S3Rec/blob/master/data/data_process.py).
- The `src/data/*_same_target.npy` files are utilized for training DuoRec and FEARec, both of which incorporate contrastive learning (a loading sketch follows below).
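As a quick reference, the following hypothetical snippet shows how these files can be inspected; it mirrors the `np.load(..., allow_pickle=True)` call in `src/dataset.py`, and the Beauty filename is only an example (run from the repository root):
```
import numpy as np

# Hypothetical inspection snippet: entry i collects the training sequences whose
# target (next) item is i, which is how src/dataset.py indexes it via the answer item.
same_target = np.load('src/data/Beauty_same_target.npy', allow_pickle=True)
print(len(same_target))    # roughly the number of items (plus padding offset)
print(same_target[1][:3])  # a few sequences whose next item is item 1
```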

## Quick Start
### Environment Setting
```
conda env create -f bsarec_env.yaml
conda activate bsarec
```

### How to train BSARec
- Note that the trained model checkpoint (.pt) and the training log (.log) will be saved in `BSARec/output`
- `train_name`: the name used for the log file and checkpoint file
```
python main.py --data_name [DATASET] \
@@ -17,7 +29,7 @@
--num_attention_heads [N_HEADS] \
--train_name [LOG_NAME]
```
- Example for Beauty
```
python main.py --data_name Beauty \
--lr 0.0005 \
@@ -26,19 +38,9 @@
--num_attention_heads 1 \
--train_name BSARec_Beauty
```

### How to test pretrained BSARec
- Note that the pretrained model (.pt file) must be located in `BSARec/output`
- `load_model`: the pretrained model's filename without the .pt extension
```
python main.py --data_name [DATASET] \
@@ -48,7 +50,7 @@
--load_model [PRETRAINED_MODEL_NAME] \
--do_eval
```
- Example for Beauty
```
python main.py --data_name Beauty \
--alpha 0.7 \
@@ -57,12 +59,20 @@
--load_model BSARec_Beauty_best \
--do_eval
```

### How to train the baselines
- You can easily train the baseline models used in the BSARec paper by changing the `model_type` argument.
- `model_type`: Caser, GRU4Rec, SASRec, BERT4Rec, FMLPRec, DuoRec, FEARec
- For the baselines' hyperparameters, check the `parse_args()` function in `src/utils.py`.
```
python main.py --model_type SASRec \
--data_name Beauty \
--num_attention_heads 1 \
--train_name SASRec_Beauty
```

## Contact
If you have any inquiries regarding our paper or code, feel free to reach out via email at yehjin.shin@yonsei.ac.kr.

## Acknowledgement
This repository is based on [FMLP-Rec](https://github.com/Woeee/FMLP-Rec).
75 changes: 67 additions & 8 deletions src/dataset.py
@@ -1,30 +1,42 @@
import tqdm
import numpy as np
import torch
import os
from scipy.sparse import csr_matrix
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import random

class RecDataset(Dataset):
    def __init__(self, args, user_seq, test_neg_items=None, data_type='train'):
        self.args = args
        self.user_seq = []
        self.max_len = args.max_seq_length
        self.user_ids = []
        self.contrastive_learning = args.model_type.lower() in ['fearec', 'duorec']
        self.data_type = data_type

        if self.data_type == 'train':
            # expand each user's sequence into all prefixes; the last two items are
            # held out for validation/test (e.g. [4, 8, 2, 9, 5] -> [4], [4, 8], [4, 8, 2])
            for user, seq in enumerate(user_seq):
                input_ids = seq[-(self.max_len + 2):-2]
                for i in range(len(input_ids)):
                    self.user_seq.append(input_ids[:i + 1])
                    self.user_ids.append(user)
        elif self.data_type == 'valid':
            # validation: drop the final (test) item from each sequence
            for sequence in user_seq:
                self.user_seq.append(sequence[:-1])
        else:
            self.user_seq = user_seq

        self.test_neg_items = test_neg_items

        if self.contrastive_learning and self.data_type == 'train':
            # DuoRec/FEARec need, for every target item, the training sequences that
            # share that target; build the index once and cache it as an .npy file
            if os.path.exists(args.same_target_path):
                self.same_target_index = np.load(args.same_target_path, allow_pickle=True)
            else:
                print("Start making same_target_index for contrastive learning")
                self.same_target_index = self.get_same_target_index()
                self.same_target_index = np.array(self.same_target_index)
                np.save(args.same_target_path, self.same_target_index)

    def get_same_target_index(self):
        num_items = max([max(v) for v in self.user_seq]) + 2
@@ -51,19 +63,66 @@ def __getitem__(self, index):
        input_ids = items[:-1]
        answer = items[-1]

        seq_set = set(items)
        neg_answer = neg_sample(seq_set, self.args.item_size)

        # left-pad with 0 and truncate to the fixed window size
        pad_len = self.max_len - len(input_ids)
        input_ids = [0] * pad_len + input_ids
        input_ids = input_ids[-self.max_len:]
        assert len(input_ids) == self.max_len

        if self.data_type in ['valid', 'test']:
            cur_tensors = (
                torch.tensor(index, dtype=torch.long),  # user_id for testing
                torch.tensor(input_ids, dtype=torch.long),
                torch.tensor(answer, dtype=torch.long),
                torch.zeros(0, dtype=torch.long),  # not used
                torch.zeros(0, dtype=torch.long),  # not used
            )

        elif self.contrastive_learning:
            # pick another training sequence with the same target item to serve as
            # the semantic augmentation for contrastive learning
            sem_augs = self.same_target_index[answer]
            sem_aug = random.choice(sem_augs)

            # resample only if the candidates are not all identical; otherwise the
            # while loop below could never terminate
            keep_random = False
            for i in range(len(sem_augs)):
                if sem_augs[0] != sem_augs[i]:
                    keep_random = True

            while keep_random and sem_aug == items:
                sem_aug = random.choice(sem_augs)

            sem_aug = sem_aug[:-1]
            pad_len = self.max_len - len(sem_aug)
            sem_aug = [0] * pad_len + sem_aug
            sem_aug = sem_aug[-self.max_len:]
            assert len(sem_aug) == self.max_len

            cur_tensors = (
                torch.tensor(self.user_ids[index], dtype=torch.long),  # user_id for testing
                torch.tensor(input_ids, dtype=torch.long),
                torch.tensor(answer, dtype=torch.long),
                torch.tensor(neg_answer, dtype=torch.long),
                torch.tensor(sem_aug, dtype=torch.long)
            )

        else:
            cur_tensors = (
                torch.tensor(self.user_ids[index], dtype=torch.long),  # user_id for testing
                torch.tensor(input_ids, dtype=torch.long),
                torch.tensor(answer, dtype=torch.long),
                torch.tensor(neg_answer, dtype=torch.long),
                torch.zeros(0, dtype=torch.long),  # not used
            )

        return cur_tensors


def neg_sample(item_set, item_size):
    # sample a random item id the user has not interacted with
    # (ids start at 1; 0 is reserved for padding)
    item = random.randint(1, item_size - 1)
    while item in item_set:
        item = random.randint(1, item_size - 1)
    return item
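As a rough usage sketch, with hypothetical toy data and a simplified stand-in for the parsed arguments (the real `args` comes from `parse_args()` in `src/utils.py`), the training dataset can be exercised like this:
```
from types import SimpleNamespace
from torch.utils.data import DataLoader

# Hypothetical stand-in: only the fields RecDataset touches for a
# non-contrastive model are provided.
args = SimpleNamespace(max_seq_length=50, model_type='SASRec', item_size=12)

user_seq = [[1, 2, 3, 4, 5], [2, 3, 6, 7, 8]]  # toy interaction sequences
train_set = RecDataset(args, user_seq, data_type='train')

loader = DataLoader(train_set, batch_size=2, shuffle=True)
user_ids, input_ids, answers, neg_answers, _ = next(iter(loader))
print(input_ids.shape)  # torch.Size([2, 50]); prefixes left-padded with 0
```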

def generate_rating_matrix_valid(user_seq, num_users, num_items):
    # three lists are used to construct the sparse matrix
    row = []
1 change: 1 addition & 0 deletions src/main.py
@@ -24,6 +24,7 @@ def main():
    args.num_users = num_users + 1

    args.checkpoint_path = os.path.join(args.output_dir, args.train_name + '.pt')
    args.same_target_path = os.path.join(args.data_dir, args.data_name + '_same_target.npy')
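    # e.g. with --data_name Beauty this resolves to <data_dir>/Beauty_same_target.npy,
    # matching the precomputed *_same_target.npy files in src/data (per the README)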
    train_dataloader, eval_dataloader, test_dataloader = get_dataloder(args, seq_dic)

    logger.info(str(args))
14 changes: 14 additions & 0 deletions src/model/__init__.py
@@ -1,5 +1,19 @@
from model.bsarec import BSARecModel
from model.caser import CaserModel
from model.gru4rec import GRU4RecModel
from model.sasrec import SASRecModel
from model.bert4rec import BERT4RecModel
from model.fmlprec import FMLPRecModel
from model.duorec import DuoRecModel
from model.fearec import FEARecModel

MODEL_DICT = {
    "bsarec": BSARecModel,
    "caser": CaserModel,
    "gru4rec": GRU4RecModel,
    "sasrec": SASRecModel,
    "bert4rec": BERT4RecModel,
    "fmlprec": FMLPRecModel,
    "duorec": DuoRecModel,
    "fearec": FEARecModel,
}
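For context, below is a minimal sketch of how this registry is presumably consumed; the call site is outside this diff, so the constructor call is an assumption, while `parse_args()` does exist in `src/utils.py` per the README:
```
from model import MODEL_DICT
from utils import parse_args  # defined in src/utils.py (see README)

args = parse_args()
model_cls = MODEL_DICT[args.model_type.lower()]  # e.g. "SASRec" -> SASRecModel
model = model_cls(args)                          # assumed constructor signature
```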
19 changes: 15 additions & 4 deletions src/model/_abstract_model.py
@@ -24,8 +24,7 @@ def add_position_embedding(self, sequence):
        return sequence_emb

    def init_weights(self, module):
        """ Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version, which uses truncated_normal for initialization
            # cf. https://github.com/pytorch/pytorch/pull/5617
@@ -39,6 +38,18 @@ def init_weights(self, module):
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def get_bi_attention_mask(self, item_seq):
        """Generate bidirectional attention mask for multi-head attention."""
        attention_mask = (item_seq > 0).long()
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # torch.int64

        # bidirectional mask: real positions become 0.0 and padding positions -10000.0,
        # so adding this to the attention logits suppresses attention to padding
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        return extended_attention_mask
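For intuition, here is a small self-contained sketch (toy values, not from the repository) of the same masking arithmetic:
```
import torch

item_seq = torch.tensor([[0, 0, 3, 7]])  # one left-padded sequence
mask = (item_seq > 0).long().unsqueeze(1).unsqueeze(2).float()
mask = (1.0 - mask) * -10000.0
print(mask)  # padding positions -> -10000.0, real items -> 0.0; added to the
             # attention logits, this zeroes attention to padding after softmax
```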

    def get_attention_mask(self, item_seq):
        """Generate left-to-right uni-directional attention mask for multi-head attention."""

@@ -60,8 +71,8 @@ def get_attention_mask(self, item_seq):
    def forward(self, input_ids, all_sequence_output=False):
        pass

    def predict(self, input_ids, user_ids, all_sequence_output=False):
        return self.forward(input_ids, user_ids, all_sequence_output)
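Since `predict` now also forwards `user_ids`, subclass `forward` implementations and call sites presumably change along these lines (an assumption; the trainer code is not part of this diff):
```
# Hypothetical call-site update (trainer not shown in this diff):
#   before: scores = model.predict(input_ids)
#   after:  scores = model.predict(input_ids, user_ids)
# Subclass forward() implementations are expected to accept user_ids as well.
```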

    def calculate_loss(self, input_ids, answers):
        pass