Description
I ran into a rather surprising case involving 3D boolean arrays, which only seems to fail on jax-metal.
Correct behavior (CPU):
But on jax-metal, I get:
After playing around with some inputs, the problem seems to occur for .at[:, i] and .at[:, :, i], but .at[i] works fine. So indexing along any axis other than axis 0 seems to trigger a bug in the scatter-update path for booleans.

Is there some way I can help debug this? Is the jax-metal code open source? If it is, point me to the build instructions and I can try to track down the bug.
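A minimal sketch for checking the pattern described above (axis 0 correct, axes 1 and 2 wrong). It assumes the failing operation is a boolean .at[...].set(True). The original reproducer is not shown here, so the helper name check_bool_scatter, the array shape, and the index are hypothetical. The script compares each indexing form against a NumPy reference. Running it once on jax-metal and once with JAX_PLATFORMS=cpu should show which forms disagree.

```python
import numpy as np
import jax
import jax.numpy as jnp


def check_bool_scatter(shape=(2, 3, 4), i=1):
    """Hypothetical helper: compare jnp boolean .at[...].set(True) with NumPy.

    Returns a dict mapping each indexing form to True if JAX matches NumPy.
    """
    x_np = np.zeros(shape, dtype=bool)
    x = jnp.asarray(x_np)
    results = {}
    # Same index written as explicit tuples: x.at[(slice(None), i)] == x.at[:, i]
    for name, idx in [
        ("at[i]", (i,)),
        ("at[:, i]", (slice(None), i)),
        ("at[:, :, i]", (slice(None), slice(None), i)),
    ]:
        expected = x_np.copy()
        expected[idx] = True  # NumPy in-place assignment as the reference
        got = np.asarray(x.at[idx].set(True))  # JAX functional scatter update
        results[name] = bool(np.array_equal(got, expected))
    return results


if __name__ == "__main__":
    # Report which backend ran the check alongside the per-form results.
    print(jax.default_backend(), check_bool_scatter())
```

On a correct backend every entry should be True. If the report above holds, a False result for "at[:, i]" and "at[:, :, i]" on jax-metal would narrow the bug to non-leading-axis boolean scatters.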
System info (python version, jaxlib version, accelerator, etc.)
On jax-metal:
On CPU (correct behavior):