forked from yehjin-shin/BSARec
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
81 lines (63 loc) · 3 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import torch
import numpy as np
from model import MODEL_DICT
from trainers import Trainer
from utils import EarlyStopping, check_path, set_seed, parse_args, set_logger
from dataset import get_seq_dic, get_dataloder, get_rating_matrix
def _build_split(args, split):
    """Load one data split ('train' / 'val' / 'test') and build its dataloader.

    Mutates ``args`` in place: points ``args.data_file`` at the split's text
    file and refreshes ``args.item_size`` / ``args.num_users`` from that
    split's statistics (so after the final call they reflect the last split
    loaded, matching the original script's behavior).

    Returns (dataloader, seq_dic, max_item).
    """
    # os.path.join is robust to data_dir with or without a trailing slash,
    # unlike the previous raw string concatenation.
    args.data_file = os.path.join(args.data_dir, f'{split}_{args.data_name}.txt')
    seq_dic, max_item, num_users = get_seq_dic(args)
    args.item_size = max_item + 1
    args.num_users = num_users + 1
    return get_dataloder(args, seq_dic, split), seq_dic, max_item


def main():
    """Train a sequential-recommendation model, or evaluate a saved checkpoint.

    Behavior is selected by the parsed CLI args: ``--do_eval`` loads
    ``--load_model`` and runs the test loop only; otherwise trains with
    early stopping and then tests the best checkpoint.
    """
    args = parse_args()

    # Create the output directory BEFORE opening the log file inside it;
    # logging first could fail on a fresh checkout with no output dir.
    check_path(args.output_dir)
    log_path = os.path.join(args.output_dir, args.train_name + '.log')
    logger = set_logger(log_path)
    set_seed(args.seed)

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    args.cuda_condition = torch.cuda.is_available() and not args.no_cuda
    args.checkpoint_path = os.path.join(args.output_dir, args.train_name + '.pt')
    args.same_target_path = os.path.join(args.data_dir, args.data_name + '_same_target.npy')

    # Build the three dataloaders; each call re-derives item/user counts
    # from its own split, exactly as the original inline code did.
    train_dataloader, _, _ = _build_split(args, 'train')
    eval_dataloader, _, _ = _build_split(args, 'val')
    test_dataloader, _, _ = _build_split(args, 'test')

    logger.info(str(args))
    logger.info(args.data_file)

    model = MODEL_DICT[args.model_type.lower()](args=args)
    logger.info(model)
    trainer = Trainer(model, train_dataloader, eval_dataloader, test_dataloader, args, logger)

    if args.do_eval:
        # Evaluation-only mode: requires an explicit checkpoint name.
        if args.load_model is None:
            logger.info("No model input!")
            exit(0)
        args.checkpoint_path = os.path.join(args.output_dir, args.load_model + '.pt')
        trainer.load(args.checkpoint_path)
        logger.info(f"Load model from {args.checkpoint_path} for test!")
        scores, result_info = trainer.test(0)
    else:
        early_stopping = EarlyStopping(args.checkpoint_path, logger=logger,
                                       patience=args.patience, verbose=True)
        for epoch in range(args.epochs):
            trainer.train(epoch)
            scores, _ = trainer.valid(epoch)
            # Early-stop on the last reported validation metric (MRR).
            early_stopping(np.array(scores[-1:]), trainer.model)
            if early_stopping.early_stop:
                logger.info("Early stopping")
                break
        logger.info("---------------Test Score---------------")
        # Restore the best weights saved by EarlyStopping before testing.
        trainer.model.load_state_dict(torch.load(args.checkpoint_path))
        scores, result_info = trainer.test(0)

    logger.info(args.train_name)
    logger.info(result_info)
main()