Errors from interaction between shard_map with auto and lax.map #23019
Labels: bug
Description
I'm converting #23015 into an issue per @yashk2810's request.
Here's a minimal working example of what I'm trying to do. My goal in function `h` is to map a second function `g` over the first dimension `a` of a 2x2 array. Within `g`, I attempt to map a third function `f` over the array's second dimension `b`. Each map applies `shard_map` to `lax.map`; `vmap` isn't a good replacement because its batching rule for my code's actual `f` is quite expensive.

My understanding is that `auto` should make this work. But I get a `NotImplementedError` from the start of `jax.experimental.shard_map._shard_map_impl`. When I wrap `h` in a `jit` before calling it, I instead get the following:

System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.31
jaxlib: 0.4.31
numpy: 2.0.1
python: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', release='5.10.16.3-microsoft-standard-WSL2', version='1 SMP Fri Apr 2 22:23:49 UTC 2021', machine='x86_64')
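For reference, the nested mapping the issue describes (`h` maps `g` over dimension `a`; `g` maps `f` over dimension `b`) can be sketched without the sharding layer. This is a minimal sketch, not the issue's own example (which is not preserved here): `f` is a hypothetical placeholder for the real, expensive-to-batch function, and the `shard_map` wrappers and their `auto` argument are deliberately left out. It only shows the unsharded result that the sharded version would be expected to reproduce.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-in for the issue's real `f`, whose vmap batching rule is
# expensive; any elementwise scalar -> scalar function works for this sketch.
def f(x):
    return jnp.sin(x) * x + 1.0

def g(row):
    # Sequentially map f over the second dimension ("b") of a single row.
    return lax.map(f, row)

def h(arr):
    # Sequentially map g over the first dimension ("a") of the 2x2 array.
    return lax.map(g, arr)

x = jnp.arange(4.0).reshape(2, 2)
out = jax.jit(h)(x)
print(out.shape)  # (2, 2)
```

Because `lax.map` lowers to a sequential loop rather than invoking a batching rule, this keeps the per-element cost of `f` unchanged, which is the motivation the issue gives for preferring it over `vmap`.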