jax-metal: XlaRuntimeError: INTERNAL: Unable to serialize MPS module #20401
Comments
Will you be able to reproduce the issue with a smaller module? |
Hi, yes, I was able to. Here's a more stripped-down version; let me know if you need something smaller.
I set JAX_TRACEBACK_FILTERING=off and here's the error I get: Unfiltered stack trace
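For anyone reproducing this, here is a minimal sketch of turning traceback filtering off from inside a script rather than the shell. It assumes the variable is set before `jax` is first imported in the process. The `try/except` guard is not part of the original report and is only there so the snippet also runs on a machine without JAX installed.

```python
import os

# Set this before JAX is imported for the first time in this process,
# since JAX reads the variable when it initializes its configuration.
os.environ["JAX_TRACEBACK_FILTERING"] = "off"

try:
    import jax

    # Equivalent in-process setting through the config API.
    jax.config.update("jax_traceback_filtering", "off")
except ImportError:
    # Assumption for illustration: JAX may be absent; the variable is still set.
    jax = None
```

With filtering off, the traceback includes JAX-internal frames, which is usually what maintainers want to see in a bug report.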
|
+1, just ran into this same issue while trying to accelerate inference on my own Apple silicon machine. I don't know nearly enough to help debug, but I'm seeing the same error.
|
@dcalacci there's not much to debug unfortunately: the metal plugin is still experimental and very incomplete, and so you should expect to run into these kinds of issues when using it. My recommendation would be to switch to non-experimental hardware. |
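One possible reading of that recommendation, if you stay on the same machine, is to fall back to JAX's CPU backend. The sketch below assumes a CPU-capable `jaxlib` is installed alongside the Metal plugin. The `backend` variable and the `try/except` guard are illustrative and not taken from this thread.

```python
import os

# Ask JAX to use only the CPU platform. This has to be set before JAX
# is imported for the first time in this process.
os.environ["JAX_PLATFORMS"] = "cpu"

try:
    import jax

    # Name of the backend JAX actually selected.
    backend = jax.default_backend()
except ImportError:
    # Assumption for illustration: JAX may be absent on this machine.
    backend = None

print("JAX backend:", backend)
```

This gives up Metal acceleration, so it only helps when correctness matters more than speed.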
Yeah no problem! I understand this is very experimental. Just giving a +1 to the issue as another person playing with these new tools. Thanks for all your hard work! |
Description
Encountered an XlaRuntimeError while running a basic NumPyro program with jax-metal. The error arises when I run the MCMC sampler.
Running this gives me the following error:
System info (python version, jaxlib version, accelerator, etc.)