
Commit

Create BaseModel.py
shuyuan-x authored Aug 28, 2023
1 parent 2267d73 commit c95ce14
Showing 1 changed file with 311 additions and 0 deletions.
311 changes: 311 additions & 0 deletions src/models/BaseModel.py
@@ -0,0 +1,311 @@
# coding=utf-8

import torch
import logging
from sklearn.metrics import *
import numpy as np
import torch.nn.functional as F
import os
import pandas as pd
from tqdm import tqdm
from collections import defaultdict
from utils.rank_metrics import *
from utils.global_p import *
from utils import utils
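# Note: the upper-case keys used below (e.g. Y, UID, TIME, X, DROPOUT, RANK, REAL_BATCH_SIZE,
# PREDICTION, CHECK, LOSS, LOSS_L2, EMBEDDING_L2, TOTAL_BATCH_SIZE, L2_BATCH, MODEL_DIR) are
# presumably constants pulled in by the star-import from utils.global_p.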




class BaseModel(torch.nn.Module):
"""
Base Model, need to overload following functions for new models:
parse_model_args,
__init__,
_init_weights,
predict,
forward,
"""

include_id = True
include_user_features = True
include_item_features = True
include_context_features = True
data_loader = 'DataLoader' # Default data_loader
data_processor = 'DataProcessor' # Default data_processor
runner = 'BaseRunner' # Default runner

    @staticmethod
    def parse_model_args(parser, model_name='BaseModel'):
        """
        Parse model-related command-line arguments.
        :param parser: argparse parser
        :param model_name: model name
        :return: parser with the model arguments added
        """
        parser.add_argument('--loss_sum', type=int, default=1,
                            help='Reduction of batch loss: 1=sum, 0=mean')
        parser.add_argument('--model_path', type=str,
                            default=os.path.join(MODEL_DIR, '%s/%s.pt' % (model_name, model_name)),
                            help='Model save path.')
        return parser
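    # Example usage (a sketch; flags other than the two defined above, and the main.py
    # entry point, are assumptions about the surrounding framework):
    #   python main.py --loss_sum 0 --model_path ../model/BaseModel/BaseModel.pt
    # would switch the batch-loss reduction to mean and override the default save path.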

    @staticmethod
    def evaluate_method(p, data, metrics, error_skip=False):
        """
        Calculate evaluation metrics.
        :param p: prediction values, np.array, generated by runner.predict
        :param data: data dict, generated by DataProcessor
        :param metrics: list of metric names, usually the same as runner.metrics, e.g. ['rmse', 'auc']
        :param error_skip: if True, append -1 instead of raising when a metric fails
        :return: list of metric values, in the same order as metrics
        """
        l = data[Y]
        evaluations = []
        rank = False
        for metric in metrics:
            if '@' in metric:
                rank = True

        split_l, split_p, split_l_sum = None, None, None
        if rank:
            # For top-k metrics, sort each user's (or user-time) candidates by prediction
            # descending, then split the sorted labels/predictions into per-list groups.
            uids = data[UID].reshape([-1])
            if TIME in data:
                times = data[TIME].reshape([-1])
                if len(np.unique(uids)) < len(np.unique(times)):
                    sorted_idx = np.lexsort((-l, -p, uids))
                    sorted_uid = uids[sorted_idx]
                    sorted_key, sorted_spl = np.unique(sorted_uid, return_index=True)
                else:
                    sorted_idx = np.lexsort((-l, -p, times, uids))
                    sorted_uid, sorted_time = uids[sorted_idx], times[sorted_idx]
                    sorted_key, sorted_spl = np.unique([sorted_uid, sorted_time], axis=1, return_index=True)
            else:
                sorted_idx = np.lexsort((-l, -p, uids))
                sorted_uid = uids[sorted_idx]
                sorted_key, sorted_spl = np.unique(sorted_uid, return_index=True)
            sorted_l, sorted_p = l[sorted_idx], p[sorted_idx]
            split_l, split_p = np.split(sorted_l, sorted_spl[1:]), np.split(sorted_p, sorted_spl[1:])
            split_l_sum = [np.sum((d > 0).astype(float)) for d in split_l]

        for metric in metrics:
            try:
                if metric == 'rmse':
                    evaluations.append(np.sqrt(mean_squared_error(l, p)))
                elif metric == 'mae':
                    evaluations.append(mean_absolute_error(l, p))
                elif metric == 'auc':
                    evaluations.append(roc_auc_score(l, p))
                elif metric == 'f1':
                    evaluations.append(f1_score(l, np.around(p)))
                elif metric == 'accuracy':
                    evaluations.append(accuracy_score(l, np.around(p)))
                elif metric == 'precision':
                    evaluations.append(precision_score(l, np.around(p)))
                elif metric == 'recall':
                    evaluations.append(recall_score(l, np.around(p)))
                else:
                    k = int(metric.split('@')[-1])
                    if metric.startswith('ndcg@'):
                        max_k = max([len(d) for d in split_l])
                        k_data = np.array([(list(d) + [0] * max_k)[:max_k] for d in split_l])
                        best_rank = -np.sort(-k_data, axis=1)
                        best_dcg = np.sum(best_rank[:, :k] / np.log2(np.arange(2, k + 2)), axis=1)
                        best_dcg[best_dcg == 0] = 1
                        dcg = np.sum(k_data[:, :k] / np.log2(np.arange(2, k + 2)), axis=1)
                        ndcgs = dcg / best_dcg
                        evaluations.append(np.average(ndcgs))

                        # k_data = np.array([(list(d) + [0] * k)[:k] for d in split_l])
                        # best_rank = -np.sort(-k_data, axis=1)
                        # best_dcg = np.sum(best_rank / np.log2(np.arange(2, k + 2)), axis=1)
                        # best_dcg[best_dcg == 0] = 1
                        # dcg = np.sum(k_data / np.log2(np.arange(2, k + 2)), axis=1)
                        # ndcgs = dcg / best_dcg
                        # evaluations.append(np.average(ndcgs))
                    elif metric.startswith('hit@'):
                        k_data = np.array([(list(d) + [0] * k)[:k] for d in split_l])
                        hits = (np.sum((k_data > 0).astype(float), axis=1) > 0).astype(float)
                        evaluations.append(np.average(hits))
                    elif metric.startswith('precision@'):
                        k_data = [d[:k] for d in split_l]
                        k_data_dict = defaultdict(list)
                        for d in k_data:
                            k_data_dict[len(d)].append(d)
                        precisions = [np.average((np.array(d) > 0).astype(float), axis=1) for d in k_data_dict.values()]
                        evaluations.append(np.average(np.concatenate(precisions)))
                    elif metric.startswith('recall@'):
                        k_data = np.array([(list(d) + [0] * k)[:k] for d in split_l])
                        recalls = np.sum((k_data > 0).astype(float), axis=1) / split_l_sum
                        evaluations.append(np.average(recalls))
            except Exception as e:
                if error_skip:
                    evaluations.append(-1)
                else:
                    raise e
        return evaluations
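    # Example (a sketch; assumes Y and UID are the dict keys defined in utils.global_p and
    # that p is aligned row-by-row with data[Y]):
    #   p = np.array([0.9, 0.3, 0.7, 0.1])
    #   data = {Y: np.array([1, 0, 1, 0]), UID: np.array([1, 1, 2, 2])}
    #   BaseModel.evaluate_method(p, data, ['auc', 'ndcg@2'])  # -> [1.0, 1.0]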

    def __init__(self, label_min, label_max, feature_num, loss_sum, l2_bias, random_seed, model_path):
        super(BaseModel, self).__init__()
        self.label_min = label_min
        self.label_max = label_max
        self.feature_num = feature_num
        self.loss_sum = loss_sum
        self.l2_bias = l2_bias
        self.random_seed = random_seed
        torch.manual_seed(self.random_seed)
        torch.cuda.manual_seed(self.random_seed)
        self.model_path = model_path

        self._init_weights()
        logging.debug(list(self.parameters()))

        self.total_parameters = self.count_variables()
        logging.info('# of params: %d' % self.total_parameters)

        # optimizer is generated by runner
        self.optimizer = None

    def _init_weights(self):
        """
        Initialize the weights required by the model (including the prediction layer).
        :return:
        """
        self.x_bn = torch.nn.BatchNorm1d(self.feature_num)
        self.prediction = torch.nn.Linear(self.feature_num, 1)
        self.l2_embeddings = []

    def count_variables(self):
        """
        Count the number of trainable model parameters.
        :return: total number of trainable parameters
        """
        total_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return total_parameters

    def init_paras(self, m):
        """
        Weight initialization function; will be called in main.py.
        :param m: module to initialize
        :return:
        """
        if 'Linear' in str(type(m)):
            torch.nn.init.normal_(m.weight, mean=0.0, std=0.01)
            if m.bias is not None:
                torch.nn.init.normal_(m.bias, mean=0.0, std=0.01)
        elif 'Embedding' in str(type(m)):
            torch.nn.init.normal_(m.weight, mean=0.0, std=0.01)
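    # Usage sketch (an assumption about how main.py wires this up; torch.nn.Module.apply
    # walks every submodule, so all Linear and Embedding layers get re-initialized):
    #   model = BaseModel(...)
    #   model.apply(model.init_paras)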

    def l2(self, out_dict):
        """
        Calculate the L2 regularization term of the model. By default this is the squared sum of
        all variables excluding embeddings; the embedding part only covers the embeddings used in
        the current batch, which are passed in via out_dict[EMBEDDING_L2].
        :param out_dict: output dict of predict/forward
        :return: L2 regularization value
        """
        l2 = utils.numpy_to_torch(np.array(0.0, dtype=np.float32), gpu=True)
        for name, p in self.named_parameters():
            if not p.requires_grad:
                continue
            if self.l2_bias == 0 and 'bias' in name:
                continue
            if name.split('.')[0] in self.l2_embeddings:
                continue
            l2 += (p ** 2).sum()
        b_l2 = utils.numpy_to_torch(np.array(0.0, dtype=np.float32), gpu=True)
        for p in out_dict[EMBEDDING_L2]:
            b_l2 += (p ** 2).sum()
        if self.loss_sum == 0:
            l2_batch = out_dict[TOTAL_BATCH_SIZE] if L2_BATCH not in out_dict else out_dict[L2_BATCH]
            b_l2 = b_l2 / l2_batch
        return l2 + b_l2
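    # Note: l2() only returns the regularization term; adding it to the loss, e.g.
    #   total_loss = out_dict[LOSS] + l2_weight * out_dict[LOSS_L2]
    # is presumably handled by the runner, which also creates the optimizer.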

    def predict(self, feed_dict):
        """
        Prediction only, without loss calculation.
        :param feed_dict: input dict
        :return: output dict; PREDICTION holds the prediction values, CHECK holds intermediate results
        """
        check_list = []
        x = self.x_bn(feed_dict[X].float())
        x = torch.nn.Dropout(p=feed_dict[DROPOUT])(x)  # dropout rate is supplied per batch via the feed dict
        prediction = F.relu(self.prediction(x)).view([-1])
        out_dict = {PREDICTION: prediction,
                    CHECK: check_list,
                    EMBEDDING_L2: []}  # base model has no batch-specific embeddings, but l2() expects this key
        return out_dict
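    # A rough sketch of the feed_dict that predict/forward expect (the exact construction is
    # handled by the DataProcessor, so treat the shapes and dtypes as assumptions):
    #   feed_dict = {X: Tensor [batch_size, feature_num], Y: FloatTensor [batch_size],
    #                DROPOUT: 0.2, RANK: 0 or 1, REAL_BATCH_SIZE: batch_size}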

    def forward(self, feed_dict):
        """
        Calculate prediction and loss.
        :param feed_dict: input dict
        :return: output dict; PREDICTION holds the prediction values, CHECK holds intermediate results,
            LOSS holds the loss value and LOSS_L2 the L2 regularization term
        """
        out_dict = self.predict(feed_dict)
        if feed_dict[RANK] == 1:
            # Top-n recommendation task: the first real_batch_size entries are the observed
            # samples and the remaining entries are the sampled negatives.
            loss = self.rank_loss(out_dict[PREDICTION], feed_dict[Y], feed_dict[REAL_BATCH_SIZE])
        else:
            # Rating/regression task: default to MSE.
            if self.loss_sum == 1:
                loss = torch.nn.MSELoss(reduction='sum')(out_dict[PREDICTION], feed_dict[Y])
            else:
                loss = torch.nn.MSELoss(reduction='mean')(out_dict[PREDICTION], feed_dict[Y])
        out_dict[LOSS] = loss
        out_dict[LOSS_L2] = self.l2(out_dict)
        return out_dict

    def rank_loss(self, prediction, label, real_batch_size):
        """
        Calculate the rank loss, similar to BPR-max. Reference:
        @inproceedings{hidasi2018recurrent,
          title={Recurrent neural networks with top-k gains for session-based recommendations},
          author={Hidasi, Bal{\'a}zs and Karatzoglou, Alexandros},
          booktitle={Proceedings of the 27th ACM International Conference on Information and Knowledge Management},
          pages={843--852},
          year={2018},
          organization={ACM}
        }
        :param prediction: prediction values [None]
        :param label: labels [None]
        :param real_batch_size: batch size of the observed interactions, excluding negative samples
        :return: rank loss value
        """
        pos_neg_tag = (label - 0.5) * 2
        observed, sample = prediction[:real_batch_size], prediction[real_batch_size:]
        # sample = sample.view([-1, real_batch_size]).mean(dim=0)
        sample = sample.view([-1, real_batch_size])
        sample_softmax = (sample * pos_neg_tag.view([1, real_batch_size])).softmax(dim=0)
        sample = (sample * sample_softmax).sum(dim=0)
        if self.loss_sum == 1:
            loss = -(pos_neg_tag * (observed - sample)).sigmoid().log().sum()
        else:
            loss = -(pos_neg_tag * (observed - sample)).sigmoid().log().mean()
        return loss
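    # Shape walkthrough (a sketch with assumed sizes): with real_batch_size = B and n negatives
    # per observed sample, prediction has length B * (n + 1). observed is [B]; the remaining
    # n * B entries are reshaped to [n, B], softmax-weighted over the n negatives in each column,
    # and collapsed back to one aggregated negative score per observed sample before the
    # BPR-style sigmoid-log loss is applied.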

    def lrp(self):
        pass

    def save_model(self, model_path=None):
        """
        Save the model; uses the default path when none is given.
        :param model_path: model path to save to
        :return:
        """
        if model_path is None:
            model_path = self.model_path
        dir_path = os.path.dirname(model_path)
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)  # create intermediate directories if needed
        torch.save(self.state_dict(), model_path)
        logging.info('Save model to ' + model_path)

    def load_model(self, model_path=None, cpu=False):
        """
        Load the model; uses the default path when none is given.
        :param model_path: model path to load from
        :param cpu: if True, map the stored tensors to CPU
        :return:
        """
        if model_path is None:
            model_path = self.model_path
        if cpu:
            self.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
        else:
            self.load_state_dict(torch.load(model_path))
        self.eval()
        logging.info('Load model from ' + model_path)
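
# Minimal subclass sketch (hypothetical, for illustration only; MyModel and its layer sizes
# are not part of this file). A new model usually only overrides _init_weights and predict
# and reuses the loss logic in forward:
#
#   class MyModel(BaseModel):
#       def _init_weights(self):
#           self.feature_embeddings = torch.nn.Embedding(self.feature_num, 64)
#           self.prediction = torch.nn.Linear(64, 1)
#           self.l2_embeddings = ['feature_embeddings']
#
#       def predict(self, feed_dict):
#           vectors = self.feature_embeddings(feed_dict[X]).sum(dim=1)
#           prediction = F.relu(self.prediction(vectors)).view([-1])
#           return {PREDICTION: prediction, CHECK: [],
#                   EMBEDDING_L2: [self.feature_embeddings(feed_dict[X])]}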
