jax.grad should auto-detect non-differentiable values #15861
Thanks for the suggestion. There's also a proposal in #10614 to make […]
As far as I understand, #10614 will still require splitting the tree into differentiable and non-differentiable parts, so this won't solve the root issue of being able to differentiate only part of the tree. Unless it was possible to mask the tree with something like:

```python
mask = {'a': True, 'b': False}  # differentiate with respect to 'a' but not 'b'

@jax.grad(static={'x': mask})
def loss_fn(x):
  ...

loss_fn({'a': a, 'b': b})  # differentiate with respect to 'a' but not 'b'
```

Related to #5487.
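For reference, a minimal sketch of how such a mask could be emulated with today's API, using `jax.tree_util.tree_map` and `jax.lax.stop_gradient`. Note this only zeroes out the masked gradients rather than excluding the leaves, and every leaf must still have a differentiable dtype; the names here are illustrative:

```python
import jax
import jax.numpy as jnp

def apply_mask(mask, tree):
  # Leaves where the mask is False are wrapped in stop_gradient,
  # so their gradients come back as zeros.
  return jax.tree_util.tree_map(
      lambda keep, leaf: leaf if keep else jax.lax.stop_gradient(leaf),
      mask, tree)

mask = {'a': True, 'b': False}

@jax.grad
def loss_fn(x):
  x = apply_mask(mask, x)
  return x['a'] * x['b']

loss_fn({'a': jnp.array(2.0), 'b': jnp.array(3.0)})
# {'a': Array(3., dtype=float32, ...), 'b': Array(0., dtype=float32, ...)}
```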
I'm missing what's wrong with the […]
I'm all for something explicit, but it should be possible to have an explicit syntax without sacrificing the user experience. For example, with something like:

```python
@jax.grad
def loss_fn(x):
  ...

loss_fn({'a': a, 'b': jax.no_grad(b)})
```

It's better than having to split the tree into separate static and dynamic arguments.
@Conchylicultor Using […]
I usually deal with PyTrees as input, so I resort to something like this:

```python
import jax
import jax.numpy as jnp


class NoGrad:
  """Marks a value as non-differentiable."""

  def __init__(self, wrapped):
    self.wrapped = wrapped

  def __repr__(self):
    return f'NoGrad({self.wrapped})'


def unwrap(x):
  return x.wrapped if isinstance(x, NoGrad) else x


# Register NoGrad as a pytree node with no children: the wrapped value is
# stored as static aux data, so jax.grad never sees it as a differentiable leaf.
jax.tree_util.register_pytree_node(
    nodetype=NoGrad,
    flatten_func=lambda x: ((), x.wrapped),
    unflatten_func=lambda x, _: NoGrad(x),
)


@jax.grad
def fn(vals):
  vals = jax.tree_map(unwrap, vals)
  return vals['x']  # y not used


vals = {
    'x': jnp.ones((), dtype=jnp.float32),
    'y': NoGrad(jnp.ones((), dtype=jnp.float32)),
}
fn(vals)
# {'x': Array(1., dtype=float32), 'y': NoGrad(1.0)}
```

Unless I'm missing something, this pattern is explicit and does not alter the function signature (no static/dynamic splitting).
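One caveat with this pattern (an observation added here, not from the original comment): `flatten_func` stores the wrapped value as static aux data in the treedef, and JAX expects aux data to be hashable once the tree goes through `jax.jit`, since the treedef is part of the compilation-cache key. Wrapping a `jax.Array` in `NoGrad` should therefore fail under `jit`, while wrapping plain Python scalars works:

```python
# Expected to fail: jax.Array is unhashable, and NoGrad stores it as
# static aux data in the treedef.
# jax.jit(fn)(vals)  # TypeError: unhashable type (or similar)

# Hashable aux data (a Python float) is fine:
vals_ok = {'x': jnp.ones((), dtype=jnp.float32), 'y': NoGrad(1.0)}
jax.jit(fn)(vals_ok)
```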
#14960 is related and I suspect it could help here too.
Currently, JAX does not allow differentiating a tree that contains non-differentiable leaves (e.g. integer or boolean values). It fails with a dtype `TypeError`, as in the sketch below:
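A minimal repro sketch (names are illustrative, error message from memory):

```python
import jax
import jax.numpy as jnp

def loss_fn(params):
  return params['w'] ** 2

# 'step' is an integer leaf, so the tree is not fully differentiable.
params = {'w': jnp.float32(1.0), 'step': jnp.int32(0)}
jax.grad(loss_fn)(params)
# Raises something like:
# TypeError: grad requires real- or complex-valued inputs (input dtype that
# is a sub-dtype of np.inexact), but got int32.
```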
Currently, it requires either:

- wrapping `jax.grad` in a custom op (like equinox's `filter_grad`), or
- splitting the tree & function with `static_argnum`, then merging the tree back.

Both of those options feel ugly/hacky. (A sketch of the equinox option is shown below.)
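For illustration, the first option with equinox might look like the following (a sketch with assumed names; `eqx.filter_grad` differentiates with respect to all inexact-array leaves of the first argument and returns `None` for the rest):

```python
import equinox as eqx
import jax.numpy as jnp

def loss_fn(params):
  return params['w'] ** 2  # 'step' never enters the loss

params = {'w': jnp.float32(1.0), 'step': 0}  # int leaf: non-differentiable
grads = eqx.filter_grad(loss_fn)(params)
# {'w': Array(2., dtype=float32), 'step': None}
```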
Additionally, JAX will return `0.` values even for variables which are not connected to the loss graph. Instead, other frameworks return `None`, so it's explicit which inputs have `0` gradients vs. which are not connected to the graph. By comparison, TF handles this out of the box:
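To make both behaviors concrete (a sketch; names are assumed). In JAX, an unconnected input comes back with a zero gradient, indistinguishable from a genuinely zero gradient:

```python
import jax
import jax.numpy as jnp

def loss(params):
  return params['a'] ** 2  # 'b' never touches the loss

jax.grad(loss)({'a': jnp.array(3.0), 'b': jnp.array(1.0)})
# {'a': Array(6., ...), 'b': Array(0., ...)}  <- 0. despite being unconnected
```

Whereas `tf.GradientTape.gradient` returns `None` for unconnected sources by default (`unconnected_gradients=tf.UnconnectedGradients.NONE`):

```python
import tensorflow as tf

a = tf.Variable(3.0)
b = tf.Variable(1.0)
with tf.GradientTape() as tape:
  loss = a ** 2  # b never touches the loss

tape.gradient(loss, {'a': a, 'b': b})
# {'a': <tf.Tensor: 6.0>, 'b': None}
```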