This repository has been archived by the owner on Mar 8, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
val_task1.py
86 lines (66 loc) · 3.89 KB
/
val_task1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import argparse
import json
from tqdm import tqdm
import torch
from dataset import *
from model import *
def task1(st_embedding_pth, encoding_pth, decoding_pth, device):
    """Run the task-1 validation pass and dump predictions next to the labels.

    Restores the spatio-temporal embedding, the encoder and the decoder from
    their checkpoints, decodes every sample of the validation dataset one step
    at a time (each predicted location is fed back in as the next decoder
    input), and writes the outcome to ``result/task1/result.json``.

    Args:
        st_embedding_pth: Path of the ``SpatioTemporalEmbedding`` state dict.
        encoding_pth: Path of the ``Encoding`` state dict.
        decoding_pth: Path of the ``Decoding`` state dict.
        device: Torch device the modules and the sample tensors are moved to.

    The JSON file holds two parallel lists, ``generated`` and ``reference``,
    with one entry per sample; every entry is a list of
    ``[day, time_of_day, x, y]`` rows.
    """
    output_dir = os.path.join('result', 'task1')
    os.makedirs(output_dir, exist_ok=True)

    samples = Task1ValDataset('./data/task1_dataset_kotae.csv')

    embedder = SpatioTemporalEmbedding().to(device)
    encoder = Encoding().to(device)
    decoder = Decoding().to(device)
    restore_plan = (
        (embedder, st_embedding_pth),
        (encoder, encoding_pth),
        (decoder, decoding_pth),
    )
    for module, checkpoint in restore_plan:
        module.load_state_dict(torch.load(checkpoint, map_location=device))

    result = {'generated': [], 'reference': []}

    # Inference only: switch every module to eval mode.
    for module, _ in restore_plan:
        module.eval()

    tensor_keys = (
        'encoding_day_of_week',
        'encoding_time_of_day',
        'encoding_location_x',
        'encoding_location_y',
        'decoding_day',
        'decoding_time_of_day',
        'label_location_x',
        'label_location_y',
    )

    with torch.no_grad():
        for sample in tqdm(samples):
            on_device = {key: sample[key].to(device) for key in tensor_keys}
            days = on_device['decoding_day']
            times = on_device['decoding_time_of_day']

            # Encode the observed history into the initial decoder state.
            history_embed = embedder(
                on_device['encoding_day_of_week'],
                on_device['encoding_time_of_day'],
                on_device['encoding_location_x'],
                on_device['encoding_location_y'],
            )
            h = encoder(history_embed)

            steps = days.size(0)
            xs = torch.zeros((steps, ), dtype=torch.int64, device=device)
            ys = torch.zeros((steps, ), dtype=torch.int64, device=device)

            # The first decoder step has no previous prediction, so it is
            # fed an all-zero location.
            prev_x = torch.zeros((1, ), dtype=torch.int64, device=device)
            prev_y = torch.zeros((1, ), dtype=torch.int64, device=device)

            for step in range(steps):
                step_embed = embedder(
                    days[step].unsqueeze(0) % 7,  # day index folded into a day of week
                    times[step].unsqueeze(0),
                    prev_x,
                    prev_y,
                )
                scores_x, scores_y, h = decoder(step_embed, h)
                # Greedy pick; the +1 shifts the argmax index by one
                # (presumably location ids are 1-based -- confirm against
                # the dataset).
                xs[step] = torch.argmax(scores_x, dim=-1) + 1
                ys[step] = torch.argmax(scores_y, dim=-1) + 1
                prev_x = xs[step].unsqueeze(0)
                prev_y = ys[step].unsqueeze(0)

            predicted_rows = torch.stack((days, times, xs, ys), dim=-1)
            label_rows = torch.stack(
                (days, times, on_device['label_location_x'], on_device['label_location_y']),
                dim=-1,
            )
            result['generated'].append(predicted_rows.cpu().tolist())
            result['reference'].append(label_rows.cpu().tolist())

    with open(os.path.join(output_dir, 'result.json'), 'w') as file:
        json.dump(result, file)
if __name__ == '__main__':
    # Command-line entry point: parse the checkpoint paths and the GPU index,
    # pick the device, and run the task-1 validation.
    parser = argparse.ArgumentParser()
    parser.add_argument('--st_embedding_pth', type=str, default='./checkpoint/task1/st_embedding.pth',
                        help='path of the spatio-temporal embedding checkpoint')
    parser.add_argument('--encoding_pth', type=str, default='./checkpoint/task1/encoding.pth',
                        help='path of the encoder checkpoint')
    parser.add_argument('--decoding_pth', type=str, default='./checkpoint/task1/decoding.pth',
                        help='path of the decoder checkpoint')
    parser.add_argument('--cuda', type=int, default=0,
                        help='index of the CUDA device to use (ignored when CUDA is unavailable)')
    args = parser.parse_args()

    # The checkpoints are loaded with map_location=device, so the validation
    # can also run on a CPU-only host; fall back instead of failing on the
    # first .to(device) call.
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{args.cuda}')
    else:
        print('CUDA is not available; running validation on CPU.')
        device = torch.device('cpu')

    task1(args.st_embedding_pth, args.encoding_pth, args.decoding_pth, device)