forked from sheoyon-jhin/CONTIME
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsolver.py
70 lines (47 loc) · 1.98 KB
/
solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import torchdiffeq
import warnings
import numpy as np
class _ContinuousDelayField(torch.nn.Module):
def __init__(self, X, func,device):
super(_ContinuousDelayField, self).__init__()
self.X = X
self.func = func
def forward(self, t, inputz):
h,z= inputz # h : 256,49 z : 256,55
vector_field,out = self.func(t,z,self.X)
return (vector_field,out)
f = forward
def g(self, t, z):
return torch.zeros_like(z).unsqueeze(-1)
def contint_delay(X, func, z0, h0, t, device, adjoint=True, backend="torchdiffeq", **kwargs):
    """Integrate the coupled ``(h, z)`` system defined by ``func`` over ``t``.

    ``func`` is wrapped in ``_ContinuousDelayField`` and solved with
    torchdiffeq, starting from the initial state ``(h0, z0)``.

    Args:
        X: Extra input forwarded unchanged to ``func`` on every evaluation.
        func: Callable ``func(t, z, X)`` returning the pair ``(dh, dz)``.
        z0: Initial value of the ``z`` state component.
        h0: Initial value of the ``h`` state component.
        t: Times at which the solution is reported (passed to ``odeint``).
        device: Forwarded to ``_ContinuousDelayField``.
        adjoint: If true, use ``torchdiffeq.odeint_adjoint``; otherwise
            ``torchdiffeq.odeint``.
        backend: Only ``"torchdiffeq"`` is supported.
        **kwargs: Extra solver arguments. These defaults are filled in when absent:
            - ``atol=1e-6`` and ``rtol=1e-4``.
            - When ``adjoint`` is true, ``adjoint_atol``/``adjoint_rtol``, copied
              from ``atol``/``rtol``.
            - ``method="rk4"``.
            - For RK4, ``options["step_size"]=1.0`` unless ``options`` already
              has ``step_size`` or ``grid_constructor``.

            The caller's ``options`` dict is never modified.

    Returns:
        ``(out_h, out_z)``: the solver outputs with the leading time axis
        moved to the second-to-last position.

    Raises:
        ValueError: If ``backend`` is not ``"torchdiffeq"``.
    """
    # Tighter tolerances than torchdiffeq's defaults, because CDEs are
    # difficult to solve accurately with the looser defaults.
    kwargs.setdefault('atol', 1e-6)
    kwargs.setdefault('rtol', 1e-4)
    if adjoint:
        # The adjoint (backward) solve inherits the forward tolerances.
        kwargs.setdefault("adjoint_atol", kwargs["atol"])
        kwargs.setdefault("adjoint_rtol", kwargs["rtol"])
    kwargs.setdefault('method', 'rk4')
    if kwargs['method'] == 'rk4':
        # Work on a copy so the caller's dict is not mutated across calls;
        # ``options=None`` is treated as "no options".
        options = dict(kwargs.get('options') or {})
        if 'step_size' not in options and 'grid_constructor' not in options:
            # Fixed-step RK4 needs a step size; 1.0 is this module's default.
            options['step_size'] = 1.0
        kwargs['options'] = options

    # Fail fast before building the vector field.
    if backend != "torchdiffeq":
        raise ValueError(f"Unrecognised backend={backend}")

    vector_field = _ContinuousDelayField(X=X, func=func, device=device)
    odeint = torchdiffeq.odeint_adjoint if adjoint else torchdiffeq.odeint
    out_h, out_z = odeint(func=vector_field, y0=(h0, z0), t=t, **kwargs)

    def _time_axis_to_penultimate(tensor):
        # odeint stacks results along a new leading time axis; move it to
        # just before the last axis: (T, d1, ..., dk, C) -> (d1, ..., dk, T, C).
        batch_dims = range(1, tensor.dim() - 1)
        return tensor.permute(*batch_dims, 0, -1)

    return _time_axis_to_penultimate(out_h), _time_axis_to_penultimate(out_z)