Example (shown with `jnp.minimum`; the other functions listed below are affected the same way):
```python
import jax
import jax.numpy as jnp


def f(x, y):
    return jnp.minimum(x, y)


x = jnp.array(2.0)
y = jnp.array(float("nan"))

# Print lowered HLO
print(jax.jit(f).lower(x, y).as_text())

# Evaluate on the default device (METAL here); NaN is the expected result
print(jax.jit(f)(x, y))
```
HLO

```mlir
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}, %arg1: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.minimum %arg0, %arg1 : tensor<f32>
    return %0 : tensor<f32>
  }
}
```
Comparisons with NaN are unordered, so the minimum of a number and NaN is NaN: `min(2.0, NaN) => NaN`. On the CPU backend, NaN is returned as expected; on the METAL device the result is wrong.
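A minimal sketch for checking this side by side, assuming the CPU backend is still available next to the jax-metal plugin via `jax.devices("cpu")`. It runs the same jitted `f` with its inputs committed to the CPU and to the default device; the helper name `run_on` is hypothetical, not part of JAX.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    return jnp.minimum(x, y)


def run_on(device, x, y):
    """Hypothetical helper: run jitted f with both inputs committed to `device`."""
    x = jax.device_put(x, device)
    y = jax.device_put(y, device)
    return jax.jit(f)(x, y)


x = jnp.array(2.0)
y = jnp.array(float("nan"))

cpu_result = run_on(jax.devices("cpu")[0], x, y)  # reference backend
default_result = run_on(jax.devices()[0], x, y)   # METAL when jax-metal is active

print("cpu:    ", cpu_result)
print("default:", default_result, "on", jax.devices()[0].platform)
print("NaN propagated on default backend:", bool(jnp.isnan(default_result)))
```

Committing the inputs with `jax.device_put` makes the jitted call execute on that device, so the only difference between the two runs is the backend.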
The same logic applies to the reducers `min`/`max` and, consequently, to `argmin`/`argmax`.
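The reductions can be lowered the same way as `f` to see which StableHLO ops they reach. This is a sketch; `lowered_text` is a hypothetical helper. For `jnp.min`/`jnp.max` the text is expected to contain a `stablehlo.reduce` whose body applies `stablehlo.minimum`/`stablehlo.maximum`, while `jnp.argmin`/`jnp.argmax` lower to a variadic `stablehlo.reduce` built from comparisons and selects, so for those the backend's handling of NaN in comparisons likely matters as well.

```python
import jax
import jax.numpy as jnp


def lowered_text(fn, *args):
    """Hypothetical helper: StableHLO text for jit(fn) without executing it."""
    return jax.jit(fn).lower(*args).as_text()


xs = jnp.array([2.0, float("nan"), 1.0])

for name, fn in [("min", jnp.min), ("max", jnp.max),
                 ("argmin", jnp.argmin), ("argmax", jnp.argmax)]:
    print(f"--- jnp.{name} ---")
    print(lowered_text(fn, xs))
```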
Currently the result on METAL is invalid for:

- `jnp.minimum`
- `jnp.maximum`
- `jnp.min`
- `jnp.max`
- `jnp.argmin`
- `jnp.argmax`
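For reference, the values these functions should produce can be computed on the CPU backend, which the report says behaves correctly. A sketch, assuming `jax.devices("cpu")` is available; with a single NaN in the input, `argmin`/`argmax` follow NumPy's convention and return the position of the NaN.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]

# Commit inputs to the CPU so the eager ops below run there.
xs = jax.device_put(jnp.array([2.0, float("nan"), 1.0]), cpu)  # NaN at index 1
two = jax.device_put(jnp.array(2.0), cpu)
nan = jax.device_put(jnp.array(float("nan")), cpu)

# Expected on CPU: minimum/maximum/min/max -> nan, argmin/argmax -> 1
results = {
    "minimum": jnp.minimum(two, nan),
    "maximum": jnp.maximum(two, nan),
    "min": jnp.min(xs),
    "max": jnp.max(xs),
    "argmin": jnp.argmin(xs),
    "argmax": jnp.argmax(xs),
}

for name, value in results.items():
    print(f"jnp.{name}: {value}")
```

Running the same dictionary without the `device_put` calls on a METAL device gives a quick way to spot which of the six entries diverge.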
System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May 1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

jax-metal 0.1.0
```