Finite precision error calculations always 0 under JIT with bfloat16 #23007
Comments
Hi - thanks for the question! I spent some time making a more concise reproduction here:

```python
import jax

def check_err(x, y):
    result = x + y
    y2 = result - x
    return y - y2

op1 = jax.random.normal(jax.random.key(0), (5,), dtype='bfloat16')
op2 = jax.random.normal(jax.random.key(1), (5,), dtype='bfloat16')

print(check_err(op1, op2))
# [0 -0.00244141 0 0.000488281 0.00390625]

print(jax.jit(check_err)(op1, op2))
# [0 0 0 0 0]
```

Since it looks like the compiler is doing something unexpected here, it will help to print the optimized HLO for the function:

```python
print(jax.jit(check_err).lower(op1, op2).compile().as_text())
```

and this shows what the problem is: the compiled program effectively computes the following, with the arithmetic upcast to float32:

```python
def check_err(x, y):
    x, y = x.astype('float32'), y.astype('float32')
    result = x + y
    y2 = result - x
    return (y - y2).astype('bfloat16')
```

I'm not aware of any way to prevent the compiler from doing this kind of casting – it's probably due to the fact that the hardware (CPU in my case) does not support native bfloat16 operations. I'll ask around to see if others have ideas.
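(As a side note, a minimal sanity-check sketch: with compilation disabled via jax.disable_jit(), the XLA optimization never runs, so the jitted call should match the eager output above.)

```python
# Sanity check: under disable_jit, jax.jit falls back to eager execution,
# so the output should match the un-jitted result above.
with jax.disable_jit():
    print(jax.jit(check_err)(op1, op2))
```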
Via @apaszke, it seems the `--xla_allow_excess_precision=false` XLA flag prevents this:

```python
import os
os.environ['XLA_FLAGS'] = "--xla_allow_excess_precision=false"

import jax

def check_err(x, y):
    result = x + y
    y2 = result - x
    return y - y2

op1 = jax.random.normal(jax.random.key(0), (5,), dtype='bfloat16')
op2 = jax.random.normal(jax.random.key(1), (5,), dtype='bfloat16')

print(check_err(op1, op2))
# [0 -0.00244141 0 0.000488281 0.00390625]

print(jax.jit(check_err)(op1, op2))
# [0 -0.00244141 0 0.000488281 0.00390625]
```

Note that XLA flag values are only read at the time the backend is initialized, so be sure to set them either as an environment variable outside your script, or in your script via os.environ before the first use of JAX, as above.
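For example, a sketch that appends the flag rather than overwriting any existing XLA_FLAGS (assuming it runs before JAX is imported):

```python
import os

# Must be set before the XLA backend initializes (in practice, before the
# first JAX operation), otherwise the flag has no effect.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_allow_excess_precision=false"
).strip()

import jax
```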
That seems to work. Thanks!
Description
I have some stochastic rounding code and uncovered a bug when using code like the following:
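Something along these lines (a minimal sketch; the helper name rounding_err and the array shapes are illustrative rather than the exact original code):

```python
import jax
import jax.numpy as jnp
import numpy as np

def rounding_err(x, y):
    # Recover the rounding error of x + y: y2 is the part of y that survived
    # the addition, so y - y2 is the part lost to finite precision.
    result = x + y
    y2 = result - x
    return y - y2

x = jax.random.normal(jax.random.key(0), (5,), dtype=jnp.bfloat16)
y = jax.random.normal(jax.random.key(1), (5,), dtype=jnp.bfloat16)

print(rounding_err(x, y))           # eager: several nonzero error terms
errs = jax.jit(rounding_err)(x, y)
print(np.all(errs == 0))            # under jit with bfloat16, this prints True
```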
With bfloat16, the final line prints True, even though it's clear from the preceding line that not all errors ought to be 0. np.float32 does not have this behavior.

Here are some lowering and compilation outputs, if that happens to be helpful. First bfloat16 and then float32:
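(A minimal sketch of how such outputs can be produced, using the illustrative names from the snippet above:)

```python
lowered = jax.jit(rounding_err).lower(x, y)

# Unoptimized lowering emitted by JAX:
print(lowered.as_text())

# Optimized HLO after XLA compilation; on a CPU backend the bfloat16 arithmetic
# shows up upcast to f32 here, which is where the error terms cancel to zero.
print(lowered.compile().as_text())
```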
(Originally reported at: jax-ml/ml_dtypes#167)
System info (python version, jaxlib version, accelerator, etc.)