forked from sheoyon-jhin/CONTIME
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmisc.py
166 lines (126 loc) · 5.84 KB
/
misc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import math
import numpy as np
import torch
def cheap_stack(tensors, dim):
    """Stack a sequence of tensors along a new dimension ``dim``.

    Behaves like ``torch.stack(tensors, dim=dim)``. When exactly one tensor is given it is unsqueezed
    instead, which returns a view rather than allocating a new stacked tensor.

    Arguments:
        tensors: a sequence (supporting ``len`` and indexing) of tensors.
        dim: position of the new dimension.

    Returns:
        The stacked tensor.
    """
    if len(tensors) != 1:
        return torch.stack(tensors, dim=dim)
    # Single input: unsqueeze avoids the copy that torch.stack would make.
    single = tensors[0]
    return single.unsqueeze(dim)
def tridiagonal_solve(b, A_upper, A_diagonal, A_lower):
    """Solves a tridiagonal system Ax = b.

    The arguments A_upper, A_diagonal, A_lower correspond to the three diagonals of A. Letting U = A_upper,
    D = A_diagonal and L = A_lower, and assuming for simplicity that there are no batch dimensions, then the matrix A
    is assumed to be of size (k, k), with entries:

        D[0]  U[0]
        L[0]  D[1]  U[1]
              L[1]  D[2]  U[2]                        0
                    L[2]  D[3]  U[3]
                          .     .     .
                                .     .         .
          0                     L[k - 3]  D[k - 2]  U[k - 2]
                                          L[k - 2]  D[k - 1]

    i.e. A[i, i] = D[i], A[i, i + 1] = U[i] and A[i + 1, i] = L[i].

    Arguments:
        b: A tensor of shape (..., k), where '...' is zero or more batch dimensions
        A_upper: A tensor of shape (..., k - 1).
        A_diagonal: A tensor of shape (..., k).
        A_lower: A tensor of shape (..., k - 1).

    Returns:
        A tensor of shape (..., k), corresponding to the x solving Ax = b

    Warning:
        This implementation isn't super fast. You probably want to cache the result, if possible.
        No pivoting is performed (Thomas algorithm), so a zero or near-zero pivot produces inf/nan.
    """
    # This implementation is very much written for clarity rather than speed.
    # Broadcast every diagonal against b so batch dimensions line up; the off-diagonals have one fewer entry.
    A_upper, _ = torch.broadcast_tensors(A_upper, b[..., :-1])
    A_lower, _ = torch.broadcast_tensors(A_lower, b[..., :-1])
    A_diagonal, b = torch.broadcast_tensors(A_diagonal, b)

    k = b.size(-1)

    # Forward elimination: zero out the sub-diagonal, tracking the modified diagonal and right-hand side.
    new_b = [b[..., 0]]
    new_A_diagonal = [A_diagonal[..., 0]]
    for i in range(1, k):
        w = A_lower[..., i - 1] / new_A_diagonal[i - 1]
        new_A_diagonal.append(A_diagonal[..., i] - w * A_upper[..., i - 1])
        new_b.append(b[..., i] - w * new_b[i - 1])

    # Back substitution on the resulting upper-bidiagonal system.
    outs = [None] * k
    outs[k - 1] = new_b[k - 1] / new_A_diagonal[k - 1]
    for i in range(k - 2, -1, -1):
        outs[i] = (new_b[i] - A_upper[..., i] * outs[i + 1]) / new_A_diagonal[i]
    return torch.stack(outs, dim=-1)
