-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fbc71e8
commit 34036fd
Showing
20 changed files
with
116,349 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
name: bsarec | ||
channels: | ||
- pytorch | ||
- nvidia | ||
- defaults | ||
dependencies: | ||
- cudatoolkit=11.1.74 | ||
- numpy=1.24.3 | ||
- python=3.9.7 | ||
- pytorch=1.8.1 | ||
- pip=23.1.2=py39h06a4308_0 | ||
- pip: | ||
- scipy==1.11.1 | ||
- tqdm==4.65.0 |
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
import tqdm | ||
import numpy as np | ||
import torch | ||
from scipy.sparse import csr_matrix | ||
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler | ||
|
||
class RecDataset(Dataset):
    """Sequential-recommendation dataset over per-user item-id sequences.

    Depending on ``data_type`` the raw sequences are sliced differently:
      * ``'train'``: every prefix of ``seq[-(max_len+2):-2]`` becomes one
        sample (prefix augmentation); the originating user id is recorded
        in ``self.user_ids``.
      * ``'valid'``: the final item is held out (``seq[:-1]``).
      * anything else (test): the full sequence is used as-is.

    ``__getitem__`` returns ``(index, input_ids, answer)`` where
    ``input_ids`` is the left-zero-padded history of length ``max_len``
    and ``answer`` is the item to predict.
    """

    def __init__(self, args, user_seq, test_neg_items=None, data_type='train'):
        self.args = args
        self.user_seq = []
        self.max_len = args.max_seq_length
        self.user_ids = []

        if data_type == 'train':
            for user, seq in enumerate(user_seq):
                # drop the last two items (validation + test targets),
                # then augment with every prefix of what remains
                input_ids = seq[-(self.max_len + 2):-2]
                for i in range(len(input_ids)):
                    self.user_seq.append(input_ids[:i + 1])
                    self.user_ids.append(user)
        elif data_type == 'valid':
            for sequence in user_seq:
                # hold out the final (test) item
                self.user_seq.append(sequence[:-1])
        else:
            self.user_seq = user_seq

        self.test_neg_items = test_neg_items
        self.data_type = data_type

    def get_same_target_index(self):
        """Group sequences by their last (target) item id.

        Returns a list indexed by item id; entry ``i`` holds every stored
        sequence whose last element is ``i``.

        Fix: single O(N) bucketing pass instead of the original
        O(num_items * N) repeated filtering (which also needed a tqdm
        progress bar to be bearable). Relative order inside each bucket
        is preserved, matching the original.
        """
        num_items = max(max(v) for v in self.user_seq) + 2
        same_target_index = [[] for _ in range(num_items)]
        for seq in self.user_seq:
            target = seq[-1]
            # target 0 is the padding id: the original scan started at
            # item 1, so sequences ending in 0 were never bucketed
            if target > 0:
                same_target_index[target].append(seq)
        return same_target_index

    def __len__(self):
        return len(self.user_seq)

    def __getitem__(self, index):
        items = self.user_seq[index]
        input_ids = items[:-1]
        answer = items[-1]

        # left-pad with 0 up to max_len, then keep the most recent max_len
        pad_len = self.max_len - len(input_ids)
        input_ids = [0] * pad_len + input_ids
        input_ids = input_ids[-self.max_len:]
        assert len(input_ids) == self.max_len

        return (
            torch.tensor(index, dtype=torch.long),
            torch.tensor(input_ids, dtype=torch.long),
            torch.tensor(answer, dtype=torch.long),
        )
|
||
def generate_rating_matrix_valid(user_seq, num_users, num_items):
    """Build the sparse user-item interaction matrix for validation.

    The last two items of every sequence (the validation and test
    targets) are excluded. Returns a ``num_users x num_items`` CSR
    matrix with 1 at each observed (user, item) pair.
    """
    rows, cols = [], []
    for user_id, item_list in enumerate(user_seq):
        # hold out the final two items: they are the evaluation targets
        for item in item_list[:-2]:
            rows.append(user_id)
            cols.append(item)

    values = np.array([1] * len(rows))
    return csr_matrix((values, (np.array(rows), np.array(cols))),
                      shape=(num_users, num_items))
