This repository has been archived by the owner on Mar 8, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
65 lines (50 loc) · 1.93 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
from torch import nn
class SpatioTemporalEmbedding(nn.Module):
    """Embed day, time-slot and (x, y) location indices and concatenate them.

    Each of the four inputs is an integer index tensor looked up in its own
    ``nn.Embedding`` table. The four index tensors must have the same shape,
    because the looked-up vectors are joined with ``torch.cat`` on the last
    axis. The result has that shape plus one trailing feature axis of size
    ``output_dim``, laid out as ``[day | time | location_x | location_y]``.

    The defaults reproduce the original fixed configuration:

    * 7 day indices and 48 time-slot indices;
    * 201 x-location and 201 y-location indices;
    * widths 32 + 32 + 128 + 128 = 320 features.

    Attribute names and parameter shapes are unchanged, so existing
    state_dicts still load.

    Args:
        num_days: Size of the day table (valid indices ``[0, num_days)``).
        num_time_slots: Size of the time-of-day table.
        num_locations_x: Size of the x-location table.
        num_locations_y: Size of the y-location table.
        day_dim: Width of each day embedding vector.
        time_dim: Width of each time-of-day embedding vector.
        location_dim: Width of each x- and y-location embedding vector.
    """

    def __init__(
        self,
        *,
        num_days=7,
        num_time_slots=48,
        num_locations_x=201,
        num_locations_y=201,
        day_dim=32,
        time_dim=32,
        location_dim=128,
    ):
        super().__init__()
        # Tables are created in the original order so that, under a fixed
        # random seed, the initial weights match the previous implementation.
        self.day_of_week_embedding = nn.Embedding(num_days, day_dim)
        self.time_of_day_embedding = nn.Embedding(num_time_slots, time_dim)
        self.location_x_embedding = nn.Embedding(num_locations_x, location_dim)
        self.location_y_embedding = nn.Embedding(num_locations_y, location_dim)
        # Width of the concatenated output (320 with the defaults).
        self.output_dim = day_dim + time_dim + 2 * location_dim

    def forward(self, day, time, location_x, location_y):
        """Return the concatenated embeddings of the four index tensors.

        Args:
            day: Integer indices into the day table.
            time: Integer indices into the time-of-day table.
            location_x: Integer indices into the x-location table.
            location_y: Integer indices into the y-location table.

        Returns:
            Tensor of shape ``day.shape + (output_dim,)``.
        """
        day_embed = self.day_of_week_embedding(day)
        time_embed = self.time_of_day_embedding(time)
        location_x_embed = self.location_x_embedding(location_x)
        location_y_embed = self.location_y_embedding(location_y)
        return torch.cat(
            (day_embed, time_embed, location_x_embed, location_y_embed), dim=-1
        )
class Encoding(nn.Module):
    """Run a multi-layer GRU over a feature sequence and return its final hidden state.

    The GRU is built with ``nn.GRU``'s default ``batch_first=False``:

    * A 3-D input is taken as batch-first ``(batch, seq, input_size)`` and
      transposed to the ``(seq, batch, input_size)`` layout the GRU expects.
    * A 2-D input is passed through unchanged and treated as a single
      unbatched ``(seq, input_size)`` sequence.

    The defaults reproduce the original fixed configuration (input 320,
    hidden 256, 4 layers, dropout 0.1). The submodule name ``gru`` is
    unchanged, so existing state_dicts still load.

    Args:
        input_size: Feature width of each time step. The default
            32 + 32 + 128 + 128 presumably matches the default
            SpatioTemporalEmbedding output — confirm against the caller.
        hidden_size: GRU hidden-state width.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout applied between GRU layers (``nn.GRU`` semantics).
    """

    def __init__(
        self,
        *,
        input_size=32 + 32 + 128 + 128,
        hidden_size=256,
        num_layers=4,
        dropout=0.1,
    ):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )

    def forward(self, input):
        """Encode ``input`` and return the GRU's final hidden state.

        Args:
            input: ``(batch, seq, input_size)`` or unbatched ``(seq, input_size)``.

        Returns:
            ``h`` of shape ``(num_layers, batch, hidden_size)``, or
            ``(num_layers, hidden_size)`` for unbatched input.
        """
        if input.dim() == 3:
            # (batch, seq, F) -> (seq, batch, F) for the seq-first GRU.
            input = input.permute(1, 0, 2)
        _, h = self.gru(input)  # per-step outputs are not needed here
        return h
