Unable to initialize backend 'METAL' #21383
Comments
Getting the same error on M1 Air. |
Duplicate of #20148? |
which is the error in #20338. |
When I ran
and running
|
This works for me on M1 Max. Thanks for sharing! |
I'm trying to install via Poetry, and I find that jax-metal installs a version of jax that cannot be overridden by specifying an additional jax dependency:
[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.26", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }
[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
When installing this, despite the clear attempt at overriding the
• Installing jax (0.4.28)
• Installing jaxlib (0.4.26 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl)
In short, this isn't working for me. For additional information, if I try:
$ python3
>>> import jax
jaxlib is version 0.4.26, but this version of jax requires version >= 0.4.27.
However, if I don't try to override, that is, with:
[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.27", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }
[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
Also take notice that it only updates the
• Updating jaxlib (0.4.26 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl -> 0.4.27 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl)
and do what I tried before:
$ python3
>>> import jax
>>> jax.print_environment_info()
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Traceback (most recent call last):
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
backend = _init_backend(platform)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
backend = registration.factory()
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
xla_client.initialize_pjrt_plugin(plugin_name)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
_xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/environment_info.py", line 45, in print_environment_info
devices_short = str(np.array(xla_bridge.devices())).replace('\n', '')
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1077, in devices
return get_backend(backend).devices()
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1011, in get_backend
return _get_backend_uncached(platform)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 990, in _get_backend_uncached
bs = backends()
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 890, in backends
raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
This same behavior occurs if I change the machine and environment info: |
"RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)". This error comes from jaxlib, which strictly checks that the PJRT API versions are equal. jax-metal 0.0.7 adopts the PJRT API from jaxlib 0.4.26. We have communicated with the JAX team, and the solution is to set the environment variable ENABLE_PJRT_COMPATIBILITY=1 when running jax-metal with jaxlib>0.4.26. This information can also be found on the PyPI jax-metal page: https://pypi.org/project/jax-metal/. |
@BeeGass Are you sure your
import platform
platform.machine() # should give you 'arm64'
I got mine (M1 Air, Sonoma) working by doing
python -m pip install jax==0.4.26 jaxlib==0.4.26
python -m pip install jax-metal
after I used #19886 to set up my environment (but make sure that your shell is also |
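The architecture check and the version pinning suggested above can be combined into one small report. This is a minimal sketch, not something from the thread: the helper names `installed_versions` and `environment_report` are hypothetical, and it only reads package metadata, so it does not import jax or touch the Metal backend.

```python
import platform
from importlib import metadata


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def environment_report():
    """Collect the interpreter architecture and the jax-related package versions."""
    return {
        # A native Apple-silicon interpreter reports 'arm64' here.
        "machine": platform.machine(),
        "versions": installed_versions(["jax", "jaxlib", "jax-metal"]),
    }


if __name__ == "__main__":
    print(environment_report())
```

Running this inside the Poetry or pip environment shows at a glance whether jax and jaxlib ended up on different versions, which is the situation reported earlier in this thread.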
Cleared the cache, all virtual environments and so forth, and did a fresh install of all the dependencies, ensuring that arm64 is the correct platform.
$ python
>>> import platform
>>> platform.machine()
'arm64'
For the sake of showing thoroughness:
$ python3
>>> import platform
>>> platform.machine()
'arm64'
Also performed right before install of all dependencies:
arch -arm64 zsh
Tried original test again:
$ python -c 'import jax; print(jax.numpy.arange(10))'
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/__init__.py", line 37, in <module>
import jax.core as _core
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/core.py", line 18, in <module>
from jax._src.core import (
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 39, in <module>
from jax._src import dtypes
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/dtypes.py", line 33, in <module>
from jax._src import config
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/config.py", line 27, in <module>
from jax._src import lib
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 75, in <module>
version = check_jaxlib_version(
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 64, in check_jaxlib_version
raise RuntimeError(msg)
RuntimeError: jaxlib is version 0.4.26, but this version of jax requires version >= 0.4.27.
For the sake of showing thoroughness: changed the
[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.27", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }
[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
• Downgrading jaxlib (0.4.28 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.28-cp310-cp310-macosx_11_0_arm64.whl -> 0.4.27 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl)
Again did the following:
$ python3
>>> import platform
>>> platform.machine()
'arm64'
arch -arm64 zsh
Performed the test above:
$ python3 -c 'import jax; print(jax.numpy.arange(10))'
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Traceback (most recent call last):
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
backend = _init_backend(platform)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
backend = registration.factory()
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
xla_client.initialize_pjrt_plugin(plugin_name)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
_xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).
During handling of the above exception, another exception occurred:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 2968, in arange
return lax.iota(dtype, start)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1282, in iota
return broadcasted_iota(dtype, (size,), 0)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1292, in broadcasted_iota
return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 387, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 391, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 879, in process_primitive
return primitive.impl(*tracers, **params)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/dispatch.py", line 86, in apply_primitive
outs = fun(*args)
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
Also tried:
[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.28", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.28-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }
[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
• Updating jaxlib (0.4.26 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl -> 0.4.28 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.28-cp310-cp310-macosx_11_0_arm64.whl)
Again did the following:
$ python3
>>> import platform
>>> platform.machine()
'arm64'
arch -arm64 zsh
Performed the test above:
$ python3 -c 'import jax; print(jax.numpy.arange(10))'
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Traceback (most recent call last):
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
backend = _init_backend(platform)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
backend = registration.factory()
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
xla_client.initialize_pjrt_plugin(plugin_name)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
_xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).
During handling of the above exception, another exception occurred:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 2968, in arange
return lax.iota(dtype, start)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1282, in iota
return broadcasted_iota(dtype, (size,), 0)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1292, in broadcasted_iota
return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 387, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 391, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 879, in process_primitive
return primitive.impl(*tracers, **params)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/dispatch.py", line 86, in apply_primitive
outs = fun(*args)
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
I have noticed that the people who have been able to get things working with this fix have been using the M3 chip. Perhaps the issue is that I'm using the M1 chip? Has anyone tried to replicate this on an M1? Just to be clear: |
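The tracebacks above all end in the same message, and the two version numbers in it are the useful part. As an illustration only, the sketch below pulls them out of the message text with a regular expression. The function name `pjrt_versions` is hypothetical, and the pattern is written against the exact wording quoted in this thread (including its unbalanced parenthesis), so it may not match messages from other jaxlib releases.

```python
import re

# Matches "... plugin PJRT API version (0.47) and framework PJRT API version 0.51)".
# The parentheses are optional because the message in this thread is unbalanced.
_MISMATCH = re.compile(
    r"plugin PJRT API version \(?(\d+\.\d+)\)? "
    r"and framework PJRT API version \(?(\d+\.\d+)\)?"
)


def pjrt_versions(message):
    """Return (plugin_version, framework_version) as strings, or None if absent."""
    match = _MISMATCH.search(message)
    if match is None:
        return None
    return match.group(1), match.group(2)


if __name__ == "__main__":
    text = (
        "Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT "
        "plugin PJRT API version (0.47) and framework PJRT API version 0.51)."
    )
    print(pjrt_versions(text))
```

Here the first number belongs to the jax-metal plugin and the second to the installed jaxlib, so a difference between them points at a version pairing problem rather than at the chip generation.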
Thanks, this works for me. |
@BeeGass, have you tried setting env ENABLE_PJRT_COMPATIBILITY=1 to run jax-metal with jaxlib>0.4.26? |
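The `ENABLE_PJRT_COMPATIBILITY=1` suggestion above can also be applied from Python instead of the shell. The following is a sketch under stated assumptions, not the jax-metal maintainers' method: it assumes the variable has to be in the environment before jax is first imported in the process, and the helper names (`parse_version`, `needs_pjrt_compat`, `enable_pjrt_compat`) are hypothetical. The 0.4.26 baseline is taken from the maintainer comment earlier in this thread.

```python
import os

# jax-metal 0.0.7 adopts the PJRT API from jaxlib 0.4.26 (per the thread).
METAL_BASELINE = (0, 4, 26)


def parse_version(text):
    """Turn a plain 'X.Y.Z' version string into a tuple of ints."""
    return tuple(int(part) for part in text.split(".")[:3])


def needs_pjrt_compat(jaxlib_version, baseline=METAL_BASELINE):
    """True when jaxlib is newer than the version jax-metal was built against."""
    return parse_version(jaxlib_version) > baseline


def enable_pjrt_compat(jaxlib_version, env=os.environ):
    """Set ENABLE_PJRT_COMPATIBILITY=1 in `env` if jaxlib is newer than the baseline.

    Returns the resulting value of the variable, or None if it is unset.
    """
    if needs_pjrt_compat(jaxlib_version):
        env["ENABLE_PJRT_COMPATIBILITY"] = "1"
    return env.get("ENABLE_PJRT_COMPATIBILITY")
```

A typical use would be to look up the installed jaxlib version with `importlib.metadata.version("jaxlib")`, pass it to `enable_pjrt_compat`, and only then run `import jax`. Pre-release version strings such as `0.4.27rc1` are not handled by this simple parser.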
|
Follow this link and try based on your macOS version. |
Having the same issue as @BeeGass, any chance you've found a fix?
Works fine if I use version 0.4.26 for jax and jaxlib, but unfortunately I need 0.4.27+ for the package I'm trying to use. |
Adding NB: I came here because I was getting the error on my M2 Pro, macOS Sonoma 14.6.1 (latest). I also saw this error on Ventura before I updated my OS. I did get it to work with
Question: Is this a good test to see if XLA is working on Metal?
import numpy as np
import jax
import jax.numpy as jnp
@jax.jit
def mult(X, Y):
    return jnp.multiply(X, Y)
mat_shape = (3000, 3000)
%timeit mult(jnp.ones(mat_shape), jnp.ones(mat_shape)).block_until_ready()
%timeit np.multiply(np.ones(mat_shape), np.ones(mat_shape))
|
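A variant of the test above that runs outside IPython might look like the sketch below. It is only one possible approach: it prints which backend JAX selected before timing anything, builds the inputs once so array creation is not part of the measurement, and makes one warm-up call so compilation is excluded. The helper `time_call` is hypothetical, and the jax part is skipped when jax is not installed.

```python
import time


def time_call(fn, repeats=5):
    """Return the best wall-clock time in seconds over `repeats` calls of fn()."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - start)
    return best


if __name__ == "__main__":
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        jax = None

    if jax is not None:
        try:
            # Shows whether the Metal device was actually picked up.
            print("backend:", jax.default_backend(), "devices:", jax.devices())

            @jax.jit
            def mult(a, b):
                return jnp.multiply(a, b)

            x = jnp.ones((3000, 3000))
            y = jnp.ones((3000, 3000))
            mult(x, y).block_until_ready()  # warm-up call triggers compilation
            elapsed = time_call(lambda: mult(x, y).block_until_ready())
            print("jit multiply, best of 5 (s):", elapsed)
        except RuntimeError as err:
            # Backend initialization failures surface here as RuntimeError.
            print("backend failed to initialize:", err)
```

The printed backend name is arguably the more direct answer to the question, since a fast timing alone does not show which device did the work.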
Description
I ran the Get Started code on the Apple Accelerated JAX training on Mac page, namely:
On running that last line I get the following error:
System info (python version, jaxlib version, accelerator, etc.)
Running
import jax; jax.print_environment_info()
returns the following error:
Running the command a second time results in:
More info
I get the same issue on both my M1 Max Mac Studio and my M1 2020 MacBook Air, both running Sonoma 14.5.
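The error text quoted in this thread itself suggests setting `JAX_PLATFORMS=cpu` to skip the failing backend. As a rough sketch of how that could be tried without changing the current shell, the code below runs a snippet in a child Python process with that variable set. The helper names `cpu_only_env` and `run_on_cpu` are hypothetical, and this only sidesteps Metal; it does not fix the version mismatch.

```python
import os
import subprocess
import sys


def cpu_only_env(base=None):
    """Return a copy of the environment with JAX_PLATFORMS=cpu set."""
    env = dict(os.environ if base is None else base)
    env["JAX_PLATFORMS"] = "cpu"
    return env


def run_on_cpu(code):
    """Run a Python snippet in a child process that skips the Metal backend."""
    return subprocess.run(
        [sys.executable, "-c", code],
        env=cpu_only_env(),
        capture_output=True,
        text=True,
        check=False,  # leave it to the caller to inspect returncode and stderr
    )
```

For example, `run_on_cpu("import jax; print(jax.numpy.arange(10))")` returns a result whose `stdout` holds the printed array if jax works on CPU, and whose `stderr` holds the traceback otherwise.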