
Metal: failed to legalize operation 'mhlo.dot_general' for einsum "ijk,kji->k" #20114

Open
dlwh opened this issue Mar 7, 2024 · 3 comments
Labels: Apple GPU (Metal) plugin, bug (Something isn't working)

Comments

dlwh (Contributor) commented Mar 7, 2024

Description

>>> import jax.numpy as jnp
>>> a = jnp.ones((2,3,4))
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-06 22:43:05.128623: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

>>> b = jnp.ones((4,3,2))
>>> jnp.einsum("ijk,kji->k", a, b)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3369, in einsum
    return _einsum_computation(operands, contractions, precision,  # type: ignore[operator]
  File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:1:0: error: failed to legalize operation 'mhlo.dot_general'
<stdin>:1:0: note: see current operation: %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [2], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [0, 1], rhs_contracting_dimensions = [2, 1]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<2x3x4xf32>, tensor<4x3x2xf32>) -> tensor<4xf32>
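For reference, the dimension numbers in the failing op follow from the einsum subscripts. `k` appears in both inputs and in the output, so it becomes a batching dimension. `i` and `j` are summed away, so they become two contracting dimensions. The hypothetical helper below (illustration only, not a JAX API) applies that rule to a two-operand spec and reproduces the numbers in the error. JAX's real einsum lowering in `lax_numpy.py` handles more cases, such as repeated labels and single-operand sums.

```python
def dot_general_dims(spec):
    """Classify the labels of a two-operand einsum spec into
    dot_general-style batching/contracting dimension numbers.

    Hypothetical helper for illustration only (not a JAX API):
    - label in both inputs and in the output -> batching dimension
    - label in both inputs but not the output -> contracting dimension
    - label in only one input -> not listed (free, or summed out earlier)
    Dimensions are listed in the order their labels appear in the lhs.
    """
    inputs, out = spec.replace(" ", "").split("->")
    lhs, rhs = inputs.split(",")
    dims = {
        "lhs_batching_dimensions": [],
        "rhs_batching_dimensions": [],
        "lhs_contracting_dimensions": [],
        "rhs_contracting_dimensions": [],
    }
    for i, label in enumerate(lhs):
        if label not in rhs:
            continue  # only in lhs: not part of the batch/contract lists
        j = rhs.index(label)
        kind = "batching" if label in out else "contracting"
        dims[f"lhs_{kind}_dimensions"].append(i)
        dims[f"rhs_{kind}_dimensions"].append(j)
    return dims


print(dot_general_dims("ijk,kji->k"))
# batching [2] / [0], contracting [0, 1] / [2, 1], as in the error above
```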

System info (python version, jaxlib version, accelerator, etc.)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-06 22:45:49.148886: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

jax:    0.4.20
jaxlib: 0.4.20
numpy:  1.26.4
python: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
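One possible workaround for this particular einsum is sketched below. It assumes that elementwise multiply, transpose and reduce-sum lower correctly on the Metal plugin, which has not been verified here. The idea is to transpose `b` so its axes line up with `a`, then multiply and sum over `i` and `j`. No `dot_general` is emitted, at the cost of materializing the full (i, j, k) product. The function name is hypothetical. The check below uses NumPy on CPU. Passing `xp=jax.numpy` should build the same computation in JAX, since `transpose` and `sum` take the same arguments there.

```python
import numpy as np


def einsum_ijk_kji_k(a, b, xp=np):
    """Compute einsum("ijk,kji->k", a, b) without a dot_general.

    Hypothetical workaround sketch: transposing b from (k, j, i) to
    (i, j, k) aligns it with a, so the contraction becomes an
    elementwise product followed by a sum over the i and j axes.
    """
    b_aligned = xp.transpose(b, (2, 1, 0))  # (k, j, i) -> (i, j, k)
    return xp.sum(a * b_aligned, axis=(0, 1))


a = np.ones((2, 3, 4))
b = np.ones((4, 3, 2))
print(einsum_ijk_kji_k(a, b))  # [6. 6. 6. 6.]
print(np.allclose(einsum_ijk_kji_k(a, b), np.einsum("ijk,kji->k", a, b)))  # True
```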
dlwh added the bug label Mar 7, 2024
dlwh changed the title Mar 7, 2024:
  from: Metal: failed to legalize operation 'mhlo.dot_general'
  to:   Metal: failed to legalize operation 'mhlo.dot_general' for einsum "ijk,kji->k"
dlwh (Contributor, Author) commented Mar 7, 2024

Also, a similar but slightly different case:

>>> b = jnp.ones((4,3,2))
>>> b = jnp.ones((2,4,3))
>>> jnp.einsum("ijk,jki->k", a, b)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3362, in einsum
    operands, contractions = contract_path(
  File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/site-packages/opt_einsum/contract.py", line 238, in contract_path
    raise ValueError("Size of label '{}' for operand {} ({}) does not match previous "
ValueError: Size of label 'j' for operand 1 (3) does not match previous terms (2).
>>> a.shape, b.shape
((2, 3, 4), (2, 4, 3))
>>> jnp.einsum("ijk,ikj->k", a, b)
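The `ValueError` in the second attempt is a subscript mismatch, not a Metal problem. It is raised by opt_einsum's `contract_path` before anything reaches the device. With `b` of shape (2, 4, 3), "jki" binds `j` to 2, while `a` already binds it to 3. By contrast, "ikj" is consistent. The hypothetical helper below performs the same kind of label/size check. Its name and error message are illustrative, not opt_einsum's.

```python
def label_sizes(spec, *shapes):
    """Map each einsum input label to its dimension size.

    Hypothetical helper for illustration (not opt_einsum's API): raises
    ValueError when one label is bound to two different sizes.
    """
    inputs = spec.replace(" ", "").split("->")[0].split(",")
    if len(inputs) != len(shapes):
        raise ValueError("number of subscripts does not match number of operands")
    sizes = {}
    for operand, (labels, shape) in enumerate(zip(inputs, shapes)):
        if len(labels) != len(shape):
            raise ValueError(f"operand {operand}: {len(labels)} labels for rank-{len(shape)} array")
        for label, size in zip(labels, shape):
            # setdefault records the first size seen; later sizes must match it
            if sizes.setdefault(label, size) != size:
                raise ValueError(
                    f"label {label!r}: operand {operand} has size {size}, "
                    f"earlier operands have size {sizes[label]}"
                )
    return sizes


a_shape, b_shape = (2, 3, 4), (2, 4, 3)
print(label_sizes("ijk,ikj->k", a_shape, b_shape))  # {'i': 2, 'j': 3, 'k': 4}
try:
    label_sizes("ijk,jki->k", a_shape, b_shape)
except ValueError as err:
    print(err)  # label 'j': operand 1 has size 2, earlier operands have size 3
```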

steeve commented Mar 21, 2024

Same with jax-metal 0.0.6

ramithuh commented Mar 26, 2024

I encountered a similar issue, though I haven't narrowed it down to the exact computation (a matrix multiplication).

System info (python version, jaxlib version, accelerator, etc.)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Metal device set to: Apple M2
2024-03-25 22:16:49.889705: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!

jax:       0.4.25
jax-metal: 0.0.6
jaxlib:    0.4.23
see current operation: %7513 = "mhlo.dot_general"(%7499, %7395) 
{dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], 
lhs_contracting_dimensions = [2, 3], rhs_contracting_dimensions = [1, 3]>, 
precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : 
(tensor<1x31x20x5xf32>, tensor<1x20x31x5xf32>) -> tensor<1x31x31xf32>
[Screenshot 2024-03-25 at 22 06 56]
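The op reported here also contracts over two dimensions (lhs dims [2, 3] against rhs dims [1, 3]), like the one in the original report. One einsum spec consistent with its shapes and dimension numbers is "bidh,bdjh->bij". This is a guess, since the actual source computation wasn't identified. The sketch below shows one possible workaround. It assumes the Metal plugin legalizes a `dot_general` with a single contracting dimension, which has not been verified here. In JAX, `matmul` still lowers to `dot_general`, just in the simpler one-contraction form. The idea is to move the two contracted axes to the end, flatten them into one, and run an ordinary batched matmul. Names are hypothetical. The check uses NumPy on CPU, and `xp=jax.numpy` should behave the same, since `reshape`, `transpose`, `swapaxes` and `matmul` exist there with these signatures.

```python
import numpy as np


def contract_two_dims_as_matmul(lhs, rhs, xp=np):
    """Compute einsum("bidh,bdjh->bij", lhs, rhs) with one contracting dim.

    Hypothetical sketch. lhs: (B, I, D, H), rhs: (B, D, J, H), the layout
    of the reported op. Flattening the contracted axes (D, H) into a single
    axis of size D*H, in the same order for both operands, turns the
    two-dimension contraction into an ordinary batched matrix multiply.
    """
    B, I, D, H = lhs.shape
    J = rhs.shape[2]
    lhs_flat = xp.reshape(lhs, (B, I, D * H))                              # (B, I, D*H)
    rhs_flat = xp.reshape(xp.transpose(rhs, (0, 2, 1, 3)), (B, J, D * H))  # (B, J, D*H)
    return xp.matmul(lhs_flat, xp.swapaxes(rhs_flat, 1, 2))                # (B, I, J)


rng = np.random.default_rng(0)
lhs = rng.standard_normal((1, 31, 20, 5))  # shapes taken from the reported op
rhs = rng.standard_normal((1, 20, 31, 5))
out = contract_two_dims_as_matmul(lhs, rhs)
print(out.shape)  # (1, 31, 31)
print(np.allclose(out, np.einsum("bidh,bdjh->bij", lhs, rhs)))  # True
```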
