Support symbolic zero operations in custom_jvp #5913
Labels: enhancement (New feature or request), P2 (eventual)
It would be nice if `jax.custom_jvp` supported symbolic zero operations, much like the deprecated `defvjp` functions: https://cs.opensource.google/jax/jax/+/master:jax/api.py;drc=654a5b332c0d2f7a77b275bf31e186cfe528cdb2;l=2565
Specifically, given a bivariate function `f(x, y)`, I would like to define the `jvp` of `f` such that the `jvp` with respect to `x` is defined and an error is raised when differentiating with respect to `y`. This can happen if I don't have a closed form for the `jvp` with respect to `y`, and I want to keep users from running into this.

In the autodiff tracer API this is allowed (for instance, see `betainc`: https://cs.opensource.google/jax/jax/+/master:jax/_src/lax/lax.py;drc=0dd1b5516d4028cd182d01865279dcebf2f27c40;l=2372). Taking `grad` or `jvp` with respect to the `x` parameter of `betainc` gives the right answer, whereas writing a `custom_jvp` for `betainc` will throw an exception.
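To make the request concrete, here is a minimal sketch of the behavior described above. It assumes a recent JAX release in which `defjvp` accepts `symbolic_zeros=True` and passes unperturbed inputs to the rule as `jax.custom_derivatives.SymbolicZero`; that option postdates this issue as filed, so check your installed version. The function `f` is a hypothetical stand-in, not `betainc`.

```python
import jax
import jax.numpy as jnp
from jax.custom_derivatives import SymbolicZero


@jax.custom_jvp
def f(x, y):
    # Hypothetical stand-in: pretend the derivative with respect to y
    # has no closed form that we are willing to expose.
    return x ** 2 * jnp.exp(y)


def f_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(x, y)
    # Assumption: with symbolic_zeros=True, an input that is not being
    # differentiated arrives as a SymbolicZero rather than an array of zeros.
    if not isinstance(y_dot, SymbolicZero):
        raise NotImplementedError("f is not differentiable with respect to y")
    if isinstance(x_dot, SymbolicZero):
        # Nothing is perturbed, so the output tangent is zero.
        return primal_out, jnp.zeros_like(primal_out)
    # Closed-form derivative with respect to x only.
    return primal_out, 2.0 * x * jnp.exp(y) * x_dot


f.defjvp(f_jvp, symbolic_zeros=True)

# Differentiating with respect to x uses the closed-form rule.
dx = jax.grad(f, argnums=0)(3.0, 0.0)
```

Under these assumptions, `jax.grad(f, argnums=1)(3.0, 0.0)` raises `NotImplementedError` at trace time instead of returning a wrong gradient. As far as I understand, calling `jax.jvp` with explicit tangent arrays for both arguments would also raise, since an explicitly supplied array of zeros is not a symbolic zero.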