
jax.grad should auto-detect non-differentiable values #15861

Open
Conchylicultor opened this issue May 4, 2023 · 7 comments
Labels
enhancement New feature or request

Comments

@Conchylicultor
Member

Conchylicultor commented May 4, 2023

Currently, JAX does not allow differentiating a tree that contains non-differentiable leaves.

import jax
import jax.numpy as jnp

@jax.grad
def fn(vals):
  return vals['x'] + vals['y']


vals = {
    'x': jnp.ones((), dtype=jnp.float32),
    'y': jnp.ones((), dtype=jnp.int32),  # Fails: integer dtype not supported
}
fn(vals)

Fails with

TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int32.
If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.

Currently, this requires one of the following workarounds:

  • Implement a custom jax.grad wrapper (like equinox's filter_grad)
  • Split the tree into differentiable / non-differentiable parts, pass them as two arguments and select which one to differentiate via argnums, then merge the tree back (see the sketch after this list)

Both of these options feel ugly / hacky.
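
A minimal sketch of the second workaround (not an official API), assuming the same vals as above; the diff_vals / static_vals split and names are purely illustrative:

import jax
import jax.numpy as jnp

def loss(diff_vals, static_vals):
  return diff_vals['x'] + static_vals['y']

diff_vals = {'x': jnp.ones((), dtype=jnp.float32)}    # differentiable part
static_vals = {'y': jnp.ones((), dtype=jnp.int32)}    # non-differentiable part

# Differentiate only with respect to the first argument; the integer leaf
# is passed through untouched and gets no gradient entry.
grads = jax.grad(loss, argnums=0)(diff_vals, static_vals)
# {'x': Array(1., dtype=float32)}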

Additionally, JAX returns 0. gradients even for variables that are not connected to the loss in the computation graph:

@jax.grad
def fn(vals):
  return vals['x']  # y not used

vals = {
    'x': jnp.ones((), dtype=jnp.float32),
    'y': jnp.ones((), dtype=jnp.float32),
}

fn(vals) == {'x': 1, 'y': 0}  # y is 0, not None

Other frameworks return None instead, so it is explicit which inputs have a zero gradient and which are simply not connected to the graph.

By comparison, TF does this out of the box:

with tf.GradientTape() as tape:
  loss = fn(vals)

grad = tape.gradient(loss, vals)
grad  # {'x': <tf.Tensor: shape=(), dtype=float32, numpy=1.0>, 'y': None}
Conchylicultor added the enhancement (New feature or request) label on May 4, 2023
@jakevdp
Collaborator

jakevdp commented May 4, 2023

Thanks for the suggestion. There's also a proposal in #10614 to make argnums more expressive in order to handle this kind of case explicitly, rather than implicitly as in your proposal. JAX's design tends to avoid implicit choices, so I suspect that's the route the team will end up taking to address this issue.

@Conchylicultor
Member Author

As far as I understand, #10614 will still require splitting the tree into static and dynamic parts.

So it won't solve the root issue of being able to differentiate only part of a tree.

Unless it were possible to mask the tree with something like:

mask = {'a': True, 'b': False}  # Differentiate with `a` but not `b`

@jax.grad(static={'x': mask})
def loss_fn(x):
  ...


loss_fn({'a': a, 'b': b})  # Differentiate with `a` but not `b`
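
Until something like this exists, the masked behaviour can be emulated by splitting along the mask and closing over the static part; a hedged sketch where partition and combine are hypothetical helpers, not a JAX or proposed API:

import jax
import jax.numpy as jnp

def partition(tree, mask):
  # Split a flat dict into the leaves to differentiate and the rest.
  diff = {k: v for k, v in tree.items() if mask[k]}
  static = {k: v for k, v in tree.items() if not mask[k]}
  return diff, static

def combine(diff, static):
  return {**diff, **static}

def loss_fn(x):
  return x['a'] * 2.0 + x['b']

mask = {'a': True, 'b': False}  # Differentiate with `a` but not `b`
a = jnp.ones((), dtype=jnp.float32)
b = jnp.ones((), dtype=jnp.int32)

diff, static = partition({'a': a, 'b': b}, mask)
grads = jax.grad(lambda d: loss_fn(combine(d, static)))(diff)
# {'a': Array(2., dtype=float32)}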

Related to #5487.

@hawkinsp
Member

hawkinsp commented May 5, 2023

I'm missing what's wrong with the allow_int option?

In [7]: @partial(jax.grad, allow_int=True)
   ...: def fn(vals):
   ...:   return vals['x'] + vals['y']
   ...:
   ...:
   ...: vals = {
   ...:     'x': jnp.ones((), dtype=jnp.float32),
   ...:     'y': jnp.ones((), dtype=jnp.int32),  # integer leaf, handled via allow_int=True
   ...: }
   ...: fn(vals)
Out[7]: {'x': Array(1., dtype=float32), 'y': array((b'',), dtype=[('float0', 'V')])}
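
For reference, the float0 placeholders can be mapped back to None afterwards if a TF-style result is wanted; a minimal sketch (not an official API), assuming the grads dict above, where float0_to_none is an illustrative helper:

import jax

def float0_to_none(g):
  # float0 is the placeholder dtype JAX uses for integer/bool tangents.
  return None if g.dtype == jax.dtypes.float0 else g

grads = fn(vals)
grads = {k: float0_to_none(v) for k, v in grads.items()}
# {'x': Array(1., dtype=float32), 'y': None}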

@Conchylicultor
Member Author

I'm all for something explicit, but it should be possible to have an explicit syntax without sacrificing the user experience.

For example with something like:

@jax.grad
def loss_fn(x):
  ...


loss_fn({'a': a, 'b': jax.no_grad(b)})

It would be better than having to split into two arguments because jax.no_grad() could be applied directly by libraries (equinox, ...).

@hawkinsp
Member

hawkinsp commented May 5, 2023

@Conchylicultor Using allow_int=True does not require you to split your arguments.

@ASEM000

ASEM000 commented May 7, 2023

I usually deal with PyTrees as input, so I resort to something like this:

import jax
import jax.numpy as jnp

class NoGrad:
    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __repr__(self):
        return f'NoGrad({self.wrapped})'

def unwrap(x):
    return x.wrapped if isinstance(x, NoGrad) else x

# Register NoGrad as a pytree node with no children: the wrapped value is
# carried as auxiliary (static) data, so grad never traces through it.
jax.tree_util.register_pytree_node(
    nodetype=NoGrad,
    flatten_func=lambda x: ((), x.wrapped),
    unflatten_func=lambda aux, _children: NoGrad(aux),
)

@jax.grad
def fn(vals):
  # `is_leaf` makes tree_map stop at NoGrad nodes so they actually get unwrapped.
  vals = jax.tree_map(unwrap, vals, is_leaf=lambda x: isinstance(x, NoGrad))
  return vals['x']  # y not used

vals = {
    'x': jnp.ones((), dtype=jnp.float32),
    'y': NoGrad(jnp.ones((), dtype=jnp.float32)),
}

fn(vals)
# {'x': Array(1., dtype=float32), 'y': NoGrad(1.0)}

Unless I'm missing something, this pattern is explicit and does not alter the function signature (no static/dynamic argument splitting).
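
If a TF-style None is wanted for the non-differentiated leaves, the output can be post-processed; a small hedged follow-up to the snippet above, not part of any proposal:

grads = fn(vals)
# Map the NoGrad placeholders in the result to None; `is_leaf` is needed so
# tree_map treats NoGrad nodes as leaves instead of recursing into them.
grads = jax.tree_map(
    lambda g: None if isinstance(g, NoGrad) else g,
    grads,
    is_leaf=lambda g: isinstance(g, NoGrad),
)
# {'x': Array(1., dtype=float32), 'y': None}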

@froystig
Member

#14960 is related and I suspect it could help here too.
