More flexible ODE integration #2628
I agree, this would all be useful to have. RE: Auxiliary solve output, see @jacobjinkelly's #2574. Doing this right may require a more comprehensive solution adding side-effects to core JAX. RE: Differentiating through the solver, this would also be great to have, along with other adjoint calculation methods (e.g., as described in Appendix B of https://arxiv.org/pdf/2001.04385.pdf and implemented in DifferentialEquations.jl). Differentiating through the solver could be a little tricky to do in JAX, because we need fixed memory usage in order to compile something with XLA (i.e., bounded loops).
Wasn't aware of #2574. Definitely happy to see progress on that front! Although it seems inevitable that more and more auxiliary info will be desirable in the future. If I'm understanding #2574 correctly, it would entail breaking API changes any time a new diagnostic output is added.
Out of curiosity, why does XLA require bounded loops in order to do reverse-mode? Other IRs don't seem to have this limitation, e.g. Relay, which supports reverse-mode gradients through arbitrary loops. I believe their approach is based on https://arxiv.org/pdf/1803.10228.pdf. Julia also doesn't seem to have any issue with this.
I think the right way to do this is to switch to returning an object with an extensible set of fields in addition to the ODE solution, like SciPy's odeint or the newer solve_ivp. This would allow adding more auxiliary fields without breaking existing code.
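For what it's worth, here's a hypothetical sketch of what such a result object could look like (all field names here are invented, not an actual JAX API):

```python
# A hypothetical sketch of an extensible result object, in the spirit of
# scipy.integrate's OdeResult. Field names are illustrative only.
from typing import NamedTuple
import jax.numpy as jnp

class ODESolution(NamedTuple):
    ts: jnp.ndarray         # requested output times
    ys: jnp.ndarray         # solution values at those times
    num_steps: jnp.ndarray  # diagnostic: number of accepted steps

# New diagnostics become new fields; existing callers that only read
# `sol.ys` keep working, and NamedTuples are already valid JAX pytrees.
```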
XLA has a general requirement that all memory allocation must be statically determined from shapes. There's no dynamic allocation of arrays based on computation results. I don't know the exact reason for this requirement, but I imagine that it makes the compiler's job much easier.
Mmm, I can see how that would make the compiler much simpler, but it seems quite limiting now. Would it be possible to implement this without jit?
It's not just about simplicity; being able to statically analyze the shapes and memory requirements enables a ton of optimizations (fusion, layout, remat, etc.).
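As a concrete illustration of that static-shape requirement (a minimal example, nothing ODE-specific): boolean masking has a value-dependent output shape, so it works eagerly but cannot be compiled:

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, -2.0, 3.0])

def positives(v):
    return v[v > 0]  # output length depends on the *values* in v

print(positives(x))      # fine eagerly: [1. 3.]
# jax.jit(positives)(x)  # fails to trace: the output shape cannot be
#                        # known statically, so XLA cannot allocate it
```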
Indeed, just write it with a regular Python while loop and don't jit it! I'm not sure this is something we'd want to maintain in the jax core repo, but it should be easy enough to stand up as a baseline.
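Something like the following, maybe: a minimal un-jitted adaptive integrator built on a plain Python while loop. The step-size heuristic and all names here are illustrative, not a proposed API; the point is that jax.grad can trace through Python control flow as long as the function isn't jitted:

```python
# Minimal sketch: adaptive Euler/Heun with a plain Python while loop.
# Differentiable with jax.grad (op-by-op), but not jit-compatible since
# the trip count depends on runtime values.
import jax
import jax.numpy as jnp

def odeint_python(f, y0, t0, t1, rtol=1e-6, dt0=1e-2):
    t, y, dt = t0, y0, dt0
    while t < t1:
        dt = min(dt, t1 - t)
        k1 = f(y, t)
        k2 = f(y + dt * k1, t + dt)
        y_euler = y + dt * k1              # 1st-order estimate
        y_heun = y + 0.5 * dt * (k1 + k2)  # 2nd-order estimate
        err = jnp.max(jnp.abs(y_heun - y_euler))
        if err < rtol:                     # accept the step
            t, y = t + dt, y_heun
        dt = 0.9 * dt * jnp.sqrt(rtol / (err + 1e-12))  # adapt step size
    return y

# e.g. gradient of the solution of dy/dt = -a*y at t=1 w.r.t. a:
g = jax.grad(lambda a: odeint_python(lambda y, t: -a * y, 1.0, 0.0, 1.0))(2.0)
```

Note this gives discretise-then-optimise gradients (differentiating through the solver's actual steps), which is exactly the behavior being asked for above.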
Ok, will do! Could I just reuse the ...
OTOH, what prevents XLA from intelligently inferring loop bounds and then applying optimizations where possible? This seems to be fairly standard practice in compilers.
Yeah, ...
Ok, gotcha!
Re: loop bounds, sure, that's possible in some cases (though it's more of a JAX issue than an XLA issue, because XLA doesn't do autodiff). Actually that was my intention with this code: if we can specialize on the trip count then we can do reverse-mode autodiff. (It failed some internal test that I can't remember, but I think it should work...) But the point of while_loop is that you might not be able to predict the trip count. Something like an ODE integrator, or anything you run to numerical convergence, is a good example: a priori you don't have a bound on the number of iterations. (There are some tricks you can play, like using recursive checkpointing and assuming that loop trip counts are always bounded by 2**32 or something, but that's getting into the weeds.)
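To make the trip-count point concrete, here's a small sketch on a toy linear ODE (not any real integrator): a bounded lax.scan reverse-differentiates fine, while the equivalent dynamically-bounded lax.while_loop does not:

```python
import jax
from jax import lax

def euler_scan(a, n_steps=100, dt=0.01):
    # Static trip count: lax.scan supports reverse-mode, because XLA
    # can preallocate storage for all n_steps residuals.
    def step(y, _):
        return y + dt * (-a * y), None  # dy/dt = -a * y
    y_final, _ = lax.scan(step, 1.0, None, length=n_steps)
    return y_final

print(jax.grad(euler_scan)(2.0))  # works

def euler_while(a, dt=0.01):
    # Dynamic trip count: reverse-mode through lax.while_loop raises.
    def cond(carry):
        t, _ = carry
        return t < 1.0
    def body(carry):
        t, y = carry
        return (t + dt, y + dt * (-a * y))
    _, y = lax.while_loop(cond, body, (0.0, 1.0))
    return y

# jax.grad(euler_while)(2.0)  # raises: lax.while_loop is not
#                             # reverse-mode differentiable
```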
Ok, after chatting with @MarisaKirisame it sounds like the key difference here is that Relay supports ADTs and allocation in the IR. So in their case the AD of a while loop is just another while loop with tracing that runs in the IR, made fast by all kinds of compiler tricks within the IR. My understanding is that when loop bounds can be inferred, they switch on optimized, shape-aware compilation. If I'm understanding correctly, this means that even an omnipotent JAX would not be able to do AD through while loops when compiling down to XLA. If that's true, it seems as though the way forward to fix AD for loops is either ...
RE: Alternate solver choices. Do you mean implicit solvers like the Livermore Solver (LSODE; https://computing.llnl.gov/casc/nsde/pubs/u113855.pdf), which are the de facto choice in SciPy and can handle stiff equations? I started a project to write LSODE for TensorFlow and, let me say, it's tricky due to the many design decisions involved (page 58 of the link). There are a ton of heuristics that make things converge nicely and efficiently. There is the option, however, of stripping most of that logic away, down to the fundamentals of an implicit ODE solver (which is also described fully on pages 1-38 of that document).
@Joshuaalbert Yeah, I say the more options the better! I guess I view these sorts of things as good stress tests for tools like JAX and XLA. That flowchart looks pretty nasty, so it'd definitely be a challenge, but I believe it should be doable. Another option would be to support third-party solvers through an FFI the same way SciPy does. IIRC JAX already has some kind of internal-ish way to do that sort of thing somewhere.
An implicit solver for stiff ODEs would definitely be a welcome addition. We have a BDF method in TF-Probability that could be a good start. It would also be nice to expose an interface for reusing the ODE gradient definition(s), allowing users to bring their own solvers without needing to write new gradient rules. This would be similar to what we did with ...
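For reference, a minimal sketch of the implicit-solver fundamentals being discussed: one backward-Euler step solved with a fixed number of Newton iterations. A real BDF method layers order- and step-size heuristics on top of this kernel, and all names here are illustrative:

```python
import jax
import jax.numpy as jnp

def implicit_euler_step(f, y, dt, num_newton=5):
    # Solve y_next = y + dt * f(y_next) with a few Newton iterations.
    def residual(y_next):
        return y_next - y - dt * f(y_next)
    y_next = y  # initial Newton guess: the previous state
    for _ in range(num_newton):
        J = jax.jacobian(residual)(y_next)  # dense Jacobian via autodiff
        y_next = y_next - jnp.linalg.solve(J, residual(y_next))
    return y_next

# One large step on a stiff linear system dy/dt = A @ y. An explicit
# Euler step with dt = 0.1 would blow up here (|1 + dt*(-1000)| >> 1),
# but the implicit step stays stable.
A = jnp.array([[-1000.0, 0.0], [1.0, -1.0]])
y1 = implicit_euler_step(lambda y: A @ y, jnp.array([1.0, 1.0]), dt=0.1)
```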
I'd like to add my voice to those who would like to see differentiating through the solver and other adjoint calculation methods. That would be great. Also, it appears that the current version of experimental.odeint cannot handle 64-bit data. The error message I'm seeing stems from my call to odeint. Is it possible to alter odeint to handle 64-bit data?
@John-Boik Do you have a code sample you could share that reproduces the issue? It looks like the forward solve is not actually using float64 as you requested. |
It was my fault. I was unintentionally passing odeint a matrix specified as float32. |
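For anyone else who hits this: JAX defaults to 32-bit, and float64 inputs get silently downcast unless x64 mode is enabled up front:

```python
import jax
jax.config.update("jax_enable_x64", True)  # best set at startup

import jax.numpy as jnp
y0 = jnp.array([1.0, 2.0])
print(y0.dtype)  # float64 with the flag; float32 without it
```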
@John-Boik and others, I would strongly advise against differentiating through an adaptive implicit ODE solver like LSODE for performance reasons, especially for a stiff problem. I would have time in June to help implement LSODE.
@shoyer ...
I'd also like to add a feature request for ...
Another feature that I'm realizing would be important: ...
I have implemented a BDF solver for stiff ODEs in JAX using TF-Probability's code. The implementation is pretty barebones right now and I still need to test jit and vmap. Also, I still need to implement an adjoint gradient method like TF-Probability's, which I'm planning to do soon. I tested it against SciPy's VODE methods for stiff chemical kinetics problems and the results look pretty good to me. The current implementation is here and I'm planning to open a PR once I get the adjoint gradient method, jit, vmap, etc. working. Feedback and suggestions are welcome! Thanks for such a great framework.
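For context, the classic Robertson problem is the kind of stiff chemical-kinetics benchmark mentioned here; a SciPy reference solution, against which a JAX BDF implementation could be validated, looks like this:

```python
# Robertson problem: three reaction rates spanning ~9 orders of
# magnitude, a standard stress test for stiff solvers.
from scipy.integrate import solve_ivp

def robertson(t, y):
    y1, y2, y3 = y
    return [-0.04 * y1 + 1.0e4 * y2 * y3,
            0.04 * y1 - 1.0e4 * y2 * y3 - 3.0e7 * y2**2,
            3.0e7 * y2**2]

ref = solve_ivp(robertson, (0.0, 1.0e5), [1.0, 0.0, 0.0],
                method="BDF", rtol=1e-8, atol=1e-12)
print(ref.y[:, -1])  # reference concentrations at t = 1e5
```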
To make this work in a way that is compatible with ...
Yeah, XLA's allocation limitations present a bit of a challenge here.
I second this! I think this could be accomplished by using ...
btw, I have a rough implementation of some other RK solvers here, in case anyone has a use for them. :)
Just checking on this thread: we are very interested in getting fast gradients through a stiff ODE solution. Some cool ideas in this thread. Are there any updates since last year?
For those watching this thread: check out Diffrax, which is a library doing pretty much everything discussed here. Other RK solvers, implicit solvers, discretise-then-optimise, etc.
The current implementation of Runge-Kutta with adjoint reverse-mode gradients is great, but there are a few things I still find myself missing, and I'd really love to help contribute, or just see in JAX one way or another.

- Auxiliary solve output.
- Differentiating through the solver. This would avoid the discrepancy between the known y0 and the "hopefully close" y(t_0) from backtracking through the dynamics in the adjoint ODE.
- Alternate solver choices. Ideally, registering a new solver would just mean wrapping an odeint function with a vjp rule. Being able to select different solvers for the forward and adjoint passes would also be useful. The ideal solution would make arbitrary solver combos a cinch, e.g. run RK on the forward pass, but Euler integration for the adjoint. (A rough sketch of this idea follows below.)
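To illustrate that last point, here is a rough, untested sketch of what "wrap any solver with a vjp rule" might look like using jax.custom_vjp. Here fwd_solve and bwd_solve are placeholder black-box solvers (any odeint-like function), and gradients with respect to the time endpoints are omitted for brevity:

```python
# Sketch: build an odeint whose forward pass uses one solver and whose
# backward pass integrates the continuous adjoint ODE with another.
# fwd_solve / bwd_solve are hypothetical: (f, state0, t0, t1) -> state1,
# where the state may be a pytree and f has signature f(state, t).
from functools import partial
import jax
import jax.numpy as jnp

def make_odeint(fwd_solve, bwd_solve):
    @partial(jax.custom_vjp, nondiff_argnums=(0,))
    def odeint(f, y0, t0, t1):
        return fwd_solve(f, y0, t0, t1)

    def odeint_fwd(f, y0, t0, t1):
        y1 = fwd_solve(f, y0, t0, t1)
        return y1, (y1, t0, t1)

    def odeint_bwd(f, res, y1_bar):
        y1, t0, t1 = res
        # Integrate the adjoint ODE from t1 back to t0 by substituting
        # s = t0 + t1 - t, so bwd_solve still runs "forwards" in s:
        #   dy/ds = -f(y, t),   da/ds = a^T (df/dy)
        def aug(state, s):
            y, a = state
            t = t0 + t1 - s
            dy, pullback = jax.vjp(lambda v: f(v, t), y)
            (a_dot,) = pullback(a)
            return (-dy, a_dot)
        _, y0_bar = bwd_solve(aug, (y1, y1_bar), t0, t1)
        # Endpoint gradients omitted in this sketch (returned as zeros).
        return (y0_bar, jnp.zeros_like(t0), jnp.zeros_like(t1))

    odeint.defvjp(odeint_fwd, odeint_bwd)
    return odeint
```

With something along these lines, `make_odeint(rk4_solve, euler_solve)` would give the "RK on the forward pass, Euler for the adjoint" combination described above, assuming both solvers accept pytree states.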