Buffer donation would be nice to have on the Metal backend. I couldn't find an existing issue for it, so I'm opening this one to track it and to ask whether it's on the Apple JAX Metal team's roadmap.
>>> import jax
>>> jax.devices()
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-12 16:07:16.498439: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro
systemMemory: 32.00 GB
maxCacheSize: 10.67 GB
[METAL(id=0)]
>>> import jax.numpy as jnp
>>> x = jax.jit(lambda x: x)(jnp.zeros((4, 5))
...
...
...
... )
>>> x = jax.jit(lambda x: x)(jnp.zeros((4, 5)))
>>> x = jax.jit(lambda x: x, donate_args=True)(jnp.zeros((4, 5)))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: jit() got an unexpected keyword argument 'donate_args'
>>> x = jax.jit(lambda x: x, donate_argnums=(0,))(jnp.zeros((4, 5)))
/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:761: UserWarning: Some donated buffers were not usable: ShapedArray(float32[4,5]).
Donation is not implemented for ('METAL',).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
  warnings.warn("Some donated buffers were not usable:"
>>>
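For anyone who wants to check this on their own machine, here is a minimal sketch of one way to tell whether donation took effect on the current backend. It assumes that a successfully donated input reports `is_deleted()` as true after the call; `donation_took_effect` is a hypothetical helper name, not part of JAX, and the `a + 1` body is only there so the output is not simply the input passed straight through.

```python
import warnings

import jax
import jax.numpy as jnp


def donation_took_effect(shape=(4, 5)):
    """Return True if the donated input buffer was consumed by the call.

    Hypothetical helper for illustration only. It relies on the assumption
    that a donated jax.Array is marked as deleted once donation succeeds.
    """
    x = jnp.zeros(shape)
    # Donate argument 0 so the backend may reuse its buffer for the output.
    step = jax.jit(lambda a: a + 1, donate_argnums=(0,))
    with warnings.catch_warnings():
        # Silence the "Some donated buffers were not usable" UserWarning
        # that backends without donation support emit.
        warnings.simplefilter("ignore", UserWarning)
        y = step(x)
    y.block_until_ready()
    # If donation was not implemented, x is left untouched and still usable.
    return x.is_deleted()


if __name__ == "__main__":
    print(jax.default_backend(), donation_took_effect())
```

Given the warning in the transcript above, I would expect this to report `False` on METAL, but I have not confirmed that beyond what the warning says.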