
On jax-metal, updating multidimensional boolean arrays sometimes fails #20675

Open
shawwn opened this issue Apr 10, 2024 · 2 comments
Labels
Apple GPU (Metal) plugin bug Something isn't working

Comments

@shawwn
Contributor

shawwn commented Apr 10, 2024

Description

I ran into a rather surprising case involving 3D boolean arrays, which only seems to fail on jax-metal.

Correct behavior (CPU):

>>> jnp.zeros((2,2,2), dtype=jnp.bool).at[:, :, 0].set(True)[:, :, 0]
Array([[ True,  True],
       [ True,  True]], dtype=bool)

But on jax-metal, I get:

>>> jnp.zeros((2,2,2), dtype=jnp.bool).at[:, :, 0].set(True)[:, :, 0]
Array([[False, False],
       [False, False]], dtype=bool)

After experimenting with a few inputs, the problem seems to occur for .at[:, i] and .at[:, :, i], while .at[i] works fine. So the scatter update for booleans appears to be buggy along any axis other than axis 0.
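
A small self-checking script can make the axis dependence easier to reproduce. This is a sketch that is not part of the original report: check_bool_set is a hypothetical helper that sets index 0 along each axis of an all-False boolean array and compares the result with NumPy. On a CPU backend every axis should match; going by the observations above, axes 1 and 2 would be expected to report False on jax-metal, although that has not been rerun here.

```python
import numpy as np
import jax
import jax.numpy as jnp


def check_bool_set(shape=(2, 2, 2)):
    """Set index 0 along each axis of an all-False array and compare with NumPy.

    Returns a dict mapping axis -> True if the JAX result matches NumPy.
    """
    results = {}
    for axis in range(len(shape)):
        # Index equivalent to x[0], x[:, 0], x[:, :, 0], ... for this axis.
        idx = (slice(None),) * axis + (0,)

        expected = np.zeros(shape, dtype=bool)
        expected[idx] = True

        actual = jnp.zeros(shape, dtype=jnp.bool_).at[idx].set(True)
        results[axis] = bool(np.array_equal(np.asarray(actual), expected))
    return results


print(jax.default_backend(), check_bool_set())
```

On a CPU-only install this should print cpu followed by True for every axis.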

Is there some way I can help debug this? Is the jax-metal code open source? If so, and you can point me to build instructions, I can try to track down the bug.
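
Even without the plugin source, JAX can show what it asks the backend to run. The sketch below is not from the original report, and update is a hypothetical helper name. It prints the jaxpr and the lowered module (StableHLO text in recent JAX versions) for the failing update. Running it on both the CPU and Metal backends and attaching the output is one way to give maintainers more to work with; whether it pinpoints the bug is untested.

```python
import jax
import jax.numpy as jnp


def update(x):
    # The update from the report that returns all-False on jax-metal.
    return x.at[:, :, 0].set(True)


x = jnp.zeros((2, 2, 2), dtype=jnp.bool_)

# The jaxpr shows which JAX primitive(s) the indexed update turns into.
print(jax.make_jaxpr(update)(x))

# The lowered module is roughly the program JAX hands to the backend plugin.
ir_text = jax.jit(update).lower(x).as_text()
print(ir_text)
```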

System info (python version, jaxlib version, accelerator, etc.)

On jax-metal:

>>> import jax; jax.print_environment_info()
jax:    0.4.25
jaxlib: 0.4.23
numpy:  1.26.2
python: 3.10.13 (main, Aug 24 2023, 22:36:46) [Clang 14.0.3 (clang-1403.0.22.14.1)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='shawn.local', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:12:49 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6020', machine='arm64')

On CPU (correct behavior):

>>> import jax; jax.print_environment_info()
jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.2
python: 3.10.13 (main, Aug 24 2023, 22:36:46) [Clang 14.0.3 (clang-1403.0.22.14.1)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='shawn.local', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:12:49 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6020', machine='arm64')
@shawwn shawwn added the bug Something isn't working label Apr 10, 2024
@shawwn
Contributor Author

shawwn commented Apr 10, 2024

(Note that integers and floats seem to work fine; only dtype=jnp.bool_ seems affected.)
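
If only the bool dtype is affected, one possible stopgap, untested on Metal, is to perform the update in an integer dtype and cast back. set_bool_via_uint8 is a hypothetical helper, and the choice of uint8 is an assumption; any integer dtype that behaves correctly on jax-metal should serve the same purpose.

```python
import numpy as np
import jax.numpy as jnp


def set_bool_via_uint8(x, idx, value):
    """Hypothetical workaround: scatter in uint8, then cast the result back to bool.

    Assumes non-bool scatters behave correctly on jax-metal (not verified on Metal).
    """
    as_int = x.astype(jnp.uint8)
    updated = as_int.at[idx].set(jnp.asarray(value, dtype=jnp.uint8))
    # Nonzero uint8 values become True when cast back to bool.
    return updated.astype(jnp.bool_)


x = jnp.zeros((2, 2, 2), dtype=jnp.bool_)
y = set_bool_via_uint8(x, (slice(None), slice(None), 0), True)
print(y[:, :, 0])
```

The cast costs an extra pass over the array, so this is only meant as a temporary measure while the bool scatter path is broken.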

@shuhand0
Collaborator

jax-metal is not open source at this time. We'll look into the issue and post any updates here.