def validate_input_path(x, t):
    """Check that a path ``x`` and its timestamps ``t`` are consistent, and return ``t``.

    Arguments:
        x: floating point tensor with at least two dimensions; dimension -2 is treated as time and the last
            dimension as channels.
        t: one dimensional floating point tensor of strictly increasing timestamps with one entry per time step
            of ``x``, or None, in which case ``0, 1, ..., x.size(-2) - 1`` is used (same dtype/device as ``x``).

    Returns:
        The validated (or generated) ``t``.

    Raises:
        ValueError: if any of the above conditions fail, or if there are fewer than two time steps.
    """
    if not x.is_floating_point():
        raise ValueError("X must both be floating point.")
    if x.ndimension() < 2:
        raise ValueError("X must have at least two dimensions, corresponding to time and channels. It instead has "
                         "shape {}.".format(tuple(x.shape)))
    if t is None:
        t = torch.linspace(0, x.size(-2) - 1, x.size(-2), dtype=x.dtype, device=x.device)
    if not t.is_floating_point():
        raise ValueError("t must both be floating point.")
    if len(t.shape) != 1:
        raise ValueError("t must be one dimensional. It instead has shape {}.".format(tuple(t.shape)))
    # One vectorised comparison over consecutive pairs, instead of a Python loop in which every `if` on a tensor
    # element forces a separate host read. Phrased as "all strictly increasing" so NaN timestamps (whose
    # comparisons are always False) are rejected instead of silently accepted.
    if not bool((t[1:] > t[:-1]).all()):
        raise ValueError("t must be monotonically increasing.")
    if x.size(-2) != t.size(0):
        raise ValueError("The time dimension of X must equal the length of t. X has shape {} and t has shape {}, "
                         "corresponding to time dimensions of {} and {} respectively."
                         .format(tuple(x.shape), tuple(t.shape), x.size(-2), t.size(0)))
    if t.size(0) < 2:
        raise ValueError("Must have a time dimension of size at least 2. It instead has shape {}, corresponding to a "
                         "time dimension of size {}.".format(tuple(t.shape), t.size(0)))
    return t
def forward_fill(x, fill_index=-2):
    """Carry the most recent non-NaN value forward along one dimension of a tensor.

    Arguments:
        x: tensor with at least two dimensions, typically of shape (..., length, input_channels) where ... is some
            number of batch dimensions.
        fill_index: the dimension along which values are carried forward. Defaults to -2, i.e. the length dimension
            under the (..., length, input_channels) convention.

    Returns:
        A tensor of the same shape as ``x`` in which every NaN that has an earlier non-NaN entry along
        ``fill_index`` is replaced by that entry. NaNs with no earlier non-NaN entry stay NaN. If ``x`` contains
        no NaNs, ``x`` itself is returned.
    """
    # Checks
    assert isinstance(x, torch.Tensor)
    assert x.dim() >= 2

    nan_mask = torch.isnan(x)
    if not nan_mask.any():
        return x

    # Running count of observed (non-NaN) entries. NaN positions are zeroed so they can never be the running max.
    observed_count = (~nan_mask).cumsum(dim=fill_index)
    observed_count[nan_mask] = 0
    # The index of the running max is the position of the latest observed entry at or before each position.
    _, source_index = observed_count.cummax(dim=fill_index)
    return x.gather(dim=fill_index, index=source_index)
class TupleControl(torch.nn.Module):
    """Group several controls that share an interval so they can be queried together.

    Each control must be a ``torch.nn.Module`` (they are held in a ``ModuleList``) exposing ``interval`` and
    ``grid_points`` tensors plus ``evaluate(t)`` and ``derivative(t)`` methods.
    """

    def __init__(self, *controls):
        super(TupleControl, self).__init__()
        if len(controls) == 0:
            raise ValueError("Expected one or more controls to batch together.")
        self._interval = controls[0].interval
        grid_points = controls[0].grid_points
        same_grid_points = True
        for control in controls[1:]:
            interval = control.interval
            # Compare shapes before values: elementwise != on mismatched shapes either raises a size-mismatch
            # error or silently broadcasts (e.g. against a size-1 tensor), neither of which means "equal".
            if interval.shape != self._interval.shape or (interval != self._interval).any():
                raise ValueError("Can only batch togehter controls over the same interval.")
            if same_grid_points:
                other_grid = control.grid_points
                if other_grid.shape != grid_points.shape or (other_grid != grid_points).any():
                    same_grid_points = False
        # Shared grid points are only exposed when every control agrees on them.
        self._grid_points = grid_points if same_grid_points else None
        self.controls = torch.nn.ModuleList(controls)

    @property
    def interval(self):
        """The interval shared by all controls (taken from the first control)."""
        return self._interval

    @property
    def grid_points(self):
        """The grid points shared by all controls.

        Raises:
            RuntimeError: if the controls do not all have identical grid points.
        """
        if self._grid_points is None:
            raise RuntimeError("Batch of controls have different grid points.")
        return self._grid_points

    def evaluate(self, t):
        """Return a tuple with each control's ``evaluate(t)``, in the order the controls were given."""
        return tuple(control.evaluate(t) for control in self.controls)

    def derivative(self, t):
        """Return a tuple with each control's ``derivative(t)``, in the order the controls were given."""
        return tuple(control.derivative(t) for control in self.controls)