jax.lax.scan() segmentation fault with jax-metal on Mac M1 #20750

Open
kdesoto-astro opened this issue Apr 14, 2024 · 12 comments
Labels: Apple GPU (Metal) plugin, bug

Comments

@kdesoto-astro

Description

Calling jax.lax.scan(), jax.lax.map(), and related functions causes a segmentation fault. The fault can be traced back to the core.AxisPrimitive().bind() call. It is reproducible with jax-metal=0.0.5 and jax-metal=0.0.4, on both M1 and M2 MacBook Pros.

import jax
rng = jax.random.PRNGKey(0)
test_input = jax.random.normal(key=rng, shape=(5,5,5))
initial_state = jax.numpy.array(0.0)

def test_func(x, y):
    return x, y

x, y = jax.lax.scan(test_func, initial_state, test_input)

System info (python version, jaxlib version, accelerator, etc.)

Device: Apple M1 Pro (and M2)
macOS: Sonoma 14.4 (and 14.5 Beta)
jax-metal: 0.0.6
jax: 0.4.26
jaxlib: 0.4.23
numpy: 1.26.4
python: 3.11.0 (main, Mar 1 2023, 12:33:14) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', release='23.5.0', version='Darwin Kernel Version 23.5.0', machine='arm64')
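A listing in this format can be produced with jax.print_environment_info(). The short sketch below captures it as a string and also prints the active backend, which helps confirm whether a test ran on the Metal device or on the CPU. The variable names are illustrative only.

```python
import jax

# Collect version and device information in the format used in bug reports.
info = jax.print_environment_info(return_string=True)
print(info)

# Name of the default backend used by this process.
backend = jax.default_backend()
print("default backend:", backend)
```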

kdesoto-astro added the bug label Apr 14, 2024
@carloswert

Thank you for posting. I have the same error. I also reported it on the Apple Developer Forums so the Apple team sees it (https://forums.developer.apple.com/forums/thread/750160).

@twiecki

twiecki commented May 27, 2024

Running into the same issue, but getting an error:

Running window adaptation
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 688, in sample
    return _sample_external_nuts(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 351, in _sample_external_nuts
    idata = pymc_jax.sample_jax_nuts(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/jax.py", line 567, in sample_jax_nuts
    raw_mcmc_samples, sample_stats, library = sampler_fn(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/jax.py", line 413, in _sample_blackjax_nuts
    states, stats = map_fn(get_posterior_samples)(keys, initial_points)
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/jax.py", line 250, in _blackjax_inference_loop
    (last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/blackjax/adaptation/window_adaptation.py", line 334, in run
    last_state, info = jax.lax.scan(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/blackjax/progress_bar.py", line 80, in wrapper_progress_bar
    _update_progress_bar(iter_num)
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/blackjax/progress_bar.py", line 46, in _update_progress_bar
    _ = lax.cond(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/blackjax/progress_bar.py", line 48, in <lambda>
    lambda _: io_callback(_define_bar, None, iter_num),
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/jax/_src/callback.py", line 502, in io_callback
    out_flat = io_callback_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: `EmitPythonCallback` not supported on METAL backend.

@paullabonne

I also have the same issue with lax.scan() when trying to run on the Metal GPU; it either segfaults or crashes the Jupyter kernel. On the CPU, with jax.config.update('jax_platform_name', 'cpu'), it works fine.
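A minimal sketch of that CPU fallback, applied to the reproducer from the original report. It assumes the config call runs before any other JAX computation in the process, and that the installed JAX version still accepts the `jax_platform_name` option. The function name `passthrough` is illustrative.

```python
import jax
import jax.numpy as jnp

# Select the CPU backend before any arrays are created.
jax.config.update("jax_platform_name", "cpu")


def passthrough(carry, x):
    # Same body as the reproducer: carry unchanged, each slice emitted.
    return carry, x


xs = jax.random.normal(jax.random.PRNGKey(0), shape=(5, 5, 5))
carry, ys = jax.lax.scan(passthrough, jnp.array(0.0), xs)
print(jax.default_backend(), carry.shape, ys.shape)
```

This only sidesteps the crash by not using the GPU, so it is a diagnostic aid rather than a fix.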

@tsumme1

tsumme1 commented May 30, 2024

I am having the same issue as well. The scan function causes a seg fault on versions 0.0.5-7 with an M3 Max (os: Sonoma 14.5). I also found that the error is triggered when using values from xs (inp in the code below) inside the scanned function. The scan function still works when only using the carry. As a stopgap, I created the function below that avoids the error while functioning similarly.

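import jax
import jax.numpy as jnp  # required for jnp.zeros and jnp.uint32 below
from collections.abc import Iterable
from jax.tree import map as tree_map

def compat_scan(f,carry,xs,unroll=False,length=None):
    # Step counter carried alongside the user's state (shape (1,), dtype uint32).
    ind = jnp.zeros(1,jnp.uint32)
    def exec(c,inp):
        # inp is deliberately unused: reading it is what triggers the segfault.
        state,k = c
        if isinstance(xs,Iterable):
            # Index the closed-over xs with the counter; [0] drops the size-1 axis.
            vals = tree_map(lambda x: x[k][0],xs)
            state,out = f(state,vals)
        else:
            state,out = f(state,xs[k][0])
        k += jnp.uint32(1)
        return (state,k),out
    # xs is still passed to scan so that it determines the number of steps.
    (carry,ind), ys = jax.lax.scan(exec,(carry,ind),xs,unroll=unroll,length=length)
    return carry, ys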
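The same idea can be sketched in a slightly different form: carry the step index and read from the closed-over xs, while scanning over nothing (`xs=None` with an explicit `length`). This variant is an assumption on my part and has not been verified on the Metal backend. `index_carry_scan` and `cumsum_step` are hypothetical names, and the check against jnp.cumsum only confirms the results on whatever backend runs it.

```python
import jax
import jax.numpy as jnp


def index_carry_scan(f, init, xs, length):
    """Scan that never reads the per-step input; it indexes xs by a carried counter."""

    def body(c, _):
        state, k = c
        # Pick step k from every leaf of the closed-over xs.
        x_k = jax.tree_util.tree_map(lambda a: a[k], xs)
        state, out = f(state, x_k)
        return (state, k + 1), out

    (final, _), ys = jax.lax.scan(body, (init, jnp.int32(0)), None, length=length)
    return final, ys


def cumsum_step(total, x):
    total = total + x
    return total, total


xs = jnp.arange(1.0, 6.0)
final, ys = index_carry_scan(cumsum_step, jnp.float32(0.0), xs, length=xs.shape[0])
expected = jnp.cumsum(xs)
print(final, ys)
```

Unlike the function above, this version requires the caller to pass `length`, since the scanned input is `None`.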

@shuhand0
Collaborator

We are aware of the issue and working on a fix.

@adam-hartshorne

This still doesn't appear to be fixed in 0.1.0

@shuhand0
Collaborator

The fix will ship in the next public OS release, so an OS upgrade will be required.

@bsarkar321

Hi @shuhand0. I upgraded to the public OS release that came out today (Darwin Kernel Version 23.6.0), and the segmentation fault still occurs, both with the original poster's library versions and with the latest versions.

@kdesoto-astro
Author

@shuhand0 Is there an update on this? With the OS update and jax-metal==0.1.0, the segfault still occurs. It's concerning that a core feature of JAX has been broken for all Apple Silicon Mac users for the past 4 months, with no fix or workaround.

@shuhand0
Collaborator

The fix is in the MetalPerformanceShadersGraph framework in macOS Sequoia. Could you try the test on the latest macOS 15 Beta 7?

@bsarkar321

> The fix is within the MetalPerformanceShaderGraph Framework in MacOS Sequoia. Could you try the test on the latest MacOS 15 Beta 7?

This worked for me (no longer segfaulting)! I'm using macOS 15 Beta 7, and I validated this with both the original poster's library versions and the latest versions.

@kdesoto-astro
Author

This works for me too after updating to Sequoia - thank you!!

9 participants