jax-metal: segmentation fault inside jax.lax.while_loop
#21552
Comments
Pretty sure the issue is specific to the dynamic slice inside the while loop. I have already run into this in several places, and removing the dynamic slice from the code makes the segfault go away.
The dynamic slice prevents the backend from encoding the whileOp. We are looking into a fix.
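Since removing the dynamic slice reportedly avoids the crash, one possible workaround is to pick out the current element with a one-hot mask and a full reduction instead of slicing it. This is a hypothetical sketch, not a confirmed fix: it has not been verified on jax-metal, and `sum_with_mask` is an illustrative name.

```python
import jax
import jax.numpy as jnp


def sum_with_mask(x):
    # Hypothetical workaround: select x[i] with a one-hot mask plus a full
    # reduction, so the while_loop body contains no dynamic_slice op.
    n = x.shape[0]  # static under jit
    idx = jnp.arange(n)

    def cond(state):
        i, _ = state
        return i < n

    def body(state):
        i, acc = state
        # Zero out every element except position i, then reduce.
        elem = jnp.sum(jnp.where(idx == i, x, jnp.zeros_like(x)))
        return i + 1, acc + elem

    init = (jnp.int32(0), jnp.zeros((), dtype=x.dtype))
    _, total = jax.lax.while_loop(cond, body, init)
    return total
```

The trade-off is O(n) work per iteration instead of O(1), so this is only reasonable for small arrays or as a stopgap until the backend fix lands.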
Running into the same issue in jax-metal 0.1.0:
```python
import jax
import jax.numpy as jnp

def f(x):
    def scan_fn(h, w):
        h_bne = w * h
        return h_bne, None
    return jax.lax.scan(scan_fn, x, jnp.array([[0.0]]))

x = jnp.ones(1)
print(jax.jit(f).lower(x).as_text())
print(jax.jit(f)(x))
```
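`jax.lax.scan` is itself lowered to a while loop that indexes `xs` on each step, which is consistent with it hitting the same problem. When the scan length is small and known at trace time, a possible workaround is to unroll the loop in Python with static indexing (`jax.lax.index_in_dim`), so no while loop is traced at all. This is a hypothetical sketch, untested on jax-metal; `f_scan` and `f_unrolled` are illustrative names.

```python
import jax
import jax.numpy as jnp


def scan_fn(h, w):
    # Same step function as in the snippet above.
    return w * h, None


def f_scan(x, xs):
    # Original formulation: lax.scan traces to a while loop over xs.
    carry, _ = jax.lax.scan(scan_fn, x, xs)
    return carry


def f_unrolled(x, xs):
    # Hypothetical workaround: a Python loop over the static leading
    # dimension of xs. Each step uses a static slice of xs, so the traced
    # program has no while loop and no traced index.
    h = x
    for i in range(xs.shape[0]):
        w = jax.lax.index_in_dim(xs, i, axis=0, keepdims=False)
        h, _ = scan_fn(h, w)
    return h
```

Unrolling grows the traced program linearly with the scan length, so it only makes sense for short, fixed-length loops.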
My system info:
I have the same issue :(

I'm having the same issue as well.
Description
HLO
The loop above computes the sum of a tensor's elements. Running the code results in a segmentation fault.
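The original loop and its HLO dump were not preserved in this copy of the issue. A minimal sketch of a loop of that shape, assuming it reads one element per iteration with `jax.lax.dynamic_slice` inside `jax.lax.while_loop` (the pattern the comments identify as the trigger), might look like this. `sum_elements` is an illustrative name, not the reporter's code.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sum_elements(x):
    # Sums a 1-D array one element per iteration. The element is read with
    # lax.dynamic_slice inside the while_loop body.
    n = x.shape[0]  # static under jit

    def cond(state):
        i, _ = state
        return i < n

    def body(state):
        i, acc = state
        elem = jax.lax.dynamic_slice(x, (i,), (1,))[0]
        return i + 1, acc + elem

    init = (jnp.int32(0), jnp.zeros((), dtype=x.dtype))
    _, total = jax.lax.while_loop(cond, body, init)
    return total


if __name__ == "__main__":
    x = jnp.arange(5.0)
    # Print the lowered StableHLO to inspect the while and dynamic_slice ops.
    print(sum_elements.lower(x).as_text())
    print(sum_elements(x))
```

On the CPU backend this runs normally; the report here is that code of this kind segfaults under jax-metal.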
System info (python version, jaxlib version, accelerator, etc.)
jax-metal 0.0.7