[Proposal] Consistent argnums and argnames parameters for transformations #10614
Comments
This is totally doable using
sig = inspect.signature(f)
sig = ... # replace all defaults with False; elided for space
static_args = tuple(True if i in static_argnums else False for i in range(static_argnums))
static_kwargs = {k: True for k in static_argnames}
bound = sig.bind_partial(*static_args, **static_kwargs)
bound.apply_defaults()
static_args = bound.static_args
static_kwargs = bound.static_kwargs
which canonicalises args and kwargs based on the signature of the function. Then, if necessary, do the same thing to the actual (This is exactly how Equinox handles
@patrick-kidger I think a native JAX solution might even be able to just lean on the existing
Supporting negative indices everywhere SGTM. (Sorry, haven't had time to reply to the broader proposal in more detail...)
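One way negative indices could be supported is to resolve them once, against the function's signature, when the transformation is applied. The sketch below is a hypothetical helper (`normalize_argnums` is not a JAX function) that counts positional parameters with `inspect.signature`. It assumes Python-style indexing and refuses negative indices when the function takes `*args`, because the number of positional arguments is then only known at call time.

```python
import inspect

_POSITIONAL = (inspect.Parameter.POSITIONAL_ONLY,
               inspect.Parameter.POSITIONAL_OR_KEYWORD)


def normalize_argnums(fun, argnums):
    """Hypothetical helper: resolve negative argnums to non-negative indices."""
    params = list(inspect.signature(fun).parameters.values())
    num_positional = sum(p.kind in _POSITIONAL for p in params)
    has_var_args = any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in params)
    resolved = set()
    for argnum in argnums:
        index = argnum
        if index < 0:
            if has_var_args:
                # With *args, "the last argument" depends on each call.
                raise ValueError(f"argnum {argnum} is ambiguous for a function with *args")
            index += num_positional
        # Out-of-range indices are errors unless *args can absorb them.
        if index < 0 or (index >= num_positional and not has_var_args):
            raise ValueError(
                f"argnum {argnum} is out of range for {num_positional} positional parameters")
        resolved.add(index)
    return tuple(sorted(resolved))


def f(x, y, z=1, *, w=2):
    return x + y + z + w


print(normalize_argnums(f, (-1, 0)))  # (0, 2)
```

Resolving at transformation time keeps the per-call path index-only; the cost is that functions with `*args` cannot use negative indices in this scheme.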
@patrick-kidger It seems that your example has some typos, and
If … Indeed, it should be `range(len(static_argnums))`. This was typed out without testing; if in doubt, use the Equinox version, which definitely works ;) Equinox provides a superset of the interface being considered here: it also handles mapping over PyTrees, filter functions, auxiliary outputs, etc.
@patrick-kidger It is not a user error since And
Ah! I see what you're saying. Sorry, yes, being a bit slow today. JAX uses an index-based way of selecting arguments and I was thinking of a mask-based way. The basic principle holds, but you're right that the parsing would be a bit more involved.
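A runnable, mask-based variant of the signature-binding idea sketched at the top of this thread might look like the following. It is written for illustration only, not taken from Equinox or JAX, and `static_mask` is a hypothetical name. Index-based `static_argnums` are first turned into a positional flag list covering indices up to `max(static_argnums) + 1` (a bound based on `len(static_argnums)` would stop short for a selection like `(2,)`). They are then bound with `bind_partial` to map them to parameter names and merged with `static_argnames`. The flags are read from `bound.arguments`; a `BoundArguments` object has `.arguments`, `.args` and `.kwargs`, but no `static_args` attribute. The sketch assumes non-negative indices given as a tuple and ignores per-element flags for `*args`.

```python
import inspect


def static_mask(fun, static_argnums=(), static_argnames=()):
    """Hypothetical helper: one is-static flag per named parameter of fun.

    Assumes a tuple of non-negative static_argnums; *args entries are skipped.
    """
    sig = inspect.signature(fun)
    params = sig.parameters
    # Positional flags up to the largest selected index (max + 1, not len).
    flags = [i in static_argnums for i in range(max(static_argnums, default=-1) + 1)]
    # bind_partial maps each position onto a parameter name and raises
    # TypeError if an index is past the last positional parameter (no *args).
    bound = sig.bind_partial(*flags)
    mask = {name: False for name, p in params.items()
            if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)}
    for name, flag in bound.arguments.items():
        if name in mask:  # skips the *args tuple, if any
            mask[name] = flag
    # Names are checked against the signature, so keyword-only parameters work too.
    for name in static_argnames:
        if name not in mask:
            raise ValueError(f"{name!r} is not a named parameter of the function")
        mask[name] = True
    return mask


def f(x, y, z=1, *, w=2):
    return x + y + z + w


print(static_mask(f, static_argnums=(1,), static_argnames=("w",)))
# {'x': False, 'y': True, 'z': False, 'w': True}
```

Because every parameter ends up keyed by name, a call-time check no longer needs to care whether a static value arrived positionally or by keyword.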
I have tried to implement it leveraging
My trial:
import inspect
neg_argnums = tuple(argnum for argnum in argnums if argnum < 0)
argnums_set = set(argnum for argnum in argnums if argnum >= 0)
sentinel = object()
# max + 1 so that the largest selected index is included; default avoids max() of an empty set
args = tuple(None if i in argnums_set else sentinel for i in range(max(argnums_set, default=-1) + 1))
kwargs = {k: None for k in argnames}
sig = inspect.signature(fun)
# merge both bindings' .arguments dicts (dict | dict needs Python 3.9+)
ba = inspect.BoundArguments(sig, sig.bind_partial(*args).arguments | sig.bind_partial(**kwargs).arguments)
args = ba.args
kwargs = ba.kwargs
# JAX needs POSITIONAL_OR_KEYWORD, KEYWORD_ONLY and VAR_KEYWORD,
# but ba.kwargs only contains KEYWORD_ONLY and VAR_KEYWORD
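The limitation noted in the comment above comes from how `inspect.BoundArguments` reports a binding: `.args` and `.kwargs` describe a canonical way to re-issue the call, not how each value was originally passed. A positional-or-keyword parameter therefore lands in `.args` whenever every earlier positional parameter is bound. The sketch below shows this and adds a hypothetical helper, `keyword_view`, that collects every bound value that could legally be passed by keyword (positional-or-keyword, keyword-only, and `**kwargs` entries), which is closer to what the comment says JAX needs.

```python
import inspect

_KEYWORDABLE = (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                inspect.Parameter.KEYWORD_ONLY)


def keyword_view(ba):
    """Hypothetical helper: all bound values that could be passed by keyword."""
    out = {}
    for name, value in ba.arguments.items():
        kind = ba.signature.parameters[name].kind
        if kind in _KEYWORDABLE:
            out[name] = value
        elif kind is inspect.Parameter.VAR_KEYWORD:
            out.update(value)  # flatten the **rest dict
    return out


def g(a, b, *, c, **rest):
    pass


sig = inspect.signature(g)
ba = sig.bind(1, b=2, c=3, d=4)
print(ba.args)           # (1, 2): b moved to .args although passed by keyword
print(ba.kwargs)         # {'c': 3, 'd': 4}
print(keyword_view(ba))  # {'a': 1, 'b': 2, 'c': 3, 'd': 4}

partial = sig.bind_partial(b=2)
print(partial.args, partial.kwargs)  # () {'b': 2}: a is unbound, so b stays a keyword
```

Positional-only parameters are deliberately left out of `keyword_view`, since they can never be passed by name.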
Idea - Interface discussion: Don't use This is not only more succinct, but also allows us to maintain full backwards compatibility: Using the container class approach as proposed (proposal not finished) in #10746 would enable relatively painless support of
@mattjj @hawkinsp I just wanted to flag this issue again; it would be great to make progress on this to improve JAX usability. The approach of just having e.g. Background: The current behavior in JAX is somewhat broken: arguments marked via static_argnums cannot be passed as kwargs, and arguments marked via static_argnames cannot be passed positionally. Moreover, counting argnums carries a good amount of mental overhead, especially when using function transformations that change the function signature, whereas argnames isn't supported everywhere (e.g. in pmap).
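For transformations that only accept index-based selections (the comment mentions `pmap`), a stopgap is to translate names to indices with `inspect.signature`. The sketch below is hypothetical (`argnames_to_argnums` is not a JAX API). It also shows why this mapping is incomplete: keyword-only parameters have no positional index, so they are returned separately rather than silently dropped.

```python
import inspect

_POSITIONAL = (inspect.Parameter.POSITIONAL_ONLY,
               inspect.Parameter.POSITIONAL_OR_KEYWORD)


def argnames_to_argnums(fun, argnames):
    """Hypothetical helper: map parameter names to positional indices.

    Returns (argnums, unmapped), where unmapped lists names that have no
    positional index (keyword-only parameters, or names not in the signature).
    """
    positional = [p.name for p in inspect.signature(fun).parameters.values()
                  if p.kind in _POSITIONAL]
    argnums = tuple(sorted(positional.index(n) for n in set(argnames) if n in positional))
    unmapped = tuple(n for n in argnames if n not in positional)
    return argnums, unmapped


def f(x, y, z=1, *, w=2):
    return x + y + z + w


print(argnames_to_argnums(f, ("z", "w")))  # ((2,), ('w',))
```

A caller could raise an error whenever `unmapped` is non-empty, which would at least turn the keyword-only case into a visible failure instead of a silently ignored name.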
@danijar Thank you for highlighting this. Even after having spent a good bit of time with this particular part of the JAX source, the behaviour still manages to confuse me from time to time. #10746 is intended as more of a rough sketch, but I think having an immutable dataclass object and passing that around might be a good option. Having had a very cursory look at the code, I think in most places I would be able to figure out how to expand the code to accept named arguments as well (many places have
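As a rough illustration of the immutable-container idea (in the spirit of #10746, not its actual code; the `ArgSelection` name, its fields and its methods are hypothetical), a frozen dataclass can normalize the accepted input forms, a single `int`/`str` or a sequence of them, once. Being frozen, it is also hashable, so it could be passed around freely and used as part of a cache key.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ArgSelection:
    """Hypothetical immutable record of which arguments a transformation targets."""
    argnums: frozenset = frozenset()
    argnames: frozenset = frozenset()

    @classmethod
    def from_user(cls, argnums=(), argnames=()):
        # Accept a single int / str as well as any iterable of them.
        if isinstance(argnums, int):
            argnums = (argnums,)
        if isinstance(argnames, str):
            argnames = (argnames,)
        return cls(frozenset(argnums), frozenset(argnames))

    def selects(self, index, name):
        """True if a parameter is selected by its position or by its name."""
        return index in self.argnums or name in self.argnames


a = ArgSelection.from_user(argnums=1, argnames="w")
b = ArgSelection.from_user(argnums=[1], argnames=("w",))
print(a == b, hash(a) == hash(b))  # True True
print(a.selects(1, "y"), a.selects(3, "w"), a.selects(0, "x"))  # True True False
```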
(I just wanted to note that looking at this is still on my radar but I haven't had time to do so between travel and other higher priorities. Sorry for the delay.)
Is this true for all functions? It's not what I've been seeing with jit or pmap. Another failure case right now is default arguments: if I mark an input with a default value as static, it raises an error if the value isn't passed in. This should all be pretty easy to do with the
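The default-argument failure described above can be handled by binding each call against the signature and calling `BoundArguments.apply_defaults()` before looking up static values by name. The helper below is a hypothetical sketch (`split_static` is not a JAX function). It also shows that, once arguments are keyed by parameter name, it no longer matters whether a static value was passed positionally or by keyword.

```python
import inspect


def split_static(fun, static_names, args, kwargs):
    """Hypothetical helper: split one call's arguments into static and dynamic dicts."""
    bound = inspect.signature(fun).bind(*args, **kwargs)
    # Fill in defaults, so an omitted static parameter still has a value.
    bound.apply_defaults()
    static = {k: v for k, v in bound.arguments.items() if k in static_names}
    dynamic = {k: v for k, v in bound.arguments.items() if k not in static_names}
    return static, dynamic


def aggregate(x, mode="sum"):
    return sum(x) if mode == "sum" else max(x)


print(split_static(aggregate, {"mode"}, ([1.0, 2.0],), {}))
# ({'mode': 'sum'}, {'x': [1.0, 2.0]})
print(split_static(aggregate, {"mode"}, ([1.0, 2.0], "max"), {}) ==
      split_static(aggregate, {"mode"}, ([1.0, 2.0],), {"mode": "max"}))  # True
```

The trade-off is a signature bind on every call; a real implementation would probably want a fast path for the common case where no defaults or keywords are involved.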
I second the proposal of #10614 (comment) and #10614 (comment) to use
Any update on this? Not being able to use
Generally speaking, I'd recommend against using In any case, I don't believe there are any plans to change the API for
Hey JAX team,
I have been trying to wrap my head around 'argument annotation' in JAX for a while, in the hope of finding a more intuitive and consistent implementation, which has led me to the big block of text below. I would be super keen to hear your thoughts as I try to dive deeper into the inner workings of JAX.
Lately there have been a number of issues requesting improvements to the `*_argnums` and `*_argnames` parameters used in transformations, in addition to other ergonomic improvements related to declaring which function arguments should be annotated with a given property. I figured it might be helpful to open an over-arching issue with the end goal of having a consistent, ergonomic way of specifying these parameters. Managing argument 'annotations' in transformations has definitely been one of the more frustrating experiences of learning JAX (which is otherwise entirely amazing, of course).
Related issues:
- `jax.jit(donate_argnames=...)` #10539
- `jit` input validation can lead to silently dynamic variables #10601 (Additional input validation for transformations #10603)
`jax.jit` correctly implements `static_argnames` even for cases with keyword-only arguments, which suggests that it should be possible to add `argnames` equivalents to any function that currently only implements `argnums`. An easier but less robust fix could be to map `argnames` to `argnums` using `inspect` (see discussion: #1159). This would likely not work for keyword-only arguments (though it might for things like donate_arg...?).
Current shortcomings
Currently, even the most robust implementation of the 'argument annotation' mechanism behaves in a somewhat counter-intuitive way (although this is suggested in the fine print of the docstring, if one reads it with sufficient care):
The fact that we have one instance where we get the expected result gives hope that a solution should be possible by inspecting the function and its arguments and modifying `static_argnums` and `static_argnames` accordingly; or perhaps a better solution exists? Ideally, we would want to avoid inspecting the arguments at call time.
I have started toying with validation of `static_argnums` and `static_argnames` in #10603.
Goals
My suggestion would be to first find a solution for `jax.jit` that fixes the inconsistencies above (or, in the worst case, documents them thoroughly).
Once that is done, it would be great to see `*_argnames` and keyword-argument support added to other functions:
- `jax.experimental.pjit`
- `jax.pmap`
- `jax.value_and_grad`
- `jax.custom_vjp`
- `jax.custom_jvp`
- `jax.hessian`
- `jax.jacrev`
- `jax.jacfwd`
- `jax.grad`
Additionally, #10476 could be explored (it could live in `jax.experimental.annotations`, if there is any interest in this feature at all).
Progress
- `jax.jit` (PR: Better `argnums`/`argnames` inference #10619)
- `jax.experimental.pjit`
- `jax.pmap`
- `jax.value_and_grad`
- `jax.custom_vjp`
- `jax.custom_jvp`
- `jax.hessian`
- `jax.jacrev`
- `jax.jacfwd`
- `jax.grad`