Finite precision error calculations always 0 under JIT with bfloat16 #23007

Open
colehaus opened this issue Aug 12, 2024 · 3 comments

Labels
bug Something isn't working

@colehaus

Description

I have some stochastic rounding code and uncovered a bug when trying to use code like the following:

import jax
import jax.numpy as jnp
from jax.numpy import bfloat16

# The ndarray[*Shape, Float] annotations are custom typing aliases (not defined in this snippet).
def _error(x: ndarray[*Shape, Float], y: ndarray[*Shape, Float], result: ndarray[*Shape, Float]):
    y2 = result - x
    x2 = result - y2
    error_y = y - y2
    error_x = x - x2
    return error_x + error_y


def add(x: ndarray[*Shape, Float], y: ndarray[*Shape, Float]):
    result = x + y
    return _error(x, y, result)


dtype = bfloat16
op1 = jax.random.normal(jax.random.key(0), (1000, 4), dtype=dtype)
op2 = jax.random.normal(jax.random.key(1), (1000, 4), dtype=dtype)
print(jax.vmap(add)(op1, op2))
print(jnp.all(jax.jit(jax.vmap(add))(op1, op2) == 0))

With bfloat16, the final line prints True even though it's clear from the preceding line that not all errors ought to be 0. np.float32 does not have this behavior.
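
As a sanity check, here is a minimal sketch of the same computation in plain NumPy (assuming the ml_dtypes package for the bfloat16 dtype); it mirrors the 5 + 0.001 element of the arange example further down and shows that the true error is clearly nonzero when the arithmetic really happens in bfloat16:

import numpy as np
from ml_dtypes import bfloat16

x = np.array(5, dtype=bfloat16)
y = np.array(0.001, dtype=bfloat16)  # stored as 0.000999451 in bfloat16

result = x + y                       # rounds back to 5.0; the small addend is lost
y2 = result - x                      # 0.0
print(y - y2)                        # about 0.000999: the rounding error is nonzero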

Here are some lowering and compilation outputs, if that happens to be helpful. First bfloat16 and then float32:

dtype = bfloat16
args = (jnp.arange(4, 7, dtype=dtype), jnp.arange(3, dtype=dtype) / 1000)
print(add(*args))
print(jax.jit(add)(*args))
print(jax.jit(add).lower(*args).as_text())
print(jax.jit(add).lower(*args).compile().as_text())
[0 0.000999451 0.0019989]
[0 0 0]
module @jit_add attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xbf16> {mhlo.layout_mode = "default"}, %arg1: tensor<3xbf16> {mhlo.layout_mode = "default"}) -> (tensor<3xbf16> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<3xbf16>
    %1 = stablehlo.subtract %0, %arg0 : tensor<3xbf16>
    %2 = stablehlo.subtract %0, %1 : tensor<3xbf16>
    %3 = stablehlo.subtract %arg1, %1 : tensor<3xbf16>
    %4 = stablehlo.subtract %arg0, %2 : tensor<3xbf16>
    %5 = stablehlo.add %4, %3 : tensor<3xbf16>
    return %5 : tensor<3xbf16>
  }
}

HloModule jit_add, is_scheduled=true, entry_computation_layout={(bf16[3]{0}, bf16[3]{0})->bf16[3]{0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="82694c3355091a0097f584dec86f3d57"}

%fused_convert (param_0.3: bf16[3], param_1.5: bf16[3]) -> bf16[3] {
  %param_0.3 = bf16[3]{0} parameter(0)
  %convert.9.1 = f32[3]{0} convert(bf16[3]{0} %param_0.3)
  %param_1.5 = bf16[3]{0} parameter(1)
  %convert.1.1 = f32[3]{0} convert(bf16[3]{0} %param_1.5)
  %add.2.1 = f32[3]{0} add(f32[3]{0} %convert.9.1, f32[3]{0} %convert.1.1), metadata={op_name="jit(add)/jit(main)/add" source_file="/tmp/ipykernel_47449/771070407.py" source_line=787}
  %subtract.8.1 = f32[3]{0} subtract(f32[3]{0} %add.2.1, f32[3]{0} %convert.9.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=779}
  %subtract.9.1 = f32[3]{0} subtract(f32[3]{0} %add.2.1, f32[3]{0} %subtract.8.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=780}
  %subtract.10.1 = f32[3]{0} subtract(f32[3]{0} %convert.9.1, f32[3]{0} %subtract.9.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=782}
  %subtract.11.1 = f32[3]{0} subtract(f32[3]{0} %convert.1.1, f32[3]{0} %subtract.8.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=781}
  %add.4.1 = f32[3]{0} add(f32[3]{0} %subtract.10.1, f32[3]{0} %subtract.11.1), metadata={op_name="jit(add)/jit(main)/add" source_file="/tmp/ipykernel_47449/771070407.py" source_line=783}
  ROOT %convert.17.1 = bf16[3]{0} convert(f32[3]{0} %add.4.1)
}

ENTRY %main.9 (Arg_0.1.0: bf16[3], Arg_1.2.0: bf16[3]) -> bf16[3] {
  %Arg_1.2.0 = bf16[3]{0} parameter(1), metadata={op_name="y"}
  %Arg_0.1.0 = bf16[3]{0} parameter(0), metadata={op_name="x"}
  ROOT %loop_convert_fusion = bf16[3]{0} fusion(bf16[3]{0} %Arg_0.1.0, bf16[3]{0} %Arg_1.2.0), kind=kLoop, calls=%fused_convert
}
dtype = np.float32
args = (jnp.arange(4, 7, dtype=dtype), jnp.arange(3, dtype=dtype) / 1000)
print(add(*args))
print(jax.jit(add)(*args))
print(jax.jit(add).lower(*args).as_text())
print(jax.jit(add).lower(*args).compile().as_text())
[0.0000000e+00 7.2526745e-08 1.4505349e-07]
[0.0000000e+00 7.2526745e-08 1.4505349e-07]
module @jit_add attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3xf32> {mhlo.layout_mode = "default"}) -> (tensor<3xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
    %1 = stablehlo.subtract %0, %arg0 : tensor<3xf32>
    %2 = stablehlo.subtract %0, %1 : tensor<3xf32>
    %3 = stablehlo.subtract %arg1, %1 : tensor<3xf32>
    %4 = stablehlo.subtract %arg0, %2 : tensor<3xf32>
    %5 = stablehlo.add %4, %3 : tensor<3xf32>
    return %5 : tensor<3xf32>
  }
}

HloModule jit_add, is_scheduled=true, entry_computation_layout={(f32[3]{0}, f32[3]{0})->f32[3]{0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="9d06ea6507421c8754deb14690ff8cd9"}

%fused_add (param_0.1: f32[3], param_1.3: f32[3]) -> f32[3] {
  %param_1.3 = f32[3]{0} parameter(1)
  %param_0.1 = f32[3]{0} parameter(0)
  %add.2.1 = f32[3]{0} add(f32[3]{0} %param_1.3, f32[3]{0} %param_0.1), metadata={op_name="jit(add)/jit(main)/add" source_file="/tmp/ipykernel_47449/771070407.py" source_line=787}
  %subtract.8.1 = f32[3]{0} subtract(f32[3]{0} %add.2.1, f32[3]{0} %param_1.3), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=779}
  %subtract.9.1 = f32[3]{0} subtract(f32[3]{0} %add.2.1, f32[3]{0} %subtract.8.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=780}
  %subtract.10.1 = f32[3]{0} subtract(f32[3]{0} %param_1.3, f32[3]{0} %subtract.9.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=782}
  %subtract.11.1 = f32[3]{0} subtract(f32[3]{0} %param_0.1, f32[3]{0} %subtract.8.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=781}
  ROOT %add.4.1 = f32[3]{0} add(f32[3]{0} %subtract.10.1, f32[3]{0} %subtract.11.1), metadata={op_name="jit(add)/jit(main)/add" source_file="/tmp/ipykernel_47449/771070407.py" source_line=783}
}

ENTRY %main.9 (Arg_0.1.0: f32[3], Arg_1.2.0: f32[3]) -> f32[3] {
  %Arg_1.2.0 = f32[3]{0} parameter(1), metadata={op_name="y"}
  %Arg_0.1.0 = f32[3]{0} parameter(0), metadata={op_name="x"}
  ROOT %loop_add_fusion = f32[3]{0} fusion(f32[3]{0} %Arg_1.2.0, f32[3]{0} %Arg_0.1.0), kind=kLoop, calls=%fused_add, metadata={op_name="jit(add)/jit(main)/add" source_file="/tmp/ipykernel_47449/771070407.py" source_line=783}
}

(Originally reported at: jax-ml/ml_dtypes#167)

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.26.1
python: 3.11.9 (main, Apr  6 2024, 17:59:24) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='npjfe11cq9', release='5.19.0-45-generic', version='#46~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Jun 7 15:06:04 UTC 20', machine='x86_64')


$ nvidia-smi
Sun Aug 11 01:03:45 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    Off  | 00000000:00:05.0 Off |                  Off |
| 30%   42C    P8    27W / 300W |  36785MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+
@colehaus colehaus added the bug Something isn't working label Aug 12, 2024
@jakevdp
Collaborator

jakevdp commented Aug 13, 2024

Hi - thanks for the question! I spent some time making a more concise reproduction here:

import jax

def check_err(x, y):
  result = x + y
  y2 = result - x
  return y - y2

op1 = jax.random.normal(jax.random.key(0), (5,), dtype='bfloat16')
op2 = jax.random.normal(jax.random.key(1), (5,), dtype='bfloat16')

print(check_err(op1, op2))
# [0 -0.00244141 0 0.000488281 0.00390625]

print(jax.jit(check_err)(op1, op2))
# [0 0 0 0 0]

Since it looks like the compiler is doing something unexpected here, it will help to print the optimized HLO for the function:

print(jax.jit(check_err).lower(op1, op2).compile().as_text())
HloModule jit_check_err, entry_computation_layout={(bf16[5]{0}, bf16[5]{0})->bf16[5]{0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}

%fused_computation (param_0.2: bf16[5], param_1.4: bf16[5]) -> bf16[5] {
  %param_1.4 = bf16[5]{0} parameter(1)
  %convert.11 = f32[5]{0} convert(bf16[5]{0} %param_1.4)
  %param_0.2 = bf16[5]{0} parameter(0)
  %convert.10 = f32[5]{0} convert(bf16[5]{0} %param_0.2)
  %add.0 = f32[5]{0} add(f32[5]{0} %convert.10, f32[5]{0} %convert.11), metadata={op_name="jit(check_err)/jit(main)/add" source_file="<ipython-input-4-c332ca662f3d>" source_line=4}
  %subtract.1 = f32[5]{0} subtract(f32[5]{0} %add.0, f32[5]{0} %convert.10), metadata={op_name="jit(check_err)/jit(main)/sub" source_file="<ipython-input-4-c332ca662f3d>" source_line=5}
  %subtract.0 = f32[5]{0} subtract(f32[5]{0} %convert.11, f32[5]{0} %subtract.1), metadata={op_name="jit(check_err)/jit(main)/sub" source_file="<ipython-input-4-c332ca662f3d>" source_line=6}
  ROOT %convert.9 = bf16[5]{0} convert(f32[5]{0} %subtract.0)
}

ENTRY %main.6 (Arg_0.1: bf16[5], Arg_1.2: bf16[5]) -> bf16[5] {
  %Arg_0.1 = bf16[5]{0} parameter(0)
  %Arg_1.2 = bf16[5]{0} parameter(1)
  ROOT %fusion = bf16[5]{0} fusion(bf16[5]{0} %Arg_0.1, bf16[5]{0} %Arg_1.2), kind=kLoop, calls=%fused_computation
}

This shows what the problem is: the line %convert.11 = f32[5]{0} convert(bf16[5]{0} %param_1.4) (and similarly %convert.10) converts the inputs to float32 before any of the arithmetic happens, and %convert.9 = bf16[5]{0} convert(f32[5]{0} %subtract.0) converts the result back to bfloat16 at the end. The error term is therefore computed in float32 precision, where it comes out essentially zero, and casting that value back to bfloat16 still gives zero. In effect, the JIT-compiled version is doing this:

def check_err(x, y):
  x, y = x.astype('float32'), y.astype('float32')
  result = x + y
  y2 = result - x
  return (y - y2).astype('bfloat16')
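
To make the effect concrete, here is a minimal standalone sketch (plain NumPy, using the 5 + 0.001 values from the arange example in the report): each bfloat16 input carries at most 8 significand bits, so their sum typically fits exactly in float32's 24-bit significand, and an exact sum leaves no rounding error to recover.

import numpy as np

x = np.float32(5.0)                    # the bfloat16 input after the convert to float32
y = np.float32(0.00099945068359375)    # bfloat16(0.001) upcast to float32, exactly

result = x + y                         # exact in float32: no rounding occurs
y2 = result - x                        # recovers y exactly
print(y - y2)                          # 0.0, and it stays 0.0 when cast back to bfloat16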

I'm not aware of any way to prevent the compiler from doing this kind of casting – it's probably due to the fact that the hardware (CPU in my case) does not support native bfloat16 operations. I'll ask around to see if others have ideas.

@jakevdp jakevdp self-assigned this Aug 13, 2024
@jakevdp
Collaborator

jakevdp commented Aug 13, 2024

Via @apaszke, it seems the xla_allow_excess_precision flag controls this behavior. If you set it to False, then the compiler won't do this sort of internal upcasting:

import os
os.environ['XLA_FLAGS'] = "--xla_allow_excess_precision=false"

import jax

def check_err(x, y):
  result = x + y
  y2 = result - x
  return y - y2

op1 = jax.random.normal(jax.random.key(0), (5,), dtype='bfloat16')
op2 = jax.random.normal(jax.random.key(1), (5,), dtype='bfloat16')

print(check_err(op1, op2))
# [0 -0.00244141 0 0.000488281 0.00390625]

print(jax.jit(check_err)(op1, op2))
# [0 -0.00244141 0 0.000488281 0.00390625]

Note that XLA flag values are only read when the backend is initialized, so be sure to set them either as an environment variable outside your script, or in your script via os.environ before running any jax commands.
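
If other XLA flags might already be in use, here is a sketch of a slightly more careful in-script variant (XLA_FLAGS is a space-separated list of flags, so appending avoids clobbering anything already set):

import os

# Append the flag rather than overwrite, in case XLA_FLAGS is already set.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_allow_excess_precision=false"
).strip()

import jax  # import jax only after XLA_FLAGS is in place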

@colehaus
Author

That seems to work. Thanks!
