Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better argnums/argnames inference #10619

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

JeppeKlitgaard
Copy link
Contributor

@JeppeKlitgaard JeppeKlitgaard commented May 7, 2022

This PR adds additional tests for infer_argnums_and_argnames which has been moved from api.py to api_utils.py.

This is still a WIP and notably it changes the inference logic.

The previous logic is described in detail in the documentation, but seems to be needlessly restrictive, since I cannot foresee a use-case where you would want:

if both static_argnums and static_argnames are provided, inspect.signature is not used, and only actual parameters listed in either static_argnums or static_argnames will be treated as static.

Particularly since this gives somewhat unexpected/counter-intuitive behaviour when switching from a positional argument to a keyword argument:

def f(a): ...

# We expect f(1) == f(a=1) always, but this is not the case with previous logic

Similarly there doesn't seem to be a reason to only look up POSITIONAL_OR_KEYWORD type parameters, which is what leads to the bug described in #10618.

TODO if this PR is deemed viable:

  • change docstring of jax.jit to reflect new logic - will do this after first review

Description of new logic

infer_argnums_and_argnames always fills argnums with parameters listed in argnames AND vice-versa. Only in the exceedingly rare event that signature inspection is not possible is this behaviour not true, in which case we return argnums and argnames unaltered as expected.

I am curious if there is a reason why one would prefer the old behaviour over the new?

Other

Similarly to #10603, two cases of invalid use of argnames/argnums was found. They are also fixed by this PR in order to make tests pass.

This ties in with #10614

Fix: #10618

inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
)
def infer_argnums_and_argnames(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will need tweaking once #10603 is merged

@shoyer
Copy link
Member

shoyer commented May 12, 2022

Similarly there doesn't seem to be a reason to only look up POSITIONAL_OR_KEYWORD type parameters, which is what leads to the bug described in #10618.

I agree, this seems worth fixing. I didn't really think about positional-only arguments when writing this function (we were not using a new enough version of Python to use them).

infer_argnums_and_argnames always fills argnums with parameters listed in argnames AND vice-versa. Only in the exceedingly rare event that signature inspection is not possible is this behaviour not true, in which case we return argnums and argnames unaltered as expected.

I am curious if there is a reason why one would prefer the old behaviour over the new?

This original behavior was intentional. I think it's better because it's more explicit -- we only infer these arguments if they are not provided at all.

This is potentially a little more manual work, but is less magical.

@JeppeKlitgaard
Copy link
Contributor Author

This original behavior was intentional. I think it's better because it's more explicit -- we only infer these arguments if they are not provided at all.

This is potentially a little more manual work, but is less magical.

@shoyer When I first ran into the following, I found the existing behaviour to be quite confusing:

@partial(jit, static_argnames=("arg2", "arg3",))
def fun(arg1, arg2, arg3):
	...

fun(1, 2, arg3=3)  # here arg2 is not static!

Partly, this is because I haven't seen other places in the Python eco-system where behaviour changes based on whether an argument is given as positional or through a keyword.

@shoyer
Copy link
Member

shoyer commented May 12, 2022

@shoyer When I first ran into the following, I found the existing behaviour to be quite confusing:

@partial(jit, static_argnames=("arg2", "arg3",))
def fun(arg1, arg2, arg3):
	...

fun(1, 2, arg3=3)  # here arg2 is not static!

Partly, this is because I haven't seen other places in the Python eco-system where behaviour changes based on whether an argument is given as positional or through a keyword.

What version of JAX are you running this on?

I tried to reproduce this:

from jax import jit
from functools import partial

@partial(jit, static_argnames=("arg2", "arg3",))
def fun(arg1, arg2, arg3):
  print(arg1, arg2, arg3)
  return arg1

fun(1, 2, arg3=3)

This prints Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> 2 3. Which indicates that arg2 is indeed static here.

The more surprising case might be something like:

@partial(jit, static_argnames=("arg2",), static_argnums=(2,))
def fun(arg1, arg2, arg3):
  print(arg1, arg2, arg3)
  return arg1

fun(1, 2, arg3=3)

In which case none of the arguments are static. But I'm not sure this is really so surprising, given that it's inconsistent with the explicit list of static arguments.

@JeppeKlitgaard
Copy link
Contributor Author

This prints Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> 2 3. Which indicates that arg2 is indeed static here.

@shoyer you're absolutely right, this is what I get as well. I can't recall where I had an example that behaved like this – it seems like I must've misremembered it – it probably involved both argnums and argnames which is, as you pointed out, disallowed under the current implementation. Apologies for the confusion!

I still like the extra magic that this PR provides, but I respect that this might not be the direction that JAX want's to take. I'll make a new PR with the fixes discussed below if you confirm that it's a no on the magic inference

Similarly there doesn't seem to be a reason to only look up POSITIONAL_OR_KEYWORD type parameters, which is what leads to the bug described in #10618.

I agree, this seems worth fixing. I didn't really think about positional-only arguments when writing this function (we were not using a new enough version of Python to use them).

@shoyer
Copy link
Member

shoyer commented May 12, 2022

I still like the extra magic that this PR provides, but I respect that this might not be the direction that JAX want's to take. I'll make a new PR with the fixes discussed below if you confirm that it's a no on the magic inference

Other JAX team members may have opinions here, but personally I don't like magic inference.

@JeppeKlitgaard
Copy link
Contributor Author

JeppeKlitgaard commented May 12, 2022

What initially lead me down this path was wanting to experiment with type annotations for 'argument annotation' (see: #10476). This would add a (in my opinion) significantly more ergonomic way of annotating arguments.

That would only make sense if the end user is allowed to 'mix' argument annotation styles such that the final annotation becomes the union of annotations given by type annotation, argnums and argnames. Perhaps this is a better description rather than calling it magic inference, which has a bit of a voodoo ring to it.

Example:

# Assuming there is appetite for type annotations approach
def f(a, b, c: Static[Any]):
	...

jitted = jit(f)  # Just c is marked static
jitted = jit(f, static_argnames=('b',))  # c and b are marked static
jitted = jit(f, static_argnums=(0,), static_argnames=('b',))  # a, b, c marked static

I wouldn't consider these magic, as in all 3 cases it does what I would expect

@partial(jit, static_argnames=("arg2",), static_argnums=(2,))
def fun(arg1, arg2, arg3):
  print(arg1, arg2, arg3)
  return arg1

fun(1, 2, arg3=3)

In which case none of the arguments are static. But I'm not sure this is really so surprising, given that it's inconsistent with the explicit list of static arguments.

I would consider this quite surprising, as the intention of the user here should be clear: the argument named arg2 should be static, as should the 3rd argument. Only upon reading the documentation of jax.jit (and doing so quite carefully) would I understand that this is not the behaviour I get.

Your example also nicely demonstrates how the user can get different behaviour depending on whether an argument is given as a positional argument or keyword argument – something that, to me, is deeply odd:

@partial(jit, static_argnames=("arg2",), static_argnums=(2,))
def fun(arg1, arg2, arg3):
  print(arg1, arg2, arg3)
  return arg1

fun(1, 2, arg3=3)
> Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>

fun(1, arg2=2, arg3=3)
> Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> 2 Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>

In both cases 2 is bound to arg2. The difference comes down to whether one thinks of the static annotations as living on the arguments of the function or in the runtime call signature, I suppose. This PR in effect moves them from the caller signature to the function signature. The current implementation can be particularly confusing when the function is defined and jitted in a different place than where it is called.

I am very keen to hear your (and other team members') thoughts on this, as I think the current approach can be quite confusing to new-comers.

@JeppeKlitgaard
Copy link
Contributor Author

JeppeKlitgaard commented May 13, 2022

I am currently working on a different PR that addresses some of these problems using a different approach, having taken the feedback received so far into consideration – feel free to disregard this for now, although hopefully the discussion is interesting.

@JeppeKlitgaard JeppeKlitgaard marked this pull request as draft May 13, 2022 14:16
@YouJiacheng
Copy link
Contributor

As a user, I think current inference behavior is surprising: no inference done if both argnums and argnames are provided.
Especially, the absence of inference will cause the behavior depends on how user pass the arguments.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

static_argnames of jax.jit does not correctly infer argnums
5 participants