jax-metal: XlaRuntimeError: INTERNAL: Unable to serialize MPS module #20401

Open
jsakaya opened this issue Mar 22, 2024 · 5 comments
Labels: Apple GPU (Metal) plugin, bug

Comments

@jsakaya

jsakaya commented Mar 22, 2024

Description

I encountered an XlaRuntimeError while running a basic NumPyro program with jax-metal. The error is raised when I run the MCMC sampler.

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from jax import random
import jax.numpy as np

J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])


def eight_schools(J, sigma, y=None):
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        theta = numpyro.sample('theta', dist.Normal(mu, tau))
        numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)


nuts_kernel = NUTS(eight_schools)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))

Running this gives me the following error:

2024-03-22 13:27:18.342904: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

  0%|                                                                                                       | 0/1500 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "[...]/test.py", line 23, in <module>
    mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))
  File "[...]/numpyro/numpyro/infer/mcmc.py", line 644, in run
    states_flat, last_state = partial_map_fn(map_args)
                              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/numpyro/numpyro/infer/mcmc.py", line 450, in _single_chain_mcmc
    collect_vals = fori_collect(
                   ^^^^^^^^^^^^^
  File "[...]/numpyro/numpyro/util.py", line 367, in fori_collect
    vals = jit(_body_fn)(i, vals)
           ^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Unable to serialize MPS module
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

System info (python version, jaxlib version, accelerator, etc.)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-22 13:36:51.819062: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

jax:    0.4.25
jaxlib: 0.4.23
numpy:  1.26.4
python: 3.11.8 (main, Feb 26 2024, 15:36:12) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='[...]-MacBook-Pro.local', release='23.4.0', version='Darwin Kernel Version 23.4.0: Wed Feb 21 21:44:54 PST 2024; root:xnu-10063.101.15~2/RELEASE_ARM64_T6031', machine='arm64')


macOS: Sonoma Version 14.4
jsakaya added the bug label on Mar 22, 2024
@shuhand0
Collaborator

Are you able to reproduce the issue with a smaller module?

@jsakaya
Author

jsakaya commented Mar 25, 2024

Hi, yes, I was able to. Here's a more stripped-down version; let me know if you need something smaller.

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from jax import random

def model():
    mu = numpyro.sample('mu', dist.Normal(0, 5))

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, extra_fields=('potential_energy',))

I set JAX_TRACEBACK_FILTERING=off and here's the error I get:

Unfiltered stack trace

Metal device set to: Apple M3 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

  0%|                                                                                                                   | 0/1500 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "[...]/pyro-svi/sample-numpyro.py", line 12, in <module>
    mcmc.run(rng_key, extra_fields=('potential_energy',))
  File "[...]/numpyro/numpyro/infer/mcmc.py", line 644, in run
    states_flat, last_state = partial_map_fn(map_args)
                              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/numpyro/numpyro/infer/mcmc.py", line 450, in _single_chain_mcmc
    collect_vals = fori_collect(
                   ^^^^^^^^^^^^^
  File "[...]/numpyro/numpyro/util.py", line 367, in fori_collect
    vals = jit(_body_fn)(i, vals)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/pjit.py", line 248, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
                                                                ^^^^^^^^^^^^^^^^^^^^
  File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/pjit.py", line 143, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/core.py", line 2727, in bind
    return self.bind_with_trace(top_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/core.py", line 423, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/core.py", line 913, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/pjit.py", line 1415, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,  # type: ignore
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/pjit.py", line 1392, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
                         ^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/pjit.py", line 1328, in _pjit_call_impl_python
    lowering_parameters=mlir.LoweringParameters()).compile()
                                                   ^^^^^^^^^
  File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2271, in compile
    executable = UnloadedMeshExecutable.from_hlo(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2734, in from_hlo
    xla_executable, compile_options = _cached_compilation(
                                      ^^^^^^^^^^^^^^^^^^^^
  File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2591, in _cached_compilation
    xla_executable = compiler.compile_or_get_cached(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/compiler.py", line 265, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "[...]/miniconda3/envs/numpyro-env/lib/python3.11/site-packages/jax/_src/compiler.py", line 237, in backend_compile
    return backend.compile(built_c, compile_options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Unable to serialize MPS module
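
For readers who want to reproduce the unfiltered trace above: the `JAX_TRACEBACK_FILTERING` variable mentioned in the error message can also be set from inside the script. The following is a minimal sketch, assuming the variable is set before `jax` is first imported; it is not taken from the original report, which does not say how the variable was set.

```python
import os

# Assumption: this runs before the first `import jax`, so JAX picks the
# setting up when it initialises its configuration.
os.environ["JAX_TRACEBACK_FILTERING"] = "off"

# The reproduction script (the jax/numpyro imports and the MCMC run) would
# follow here; exceptions should then include JAX's internal frames.
```

Equivalently, the variable can be exported in the shell before launching Python.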

@dcalacci

+1, I just ran into this same issue while trying to accelerate inference on my own Apple silicon machine. I don't know nearly enough to help debug, but I am seeing the same error.

jax 0.4.26
jaxlib 0.4.23
numpy 1.26.4
python 3.12.3

@jakevdp
Collaborator

jakevdp commented Apr 22, 2024

@dcalacci there's not much to debug, unfortunately: the Metal plugin is still experimental and very incomplete, so you should expect to run into these kinds of issues when using it. My recommendation would be to switch to non-experimental hardware.
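
One way to act on this recommendation on the same machine is to run the model on JAX's CPU backend instead of the Metal plugin. The sketch below assumes that the CPU backend is an acceptable substitute (the comment above does not name a specific backend) and uses the `JAX_PLATFORMS` environment variable, set before `jax` is imported.

```python
import os

# Assumption: falling back to the CPU is acceptable for this workload.
# This must run before the first `import jax` so that the Metal backend
# is not selected.
os.environ["JAX_PLATFORMS"] = "cpu"

# The reproduction script would follow here unchanged; `jax.devices()`
# should then report CPU devices rather than METAL(id=0).
```

For a model as small as the ones in this thread, CPU execution is likely to be fast enough, although that has not been measured here.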

@dcalacci

Yeah, no problem! I understand this is very experimental. Just giving a +1 to the issue as another person playing with these new tools. Thanks for all your hard work!
