Unable to initialize backend 'METAL' #21383

Open
drbenvincent opened this issue May 23, 2024 · 16 comments
Labels: Apple GPU (Metal) plugin, bug (Something isn't working)

@drbenvincent

drbenvincent commented May 23, 2024

Description

I ran the Get Started code from Apple's "Accelerated JAX training on Mac" page, namely:

python3 -m venv ~/jax-metal
source ~/jax-metal/bin/activate
python -m pip install -U pip
python -m pip install numpy wheel ml-dtypes==0.2.0

python -m pip install jax-metal

python -c 'import jax; print(jax.numpy.arange(10))'

On running that last line I get the following error:

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Traceback (most recent call last):
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
    backend = _init_backend(platform)
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
    backend = registration.factory()
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
    xla_client.initialize_pjrt_plugin(plugin_name)
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
    _xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).

During handling of the above exception, another exception occurred:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 2968, in arange
    return lax.iota(dtype, start)
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1282, in iota
    return broadcasted_iota(dtype, (size,), 0)
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1292, in broadcasted_iota
    return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 387, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 391, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 879, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/dispatch.py", line 86, in apply_primitive
    outs = fun(*args)
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

System info (python version, jaxlib version, accelerator, etc.)

Running import jax; jax.print_environment_info() returns the following error:

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:874, in backends()
    873 try:
--> 874   backend = _init_backend(platform)
    875   _backends[platform] = backend

File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:965, in _init_backend(platform)
    964 logger.debug("Initializing backend '%s'", platform)
--> 965 backend = registration.factory()
    966 # TODO(skye): consider raising more descriptive errors directly from backend
    967 # factories instead of returning None.

File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:657, in register_plugin.<locals>.factory()
    656 if not xla_client.pjrt_plugin_initialized(plugin_name):
--> 657   xla_client.initialize_pjrt_plugin(plugin_name)
    658 updated_options = {}

File ~/jax-metal/lib/python3.10/site-packages/jaxlib/xla_client.py:176, in initialize_pjrt_plugin(plugin_name)
    169 """Initializes a PJRT plugin.
    170
    171 The plugin needs to be loaded first (through load_pjrt_plugin_dynamically or
   (...)
    174   plugin_name: the name of the PJRT plugin.
    175 """
--> 176 _xla.initialize_pjrt_plugin(plugin_name)

XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
Cell In [2], line 1
----> 1 jax.print_environment_info()

File ~/jax-metal/lib/python3.10/site-packages/jax/_src/environment_info.py:45, in print_environment_info(return_string)
     43   python_version = sys.version.replace('\n', ' ')
     44   with np.printoptions(threshold=4, edgeitems=2):
---> 45     devices_short = str(np.array(xla_bridge.devices())).replace('\n', '')
     46   info = textwrap.dedent(
     47       f"""\
     48   jax:    {version.__version__}
   (...)
     55 """
     56   )
     57   nvidia_smi = try_nvidia_smi()

File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:1077, in devices(backend)
   1052 def devices(
   1053     backend: str | xla_client.Client | None = None
   1054 ) -> list[xla_client.Device]:
   1055   """Returns a list of all devices for a given backend.
   1056
   1057   .. currentmodule:: jaxlib.xla_extension
   (...)
   1075     List of Device subclasses.
   1076   """
-> 1077   return get_backend(backend).devices()

File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:1011, in get_backend(platform)
   1007 @lru_cache(maxsize=None)  # don't use util.memoize because there is no X64 dependence.
   1008 def get_backend(
   1009     platform: None | str | xla_client.Client = None
   1010 ) -> xla_client.Client:
-> 1011   return _get_backend_uncached(platform)

File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:990, in _get_backend_uncached(platform)
    986   return platform
    988 platform = (platform or _XLA_BACKEND.value or _PLATFORM_NAME.value or None)
--> 990 bs = backends()
    991 if platform is not None:
    992   platform = canonicalize_platform(platform)

File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:890, in backends()
    888       else:
    889         err_msg += " (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)"
--> 890       raise RuntimeError(err_msg)
    892 assert _default_backend is not None
    893 if not config.jax_platforms.value:

RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

Running the command a second time results in:

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='BenjamicStudio7', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

More info

I get the same issue on both my M1 Max Mac Studio and my 2020 M1 MacBook Air, both running Sonoma 14.5.
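As the output above shows, jax.print_environment_info() has to initialize the backends before it prints anything, so it fails with the same METAL error. A minimal sketch that reads the installed versions of the relevant packages from their package metadata instead, without importing jax, so it keeps working when backend initialization does not. The helper name installed_versions and the default package list are illustrative assumptions, not part of jax:

```python
from importlib.metadata import PackageNotFoundError, version

# Distributions involved in this issue; an illustrative default, not an official list.
DEFAULT_PACKAGES = ("jax", "jaxlib", "jax-metal", "ml-dtypes", "numpy")


def installed_versions(packages=DEFAULT_PACKAGES):
    """Map each distribution name to its installed version, or None if absent.

    Only package metadata is read, so jax is never imported and no backend
    (METAL or otherwise) is initialized.
    """
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

In this thread, the setups that worked had jax and jaxlib pinned to the same release (0.4.26), so comparing those two lines is a reasonable first check.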

@drbenvincent added the bug (Something isn't working) label on May 23, 2024
@twiecki

twiecki commented May 23, 2024

Getting the same error on M1 Air.

@twiecki

twiecki commented May 23, 2024

Duplicate of #20148?

@twiecki

twiecki commented May 23, 2024

pip install jax==0.4.26 jaxlib==0.4.26 gives:

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1716453894.367380  849760 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

I0000 00:00:1716453894.452128  849760 service.cc:145] XLA service 0x600002350e00 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1716453894.452162  849760 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1716453894.454461  849760 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1716453894.454479  849760 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
loc("-":0:0): error: current mps dialect version is 1.0.0, can't parse version 1.1.0
/AppleInternal/Library/BuildRoots/1dd9a6a2-74cf-11ee-8ed5-2a65a1af8551/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1097: failed assertion `Error importing MLIR bytecode.
'
Abort trap: 6

which is the error in #20338.

@drbenvincent
Author

When I ran pip install jax==0.4.26 jaxlib==0.4.26 then I think I got success.

jax.print_environment_info() now gives:

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='BenjamicStudio7', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

and running print(jax.numpy.arange(10)) in an IPython session gives:

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1716454190.044810 4298556 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

I0000 00:00:1716454190.058568 4298556 service.cc:145] XLA service 0x600000588a00 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1716454190.058577 4298556 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1716454190.059879 4298556 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1716454190.059894 4298556 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.
[0 1 2 3 4 5 6 7 8 9]

@zhibor

zhibor commented May 24, 2024

> When I ran pip install jax==0.4.26 jaxlib==0.4.26 then I think I got success.

this works for me on M1 MAX. thanks for sharing!

@BeeGass

BeeGass commented May 24, 2024

I'm trying to install via Poetry, and I find that jax-metal installs a version of jax that cannot be overridden by specifying an additional jax dependency:

[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.26", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }


[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

When installing this, despite the explicit attempt to override the jax dependency, jax-metal still sets it to 0.4.28:

  • Installing jax (0.4.28)
  • Installing jaxlib (0.4.26 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl)

> When I ran pip install jax==0.4.26 jaxlib==0.4.26 then I think I got success.

In short, this isn't working for me.

For additional information, if I try:

$ python3
>>> import jax
jaxlib is version 0.4.26, but this version of jax requires version >= 0.4.27.
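This failure happens while jax is being imported: jax checks that the installed jaxlib is at least the minimum version that jax release requires (check_jaxlib_version in the traceback further down). A minimal sketch of that kind of comparison, useful for sanity-checking a planned pin before installing. The helper names are hypothetical, and only the leading numeric release segment is compared, so pre-release or local suffixes are ignored:

```python
import re


def release_tuple(version_string):
    """Return the leading numeric release segment, e.g. "0.4.27" -> (0, 4, 27).

    Suffixes such as ".dev20240501" or "+local" are ignored (a simplification).
    """
    match = re.match(r"\d+(?:\.\d+)*", version_string)
    if match is None:
        raise ValueError(f"not a version string: {version_string!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed, minimum):
    """True if `installed` >= `minimum`, padding the shorter tuple with zeros."""
    a, b = release_tuple(installed), release_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


# The combination reported above: jaxlib 0.4.26 with a jax that needs >= 0.4.27.
print(meets_minimum("0.4.26", "0.4.27"))  # False
```

For full PEP 440 semantics, packaging.version.Version from the packaging library is the more robust choice.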

However, if I don't try to override it, that is, with jax and jaxlib at version 0.4.27:

[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.27", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }


[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

Also notice that it only updates jaxlib, not jax:

  • Updating jaxlib (0.4.26 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl -> 0.4.27 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl)

and then do what I tried before:

$ python3
>>> import jax
>>> jax.print_environment_info()
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Traceback (most recent call last):
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
    backend = _init_backend(platform)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
    backend = registration.factory()
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
    xla_client.initialize_pjrt_plugin(plugin_name)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
    _xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/environment_info.py", line 45, in print_environment_info
    devices_short = str(np.array(xla_bridge.devices())).replace('\n', '')
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1077, in devices
    return get_backend(backend).devices()
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1011, in get_backend
    return _get_backend_uncached(platform)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 990, in _get_backend_uncached
    bs = backends()
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 890, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

The same behavior occurs if I change jax and jaxlib to version 0.4.28, or remove them entirely, allowing jax-metal to install the correct jax and jaxlib versions.

Machine and environment info:
Chip: Apple M1 Pro
MacOS: Sonoma 14.5
python version: 3.10.13
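Every failing traceback in this thread ends with the same INVALID_ARGUMENT message, which carries both the plugin's PJRT API version (0.47) and the framework's (0.51). A small hypothetical diagnostic helper that extracts the two versions, for example to include them in a bug report. The regular expression is written against the exact wording shown above, including its unbalanced parenthesis; whether that wording stays stable across jaxlib releases is an assumption:

```python
import re

# Matches the wording seen in this issue, e.g.
# "Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51)."
# Parentheses are optional because the framework version is printed with only
# a closing one.
_MISMATCH = re.compile(
    r"Mismatched PJRT plugin PJRT API version \(?(\d+)\.(\d+)\)?"
    r" and framework PJRT API version \(?(\d+)\.(\d+)\)?"
)


def parse_pjrt_mismatch(message):
    """Return ((plugin_major, plugin_minor), (framework_major, framework_minor)),
    or None if the message is not a PJRT API version mismatch."""
    match = _MISMATCH.search(message)
    if match is None:
        return None
    plugin_major, plugin_minor, fw_major, fw_minor = (int(g) for g in match.groups())
    return (plugin_major, plugin_minor), (fw_major, fw_minor)


msg = ("Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT "
       "plugin PJRT API version (0.47) and framework PJRT API version 0.51).")
print(parse_pjrt_mismatch(msg))  # ((0, 47), (0, 51))
```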

@shuhand0
Collaborator

"RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)". This error comes from jaxlib, which strictly checks that the PJRT API versions are equal. jax-metal 0.0.7 uses the PJRT API from jaxlib 0.4.26. We have communicated this to the JAX team, and the solution is to set the environment variable ENABLE_PJRT_COMPATIBILITY=1 when running jax-metal with jaxlib > 0.4.26. This information can also be found on the jax-metal PyPI page: https://pypi.org/project/jax-metal/.

@yrahul3910

yrahul3910 commented May 25, 2024

@BeeGass Are you sure your python3 is an arm64 binary? I've been bitten by this more times than I care for. Try

import platform

platform.machine()  # should give you 'arm64'

I got mine (M1 Air, Sonoma) working by doing

python -m pip install jax==0.4.26 jaxlib==0.4.26
python -m pip install jax-metal

after I used #19886 to set up my environment (but make sure that your shell is also arm64 before doing this, e.g. arch -arm64 zsh).

@BeeGass

BeeGass commented May 25, 2024

> @BeeGass Are you sure your python3 is an arm64 binary? I've been bitten by this more times than I care for. Try
>
> import platform
>
> platform.machine()  # should give you 'arm64'

I cleared the cache, all virtual environments and so forth, and did a fresh install of all the dependencies, making sure arm64 is the platform in use.

$ python
>>> import platform
>>> platform.machine()
'arm64'

For thoroughness:

$ python3
>>> import platform
>>> platform.machine()
'arm64'

I also ran this right before installing all dependencies:

arch -arm64 zsh

I tried the original test again:

$ python -c 'import jax; print(jax.numpy.arange(10))'
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/__init__.py", line 37, in <module>
    import jax.core as _core
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/core.py", line 18, in <module>
    from jax._src.core import (
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 39, in <module>
    from jax._src import dtypes
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/dtypes.py", line 33, in <module>
    from jax._src import config
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/config.py", line 27, in <module>
    from jax._src import lib
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 75, in <module>
    version = check_jaxlib_version(
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 64, in check_jaxlib_version
    raise RuntimeError(msg)
RuntimeError: jaxlib is version 0.4.26, but this version of jax requires version >= 0.4.27.

For thoroughness:

I changed the jax and jaxlib dependency versions to 0.4.28. (I know jax-metal's PJRT API comes from jaxlib 0.4.26, but since that isn't working, I hoped that the 0.4.27 or 0.4.28 release might support that PJRT version as well.)

[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.27", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }


[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

  • Downgrading jaxlib (0.4.28 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.28-cp310-cp310-macosx_11_0_arm64.whl -> 0.4.27 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl)

Again, I did the following:

$ python3
>>> import platform
>>> platform.machine()
'arm64'
arch -arm64 zsh

Then I ran the test above:

$ python3 -c 'import jax; print(jax.numpy.arange(10))'
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Traceback (most recent call last):
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
    backend = _init_backend(platform)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
    backend = registration.factory()
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
    xla_client.initialize_pjrt_plugin(plugin_name)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
    _xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).

During handling of the above exception, another exception occurred:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 2968, in arange
    return lax.iota(dtype, start)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1282, in iota
    return broadcasted_iota(dtype, (size,), 0)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1292, in broadcasted_iota
    return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 387, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 391, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 879, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/dispatch.py", line 86, in apply_primitive
    outs = fun(*args)
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

I also tried:

[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.28", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.28-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }


[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

  • Updating jaxlib (0.4.26 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl -> 0.4.28 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.28-cp310-cp310-macosx_11_0_arm64.whl)

Again, I did the following:

$ python3
>>> import platform
>>> platform.machine()
'arm64'
arch -arm64 zsh

Then I ran the test above:

$ python3 -c 'import jax; print(jax.numpy.arange(10))'
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Traceback (most recent call last):
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
    backend = _init_backend(platform)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
    backend = registration.factory()
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
    xla_client.initialize_pjrt_plugin(plugin_name)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
    _xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).

During handling of the above exception, another exception occurred:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 2968, in arange
    return lax.iota(dtype, start)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1282, in iota
    return broadcasted_iota(dtype, (size,), 0)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1292, in broadcasted_iota
    return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 387, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 391, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 879, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/dispatch.py", line 86, in apply_primitive
    outs = fun(*args)
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

I have noticed that the people who have gotten this fix to work have been using the M3 chip. Could the issue be that I'm using an M1? Has anyone tried to reproduce this on an M1?

Just to be clear:
Chip: Apple M1 Pro
MacOS: Sonoma 14.5

@shinhookang

> @BeeGass Are you sure your python3 is an arm64 binary? I've been bitten by this more times than I care for. Try
>
> import platform
>
> platform.machine()  # should give you 'arm64'
>
> I got mine (M1 Air, Sonoma) working by doing
>
> python -m pip install jax==0.4.26 jaxlib==0.4.26
> python -m pip install jax-metal
>
> after I used #19886 to set up my environment (but make sure that your shell is also arm64 before doing this, e.g. arch -arm64 zsh).

Thanks, this works for me.
Mine is M3 Pro, Sonoma 14.5.

@shuhand0
Collaborator

shuhand0 commented May 28, 2024

@BeeGass, have you tried setting env ENABLE_PJRT_COMPATIBILITY=1 to run jax-metal with jaxlib>0.4.26?

@BeeGass

BeeGass commented May 28, 2024

> @BeeGass, have you tried setting env ENABLE_PJRT_COMPATIBILITY=1 to run jax-metal with jaxlib>0.4.26?

Yeah, still the same behavior. I'm told that the jax version needs to be 0.4.27 or higher. @shuhand0

@limyeeun1

> > When I ran pip install jax==0.4.26 jaxlib==0.4.26 then I think I got success.
>
> this works for me on M1 MAX. thanks for sharing!

this works for me on M3. Thank you!!!!!

@Aiyaz3007

Follow this link and try the steps for your macOS version; this works for me!

@tessadgreen

Having the same issue as @BeeGass. Any chance you've found a fix?

RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

It works fine if I use version 0.4.26 for jax and jaxlib, but unfortunately I need 0.4.27+ for the package I'm trying to use.

@TylerMclaughlin

Adding export ENABLE_PJRT_COMPATIBILITY=1 to my ~/.zshrc works. So far so good! (=
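Exporting the variable in ~/.zshrc applies it to every new shell. To scope it to a single run instead, here is a minimal sketch that launches a child Python with ENABLE_PJRT_COMPATIBILITY=1 added to a copy of the current environment. The helper name is hypothetical, and the sketch only passes the variable through; whether it takes effect is up to jax-metal/jaxlib, as described earlier in the thread. Setting os.environ["ENABLE_PJRT_COMPATIBILITY"] = "1" at the top of a script, before jax is imported, should presumably have a similar effect:

```python
import os
import subprocess
import sys


def run_with_pjrt_compat(code):
    """Run `code` in a fresh interpreter with ENABLE_PJRT_COMPATIBILITY=1 set.

    Uses the current interpreter (sys.executable), so the child sees the same
    virtual environment; the parent's own environment is left unchanged.
    """
    env = dict(os.environ, ENABLE_PJRT_COMPATIBILITY="1")
    return subprocess.run(
        [sys.executable, "-c", code],
        env=env,
        capture_output=True,
        text=True,
        check=False,
    )


if __name__ == "__main__":
    result = run_with_pjrt_compat("import jax; print(jax.numpy.arange(10))")
    print(result.stdout if result.returncode == 0 else result.stderr)
```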

NB: I came here because I was getting the error on my M2 Pro with macOS Sonoma 14.6.1 (latest). I also saw this error on Ventura before I updated my OS. I did get it to work with pip install jax==0.4.26 jaxlib==0.4.26 as above, but Flax and Optax both require >= 0.4.27, so there might be runtime errors when using 0.4.26 (unconfirmed).

Question: Is this a good test to see if XLA is working on Metal?

import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def mult(X, Y):
    return jnp.multiply(X, Y)

mat_shape = (3000, 3000)

%timeit mult(jnp.ones(mat_shape), jnp.ones(mat_shape)).block_until_ready()
%timeit np.multiply(np.ones(mat_shape), np.ones(mat_shape))
3.7 ms ± 183 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
19.2 ms ± 159 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
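A fast %timeit result on its own does not show which device ran the code; jax.devices() does, and in this thread jax.print_environment_info() reported [METAL(id=0)] for the working setup and [CpuDevice(id=0)] for the failing one. For the timing itself, two details matter: the first call to a jitted function includes tracing and compilation, and JAX dispatches work asynchronously, which is why the result is waited on with block_until_ready(). A minimal, library-agnostic sketch (the helper name is hypothetical) that reports the first call separately from the best of several later calls:

```python
import time


def time_first_and_steady(fn, *args, repeat=5):
    """Return (first_call_seconds, best_later_call_seconds) for fn(*args).

    If the result has a block_until_ready() method (JAX arrays do), it is
    called so that asynchronous dispatch does not make a call look faster
    than it is.
    """
    if repeat < 1:
        raise ValueError("repeat must be >= 1")

    def timed_call():
        start = time.perf_counter()
        out = fn(*args)
        wait = getattr(out, "block_until_ready", None)
        if callable(wait):
            wait()
        return time.perf_counter() - start

    first = timed_call()  # for a jitted fn this includes tracing and compilation
    best = min(timed_call() for _ in range(repeat))
    return first, best
```

With the mult function above, the inputs can be built once (for example x = jnp.ones(mat_shape)) and then passed as time_first_and_steady(mult, x, x), which keeps array creation out of the timed region.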
