jax-metal: cond fails in compile in certain cases #21601

Open
jonatanklosko opened this issue Jun 3, 2024 · 2 comments
Labels: Apple GPU (Metal) plugin, bug

Comments

@jonatanklosko

Description

import jax
import jax.numpy as jnp

def f(pred, x, y):
  return jax.lax.cond(pred, lambda xy: xy[0], lambda xy: xy[1], (x, y))

pred = jnp.array(0)
x = jnp.array(10.0)
y = jnp.array(20.0)

# Print lowered HLO
print(jax.jit(f).lower(pred, x, y).as_text())
print(jax.jit(f)(pred, x, y))
HLO
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}, %arg1: tensor<f32> {mhlo.layout_mode = "default"}, %arg2: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.constant dense<0> : tensor<i32>
    %1 = stablehlo.compare  NE, %arg0, %0,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %2 = stablehlo.convert %1 : (tensor<i1>) -> tensor<i32>
    %3 = "stablehlo.case"(%2) ({
      stablehlo.return %arg2 : tensor<f32>
    }, {
      stablehlo.return %arg1 : tensor<f32>
    }) : (tensor<i32>) -> tensor<f32>
    return %3 : tensor<f32>
  }
}

The above fails with

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Unable to serialize MPS module

However, it works if we change the branches slightly (just adding a constant):

jax.lax.cond(pred, lambda xy: xy[0] + 1.0, lambda xy: xy[1] + 2.0, (x, y))
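When the branches only pick between two values that are already computed, a select-based formulation sidesteps `stablehlo.case` altogether. This is a minimal sketch, not a confirmed fix: `f_select` is a hypothetical name, and whether this path compiles on jax-metal has not been verified here. Unlike `cond`, `jnp.where` evaluates both operands, so it only makes sense when both are cheap.

```python
import jax
import jax.numpy as jnp

def f_select(pred, x, y):
  # Picks x when pred is non-zero and y otherwise, mirroring
  # cond(pred, lambda xy: xy[0], lambda xy: xy[1], (x, y)),
  # but expressed as a select instead of a control-flow op.
  return jnp.where(pred != 0, x, y)

pred = jnp.array(0)
x = jnp.array(10.0)
y = jnp.array(20.0)

result = jax.jit(f_select)(pred, x, y)
print(result)
```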

Another example that fails is this:

def f(pred, x):
  return jax.lax.cond(pred, lambda x: (x + 1.0, 1.0), lambda x: (x + 2.0, 2.0), x)

pred = jnp.array(0)
x = jnp.array([1, 2, 3])

but works when changed to

jax.lax.cond(pred, lambda x: (x + 1.0, x + 1.0), lambda x: (x + 2.0, x + 2.0), x)

From the above I gather that this operation is already supposed to work, but it clearly behaves inconsistently.
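To compare the four variants side by side, something like the following could be run on the affected backend. It is only a sketch: `cases` and `run_all` are names made up for this illustration, and the script records the exception type rather than assuming any particular error message.

```python
import jax
import jax.numpy as jnp

pred = jnp.array(0)
x, y = jnp.array(10.0), jnp.array(20.0)
v = jnp.array([1, 2, 3])

# Each entry pairs one variant from this report with its arguments.
cases = {
  "passthrough": (
    lambda p, a, b: jax.lax.cond(p, lambda ab: ab[0], lambda ab: ab[1], (a, b)),
    (pred, x, y)),
  "passthrough plus constant": (
    lambda p, a, b: jax.lax.cond(
      p, lambda ab: ab[0] + 1.0, lambda ab: ab[1] + 2.0, (a, b)),
    (pred, x, y)),
  "tuple with constant output": (
    lambda p, a: jax.lax.cond(
      p, lambda a: (a + 1.0, 1.0), lambda a: (a + 2.0, 2.0), a),
    (pred, v)),
  "tuple without constant output": (
    lambda p, a: jax.lax.cond(
      p, lambda a: (a + 1.0, a + 1.0), lambda a: (a + 2.0, a + 2.0), a),
    (pred, v)),
}

def run_all(cases):
  # Compile and run every variant, recording success or the exception type.
  results = {}
  for name, (fn, args) in cases.items():
    try:
      jax.block_until_ready(jax.jit(fn)(*args))
      results[name] = "ok"
    except Exception as e:
      results[name] = f"failed: {type(e).__name__}"
  return results

results = run_all(cases)
for name, status in results.items():
  print(f"{name}: {status}")
```

Going by the description above, the first and third entries would be expected to fail on jax-metal 0.0.7 while the other two succeed; on other backends all four may simply pass.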

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

jax-metal 0.0.7

jonatanklosko added the bug label Jun 3, 2024
@jonatanklosko
Author

On a side note, I am wondering why jax.lax.cond(pred, true_fun, false_fun, *operands) uses stablehlo.case and not stablehlo.if (which is semantically closer). I found openxla/stablehlo#599, so perhaps there's nothing to be gained from stablehlo.if, and stablehlo.case is used just because it's more generic? I would love for someone from the JAX team to confirm :)
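For anyone who wants to check which op their own JAX version emits, the lowered text can be searched directly. This is a small sketch; `ops_found` is a name made up for this example, and the result may differ between JAX versions (the HLO in the description, from jax 0.4.26, shows `stablehlo.case`).

```python
import jax
import jax.numpy as jnp

def f(pred, x, y):
  return jax.lax.cond(pred, lambda xy: xy[0], lambda xy: xy[1], (x, y))

# Lower without executing, so this also works where compilation fails.
text = jax.jit(f).lower(
  jnp.array(0), jnp.array(10.0), jnp.array(20.0)).as_text()

# Record which of the two control-flow ops appears in the lowered module.
ops_found = {op: (op in text) for op in ("stablehlo.case", "stablehlo.if")}
print(ops_found)
```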

@dlwh
Contributor

dlwh commented Jul 17, 2024

Can still reproduce in 0.1.0 fwiw