|
||
def generate_rating_matrix_test(user_seq, num_users, num_items):
    """Build the sparse user-item interaction matrix for testing.

    Only the last item of every sequence (the test target) is excluded.
    Returns a ``num_users x num_items`` CSR matrix with 1 at each
    observed (user, item) pair.
    """
    rows, cols = [], []
    for user_id, item_list in enumerate(user_seq):
        # hold out only the final item: it is the test target
        for item in item_list[:-1]:
            rows.append(user_id)
            cols.append(item)

    values = np.array([1] * len(rows))
    return csr_matrix((values, (np.array(rows), np.array(cols))),
                      shape=(num_users, num_items))
|
||
def get_rating_matrix(data_name, seq_dic, max_item):
    """Return the (valid, test) rating matrices for a dataset.

    ``data_name`` is not used here but is kept for interface
    compatibility with existing callers.
    """
    num_items = max_item + 1
    user_seq = seq_dic['user_seq']
    num_users = seq_dic['num_users']

    valid_matrix = generate_rating_matrix_valid(user_seq, num_users, num_items)
    test_matrix = generate_rating_matrix_test(user_seq, num_users, num_items)
    return valid_matrix, test_matrix
|
||
def get_user_seqs_and_max_item(data_file):
    """Parse an interaction file, skipping its first (header) line.

    Each remaining line has the form ``user item1 item2 ...``.

    Args:
        data_file: path to the whitespace-separated interaction file.
    Returns:
        (user_seq, max_item): list of per-user item-id lists, and the
        largest item id seen.
    Raises:
        ValueError: if a line has no items after the user id.

    Fix: the file handle is now closed deterministically via ``with``
    (the original ``open(...).readlines()`` leaked it until GC).
    """
    with open(data_file) as f:
        lines = f.readlines()
    user_seq = []
    item_set = set()
    for line in lines[1:]:  # first line is a header
        user, items = line.strip().split(' ', 1)
        item_ids = [int(item) for item in items.split()]
        user_seq.append(item_ids)
        item_set.update(item_ids)
    max_item = max(item_set)
    return user_seq, max_item
|
||
def get_user_seqs(data_file):
    """Parse an interaction file with no header line.

    Each line has the form ``user item1 item2 ...`` (single-space
    separated; the original splits on ``' '`` exactly, which is
    preserved here).

    Args:
        data_file: path to the interaction file.
    Returns:
        (user_seq, max_item, num_users): per-user item-id lists, the
        largest item id, and the number of lines (= users).

    Fix: the file handle is now closed deterministically via ``with``
    (the original ``open(...).readlines()`` leaked it until GC).
    """
    with open(data_file) as f:
        lines = f.readlines()
    user_seq = []
    item_set = set()
    for line in lines:
        user, items = line.strip().split(' ', 1)
        item_ids = [int(item) for item in items.split(' ')]
        user_seq.append(item_ids)
        item_set.update(item_ids)
    max_item = max(item_set)
    num_users = len(lines)
    return user_seq, max_item, num_users
|
||
def get_seq_dic(args):
    """Load the user sequences for ``args.data_name`` into a dict.

    Side effect: stores the resolved dataset path on ``args.data_file``.

    Returns:
        (seq_dic, max_item, num_users) where seq_dic holds 'user_seq'
        and 'num_users'.
    """
    args.data_file = args.data_dir + args.data_name + '.txt'
    user_seq, max_item, num_users = get_user_seqs(args.data_file)
    return {'user_seq': user_seq, 'num_users': num_users}, max_item, num_users
