More flexible ODE integration #2628

Open · 3 tasks
samuela opened this issue Apr 7, 2020 · 29 comments

@samuela commented Apr 7, 2020

The current implementation of Runge-Kutta with adjoint reverse-mode gradients is great, but there are a few things I still find myself missing that I'd really love to help contribute, or just see land in JAX one way or another.

  • Auxiliary solver output. Number of function evaluations, locations of time steps, diagnostic information, etc. Such output is useful from both the forward and adjoint solves. One especially interesting metric for me at the moment is the difference between the initial y0 and the "hopefully close" y(t_0) recovered by backtracking through the dynamics in the adjoint ODE.
  • Differentiating through the solver, aka lame gradients. I know that this is often inferior to solving the reverse-time adjoint ODE, but for the sake of comparison it's an essential baseline. I'm assuming this shouldn't be too hard?
  • Alternate solver choices. The scipy library does a good job of exposing multiple solver options to the user. I'm not sure that their API is the cleanest approach, but having the ability to choose between types of solvers would be great; being locked into RK can be annoying. I envision a design with a set of raw solvers defined against a unified API, which can then be bundled up into a usable odeint function with a vjp rule. Being able to select different solvers for the forward and adjoint passes would also be useful. The ideal solution would make arbitrary solver combos a cinch, e.g. RK on the forward pass but Euler integration for the adjoint. A rough sketch of the kind of API I have in mind is below.
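
Sketching the kind of thing I mean (all names here are hypothetical, and the vjp rule is elided):

```python
import jax
import jax.numpy as jnp

# Hypothetical unified stepper API: every solver is a function
# (func, y, t, dt) -> y_next, so steppers are interchangeable.
def euler_step(func, y, t, dt):
    return y + dt * func(y, t)

def rk4_step(func, y, t, dt):
    k1 = func(y, t)
    k2 = func(y + dt * k1 / 2, t + dt / 2)
    k3 = func(y + dt * k2 / 2, t + dt / 2)
    k4 = func(y + dt * k3, t + dt)
    return y + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6

def build_odeint(forward_step, adjoint_step):
    # Bundle a forward stepper and an independent adjoint stepper into
    # an odeint with a custom vjp; the vjp rule (which would run the
    # adjoint ODE with `adjoint_step`) is elided in this sketch.
    def odeint(func, y0, ts):
        def scan_fn(y, t_pair):
            t0, t1 = t_pair
            y_next = forward_step(func, y, t0, t1 - t0)
            return y_next, y_next
        _, ys = jax.lax.scan(scan_fn, y0, jnp.stack([ts[:-1], ts[1:]], axis=1))
        return jnp.concatenate([y0[None], ys])
    return odeint

# e.g. RK4 on the forward pass, Euler for the adjoint:
odeint_mixed = build_odeint(rk4_step, adjoint_step=euler_step)
```
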
@shoyer commented Apr 7, 2020

I agree, this would all be useful to have.

RE: Auxiliary solve output, see @jacobjinkelly's #2574. Doing this right may require a more comprehensive solution adding side-effects to core JAX.

RE: Differentiating through the solver. This would also be great to have, along with other adjoint calculation methods (e.g., as described in Appendix B of https://arxiv.org/pdf/2001.04385.pdf and implemented in DifferentialEquations.jl).

Differentiating through the solver could be a little tricky to do in JAX, because we need fixed memory usage in order to compile something with XLA (i.e., jax.jit). This could be achieved if we have some sort of maxiter parameter and used scan instead of while_loop, but a naive implementation would be extremely memory intensive. To make this practical, I think we would need some form of gradient checkpointing (also useful for other adjoint methods).
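
For concreteness, a toy fixed-step version of the scan idea might look like this (a hypothetical sketch, not the experimental odeint; note the linear memory growth that checkpointing would address):

```python
import jax
import jax.numpy as jnp

def odeint_rk4_scan(func, y0, t0, t1, num_steps):
    # A fixed step count turns the loop into a lax.scan, which
    # reverse-mode autodiff can handle (unlike lax.while_loop).
    # Memory scales with num_steps because scan stores residuals
    # for the backward pass.
    dt = (t1 - t0) / num_steps

    def step(carry, _):
        y, t = carry
        k1 = func(y, t)
        k2 = func(y + dt * k1 / 2, t + dt / 2)
        k3 = func(y + dt * k2 / 2, t + dt / 2)
        k4 = func(y + dt * k3, t + dt)
        y_next = y + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6
        return (y_next, t + dt), None

    (y1, _), _ = jax.lax.scan(step, (y0, t0), None, length=num_steps)
    return y1

# Differentiating *through* the solver:
f = lambda y, t: -y
grad_y0 = jax.grad(lambda y0: odeint_rk4_scan(f, y0, 0.0, 1.0, 100))(2.0)
```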

@samuela commented Apr 7, 2020

> RE: Auxiliary solve output, see @jacobjinkelly's #2574. Doing this right may require a more comprehensive solution adding side-effects to core JAX.

Wasn't aware of #2574. Def happy to see progress on that front! Although it seems inevitable that more and more auxiliary info will be desirable in the future. If I'm understanding #2574 correctly this would entail breaking API changes any time a new diagnostic output is added.

> RE: Differentiating through the solver. This would also be great to have, along with other adjoint calculation methods (e.g., as described in Appendix B of https://arxiv.org/pdf/2001.04385.pdf and implemented in DifferentialEquations.jl).
>
> Differentiating through the solver could be a little tricky to do in JAX, because we need fixed memory usage in order to compile something with XLA (i.e., jax.jit). This could be achieved if we have some sort of maxiter parameter and used scan instead of while_loop, but a naive implementation would be extremely memory intensive. To make this practical, I think we would need some form of gradient checkpointing (also useful for other adjoint methods).

Out of curiosity, why does XLA require bounded loops in order to do reverse-mode? Other IRs don't seem to have this limitation, e.g. Relay, which supports reverse-mode gradients through arbitrary loops. I believe their approach is based on https://arxiv.org/pdf/1803.10228.pdf. Julia also doesn't seem to have any issue with this.

@shoyer commented Apr 7, 2020

> If I'm understanding #2574 correctly this would entail breaking API changes any time a new diagnostic output is added.

I think the right way to do this is to switch to returning an object with an extensible list of fields in addition to the ODE solution, like SciPy's odeint or (newer) solve_ivp. This would allow for adding more auxiliary fields without breaking code.
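
For example (field names hypothetical):

```python
from typing import NamedTuple
import jax.numpy as jnp

class ODEResult(NamedTuple):
    # Extensible result container, loosely modeled on scipy's
    # OdeResult: new diagnostic fields can be appended over time
    # without breaking callers that only read `ys`.
    ys: jnp.ndarray       # solution at the requested times
    num_fev: jnp.ndarray  # number of RHS evaluations
    # future fields (step locations, error estimates, ...) go here

# dummy construction just to show the shape of the API:
result = ODEResult(ys=jnp.zeros((10, 3)), num_fev=jnp.array(42))
```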

> Out of curiosity, why does XLA require bounded loops in order to do reverse-mode?

XLA has a general requirement that all memory allocation must be statically known based on shapes. There's no dynamic allocation of arrays based on computation results. I don't know the exact reason for this requirement, but I imagine that it makes the compiler's job much easier.

@samuela commented Apr 7, 2020

> XLA has a general requirement that all memory allocation must be statically known based on shapes. There's no dynamic allocation of arrays based on computation results. I don't know the exact reason for this requirement, but I imagine that it makes the compiler's job much easier.

Mmm, I can see how that would make the compiler much simpler, but it seems quite limiting now. Would it be possible to implement this without jit to circumvent the peculiarities of XLA? Doing so would be slower to be sure, but it would be better than nothing.

@mattjj commented Apr 7, 2020

> Mmm, I can see how that would make the compiler much simpler

It's not just about simplicity; being able to statically analyze the shapes and memory requirements enables a ton of optimizations (fusion, layout, remat, etc).

> Would it be possible to implement this without jit to circumvent the peculiarities of XLA? Doing so would be slower to be sure, but it would be better than nothing.

Indeed, just write it with a regular Python while loop and don't jit it! I'm not sure this is something we'd want to maintain in the jax core repo, but it should be easy enough to stand up as a baseline.
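
For instance, a no-jit adaptive Euler sketch with toy step-size control (purely illustrative):

```python
import jax
import jax.numpy as jnp

def odeint_euler_adaptive(func, y0, t0, t1, dt0=1e-2, tol=1e-6):
    # A plain Python while loop, never jitted: jax.grad happily
    # traces through data-dependent control flow; it just can't
    # be compiled by XLA.
    y, t, dt = y0, t0, dt0
    while t < t1:
        dt = min(dt, t1 - t)
        # crude error estimate by step doubling
        y_full = y + dt * func(y, t)
        y_half = y + (dt / 2) * func(y, t)
        y_two = y_half + (dt / 2) * func(y_half, t + dt / 2)
        err = jnp.max(jnp.abs(y_full - y_two))
        if err < tol:           # concrete under grad (no jit), so ok
            y, t = y_two, t + dt
            dt *= 1.5
        else:
            dt /= 2.0
    return y

f = lambda y, t: -y
g = jax.grad(lambda y0: odeint_euler_adaptive(f, y0, 0.0, 1.0))(1.0)
```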

@samuela commented Apr 7, 2020

> Indeed, just write it with a regular Python while loop and don't jit it! I'm not sure this is something we'd want to maintain in the jax core repo, but it should be easy enough to stand up as a baseline.

Ok, will do! Could I just reuse the jax.experimental.ode implementation of RK to do this? Is there any reason jax.disable_jit() wouldn't do the trick?

@samuela commented Apr 7, 2020

> It's not just about simplicity; being able to statically analyze the shapes and memory requirements enables a ton of optimizations (fusion, layout, remat, etc).

OTOH, what prevents XLA from intelligently inferring loop bounds and then applying optimizations where possible? This seems to be fairly standard practice in compilers.

@mattjj commented Apr 7, 2020

Yeah, disable_jit might work! Though perhaps not with vmap because of the other issue you identified.
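
Usage would be something like:

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

# With jit disabled, the control-flow primitives fall back to Python
# loops, so the solver runs op-by-op: slow, but no static-shape limits.
with jax.disable_jit():
    ys = odeint(lambda y, t: -y, 1.0, jnp.linspace(0.0, 1.0, 10))
```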

@samuela commented Apr 7, 2020

Ok, gotcha!

@mattjj commented Apr 7, 2020

Re: loop bounds, sure that's possible in some cases (though it's more of a JAX issue than an XLA issue because XLA doesn't do autodiff). Actually that was my intention with this code: if we can specialize on the trip count then we can do reverse-mode autodiff. (It failed some internal test that I can't remember, but I think it should work...)

But the point of while_loop is that you might not be able to predict the trip count. Something like an ode integrator, or anything you run to numerical convergence, is a good example: a priori you don't have a bound on the number of iterations. (There are some tricks you can play, like use recursive checkpointing and assume that loop trip counts are always bounded by 2**32 or something, but that's just in the weeds.)
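
For reference, the bounded-trip-count trick looks roughly like this (a sketch; it always pays for max_iter iterations):

```python
import jax
import jax.numpy as jnp

def bounded_while(cond_fun, body_fun, init_state, max_iter):
    # Reverse-mode-differentiable stand-in for lax.while_loop: always
    # runs max_iter scan iterations, but stops updating the state once
    # cond_fun turns False, so the result matches while_loop whenever
    # the true trip count is <= max_iter.
    def step(state, _):
        should_step = cond_fun(state)
        new_state = body_fun(state)
        state = jax.tree_util.tree_map(
            lambda new, old: jnp.where(should_step, new, old),
            new_state, state)
        return state, None

    final, _ = jax.lax.scan(step, init_state, None, length=max_iter)
    return final

# e.g. Newton iteration for sqrt(2), at most 50 "steps":
x = bounded_while(lambda x: jnp.abs(x * x - 2.0) > 1e-10,
                  lambda x: (x + 2.0 / x) / 2.0,
                  jnp.array(1.0), max_iter=50)
```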

@samuela commented Apr 7, 2020

Ok, after chatting with @MarisaKirisame it sounds like the key difference here is that Relay supports ADTs and allocation in the IR. So in their case the AD of a while loop is just another while loop, with tracing that runs in the IR. This is made fast by all kinds of compiler tricks within the IR. My understanding is that when loop bounds can be inferred, they switch to optimized, shape-aware compilation.

If I'm understanding correctly this means that even an omnipotent JAX would not be able to do AD through while loops when compiling down to XLA. If that's true it seems as though the way forward to fix AD for loops is either

  • Convincing the XLA crew that dynamic allocation and support for custom data types (ideally ADTs) are worthwhile features, or
  • Using XLA custom calls to hack around the system with stateful updates outside the realm of XLA. This way a dynamic (but fast!) tape could be maintained while still enjoying all the usual benefits of XLA.

@Joshuaalbert commented

RE: Alternate solver choices. Do you mean implicit solvers like the Livermore Solver (LSODE; https://computing.llnl.gov/casc/nsde/pubs/u113855.pdf), which are the de facto standard in scipy and can handle stiff equations? I started a project to write LSODE for TensorFlow and, let me say, it's tricky due to the many decisions involved (see page 58 of the link). There are a ton of heuristics that make things converge nicely and efficiently. There is the option, however, of stripping most of that logic out, down to the fundamentals of an implicit ODE solver (which is also described fully on pages 1-38 of that document).
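
To illustrate just those fundamentals: a single backward Euler step solves y1 = y0 + h*f(y1, t1) with a few Newton iterations, and JAX can supply the Jacobian (a minimal sketch, with none of the step-size or convergence heuristics):

```python
import jax
import jax.numpy as jnp

def backward_euler_step(func, y0, t0, h, newton_iters=5):
    # Implicit step: solve the nonlinear system
    #   g(y1) = y1 - y0 - h * f(y1, t0 + h) = 0
    # with a fixed number of Newton iterations.
    t1 = t0 + h
    g = lambda y: y - y0 - h * func(y, t1)
    jac_g = jax.jacfwd(g)
    y = y0  # initial guess
    for _ in range(newton_iters):
        y = y - jnp.linalg.solve(jac_g(y), g(y))
    return y

# stiff scalar test problem y' = -1000 (y - cos(t)), as a 1-vector:
f = lambda y, t: -1000.0 * (y - jnp.cos(t))
y1 = backward_euler_step(f, jnp.array([1.0]), 0.0, 0.1)
```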

@samuela commented Apr 7, 2020

@Joshuaalbert Yeah, I say the more options the better! I guess I view these sorts of things as good stress tests for tools like JAX and XLA. That flowchart looks pretty nasty, so it'd definitely be a challenge but I believe it should be do-able. Another option would be to support third-party solvers through FFI the same way scipy does. IIRC JAX has some kind of internal-ish way to do that sort of thing somewhere already.

@shoyer commented Apr 7, 2020

An implicit solver for stiff ODEs would definitely be a welcome addition. We have a BDF method in TF-probability that could be a good start.

It would also be nice to expose an interface for reusing the ODE gradient definition(s), allowing users to bring their own solvers without needing to write new gradient rules. This would be similar to what we did with lax.custom_root and lax.custom_linear_solve.
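
Roughly, such an interface might look like this (everything below is hypothetical, by analogy with lax.custom_linear_solve; the adjoint body is elided):

```python
import jax
from functools import partial

# Hypothetical `custom_ode_solve`: the user supplies any black-box
# `solve(func, y0, ts)`, and the single shared adjoint rule below
# means no per-solver gradient code is needed.
@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def custom_ode_solve(func, solve, y0, ts):
    return solve(func, y0, ts)

def _fwd(func, solve, y0, ts):
    ys = solve(func, y0, ts)
    return ys, (ys, ts)

def _bwd(func, solve, res, g):
    ys, ts = res
    # ... integrate the adjoint ODE backwards (reusing `solve`),
    # accumulating vjps of `func` along the way; elided here ...
    raise NotImplementedError

custom_ode_solve.defvjp(_fwd, _bwd)
```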

@John-Boik commented

I'd like to add my voice to those who would like to see differentiating through the solver and other adjoint calculation methods. That would be great.

Also, it appears that the current version of experimental.odeint cannot handle 64-bit data. The error message below stems from my call to odeint. Is it possible to alter odeint to handle 64-bit data?

    tmp = odeint(MassSpring, latents, tp_train, rtol=rtol_make_data, atol=rtol_make_data)  
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/experimental/ode.py", line 158, in odeint
    return _odeint_wrapper(func, rtol, atol, mxstep, y0, t, *args)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/api.py", line 150, in f_jitted
    name=flat_fun.__name__)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/core.py", line 951, in call_bind
    outs = primitive.impl(f, *args, **params)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/interpreters/xla.py", line 463, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/linear_util.py", line 221, in memoized_fun
    ans = call(fun, *args)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/interpreters/xla.py", line 480, in _xla_callable
    fun, pvals, instantiate=False, stage_out=True, bottom=True)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 374, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/linear_util.py", line 150, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/experimental/ode.py", line 164, in _odeint_wrapper
    out = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/custom_derivatives.py", line 455, in __call__
    *args_flat, out_trees=out_trees)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/custom_derivatives.py", line 502, in _custom_vjp_call_bind
    out_trees=out_trees)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 245, in process_custom_vjp_call
    return fun.call_wrapped(*tracers)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/linear_util.py", line 150, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/experimental/ode.py", line 199, in _odeint
    _, ys = lax.scan(scan_fun, init_carry, ts[1:])
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 848, in scan
    jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 62, in _initial_style_jaxpr
    wrapped_fun, in_pvals, instantiate=True, stage_out=False)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 374, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/linear_util.py", line 150, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/experimental/ode.py", line 189, in scan_fun
    _, *carry = lax.while_loop(cond_fun, body_fun, [0] + carry)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 228, in while_loop
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 62, in _initial_style_jaxpr
    wrapped_fun, in_pvals, instantiate=True, stage_out=False)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 374, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/linear_util.py", line 150, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/experimental/ode.py", line 179, in body_fun
    next_y, next_f, next_y_error, k = runge_kutta_step(func_, y, f, t, dt)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/experimental/ode.py", line 116, in runge_kutta_step
    k = lax.fori_loop(1, 7, body_fun, k)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 171, in fori_loop
    (lower, upper, init_val))
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 228, in while_loop
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 62, in _initial_style_jaxpr
    wrapped_fun, in_pvals, instantiate=True, stage_out=False)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 374, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/linear_util.py", line 150, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 100, in while_body_fun
    return lax.add(i, lax._const(i, 1)), upper, body_fun(i, x)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/experimental/ode.py", line 112, in body_fun
    ft = func(yi, ti)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/experimental/ode.py", line 169, in <lambda>
    func_ = lambda y, t: func(y, t, *args)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/linear_util.py", line 146, in call_wrapped
    args, kwargs = next(gen)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/experimental/ode.py", line 51, in ravel_first_arg_
    y = unravel(y_flat)
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/flatten_util.py", line 29, in <lambda>
    unravel_pytree = lambda flat: tree_unflatten(treedef, unravel_list(flat))
  File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/api.py", line 1291, in _vjp_pullback_wrapper
    raise TypeError(msg.format(_dtype(a), dtype))
TypeError: Type of cotangent input to vjp pullback function (float64) does not match type of corresponding primal output (float32)
> /home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/api.py(1291)_vjp_pullback_wrapper()
-> raise TypeError(msg.format(_dtype(a), dtype))

@samuela commented Apr 24, 2020

@John-Boik Do you have a code sample you could share that reproduces the issue? It looks like the forward solve is not actually using float64 as you requested.

@John-Boik commented

It was my fault. I was unintentionally passing odeint a matrix specified as float32.
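
For anyone who hits the same TypeError: enable 64-bit mode before any JAX computation and make sure every input is actually float64 (array names here are placeholders):

```python
import jax
jax.config.update("jax_enable_x64", True)  # must run before other JAX calls

import jax.numpy as jnp

# With x64 enabled, cast every input to odeint explicitly so no
# stray float32 array sneaks into the solve.
latents = jnp.asarray(latents_np, dtype=jnp.float64)    # latents_np: your data
tp_train = jnp.asarray(tp_train_np, dtype=jnp.float64)  # tp_train_np: your times
```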

@Joshuaalbert commented

@John-Boik and others, I would strongly advise against differentiating through an adaptive implicit ODE solver like LSODE, for performance reasons, especially on a stiff problem. I would have time in June to help implement LSODE.

@shoyer commented May 20, 2020

jax.experimental.host_callback looks like it could be a nice way to get auxiliary outputs out of ODE solvers. It's still very experimental, but it makes it possible to thread values out of jit-compiled code back into Python (e.g., see the example implementing a printer in #3127).
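
A sketch of what that could look like for per-step diagnostics (the tap-function signature has varied across JAX versions, so treat this as approximate):

```python
import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb

step_log = []  # accumulated on the host while the jitted code runs

def record(arg, transforms):  # tap signature varies by JAX version
    step_log.append(arg)

def dynamics(y, t):
    # id_tap streams (t, |y|) out to the host without affecting
    # the compiled computation itself.
    hcb.id_tap(record, (t, jnp.linalg.norm(y)))
    return -y

dydt = jax.jit(dynamics)(jnp.ones(3), 0.0)
hcb.barrier_wait()  # flush outstanding taps before reading step_log
```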

@samuela commented May 21, 2020

@shoyer jax.experimental.host_callback looks like everything I've ever dreamed of. Haven't tried it yet, but I think my life may be complete now.

@samuela commented May 29, 2020

I'd also like to add a feature request for

@samuela commented Jun 25, 2020

Another feature that I'm realizing would be important:

  • Returning the "dense output" spline, e.g. the dense_output=True option in scipy's solve_ivp. All of the necessary math is done in the RK solve anyhow.

@skrsna commented Jul 1, 2020

I have implemented a BDF solver for stiff ODEs in JAX using TF-probability's code. The implementation is pretty barebones right now and I still need to test jit and vmap. Also, I still need to implement an adjoint gradient method like TF-probability's, which I'm planning to do soon. I tested it against SciPy's VODE methods on stiff chemical kinetics problems and the results look pretty good to me. The current implementation is here and I'm planning to open a PR here once I get the adjoint gradient method, jit, vmap, etc. working. Feedback and suggestions are welcome! Thanks for such a great framework.

@shoyer commented Jul 1, 2020

> Returning the "dense output" spline, e.g. the dense_output=True option in scipy's solve_ivp. All of the necessary math is done in the RK solve anyhow.

To make this work in a way that is compatible with jit, we would need to support picking a (max) static number of interpolation points for the spline. But I agree, this would be nice to have, particularly for the adjoint calculation because the stored interpolation could be used instead of integrating the solution backwards in time to recompute primal values.
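
The interpolation itself is cheap given endpoint values and derivatives. For instance, a cubic Hermite segment (the Dormand-Prince pair has its own dense-output coefficients, but this conveys the shape of the idea):

```python
import jax.numpy as jnp

def cubic_hermite(t, t0, t1, y0, y1, f0, f1):
    # Cubic Hermite interpolant matching y and y' at both ends of a
    # step; with a fixed max number of saved steps this is jit-safe.
    h = t1 - t0
    s = (t - t0) / h
    h00 = (1 + 2 * s) * (1 - s) ** 2
    h10 = s * (1 - s) ** 2
    h01 = s ** 2 * (3 - 2 * s)
    h11 = s ** 2 * (s - 1)
    return h00 * y0 + h10 * h * f0 + h01 * y1 + h11 * h * f1
```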

@samuela commented Jul 1, 2020

> To make this work in a way that is compatible with jit, we would need to support picking a (max) static number of interpolation points for the spline. But I agree, this would be nice to have, particularly for the adjoint calculation because the stored interpolation could be used instead of integrating the solution backwards in time to recompute primal values.

Yeah, XLA's allocation limitations present a bit of a challenge here.

@cisprague commented Oct 9, 2020

> I'd also like to add a feature request for

I second this! I think this could be accomplished by using odeint in parallel with a multiple shooting scheme. Otherwise, one could perform direct collocation with a chosen quadrature (e.g. Hermite-Simpson) — I've done this with JAX and IPOPT, and it works pretty well. These methods require optimising a sparse nonlinear programme, which is more amenable to constrained optimisers. And, I think extra consideration will be needed for how to handle the Jacobian and Hessian sparsity, as well as mesh refinement (which might be a problem for jit).
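
As a concrete sketch, the Hermite-Simpson defect constraints (which the NLP solver drives to zero) are only a few lines in JAX:

```python
import jax
import jax.numpy as jnp

def hermite_simpson_defects(func, ys, ts):
    # Hermite-Simpson collocation: for each interval the defect is
    #   d = y1 - y0 - h/6 * (f0 + 4*fm + f1)
    # where the midpoint state comes from the Hermite interpolant:
    #   ym = (y0 + y1)/2 + h/8 * (f0 - f1).
    y0, y1 = ys[:-1], ys[1:]
    t0, t1 = ts[:-1], ts[1:]
    h = (t1 - t0)[:, None]
    f0 = jax.vmap(func)(y0, t0)
    f1 = jax.vmap(func)(y1, t1)
    ym = (y0 + y1) / 2 + h / 8 * (f0 - f1)
    fm = jax.vmap(func)(ym, (t0 + t1) / 2)
    return y1 - y0 - h / 6 * (f0 + 4 * fm + f1)
```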

@jacobjinkelly commented

btw, I have a rough implementation of some other RK solvers here, in case anyone has a use for them. :)

@benjaminpope commented

Just checking on this thread - we are very interested in getting fast gradients through a stiff ODE solution.

Some cool ideas in this thread - are there any updates since last year?

@patrick-kidger commented

For those watching this thread -- check out Diffrax, which is a library doing pretty much everything discussed here. Other RK solvers, implicit solvers, discretise-then-optimise, etc.
