jax-metal: segmentation fault inside jax.lax.while_loop #21552

Open
jonatanklosko opened this issue May 31, 2024 · 6 comments
Labels: Apple GPU (Metal) plugin, bug


jonatanklosko commented May 31, 2024

Description

import jax
import jax.numpy as jnp

def f(x):
  def cond(carry):
    i, x, acc = carry
    return i < x.shape[0]

  def body(carry):
    i, x, acc = carry
    return (i + 1, x, acc + x[i])

  i = jnp.array(0)
  acc = jnp.array(0)

  return jax.lax.while_loop(cond, body, (i, x, acc))

x = jnp.array([1, 2, 3])

# Print lowered HLO
print(jax.jit(f).lower(x).as_text())
print(jax.jit(f)(x))
HLO
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<i32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.constant dense<0> : tensor<i32>
    %1 = stablehlo.constant dense<1> : tensor<i32>
    %2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %arg0, %iterArg_1 = %1) : tensor<i32>, tensor<3xi32>, tensor<i32>
     cond {
      %3 = stablehlo.constant dense<3> : tensor<i32>
      %4 = stablehlo.compare  LT, %iterArg, %3,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %4 : tensor<i1>
    } do {
      %3 = stablehlo.constant dense<1> : tensor<i32>
      %4 = stablehlo.add %iterArg, %3 : tensor<i32>
      %5 = stablehlo.constant dense<0> : tensor<i32>
      %6 = stablehlo.compare  LT, %iterArg, %5,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      %7 = stablehlo.convert %iterArg : tensor<i32>
      %8 = stablehlo.constant dense<3> : tensor<i32>
      %9 = stablehlo.add %7, %8 : tensor<i32>
      %10 = stablehlo.select %6, %9, %iterArg : tensor<i1>, tensor<i32>
      %11 = stablehlo.dynamic_slice %iterArg_0, %10, sizes = [1] : (tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
      %12 = stablehlo.reshape %11 : (tensor<1xi32>) -> tensor<i32>
      %13 = stablehlo.convert %iterArg_1 : tensor<i32>
      %14 = stablehlo.add %13, %12 : tensor<i32>
      stablehlo.return %4, %iterArg_0, %14 : tensor<i32>, tensor<3xi32>, tensor<i32>
    }
    return %2#0, %2#1, %2#2 : tensor<i32>, tensor<3xi32>, tensor<i32>
  }
}

The loop above computes the sum of the tensor's elements. Running the code results in a segmentation fault.
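For reference, the result the loop is expected to produce can be worked out without touching any accelerator backend. The sketch below is a plain-Python model of the documented `jax.lax.while_loop` semantics (repeat `body` while `cond` holds); `while_loop_reference` is a hypothetical helper name, not part of JAX, and a Python list stands in for the array.

```python
def while_loop_reference(cond, body, init):
    """Plain-Python model of while_loop: apply body while cond is true."""
    carry = init
    while cond(carry):
        carry = body(carry)
    return carry

def cond(carry):
    i, x, acc = carry
    return i < len(x)

def body(carry):
    i, x, acc = carry
    return (i + 1, x, acc + x[i])

xs = [1, 2, 3]
result = while_loop_reference(cond, body, (0, xs, 0))
print(result)  # (3, [1, 2, 3], 6)
```

With `i` and `acc` both starting at 0, as in the Python snippet of the report, the final carry is `(3, [1, 2, 3], 6)`.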

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

jax-metal 0.0.7

jonatanklosko added the bug label May 31, 2024
@jonatanklosko (Author)

I'm fairly sure the issue is specific to the dynamic slice inside the while loop. I have already run into this in several places, and removing the dynamic slice from the code stops the segfault.
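One way to remove the dynamic slice while keeping the same computation is to select element `i` with a mask instead of indexing `x[i]`. This is only a sketch: `f_masked` is a hypothetical name, and whether this variant actually avoids the crash on jax-metal has not been confirmed in this thread. It does more work per iteration (a full pass over `x`), so it is meant as an illustration rather than a recommended fix.

```python
import jax
import jax.numpy as jnp

def f_masked(x):
    n = x.shape[0]
    positions = jnp.arange(n)

    def cond(carry):
        i, x, acc = carry
        return i < n

    def body(carry):
        i, x, acc = carry
        # Pick element i via a mask and a reduction instead of x[i],
        # so the loop body lowers without a dynamic_slice op.
        picked = jnp.sum(jnp.where(positions == i, x, 0))
        return (i + 1, x, acc + picked)

    i = jnp.array(0)
    acc = jnp.array(0)
    return jax.lax.while_loop(cond, body, (i, x, acc))

x = jnp.array([1, 2, 3])

# Lowering alone does not execute the loop on the device.
lowered_text = jax.jit(f_masked).lower(x).as_text()
print("dynamic_slice" in lowered_text)
```

Calling `jax.jit(f_masked)(x)` should return the same carry as the original function on backends where the loop runs.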

shuhand0 (Collaborator) commented Jun 3, 2024

The dynamic slice prevents the backend from encoding the whileOp. We are looking into a fix.


acranej commented Jun 11, 2024

Running into the same issue in jax-metal 0.1.0

@abrasumente233

import jax
import jax.numpy as jnp

def f(x):
    def scan_fn(h, w):
        h_bne = w * h
        return h_bne, None

    return jax.lax.scan(scan_fn, x, jnp.array([[0.0]]))

x = jnp.ones(1)
print(jax.jit(f).lower(x).as_text())
print(jax.jit(f)(x))

jax.lax.scan hits a segfault as well, and its lowered HLO also contains a dynamic_slice.
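To check whether a given function is likely affected, the lowered text can be searched for the op before anything is executed. A minimal sketch, assuming the substring `dynamic_slice` in the output of `.lower(...).as_text()` is a good enough signal; `uses_dynamic_slice`, `take_one`, and `add_one` are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp

def uses_dynamic_slice(fn, *args):
    """Return True if the lowered text of jit(fn) mentions dynamic_slice."""
    text = jax.jit(fn).lower(*args).as_text()
    return "dynamic_slice" in text

def take_one(x, i):
    # Explicit dynamic slice of length 1 starting at a traced index.
    return jax.lax.dynamic_slice(x, (i,), (1,))

def add_one(x):
    return x + 1

x = jnp.array([1.0, 2.0, 3.0])
print(uses_dynamic_slice(take_one, x, jnp.array(0)))
print(uses_dynamic_slice(add_one, x))
```

The same helper can be pointed at the `while_loop` and `scan` reproductions from this thread, since both reports print the lowered text before the crash occurs.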

My system info:

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:34:54) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='[redacted]', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:16:51 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T8103', machine='arm64')

@aniquetahir

I have the same issue :(

@vyeevani

I'm having the same issue as well
