
Metal: reduction "Runtime canonicalization must simplify reduction axes to minor 4 dimensions" #20112

Open
dlwh opened this issue Mar 7, 2024 · 4 comments
Labels: Apple GPU (Metal) plugin, bug

dlwh (Contributor) commented Mar 7, 2024

Description

Python 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.numpy as jnp
>>> a = jnp.zeros( (2, 3, 4, 5, 6))
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-06 22:16:28.766754: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

>>> jnp.all(a == a)
Assertion failed: (0 <= mpsAxis && mpsAxis < 4 && "Runtime canonicalization must simplify reduction axes to minor 4 dimensions."), function encodeNDArrayOp, file GPUReductionOps.mm, line 76.
Abort trap: 6

System info (python version, jaxlib version, accelerator, etc.)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-06 22:15:35.748974: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

jax:    0.4.20
jaxlib: 0.4.20
numpy:  1.26.4
dlwh added the bug label on Mar 7, 2024

shuhand0 (Collaborator)

The backend kernel doesn't support reduce ops on tensors of rank > 4. Could the app work around the issue by reshaping the tensor first? For example:
a = jnp.zeros((2, 3, 4, 5, 6)).reshape(-1, 4, 5, 6)
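The suggested workaround for the full reduction in the report can be sketched as follows. This is a minimal sketch that uses NumPy as a stand-in for `jax.numpy` (the same `zeros`, `reshape`, and `all` functions exist there); `all_via_rank4` is a hypothetical helper, not part of JAX or jax-metal. It only checks that collapsing the leading axes preserves the result of a full reduction; it does not exercise the Metal backend.

```python
import numpy as np

def all_via_rank4(x):
    """Hypothetical helper: full `all` reduction on an array of rank <= 4."""
    # Collapse the leading axes so the array has at most 4 dimensions.
    # A reduction over *every* axis does not depend on the shape, so the
    # result matches reducing the original array directly.
    if x.ndim > 4:
        x = x.reshape(-1, *x.shape[-3:])
    return np.all(x)

a = np.zeros((2, 3, 4, 5, 6))
print(a.reshape(-1, 4, 5, 6).shape)  # (6, 4, 5, 6)
print(all_via_rank4(a == a))  # True
```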

dlwh (Contributor, Author) commented Mar 12, 2024

Thanks! I can work around it, particularly for these tensor-to-scalar reductions. (But for this case, it also seems like a straightforward thing to handle on the plugin end?) I could be wrong, but I think any reduction over either one axis or all axes can be written as reshape -> reduce -> reshape.
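The single-axis case of reshape -> reduce -> reshape can be sketched like this. It is a minimal sketch that again uses NumPy as a stand-in for `jax.numpy`; `reduce_one_axis` is a hypothetical helper, not a JAX or jax-metal API. It groups the axes before and after the reduced one, so the reduction itself only ever sees a rank-3 array.

```python
import math
import numpy as np

def reduce_one_axis(x, axis, reduce_fn=np.sum):
    """Hypothetical helper: reduce one axis of an N-d array via a rank-3 reduction."""
    axis = axis % x.ndim  # allow negative axes
    before = math.prod(x.shape[:axis])     # product of leading dims (1 if none)
    after = math.prod(x.shape[axis + 1:])  # product of trailing dims (1 if none)
    # reshape: group leading and trailing axes; the reduced axis stays in the middle
    x3 = x.reshape(before, x.shape[axis], after)
    # reduce: the input is rank 3, so it never exceeds 4 dimensions
    out = reduce_fn(x3, axis=1)
    # reshape: restore the original shape minus the reduced axis
    return out.reshape(x.shape[:axis] + x.shape[axis + 1:])

x = np.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
print(reduce_one_axis(x, 2).shape)  # (2, 3, 5, 6)
```

Because row-major reshapes preserve element order, the grouped axes line up with the original ones, so for integer input the result matches `x.sum(axis=axis)` exactly.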

Can we leave this open as a signpost?

Is there a guide on Metal performance yet? (Presumably not JAX-focused, but something close?)

shuhand0 (Collaborator)

Per the StableHLO spec, the reduction dimensions are not limited to one axis or all axes, so the reshape-reduce-reshape pattern will not resolve every case. We will look into whether a more general conversion pattern could be added to jax-metal.

dlwh (Contributor, Author) commented Mar 12, 2024

Sure, I meant those are the easy cases and felt like a minimum. I think you can define something that's correct modulo floating point (and definitely not optimal) with transpose -> reshape -> reduce -> reshape.
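The general transpose -> reshape -> reduce -> reshape pattern can be sketched as below. This is a minimal sketch, assuming NumPy as a stand-in for `jax.numpy` (`np.transpose` corresponds to `jnp.transpose`); `reduce_axes` is a hypothetical helper, not a jax-metal API. It moves the reduced axes to the end, collapses them into one, and reduces a rank-2 array, which also covers non-contiguous axis sets such as `(0, 2)` that a plain reshape cannot group.

```python
import math
import numpy as np

def reduce_axes(x, axes, reduce_fn=np.sum):
    """Hypothetical helper: reduce an arbitrary set of axes via a rank-2 reduction."""
    axes = sorted({a % x.ndim for a in axes})  # normalize negative/duplicate axes
    kept = [a for a in range(x.ndim) if a not in axes]
    # 1. transpose: kept axes first, reduced axes last
    xt = np.transpose(x, kept + axes)
    kept_shape = tuple(x.shape[a] for a in kept)
    red_size = math.prod(x.shape[a] for a in axes)
    # 2. reshape: collapse to (number of kept elements, number of reduced elements)
    x2 = xt.reshape(math.prod(kept_shape), red_size)
    # 3. reduce over the single collapsed axis
    out = reduce_fn(x2, axis=1)
    # 4. reshape back to the kept shape
    return out.reshape(kept_shape)

x = np.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
print(reduce_axes(x, (0, 2)).shape)  # (3, 5, 6)
```

This is only correct modulo floating point: collapsing the reduced axes changes the order in which elements are combined, so float sums may differ in the last bits from a direct reduction, and the transpose generally costs an extra copy.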
