jax.debug.breakpoint() gives strange results
Description
Some of the inputs are tensors of ones with shape (2048,) and dtype bfloat16. If I run x.sum() at the breakpoint prompt, it returns 256. If I do the same thing in the Python script itself, I get the correct result (2048).
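A likely explanation, assuming the breakpoint hands over plain NumPy arrays with a bfloat16 dtype and NumPy's sum keeps its running total in bfloat16: bfloat16 has only 8 significant bits, so every integer up to 256 is exact, but 256 + 1 = 257 lies exactly halfway between 256 and 258 and rounds (ties to even) back to 256. From then on the running sum never grows. As far as I can tell, jnp.sum upcasts bfloat16 to float32 for the reduction and casts back at the end, and 2048 = 2^11 is exactly representable in bfloat16. The sketch below simulates both accumulation strategies with the standard library only; to_float32, to_bfloat16 and the two sum helpers are hypothetical names for illustration, not JAX or NumPy APIs.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (binary64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bfloat16(x: float) -> float:
    """Round a Python float to the nearest bfloat16 value.

    bfloat16 keeps the upper 16 bits of an IEEE float32 (1 sign bit,
    8 exponent bits, 7 explicit mantissa bits). Ties round to even.
    NaN/inf handling is omitted to keep the sketch short.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add 0x7FFF, plus 1 if the lowest kept bit is odd: round-to-nearest-even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def sum_bf16_accumulator(values):
    """Sequential sum whose running total is rounded to bfloat16 after every add."""
    acc = 0.0
    for v in values:
        acc = to_bfloat16(acc + v)
    return acc


def sum_f32_accumulator(values):
    """Sum with a float32 running total, rounding to bfloat16 only at the end."""
    acc = 0.0
    for v in values:
        acc = to_float32(acc + v)
    return to_bfloat16(acc)


ones = [1.0] * 2048
print(sum_bf16_accumulator(ones))  # 256.0
print(sum_f32_accumulator(ones))   # 2048.0
```

If this is the cause, casting before reducing at the debugger prompt (for example x.astype(np.float32).sum()) should sidestep the bfloat16 accumulator.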
To prevent this whole category of issues, jax.debug.breakpoint() should expose ArrayImpl objects rather than NumPy arrays.
@sharadmv
What jax/jaxlib version are you using?
upstream
Which accelerator(s) are you using?
Independent of backend.
Additional system info
No response
NVIDIA GPU info
No response