
jax.debug.breakpoint() gives numpy arrays, and with bfloat16 it computes wrong results #17402

Open
nouiz opened this issue Sep 1, 2023 · 1 comment
Labels
bug Something isn't working


@nouiz
Collaborator

nouiz commented Sep 1, 2023

Description

jax.debug.breakpoint() gives strange results.
Some of the inputs are bfloat16 arrays of ones with shape (2048,). Inside the debugger, x.sum() returns 256.
But if I do the same thing in the Python script, I get the correct result (2048):

# python test.py
2048 <class 'jax.numpy.bfloat16'> <class 'jaxlib.xla_extension.ArrayImpl'>
Entering jdb:
(jdb) x.sum()
256
(jdb) c
2048.0 <class 'jax.numpy.float16'> <class 'jaxlib.xla_extension.ArrayImpl'>
Entering jdb:
(jdb) x.sum()
2048.0
(jdb) c
2048.0 <class 'jax.numpy.float32'> <class 'jaxlib.xla_extension.ArrayImpl'>
Entering jdb:
(jdb) x.sum()
2048.0
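The 256 vs. 2048 result is consistent with the debugger's numpy bfloat16 array being reduced with a bfloat16 accumulator. bfloat16 has only 8 significant bits, so once the running total reaches 256, adding 1 gives 257, which rounds (ties to even) back to 256. float16 has 11 significant bits, so every integer up to 2048 is exact. That is why the float16 run still shows 2048.0. The sketch below emulates both accumulation strategies using only the standard library. It does not call JAX or NumPy. `to_bfloat16` and the two sum helpers are hypothetical names for illustration. The idea that jdb accumulates sequentially in bfloat16 is inferred from the output above, not confirmed in this issue.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even).

    bfloat16 is the upper 16 bits of an IEEE-754 float32, so we round to
    float32 first, then round away the low 16 bits. NaN/inf are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1                      # lowest bit that survives
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000   # round half to even, drop low bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def sum_bf16_accumulator(values):
    """Hypothetical model of the jdb path: every partial sum is stored as bfloat16."""
    acc = 0.0
    for v in values:
        acc = to_bfloat16(acc + to_bfloat16(v))
    return acc


def sum_wide_accumulator(values):
    """Accumulate in float64 (wider than bfloat16) and round to bfloat16 once."""
    return to_bfloat16(sum(to_bfloat16(v) for v in values))


ones = [1.0] * 2048
print(sum_bf16_accumulator(ones))  # 256.0, matches the (jdb) output
print(sum_wide_accumulator(ones))  # 2048.0, matches the script output
```

If that hypothesis holds, casting the numpy array before reducing in jdb (for example `x.astype(numpy.float32).sum()`) would correspond to the wide accumulator. Returning a JAX Array from the breakpoint would keep the script's semantics without needing such a cast.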

To prevent this whole category of issues, the breakpoint should give ArrayImpl objects, not numpy arrays.
@sharadmv

What jax/jaxlib version are you using?

upstream

Which accelerator(s) are you using?

Independent of backend.

Additional system info

No response

NVIDIA GPU info

No response

@nouiz nouiz added the bug Something isn't working label Sep 1, 2023
@hawkinsp
Member

hawkinsp commented Sep 1, 2023

We should probably return a CPU Array in all cases.