|
||
def get_dataloder(args, seq_dic):
    """Build the train/valid/test DataLoaders over the same raw sequences.

    Training is shuffled via ``RandomSampler``; validation and test
    iterate sequentially. (Function name kept as-is — callers import
    ``get_dataloder``.)
    """
    user_seq = seq_dic['user_seq']

    def _build(data_type, sampler_cls):
        # one dataset + loader per split; RecDataset slices user_seq itself
        dataset = RecDataset(args, user_seq, data_type=data_type)
        return DataLoader(dataset,
                          sampler=sampler_cls(dataset),
                          batch_size=args.batch_size,
                          num_workers=args.num_workers)

    train_dataloader = _build('train', RandomSampler)
    eval_dataloader = _build('valid', SequentialSampler)
    test_dataloader = _build('test', SequentialSampler)
    return train_dataloader, eval_dataloader, test_dataloader
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import os | ||
import torch | ||
import numpy as np | ||
|
||
from model import MODEL_DICT | ||
from trainers import Trainer | ||
from utils import EarlyStopping, check_path, set_seed, parse_args, set_logger | ||
from dataset import get_seq_dic, get_dataloder, get_rating_matrix | ||
|
||
def main():
    """Entry point: parse args, build data loaders, then train or evaluate.

    With ``--do_eval`` a previously saved checkpoint is loaded and tested
    once; otherwise the model trains with early stopping on the validation
    score and the best checkpoint is then evaluated on the test split.
    """

    args = parse_args()
    log_path = os.path.join(args.output_dir, args.train_name + '.log')
    logger = set_logger(log_path)

    # reproducibility and output directory
    set_seed(args.seed)
    check_path(args.output_dir)

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    args.cuda_condition = torch.cuda.is_available() and not args.no_cuda

    seq_dic, max_item, num_users = get_seq_dic(args)
    # +1 to reserve id 0 (padding) alongside the real ids
    args.item_size = max_item + 1
    args.num_users = num_users + 1

    args.checkpoint_path = os.path.join(args.output_dir, args.train_name + '.pt')
    train_dataloader, eval_dataloader, test_dataloader = get_dataloder(args,seq_dic)

    logger.info(str(args))
    # look up the model class by the (lowercased) --model_type value
    model = MODEL_DICT[args.model_type.lower()](args=args)
    logger.info(model)
    trainer = Trainer(model, train_dataloader, eval_dataloader, test_dataloader, args, logger)

    # NOTE(review): matrices are attached to args after Trainer creation —
    # presumably the trainer reads them lazily during evaluation; verify.
    args.valid_rating_matrix, args.test_rating_matrix = get_rating_matrix(args.data_name, seq_dic, max_item)

    if args.do_eval:
        # evaluation-only mode requires --load_model to name a checkpoint
        if args.load_model is None:
            logger.info(f"No model input!")
            exit(0)
        else:
            args.checkpoint_path = os.path.join(args.output_dir, args.load_model + '.pt')
            trainer.load(args.checkpoint_path)
            logger.info(f"Load model from {args.checkpoint_path} for test!")
            scores, result_info = trainer.test(0)

    else:
        early_stopping = EarlyStopping(args.checkpoint_path, logger=logger, patience=args.patience, verbose=True)
        for epoch in range(args.epochs):

            trainer.train(epoch)
            scores, _ = trainer.valid(epoch)
            # evaluate on MRR
            early_stopping(np.array(scores[-1:]), trainer.model)
            if early_stopping.early_stop:
                logger.info("Early stopping")
                break

        # reload the best checkpoint saved by early stopping, then test
        logger.info("---------------Test Score---------------")
        trainer.model.load_state_dict(torch.load(args.checkpoint_path))
        scores, result_info = trainer.test(0)

    logger.info(args.train_name)
    logger.info(result_info)
|
||
|
||
# Guard the entry point so importing this module does not start training.
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import math | ||
|
||
def recall_at_k(actual, predicted, topk):
    """Mean Recall@k over users.

    ``actual[i]`` is user i's single ground-truth item and
    ``predicted[i]`` is that user's ranked item list. With exactly one
    relevant item per user this reduces to Hit-Rate@k.
    """
    total = 0.0
    counted = 0
    for i in range(len(predicted)):
        relevant = {actual[i]}
        top = set(predicted[i][:topk])
        # the relevant set always has one element; guard kept to mirror
        # the original control flow
        if relevant:
            total += len(relevant & top) / float(len(relevant))
            counted += 1
    return total / counted
|
||
def ndcg_k(actual, predicted, topk):
    """Mean NDCG@k over users; ``actual[i]`` is the single relevant item."""
    total = 0.0
    for user_id, truth in enumerate(actual):
        # exactly one relevant item, so the ideal ranking has it at rank 1
        # and the ideal DCG uses k = min(topk, 1)
        ideal = idcg_k(min(topk, 1))
        ranked = predicted[user_id]
        gain = 0.0
        for rank in range(topk):
            if ranked[rank] == truth:
                gain += 1.0 / math.log(rank + 2, 2)
        total += gain / ideal
    return total / float(len(actual))


