System info (python version, jaxlib version, accelerator, etc.)
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-06 22:15:35.748974: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro
systemMemory: 32.00 GB
maxCacheSize: 10.67 GB
jax: 0.4.20
jaxlib: 0.4.20
numpy: 1.26.4
The backend kernel doesn't support rank > 4 for the reduce op. Is it possible for the app to work around the issue by reshaping the tensor, e.g. a = jnp.zeros((2, 3, 4, 5, 6)).reshape(-1, 4, 5, 6)?
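A minimal sketch of that workaround; the shape comes from the comment above, but the specific reductions shown are illustrative and not from the original reproducer:

```python
import jax.numpy as jnp

a = jnp.zeros((2, 3, 4, 5, 6))   # rank-5 array: exceeds the rank-4 limit of the Metal reduce kernel
a4 = a.reshape(-1, 4, 5, 6)      # collapse the leading axes into one -> rank-4, shape (6, 4, 5, 6)

total = a4.sum()                                   # full reduction; same value as a.sum()
per_lead = a4.sum(axis=(1, 2, 3)).reshape(2, 3)    # reduce the trailing axes, then restore the leading shape
```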
Thanks! I can work around it, particularly for these array -> scalar reductions. (But for this case it also seems like a straightforward thing to do on the plugin end?) I could be wrong, but I think any reduction over either one axis or all axes can be written as a reshape -> reduce -> reshape.
Can we leave this open as a signpost?
Is there a guide on Metal performance yet (presumably not JAX-focused, but something close)?
Per the StableHLO spec, the reduction dimensions are not limited to a single axis or all axes, so the reshape-reduce-reshape pattern will not cover all the cases. We will look into whether a more general conversion pattern could be added to jax-metal.
Sure, I meant those are the easy cases and felt like a minimum. I think you can define something that's correct modulo floating point (and definitely not optimal) with transpose -> reshape -> reduce -> reshape; see the sketch below.
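A hedged sketch of that general decomposition: move the reduced axes to the back with a transpose, collapse the kept and reduced axes into two groups with a reshape, reduce over the single trailing axis, then reshape back to the kept shape. `reduce_via_reshape` is a hypothetical helper written for illustration, not part of jax or jax-metal, and the summation order differs from a direct multi-axis reduce, so results match only up to floating-point rounding.

```python
import numpy as np
import jax.numpy as jnp


def reduce_via_reshape(x, axes):
    """Sum over `axes` using transpose -> reshape -> reduce -> reshape."""
    axes = tuple(a % x.ndim for a in axes)
    kept = tuple(a for a in range(x.ndim) if a not in axes)
    kept_shape = tuple(x.shape[a] for a in kept)
    # Bring kept axes to the front, reduced axes to the back.
    xt = jnp.transpose(x, kept + axes)
    # Collapse into (kept, reduced) so the reduce only ever sees rank <= 2.
    xt = xt.reshape(int(np.prod(kept_shape, dtype=np.int64)), -1)
    out = xt.sum(axis=1)
    return out.reshape(kept_shape)


x = jnp.arange(2 * 3 * 4 * 5 * 6, dtype=jnp.float32).reshape(2, 3, 4, 5, 6)
assert jnp.allclose(reduce_via_reshape(x, (1, 3)), x.sum(axis=(1, 3)))
```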