>>> import jax.numpy as jnp
>>> a = jnp.ones((2,3,4))
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-06 22:43:05.128623: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro
systemMemory: 32.00 GB
maxCacheSize: 10.67 GB
>>> b = jnp.ones((4,3,2))
>>> jnp.einsum("ijk,kji->k", a, b)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3369, in einsum
return _einsum_computation(operands, contractions, precision, # type: ignore[operator]
File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:1:0: error: failed to legalize operation 'mhlo.dot_general'
<stdin>:1:0: note: see current operation: %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [2], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [0, 1], rhs_contracting_dimensions = [2, 1]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<2x3x4xf32>, tensor<4x3x2xf32>) -> tensor<4xf32>
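The note above gives the exact dimension numbers of the op that fails to legalize. The sketch below is hypothetical and was not posted in the thread. It calls `jax.lax.dot_general` directly with those dimension numbers, which may serve as a smaller repro. It also computes the same result without a `dot_general`: it transposes `b` into `a`'s `(i, j, k)` layout, multiplies elementwise, and sums over `i` and `j`. Both results are checked against `jnp.einsum`. The thread does not establish whether either form avoids the error on the Metal backend. The check itself runs on any backend, such as CPU.

```python
import jax.numpy as jnp
from jax import lax

# Non-constant inputs so the comparison is meaningful (shapes as in the report).
a = jnp.arange(24.0).reshape(2, 3, 4)  # labels (i, j, k)
b = jnp.arange(24.0).reshape(4, 3, 2)  # labels (k, j, i)

# Dimension numbers copied from the failing mhlo.dot_general:
# ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
dimension_numbers = (((0, 1), (2, 1)), ((2,), (0,)))
via_dot_general = lax.dot_general(a, b, dimension_numbers)

# Alternative that avoids dot_general: reorder b from (k, j, i) to (i, j, k),
# multiply elementwise, then reduce over the i and j axes.
b_aligned = jnp.transpose(b, (2, 1, 0))
via_reduction = jnp.sum(a * b_aligned, axis=(0, 1))

via_einsum = jnp.einsum("ijk,kji->k", a, b)
print(via_dot_general.shape)  # (4,)
```

The elementwise form materializes a `(2, 3, 4)` intermediate. That is negligible here but could matter for large operands.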
System info (python version, jaxlib version, accelerator, etc.)
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-06 22:45:49.148886: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro
systemMemory: 32.00 GB
maxCacheSize: 10.67 GB
jax: 0.4.20
jaxlib: 0.4.20
numpy: 1.26.4
python: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
dlwh changed the title from "Metal: failed to legalize operation 'mhlo.dot_general'" to "Metal: failed to legalize operation 'mhlo.dot_general' for einsum "ijk,kji->k"" on Mar 7, 2024
>>> b = jnp.ones((4,3,2))
>>> b = jnp.ones((2,4,3))
>>> jnp.einsum("ijk,jki->k", a, b)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3362, in einsum
operands, contractions = contract_path(
File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/site-packages/opt_einsum/contract.py", line 238, in contract_path
raise ValueError("Size of label '{}' for operand {} ({}) does not match previous "
ValueError: Size of label 'j' for operand 1 (3) does not match previous terms (2).
>>> a.shape, b.shape
((2, 3, 4), (2, 4, 3))
>>> jnp.einsum("ijk,ikj->k", a, b)
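The last call above contracts a `b` of shape `(2, 4, 3)` with `"ijk,ikj->k"`. It computes the same thing as the original `"ijk,kji->k"` only if `b` is the original `(4, 3, 2)` array with its axes reordered. The hypothetical sketch below shows the correspondence: it transposes the original `b` from `(k, j, i)` to `(i, k, j)` and checks that both einsums agree. The thread does not show whether the permuted variant runs on Metal. The check runs on any backend.

```python
import jax.numpy as jnp

a = jnp.arange(24.0).reshape(2, 3, 4)       # labels (i, j, k)
b_orig = jnp.arange(24.0).reshape(4, 3, 2)  # labels (k, j, i), as in the report

# Reorder b's axes from (k, j, i) to (i, k, j); resulting shape is (2, 4, 3).
b_perm = jnp.transpose(b_orig, (2, 0, 1))

original = jnp.einsum("ijk,kji->k", a, b_orig)
permuted = jnp.einsum("ijk,ikj->k", a, b_perm)
print(b_perm.shape)  # (2, 4, 3)
```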
I encountered a similar issue, although I didn't narrow it down to the exact computation (a matrix multiplication).
System info (python version, jaxlib version, accelerator, etc.)
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Metal device set to: Apple M2
2024-03-25 22:16:49.889705: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
jax: 0.4.25
jax-metal: 0.0.6
jaxlib: 0.4.23