
xmap doesn't preserve static argnums #10741

Closed
nestordemeure opened this issue May 17, 2022 · 3 comments
Labels
bug Something isn't working

nestordemeure commented May 17, 2022

`xmap` appears to turn static arguments into traced `ShapedArray` values. This breaks code that needs an argument to stay static, for example so it can be evaluated in a Python `if` test.
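For context, here is a minimal sketch of the same failure mode in plain `jax.jit`, with no mapping transform involved. The function `scale` and its inputs are hypothetical. A Python `if` works only when `flag` reaches the function as a concrete value, which is what `static_argnames` guarantees.

```python
import jax
import jax.numpy as jnp


def scale(flag, x):
    # A Python `if` needs a concrete bool; a tracer cannot be converted to one.
    if flag:
        return 2.0 * x
    return x


x = jnp.ones(3)

# `flag` is static: jit sees the Python value and caches one compiled version per value.
scale_static = jax.jit(scale, static_argnames=['flag'])
print(scale_static(True, x))

# `flag` is traced: the `if` raises ConcretizationTypeError (or a subclass of it).
scale_traced = jax.jit(scale)
try:
    scale_traced(True, x)
    traced_failed = False
except jax.errors.ConcretizationTypeError:
    traced_failed = True
print(traced_failed)  # True
```

The issue reports that `xmap` causes this second behaviour even when the enclosing `jit` marks `cond` as static.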

The following code demonstrates the issue: first with `vmap` (which works), then with `xmap`, which fails with "Abstract tracer value encountered where concrete value is expected" when `cond` is tested at the very beginning of `func`:

```python
import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap

# dummy function
def func(cond, data):
    # `if cond` requires `cond` to be a concrete (static) value
    if cond:
        return data
    else:
        return data

# vectorisation
func_vmap = jax.vmap(func, in_axes=[None, 0], out_axes=0)
func_xmap = xmap(func, in_axes=[[...], ['axis']], out_axes=['axis'])

# jit compiling
func_vmap_jit = jax.jit(func_vmap, static_argnames=['cond'])
func_xmap_jit = jax.jit(func_xmap, static_argnames=['cond'])

# running
cond = True
data = jnp.ones(100)
out_vmap = func_vmap_jit(cond, data)  # works
out_xmap = func_xmap_jit(cond, data)  # fails: Abstract tracer value encountered where concrete value is expected
```

This happens with the latest version of JAX (0.3.13).
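One possible workaround, sketched under the assumption that `cond` only ever needs to be static and is never mapped: bind it with `functools.partial` inside the jitted function, so the mapping transform only receives `data`. The sketch uses `vmap`, and its `func` is a modified version of the dummy function with differing branches. It has not been tested with `xmap` on 0.3.13.

```python
import functools

import jax
import jax.numpy as jnp


def func(cond, data):
    # Variant of the report's dummy function; the branches differ here
    # only so the output shows which branch ran.
    if cond:
        return 2.0 * data
    else:
        return data


@functools.partial(jax.jit, static_argnames=['cond'])
def func_mapped(cond, data):
    # `cond` is static for jit, so here it is a plain Python bool.
    # Binding it with functools.partial means the mapping transform only
    # ever sees `data`, so it never gets a chance to trace `cond`.
    bound = functools.partial(func, cond)
    return jax.vmap(bound)(data)


out = func_mapped(True, jnp.ones(100))
print(out.shape)  # (100,)
```

The design choice is to resolve static values before any mapping transform runs, rather than relying on each transform to pass static arguments through untouched.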

nestordemeure added the bug label (Something isn't working) May 17, 2022
@JeppeKlitgaard
Contributor

I believe the changes I am working on related to #10614 would fix this.

@froystig
Member

Assigned @apaszke for xmap and @JeppeKlitgaard based on the previous comment.

@apaszke
Collaborator

apaszke commented Jul 30, 2024

xmap was removed from JAX, so this issue is obsolete.
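Since `xmap` no longer exists, the named axis `'axis'` from the report can be expressed on a single device with `jax.vmap(..., axis_name='axis')`. The sketch below is an assumption about how the repro might be ported, not an official migration path (for multi-device code, `shard_map` covers much of what `xmap` was used for). The `psum` normalisation in `func` is hypothetical and only shows that collectives can refer to the mapped axis by name.

```python
import functools

import jax
import jax.numpy as jnp


def func(cond, data):
    # Hypothetical variant of the report's func: inside vmap, `data` is one
    # element of the mapped axis, and psum sums it over the axis named 'axis'.
    total = jax.lax.psum(data, axis_name='axis')
    if cond:  # `cond` arrives as a plain Python bool (static, unmapped)
        return data / total
    return data


@functools.partial(jax.jit, static_argnames=['cond'])
def func_named(cond, data):
    # in_axes=(None, 0) plays the role of xmap's in_axes=[[...], ['axis']]:
    # `cond` is not mapped, `data` is mapped along its leading dimension.
    return jax.vmap(func, in_axes=(None, 0), axis_name='axis')(cond, data)


# Each element of ones(4) is divided by the sum over 'axis', i.e. by 4.
print(func_named(True, jnp.ones(4)))
```

This mirrors the `vmap` half of the original repro, which the report confirms keeps `cond` static when it is unmapped and marked static in the enclosing `jit`.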

apaszke closed this as completed Jul 30, 2024

4 participants