forked from patrick-kidger/diffrax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsolution.py
171 lines (134 loc) · 6.66 KB
/
solution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from dataclasses import field
from typing import Any, Dict, Optional
import equinox.internal as eqxi
import jax
from .custom_types import Array, Bool, PyTree, Scalar
from .global_interpolation import DenseInterpolation
from .misc import static_select
from .path import AbstractPath
# Status codes reported by a solve. `eqxi.ContainerMeta` converts each class
# attribute into a member of the container; the attribute's string value is the
# human-readable message recovered by indexing, e.g. `RESULTS[<value>]`.
# NOTE(review): member order/values come from the metaclass — do not reorder.
class RESULTS(metaclass=eqxi.ContainerMeta):
    # The empty string marks a successful solve: there is no message to report.
    successful = ""
    # A user-supplied discrete event fired and terminated the solve early.
    discrete_terminating_event_occurred = (
        "Terminating solve because a discrete event occurred."
    )
    # The step budget was exhausted before reaching the end of integration.
    max_steps_reached = (
        "The maximum number of solver steps was reached. Try increasing `max_steps`."
    )
    # The adaptive step controller shrank the step below its permitted minimum.
    dt_min_reached = "The minimum step size was reached."
    # Failure modes specific to implicit solvers' nonlinear iterations.
    implicit_divergence = "Implicit method diverged."
    implicit_nonconvergence = (
        "Implicit method did not converge within the required number of iterations."
    )
def is_okay(result: RESULTS) -> Bool:
    """Return whether `result` is acceptable: either a clean success or a
    (non-error) terminating event."""
    with jax.ensure_compile_time_eval():
        ended_by_event = is_event(result)
        finished_cleanly = is_successful(result)
        return finished_cleanly | ended_by_event
def is_successful(result: RESULTS) -> Bool:
    """Return whether `result` denotes a fully successful solve."""
    with jax.ensure_compile_time_eval():
        return RESULTS.successful == result
# TODO: if additional event types are ever supported then this predicate will
# need extending to cover them.
def is_event(result: RESULTS) -> Bool:
    """Return whether `result` denotes an early termination due to an event."""
    with jax.ensure_compile_time_eval():
        event_code = RESULTS.discrete_terminating_event_occurred
        return result == event_code
def update_result(old_result: RESULTS, new_result: RESULTS) -> RESULTS:
    """Combine a previously-recorded result with a newly-produced one.

    Errors are sticky (once recorded they are never overwritten), and an
    already-recorded event is only displaced by a new *error*:

        old | success event_o error_o
    new |
    --------+-------------------------
    success | success event_o error_o
    event_n | event_n event_o error_o
    error_n | error_n error_n error_o
    """
    with jax.ensure_compile_time_eval():
        # If the old result is okay (success or event) the new one takes over;
        # an old error is never replaced.
        merged = static_select(is_okay(old_result), new_result, old_result)
        # ...except that an old event is kept whenever the new result is not an
        # error, matching the table above.
        keep_old = is_event(old_result) & is_okay(new_result)
        return static_select(keep_old, old_result, merged)
class Solution(AbstractPath):
    """The solution to a differential equation.

    **Attributes:**

    - `t0`: The start of the interval over which the differential equation was
        solved.
    - `t1`: The end of the interval over which the differential equation was
        solved.
    - `ts`: Some ordered collection of times. Might be `None` if no values were
        saved. (i.e. just `diffeqsolve(..., saveat=SaveAt(dense=True))` is used.)
    - `ys`: The value of the solution at each of the times in `ts`. Might `None`
        if no values were saved.
    - `stats`: Statistics for the solve (number of steps etc.).
    - `result`: Integer specifying the success or cause of failure of the solve.
        A value of `0` corresponds to a successful solve. Any other value is a
        failure. A human-readable message can be obtained by looking up messages
        via `diffrax.RESULTS[<integer>]`.
    - `solver_state`: If saved, the final internal state of the numerical solver.
    - `controller_state`: If saved, the final internal state for the step size
        controller.
    - `made_jump`: If saved, the final internal state for the jump tracker.

    !!! note

        If `diffeqsolve(..., saveat=SaveAt(steps=True))` is set, then the `ts`
        and `ys` in the solution object will be padded with `NaN`s, out to the
        value of `max_steps` passed to [`diffrax.diffeqsolve`][].

        This is because JAX demands that shapes be known statically
        ahead-of-time. As we do not know how many steps we will take until the
        solve is performed, we must allocate enough space for the maximum
        possible number of steps.
    """

    t0: Scalar = field(init=True, repr=True)
    t1: Scalar = field(init=True, repr=True)  # override AbstractPath
    ts: Optional[Array["times"]]  # noqa: F821
    ys: Optional[PyTree["times", ...]]  # noqa: F821
    interpolation: Optional[DenseInterpolation]
    stats: Dict[str, Any]
    result: RESULTS
    solver_state: Optional[PyTree]
    controller_state: Optional[PyTree]
    made_jump: Optional[Bool]

    def _dense_interpolation(self) -> DenseInterpolation:
        # Shared guard: dense output only exists if SaveAt(dense=True) was used.
        if self.interpolation is None:
            raise ValueError(
                "Dense solution has not been saved; pass SaveAt(dense=True)."
            )
        return self.interpolation

    def evaluate(
        self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True
    ) -> PyTree:
        """If dense output was saved, then evaluate the solution at any point in
        the region of integration `self.t0` to `self.t1`.

        **Arguments:**

        - `t0`: The point to evaluate the solution at.
        - `t1`: If passed, then the increment from `t0` to `t1` is returned.
            (`=evaluate(t1) - evaluate(t0)`)
        - `left`: When evaluating at a jump in the solution, whether to return
            the left-limit or the right-limit at that point.

        **Raises:**

        A `ValueError` if dense output was not saved.
        """
        return self._dense_interpolation().evaluate(t0, t1, left)

    def derivative(self, t: Scalar, left: bool = True) -> PyTree:
        r"""If dense output was saved, then calculate an **approximation** to
        the derivative of the solution at any point in the region of
        integration `self.t0` to `self.t1`.

        That is, letting $y$ denote the solution over the interval `[t0, t1]`,
        then this calculates an approximation to
        $\frac{\mathrm{d}y}{\mathrm{d}t}$.

        (This is *not* backpropagating through the differential equation --
        that typically corresponds to e.g.
        $\frac{\mathrm{d}y(t_1)}{\mathrm{d}y(t_0)}$.)

        !!! example

            For an ODE satisfying

            $\frac{\mathrm{d}y}{\mathrm{d}t} = f(t, y(t))$

            then this value is approximately equal to $f(t, y(t))$.

        !!! warning

            This value is generally not very accurate. Differential equation
            solvers are usually designed to produce splines whose value is
            close to the true solution; not to produce splines whose derivative
            is close to the derivative of the true solution.

            If you need accurate derivatives for the solution of an ODE, it is
            usually best to calculate `vector_field(t, sol.evaluate(t), args)`.
            That is, to pay the extra computational cost of another vector
            field evaluation, in order to get a more accurate value.

            Put precisely: this `derivative` method returns the *derivative of
            the numerical solution*, and *not* an approximation to the
            derivative of the true solution.

        **Arguments:**

        - `t`: The point to calculate the derivative of the solution at.
        - `left`: When evaluating at a jump in the solution, whether to return
            the left-limit or the right-limit at that point.

        **Raises:**

        A `ValueError` if dense output was not saved.
        """
        return self._dense_interpolation().derivative(t, left)