class Decoding(nn.Module):
    """GRU decoder followed by two independent MLP classification heads.

    The GRU uses ``nn.GRU``'s default ``batch_first=False``:

    * A 3-D input is taken as batch-first ``(batch, seq, input_size)``. It is
      transposed for the GRU, and the GRU output is transposed back, so the
      predictions are batch-first too.
    * A 2-D input is treated as one unbatched ``(seq, input_size)`` sequence.

    ``mlp1`` produces ``pred_x`` and ``mlp2`` produces ``pred_y``. Both map
    every GRU output step to raw (pre-softmax) scores.

    The defaults reproduce the original configuration. Submodule names and
    ``nn.Sequential`` layer indices (0 Linear, 1 Dropout, 2 ReLU, 3 Linear)
    are unchanged, so existing state_dicts still load.

    Args:
        input_size: Feature width of each input time step (default 320).
        hidden_size: GRU hidden-state width; also the heads' input width.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout applied between GRU layers.
        mlp_hidden_size: Hidden width of each head.
        mlp_dropout: Dropout inside each head.
        num_classes_x: Output width of ``mlp1``.
        num_classes_y: Output width of ``mlp2``.
            NOTE(review): heads default to 200 outputs while
            SpatioTemporalEmbedding defaults to 201 location indices —
            confirm whether index 200 is a reserved/padding value.
    """

    def __init__(
        self,
        *,
        input_size=32 + 32 + 128 + 128,
        hidden_size=256,
        num_layers=4,
        dropout=0.1,
        mlp_hidden_size=1024,
        mlp_dropout=0.1,
        num_classes_x=200,
        num_classes_y=200,
    ):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )
        # Built in the original order (gru, mlp1, mlp2) so a fixed seed gives
        # the same initial weights as before.
        self.mlp1 = self._make_head(hidden_size, mlp_hidden_size, num_classes_x, mlp_dropout)
        self.mlp2 = self._make_head(hidden_size, mlp_hidden_size, num_classes_y, mlp_dropout)

    @staticmethod
    def _make_head(in_features, hidden_features, out_features, dropout):
        """Build one head: Linear -> Dropout -> ReLU -> Linear (original layer order)."""
        return nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
        )

    def forward(self, input, h0=None):
        """Decode ``input`` starting from hidden state ``h0``.

        Args:
            input: ``(batch, seq, input_size)`` or unbatched ``(seq, input_size)``.
            h0: Initial GRU hidden state, ``(num_layers, batch, hidden_size)``
                or ``(num_layers, hidden_size)`` for unbatched input. ``None``
                (new, optional) lets ``nn.GRU`` start from zeros. Presumably
                callers pass the hidden state returned by ``Encoding``.

        Returns:
            A tuple ``(pred_x, pred_y, h)``:

            * ``pred_x``: ``(..., seq, num_classes_x)`` scores from ``mlp1``.
            * ``pred_y``: ``(..., seq, num_classes_y)`` scores from ``mlp2``.
            * ``h``: the final GRU hidden state, shaped like ``h0``.
        """
        if input.dim() == 3:
            # (batch, seq, F) -> (seq, batch, F) for the seq-first GRU.
            input = input.permute(1, 0, 2)
        out, h = self.gru(input, h0)
        if out.dim() == 3:
            # (seq, batch, H) -> (batch, seq, H) so predictions are batch-first.
            out = out.permute(1, 0, 2)
        pred_x = self.mlp1(out)
        pred_y = self.mlp2(out)
        return pred_x, pred_y, h