# Calculates the ideal discounted cumulative gain at k
def idcg_k(k):
    """Ideal DCG@k for binary relevance; 1.0 when k == 0 so the caller
    never divides by zero."""
    ideal = sum(1.0 / math.log(position + 2, 2) for position in range(k))
    return ideal if ideal else 1.0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from model.bsarec import BSARecModel | ||
|
||
# Registry mapping the (lowercased) --model_type CLI value to its model class.
MODEL_DICT = {
    "bsarec": BSARecModel,
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import torch | ||
import torch.nn as nn | ||
from model._modules import LayerNorm | ||
from torch.nn.init import xavier_uniform_ | ||
|
||
class SequentialRecModel(nn.Module):
    """Base class for sequential recommendation models.

    Provides item/position embeddings, weight initialization, and a
    causal (left-to-right) additive attention mask. Subclasses implement
    ``forward`` and ``calculate_loss``.

    NOTE(review): ``add_position_embedding`` uses ``self.LayerNorm`` and
    ``self.dropout``, which are NOT created here — subclasses are
    apparently expected to define them; confirm against the concrete
    model classes.
    """

    def __init__(self, args):
        super(SequentialRecModel, self).__init__()
        self.args = args
        # padding_idx=0: item id 0 is the left-pad token and embeds to zeros
        self.item_embeddings = nn.Embedding(args.item_size, args.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(args.max_seq_length, args.hidden_size)
        self.batch_size = args.batch_size

    def add_position_embedding(self, sequence):
        """Embed item ids and add learned positional embeddings.

        Args:
            sequence: (batch, seq_len) long tensor of item ids.
        Returns:
            (batch, seq_len, hidden_size) tensor after LayerNorm and
            dropout (both provided by the subclass).
        """
        seq_length = sequence.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=sequence.device)
        position_ids = position_ids.unsqueeze(0).expand_as(sequence)
        item_embeddings = self.item_embeddings(sequence)
        position_embeddings = self.position_embeddings(position_ids)
        sequence_emb = item_embeddings + position_embeddings
        sequence_emb = self.LayerNorm(sequence_emb)  # defined by subclass
        sequence_emb = self.dropout(sequence_emb)    # defined by subclass

        return sequence_emb

    def init_weights(self, module):
        """ Initialize the weights.

        Intended for use with ``nn.Module.apply``: normal init for
        Linear/Embedding, ones/zeros for LayerNorm, Xavier-uniform for
        the first GRU layer's weights, zero for Linear biases.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.args.initializer_range)
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.GRU):
            # only the first GRU layer (l0) is initialized here
            xavier_uniform_(module.weight_hh_l0)
            xavier_uniform_(module.weight_ih_l0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def get_attention_mask(self, item_seq):
        """Generate left-to-right uni-directional attention mask for multi-head attention.

        Combines a padding mask (item id 0 = padding) with a causal mask
        (position i may attend only to positions <= i). Returns an
        additive mask of shape (batch, 1, seq_len, seq_len): 0.0 at
        attendable positions, -10000.0 at blocked ones.
        """

        attention_mask = (item_seq > 0).long()
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # torch.int64

        # strictly-upper-triangular ones mark the "future" positions
        max_len = attention_mask.size(-1)
        attn_shape = (1, max_len, max_len)
        subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1)  # torch.uint8
        subsequent_mask = (subsequent_mask == 0).unsqueeze(1)
        subsequent_mask = subsequent_mask.long().to(item_seq.device)

        # 1 only where the position is both non-padded and not in the future
        extended_attention_mask = extended_attention_mask * subsequent_mask
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        return extended_attention_mask

    def forward(self, input_ids, all_sequence_output=False):
        # To be implemented by subclasses.
        pass

    def predict(self, input_ids, all_sequence_output=False):
        """Inference entry point; delegates to ``forward``."""
        return self.forward(input_ids, all_sequence_output)

    def calculate_loss(self, input_ids, answers):
        # To be implemented by subclasses.
        pass
|
Oops, something went wrong.