Better argnums/argnames inference #10619
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
)
def infer_argnums_and_argnames(
Will need tweaking once #10603 is merged
I agree, this seems worth fixing. I didn't really think about positional-only arguments when writing this function (we were not using a new enough version of Python to use them).
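For readers unfamiliar with the parameter kinds under discussion, here is a small illustration using only the standard `inspect` module. The function `example` and the helper `parameter_kinds` are hypothetical names made up for this sketch; they are not part of JAX.

```python
import inspect


def example(pos_only, /, pos_or_kw, *args, kw_only, **kwargs):
    """Hypothetical function that uses all five parameter kinds."""


def parameter_kinds(fun):
    """Map each parameter name to (index in the signature, kind name)."""
    params = inspect.signature(fun).parameters
    return {name: (index, param.kind.name)
            for index, (name, param) in enumerate(params.items())}


kinds = parameter_kinds(example)
for name, (index, kind) in kinds.items():
    print(index, name, kind)
# 0 pos_only POSITIONAL_ONLY
# 1 pos_or_kw POSITIONAL_OR_KEYWORD
# 2 args VAR_POSITIONAL
# 3 kw_only KEYWORD_ONLY
# 4 kwargs VAR_KEYWORD
```

A lookup that only considers POSITIONAL_OR_KEYWORD parameters would skip `pos_only` and `kw_only` here, which is the gap being acknowledged.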
This original behavior was intentional. I think it's better because it's more explicit -- we only infer these arguments if they are not provided at all. This is potentially a little more manual work, but is less magical.
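A minimal sketch of the "infer only when not provided at all" rule described above, written against the standard library only. The helper name `infer_if_missing` is an assumption for illustration; it is not the actual JAX implementation, which may differ in details.

```python
import inspect


def infer_if_missing(fun, argnums=None, argnames=None):
    """Hypothetical helper: infer one spec only if it was not given at all."""
    params = inspect.signature(fun).parameters
    pos_or_kw = inspect.Parameter.POSITIONAL_OR_KEYWORD
    if argnums is None and argnames is not None:
        # Only names were given: look up their positions.
        argnums = tuple(i for i, (name, p) in enumerate(params.items())
                        if p.kind == pos_or_kw and name in argnames)
    elif argnames is None and argnums is not None:
        # Only positions were given: look up their names.
        argnames = tuple(name for i, (name, p) in enumerate(params.items())
                         if p.kind == pos_or_kw and i in argnums)
    # If both were given, neither is touched.
    return tuple(argnums or ()), tuple(argnames or ())


def fun(arg1, arg2, arg3):
    return arg1


print(infer_if_missing(fun, argnames=("arg2", "arg3")))
# ((1, 2), ('arg2', 'arg3'))
print(infer_if_missing(fun, argnums=(2,), argnames=("arg2",)))
# ((2,), ('arg2',))
```

Under this rule the second call leaves `arg2` static only by keyword and `arg3` static only by position, which is the mixed case debated in the rest of the thread.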
@shoyer When I first ran into the following, I found the existing behaviour to be quite confusing:
@partial(jit, static_argnames=("arg2", "arg3",))
def fun(arg1, arg2, arg3):
    ...
fun(1, 2, arg3=3) # here arg2 is not static!
Partly, this is because I haven't seen other places in the Python ecosystem where behaviour changes based on whether an argument is given as positional or through a keyword.
What version of JAX are you running this on? I tried to reproduce this:
from jax import jit
from functools import partial
@partial(jit, static_argnames=("arg2", "arg3",))
def fun(arg1, arg2, arg3):
    print(arg1, arg2, arg3)
    return arg1
fun(1, 2, arg3=3)
This prints
The more surprising case might be something like:
@partial(jit, static_argnames=("arg2",), static_argnums=(2,))
def fun(arg1, arg2, arg3):
    print(arg1, arg2, arg3)
    return arg1
fun(1, 2, arg3=3)
In which case none of the arguments are static. But I'm not sure this is really so surprising, given that it's inconsistent with the explicit list of static arguments.
@shoyer you're absolutely right, this is what I get as well. I can't recall where I had an example that behaved like this; it seems I must have misremembered it. It probably involved both
I still like the extra magic that this PR provides, but I respect that this might not be the direction that JAX wants to take. I'll make a new PR with the fixes discussed below if you confirm that it's a no on the magic inference.
Other JAX team members may have opinions here, but personally I don't like magic inference.
What initially led me down this path was wanting to experiment with type annotations for 'argument annotation' (see: #10476). This would add an (in my opinion) significantly more ergonomic way of annotating arguments. That would only make sense if the end user is allowed to 'mix' argument annotation styles such that the final annotation becomes the union of annotations given by type annotation,
Example:
# Assuming there is appetite for type annotations approach
def f(a, b, c: Static[Any]):
    ...
jitted = jit(f) # Just c is marked static
jitted = jit(f, static_argnames=('b',)) # c and b are marked static
jitted = jit(f, static_argnums=(0,), static_argnames=('b',)) # a, b, c marked static
I wouldn't consider these magic, as in all 3 cases it does what I would expect.
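One way the union-of-annotations idea could be prototyped without touching JAX is with `typing.Annotated`. Everything here is an assumption for illustration: `Static`, `_STATIC`, and `static_names` are hypothetical names, and this is not how #10476 or JAX implements anything.

```python
import inspect
from typing import Annotated, Any, TypeVar, get_type_hints

T = TypeVar("T")
_STATIC = object()               # hypothetical marker stored in the annotation
Static = Annotated[T, _STATIC]   # hypothetical alias, so Static[Any] works


def static_names(fun, static_argnums=(), static_argnames=()):
    """Union of names marked static by annotation, by index, or by name."""
    params = list(inspect.signature(fun).parameters)
    hints = get_type_hints(fun, include_extras=True)
    names = {name for name, hint in hints.items()
             if _STATIC in getattr(hint, "__metadata__", ())}
    names.update(static_argnames)
    names.update(params[i] for i in static_argnums)
    # Report in signature order and drop non-parameters such as 'return'.
    return tuple(name for name in params if name in names)


def f(a, b, c: Static[Any]):
    ...


print(static_names(f))                                                # ('c',)
print(static_names(f, static_argnames=("b",)))                        # ('b', 'c')
print(static_names(f, static_argnums=(0,), static_argnames=("b",)))   # ('a', 'b', 'c')
```

The three calls mirror the three `jit` calls in the comment above and yield the same sets of static arguments the comment describes.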
I would consider this quite surprising, as the intention of the user here should be clear: the argument named
Your example also nicely demonstrates how the user can get different behaviour depending on whether an argument is given as a positional argument or keyword argument, something that, to me, is deeply odd:
@partial(jit, static_argnames=("arg2",), static_argnums=(2,))
def fun(arg1, arg2, arg3):
    print(arg1, arg2, arg3)
    return arg1
fun(1, 2, arg3=3)
> Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
fun(1, arg2=2, arg3=3)
> Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> 2 Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
In both cases
I am very keen to hear your (and other team members') thoughts on this, as I think the current approach can be quite confusing to newcomers.
I am currently working on a different PR that addresses some of these problems using a different approach, having taken the feedback received so far into consideration. Feel free to disregard this for now, although hopefully the discussion is interesting.
As a user, I think current inference behavior is surprising: no inference done if both
This PR adds additional tests for infer_argnums_and_argnames, which has been moved from api.py to api_utils.py.
This is still a WIP and notably it changes the inference logic.
The previous logic is described in detail in the documentation, but seems to be needlessly restrictive, since I cannot foresee a use-case where you would want:
Particularly since this gives somewhat unexpected/counter-intuitive behaviour when switching from a positional argument to a keyword argument:
Similarly there doesn't seem to be a reason to only look up POSITIONAL_OR_KEYWORD type parameters, which is what leads to the bug described in #10618.
TODO if this PR is deemed viable:
jax.jit to reflect new logic - will do this after first review
Description of new logic
infer_argnums_and_argnames always fills argnums with parameters listed in argnames AND vice-versa. Only in the exceedingly rare event that signature inspection is not possible is this behaviour not true, in which case we return argnums and argnames unaltered as expected.
I am curious if there is a reason why one would prefer the old behaviour over the new?
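The new logic described above can be sketched roughly as follows. This is an illustrative reconstruction from the description, not the code in this PR; the name `infer_both` is hypothetical, and negative indices are deliberately not handled.

```python
import inspect


def infer_both(fun, argnums=(), argnames=()):
    """Hypothetical sketch: each spec is extended with what the other implies."""
    try:
        params = inspect.signature(fun).parameters
    except (ValueError, TypeError):
        # Signature inspection is not possible: return the inputs unaltered.
        return tuple(argnums), tuple(argnames)
    P = inspect.Parameter
    nums, names = set(argnums), list(argnames)
    for index, (name, param) in enumerate(params.items()):
        # A listed name that can be passed positionally also gets its index.
        if name in argnames and param.kind in (P.POSITIONAL_ONLY,
                                               P.POSITIONAL_OR_KEYWORD):
            nums.add(index)
        # A listed index that can be passed by keyword also gets its name.
        if (index in argnums and param.kind == P.POSITIONAL_OR_KEYWORD
                and name not in names):
            names.append(name)
    return tuple(sorted(nums)), tuple(names)


def fun(arg1, arg2, arg3):
    return arg1


print(infer_both(fun, argnums=(2,), argnames=("arg2",)))
# ((1, 2), ('arg2', 'arg3'))
```

With this rule, `arg2` and `arg3` are treated the same way whether the caller passes them positionally or by keyword, which is the behaviour change being proposed.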
Other
Similarly to #10603, two cases of invalid use of argnames/argnums were found. They are also fixed by this PR in order to make tests pass.
This ties in with #10614
Fix: #10618