jax.lax.scan() segmentation fault with jax-metal on Mac M1 #20750
Comments
Thank you for posting. I have the same error and have also raised it on the Apple Developer Forums so the Apple team sees it (https://forums.developer.apple.com/forums/thread/750160).
Running into the same issue, but getting an error:
I also have the same issue with lax.scan() when running on the Metal GPU: it either segfaults or crashes the Jupyter kernel. On the CPU, with jax.config.update('jax_platform_name', 'cpu'), it works fine.
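A minimal sketch of the CPU fallback described above, assuming a standard JAX install. The `cumulative_sum` helper is hypothetical and only exists to exercise `lax.scan`. The config update has to run before any JAX computation initializes a backend.

```python
import jax

# Force the CPU backend; must run before any JAX computation.
jax.config.update('jax_platform_name', 'cpu')

import jax.numpy as jnp
from jax import lax


def cumulative_sum(xs):
    """Hypothetical helper: running total of xs computed with lax.scan."""
    def step(carry, x):
        total = carry + x
        return total, total  # (new carry, per-step output)

    _, totals = lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return totals


print(jax.default_backend())           # expected: cpu
print(cumulative_sum(jnp.arange(5.0)))
```

This only sidesteps the Metal backend; it gives up GPU acceleration for the whole process.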
I am having the same issue as well. The scan function segfaults on jax-metal versions 0.0.5 through 0.0.7 on an M3 Max (macOS Sonoma 14.5). I also found that the error is triggered when using values from
We are aware of the issue and working on a fix.
This still doesn't appear to be fixed in jax-metal 0.1.0.
The fix will ship in the next public macOS release and will require an OS upgrade.
Hi @shuhand0. I upgraded to the next public OS release, out today (Darwin Kernel Version 23.6.0), and there is still a segmentation fault, both with the original poster's library versions and with the latest versions.
@shuhand0 Is there an update on this? With the OS update and jax-metal==0.1.0, the segfault still occurs. It's concerning that core JAX functionality has been broken for all Apple Silicon Mac users for the past four months, with no fix or workaround.
The fix is in the MetalPerformanceShadersGraph framework in macOS Sequoia. Could you try the test on the latest macOS 15 Beta 7?
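Since the fix is tied to the OS rather than to jax-metal, a script could check the macOS version before relying on the Metal backend. This is a sketch under the assumption that macOS 15 (Sequoia) is the first release carrying the fix, as stated above; both function names are hypothetical.

```python
import platform


def macos_major_version(version_string=None):
    """Hypothetical helper: macOS major version, or None when not on macOS."""
    if version_string is None:
        # platform.mac_ver() returns ('', ...) on non-macOS systems
        version_string = platform.mac_ver()[0]
    if not version_string:
        return None
    return int(version_string.split('.')[0])


def metal_scan_fix_expected(version_string=None):
    """Hypothetical helper: True if the OS is assumed to include the MPSGraph fix."""
    major = macos_major_version(version_string)
    return major is not None and major >= 15


print(metal_scan_fix_expected())
```

If it returns False, the CPU fallback mentioned earlier in the thread is the safer choice.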
This worked for me (no longer segfaulting)! I'm on macOS 15 Beta 7 and validated this with both the original poster's library versions and the latest versions.
This works for me too after updating to Sequoia. Thank you!!
Description
Segmentation fault when calling jax.lax.scan(), jax.lax.map(), and related functions. The crash can be traced back to the core.AxisPrimitive().bind() call. Reproducible with jax-metal==0.0.5 and jax-metal==0.0.4, on either an M1 or M2 MacBook Pro.
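The original reproducer is not included here. A minimal script of the following shape (an assumption, not the poster's code) exercises both `lax.scan` and `lax.map`; per the reports in this thread, calls like these segfaulted on the affected jax-metal setups, while they complete normally on CPU.

```python
import jax.numpy as jnp
from jax import lax


def step(carry, x):
    # Simple recurrence: carry_{t+1} = 2 * carry_t + x_t
    carry = carry * 2.0 + x
    return carry, carry


xs = jnp.arange(4.0)
final, history = lax.scan(step, jnp.zeros(()), xs)  # scan over xs
squares = lax.map(lambda x: x * x, xs)              # map over xs

print(final, history, squares)
```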
System info (python version, jaxlib version, accelerator, etc.)
Device: Apple M1 Pro (and M2)
macOS: Sonoma 14.4 (and 14.5 Beta)
jax-metal: 0.0.6
jax: 0.4.26
jaxlib: 0.4.23
numpy: 1.26.4
python: 3.11.0 (main, Mar 1 2023, 12:33:14) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', release='23.5.0', version='Darwin Kernel Version 23.5.0', machine='arm64')
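The system info above appears to follow the output format of `jax.print_environment_info()`. A sketch of collecting the same details for a new report, assuming a JAX version that provides this function:

```python
import jax

# Prints jax, jaxlib, numpy and python versions, followed by
# jax.devices, process_count and platform.
jax.print_environment_info()
```

Passing `return_string=True` returns the text instead of printing it, which is handy for pasting into an issue.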