Apple Metal: tile aborting kernel with error: 'anec.reshape' op result #0 must be 4D/5D memref... #20413

Open
muffin-rice opened this issue Mar 24, 2024 · 1 comment
Labels: Apple GPU (Metal) plugin, bug (Something isn't working)

muffin-rice commented Mar 24, 2024

Description

A reproducible example (also reproducible with reshape alone, but tile aborts the process whereas reshape prints the error and continues):

 $ python 
Python 3.12.2 (main, Feb  6 2024, 20:19:44) [Clang 15.0.0 (clang-1500.1.0.2.5)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.numpy as jnp 
>>> arr1 = jnp.array([[[2, 5, 0]]], dtype=jnp.uint8) # default datatype int32 actually does work 
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-24 16:15:50.307903: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

>>> arr2 = jnp.array([12, 13,  1]) # both uint8 and default work 
>>> jnp.tile(arr1, arr2)
loc("jit(reshape)/jit(main)/reshape[new_sizes=(1, 1, 1, 1, 1, 3) dimensions=None]"("<stdin>":1:0)): error: 'anec.reshape' op result #0 must be 4D/5D memref of 16-bit float or 8-bit signed integer or 8-bit unsigned integer values, but got 'memref<1x1x1x1x1x3xui8>'
zsh: abort      python

Here's the same error in a session that uses only reshape:

 $ python                
Python 3.12.2 (main, Feb  6 2024, 20:19:44) [Clang 15.0.0 (clang-1500.1.0.2.5)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.numpy as jnp 
>>> 
>>> arr = jnp.arange(20)
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-24 16:21:53.255063: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

>>> arr = jnp.arange(20).reshape((1,1,1,1,1,5,4)) # does not error? 
>>> arr = jnp.arange(20, dtype=jnp.uint8).reshape((1,1,1,1,1,5,4)) # has the error but does not abort the kernel
loc("jit(reshape)/jit(main)/reshape[new_sizes=(1, 1, 1, 1, 1, 5, 4) dimensions=None]"("<stdin>":1:6)): error: 'anec.reshape' op result #0 must be 4D/5D memref of 16-bit float or 8-bit signed integer or 8-bit unsigned integer values, but got 'memref<1x1x1x1x1x5x4xui8>'
>>> arr
Array([[[[[[[ 0,  1,  2,  3],
            [ 4,  5,  6,  7],
            [ 8,  9, 10, 11],
            [12, 13, 14, 15],
            [16, 17, 18, 19]]]]]]], dtype=uint8)
>>> arr = jnp.arange(20, dtype=jnp.uint8).reshape((1,1,1,1,1,5,4)) # unsure why this doesn't error, assuming some under-the-hood behavior
>>> arr = jnp.arange(21, dtype=jnp.uint8).reshape((1,1,1,1,1,3,7))
loc("jit(reshape)/jit(main)/reshape[new_sizes=(1, 1, 1, 1, 1, 3, 7) dimensions=None]"("<stdin>":1:6)): error: 'anec.reshape' op result #0 must be 4D/5D memref of 16-bit float or 8-bit signed integer or 8-bit unsigned integer values, but got 'memref<1x1x1x1x1x3x7xui8>'
>>> 

System info (python version, jaxlib version, accelerator, etc.)

M1 Mac running macOS 14.4 (non-beta), Python 3.12, with the following package versions:

 $ pip freeze | grep jax 
jax==0.4.20
jax-metal==0.0.5
jaxlib==0.4.20

I first ran into this issue with jax-metal 0.0.6 and jaxlib 0.4.23; downgrading to the versions above did not fix it.

muffin-rice added the bug label on Mar 24, 2024
muffin-rice changed the title from "tile aborting kernel with error: 'anec.reshape' op result #0 must be 4D/5D memref... on Apple Metal" to "Apple Metal: tile aborting kernel with error: 'anec.reshape' op result #0 must be 4D/5D memref..." on Mar 25, 2024
shuhand0 (Collaborator) commented

The error message appears to be spurious; to avoid it, we will need to bypass certain backend compilation verification passes. The abort in tile is reproducible and points to the combination of dtype != f32 and rank > 4.
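
Until the plugin is fixed, one possible way to sidestep the reported combination is to do the tiling in float32 and cast back afterwards. This is an untested sketch: it assumes float32 arrays avoid the failing path, as the comment above suggests, and it has not been verified on jax-metal. `tile_via_float32` is a hypothetical helper, not part of JAX, and it is only exact when every input value is representable in float32 (true for uint8 and int8).

```python
import jax.numpy as jnp
import numpy as np


def tile_via_float32(a, reps):
    """Tile `a` with the intermediate work done in float32.

    Hypothetical workaround helper. Assumes all values in `a` are exactly
    representable in float32, so the round trip does not change them.
    """
    original_dtype = a.dtype
    tiled = jnp.tile(a.astype(jnp.float32), reps)
    # Cast back on the final low-rank result, not on the high-rank intermediate.
    return tiled.astype(original_dtype)


arr1 = jnp.array([[[2, 5, 0]]], dtype=jnp.uint8)
out = tile_via_float32(arr1, (12, 13, 1))
print(out.shape, out.dtype)
```

The repetitions are passed as a plain tuple of Python ints here rather than as a `jnp` array, which keeps them static and avoids a second device array.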
