Jax metal failed to install #19886

Open
Kubiczek36 opened this issue Feb 20, 2024 · 8 comments
Labels: Apple GPU (Metal) plugin, bug (Something isn't working)

Kubiczek36 commented Feb 20, 2024

Description

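Following the instructions on the PyPI page, I could not get jax-metal to work: pip install succeeds, but importing jax then fails with a RuntimeError about AVX instructions.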

(base) jakub.dokulil@nbm-imp-134 jd_python_learning % conda create -n jax_metal python=3.10          
Channels:
 - defaults
Platform: osx-64
Collecting package metadata (repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /Users/jakub.dokulil/opt/anaconda3/envs/jax_metal

  added / updated specs:
    - python=3.10


The following NEW packages will be INSTALLED:

  bzip2              pkgs/main/osx-64::bzip2-1.0.8-h1de35cc_0 
  ca-certificates    pkgs/main/osx-64::ca-certificates-2023.12.12-hecd8cb5_0 
  libffi             pkgs/main/osx-64::libffi-3.4.4-hecd8cb5_0 
  ncurses            pkgs/main/osx-64::ncurses-6.4-hcec6c5f_0 
  openssl            pkgs/main/osx-64::openssl-3.0.13-hca72f7f_0 
  pip                pkgs/main/osx-64::pip-23.3.1-py310hecd8cb5_0 
  python             pkgs/main/osx-64::python-3.10.13-h5ee71fb_0 
  readline           pkgs/main/osx-64::readline-8.2-hca72f7f_0 
  setuptools         pkgs/main/osx-64::setuptools-68.2.2-py310hecd8cb5_0 
  sqlite             pkgs/main/osx-64::sqlite-3.41.2-h6c40b1e_0 
  tk                 pkgs/main/osx-64::tk-8.6.12-h5d9f67b_0 
  tzdata             pkgs/main/noarch::tzdata-2023d-h04d1e81_0 
  wheel              pkgs/main/osx-64::wheel-0.41.2-py310hecd8cb5_0 
  xz                 pkgs/main/osx-64::xz-5.4.5-h6c40b1e_0 
  zlib               pkgs/main/osx-64::zlib-1.2.13-h4dc903c_0 


Proceed ([y]/n)? 


Downloading and Extracting Packages:

Preparing transaction: done
Verifying transaction: done
Executing transaction: done
#
# To activate this environment, use
#
#     $ conda activate jax_metal
#
# To deactivate an active environment, use
#
#     $ conda deactivate

(base) jakub.dokulil@nbm-imp-134 jd_python_learning % conda activate jax_metal
(jax_metal) jakub.dokulil@nbm-imp-134 jd_python_learning % python -m pip install -U pip                       
Requirement already satisfied: pip in /Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages (23.3.1)
Collecting pip
  Using cached pip-24.0-py3-none-any.whl.metadata (3.6 kB)
Using cached pip-24.0-py3-none-any.whl (2.1 MB)
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 23.3.1
    Uninstalling pip-23.3.1:
      Successfully uninstalled pip-23.3.1
Successfully installed pip-24.0
(jax_metal) jakub.dokulil@nbm-imp-134 jd_python_learning % python -m pip install numpy                        
Collecting numpy
  Using cached numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl.metadata (61 kB)
Using cached numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl (20.6 MB)
Installing collected packages: numpy
Successfully installed numpy-1.26.4
(jax_metal) jakub.dokulil@nbm-imp-134 jd_python_learning % python -m pip install jax-metal                    
Collecting jax-metal
  Using cached jax_metal-0.0.5-py3-none-macosx_10_14_x86_64.whl.metadata (1.4 kB)
Requirement already satisfied: wheel~=0.35 in /Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages (from jax-metal) (0.41.2)
Collecting six>=1.15.0 (from jax-metal)
  Using cached six-1.16.0-py2.py3-none-any.whl (11 kB)
Collecting jax==0.4.20 (from jax-metal)
  Using cached jax-0.4.20-py3-none-any.whl.metadata (23 kB)
Collecting jaxlib==0.4.20 (from jax-metal)
  Downloading jaxlib-0.4.20-cp310-cp310-macosx_10_14_x86_64.whl.metadata (2.1 kB)
Collecting ml-dtypes>=0.2.0 (from jax==0.4.20->jax-metal)
  Using cached ml_dtypes-0.3.2-cp310-cp310-macosx_10_9_universal2.whl.metadata (20 kB)
Requirement already satisfied: numpy>=1.22 in /Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages (from jax==0.4.20->jax-metal) (1.26.4)
Collecting opt-einsum (from jax==0.4.20->jax-metal)
  Using cached opt_einsum-3.3.0-py3-none-any.whl (65 kB)
Collecting scipy>=1.9 (from jax==0.4.20->jax-metal)
  Using cached scipy-1.12.0-cp310-cp310-macosx_10_9_x86_64.whl.metadata (60 kB)
Using cached jax_metal-0.0.5-py3-none-macosx_10_14_x86_64.whl (54.6 MB)
Using cached jax-0.4.20-py3-none-any.whl (1.7 MB)
Downloading jaxlib-0.4.20-cp310-cp310-macosx_10_14_x86_64.whl (82.6 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 82.6/82.6 MB 3.5 MB/s eta 0:00:00
Using cached ml_dtypes-0.3.2-cp310-cp310-macosx_10_9_universal2.whl (389 kB)
Using cached scipy-1.12.0-cp310-cp310-macosx_10_9_x86_64.whl (38.9 MB)
Installing collected packages: six, scipy, opt-einsum, ml-dtypes, jaxlib, jax, jax-metal
Successfully installed jax-0.4.20 jax-metal-0.0.5 jaxlib-0.4.20 ml-dtypes-0.3.2 opt-einsum-3.3.0 scipy-1.12.0 six-1.16.0
(jax_metal) jakub.dokulil@nbm-imp-134 jd_python_learning % python -c 'import jax; print(jax.numpy.arange(10))'
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages/jax/__init__.py", line 39, in <module>
    from jax import config as _config_module
  File "/Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages/jax/config.py", line 15, in <module>
    from jax._src.config import config as _deprecated_config  # noqa: F401
  File "/Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages/jax/_src/config.py", line 28, in <module>
    from jax._src import lib
  File "/Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 83, in <module>
    cpu_feature_guard.check_cpu_features()
RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source.

System info (python version, jaxlib version, accelerator, etc.)

MacBook Air M2
macOS Sonoma 14.3.1 (23D60)
Python 3.10

Kubiczek36 added the bug (Something isn't working) label Feb 20, 2024
Kubiczek36 changed the title from "failed to install" to "Jax metal failed to install" Feb 20, 2024
shuhand0 (Collaborator) commented

Based on the osx-64 packages, is this a Mac with an AMD GPU? Could you try a venv with python=3.9?
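The platform tags in the original pip log point the same way: pip only selects wheels that match the running interpreter, and every binary wheel it chose (numpy, scipy, jaxlib, jax-metal) is tagged x86_64 or universal2, never arm64. A minimal stdlib-only sketch, with hypothetical helper names, that pulls the architecture out of a wheel filename, assuming the standard wheel naming convention (name-version[-build]-python-abi-platform.whl):

```python
import platform


def wheel_platform_tag(filename: str) -> str:
    """Return the platform tag of a wheel filename.

    Wheel names look like name-version[-build]-python-abi-platform.whl,
    so the platform tag is always the last dash-separated field.
    """
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel filename: {filename}")
    return filename[: -len(".whl")].split("-")[-1]


def wheel_arch(filename: str) -> str:
    """Return the CPU architecture encoded in a macOS wheel's platform tag.

    macOS tags have the form macosx_<major>_<minor>_<arch>, e.g.
    macosx_10_14_x86_64 or macosx_11_0_arm64. Pure-Python wheels use the
    tag "any". Compressed tag sets (several tags joined by ".") are not handled.
    """
    tag = wheel_platform_tag(filename)
    if tag == "any":
        return "any"
    if tag.startswith("macosx_"):
        # Split into at most 4 parts so arch names containing "_" stay whole.
        return tag.split("_", 3)[3]
    raise ValueError(f"not a macOS or pure-Python wheel: {filename}")


if __name__ == "__main__":
    # Wheel names copied from the pip log in this issue.
    for name in [
        "jaxlib-0.4.20-cp310-cp310-macosx_10_14_x86_64.whl",
        "ml_dtypes-0.3.2-cp310-cp310-macosx_10_9_universal2.whl",
        "jax-0.4.20-py3-none-any.whl",
    ]:
        print(f"{name}: {wheel_arch(name)}")
    print("this interpreter:", platform.machine())
```

If the wheels pip picks on an Apple Silicon Mac come out as x86_64, the interpreter itself is an x86_64 build, regardless of the hardware.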

curlup (Contributor) commented Feb 26, 2024

Reproduces on my M2 Mac with both Python 3.10.6 and 3.9.13.

curlup (Contributor) commented Feb 26, 2024

Tried jax==0.4.11, jaxlib==0.4.11, and jax-metal==0.0.4: same result.

shuhand0 (Collaborator) commented Mar 9, 2024

I haven't been able to reproduce the issue. Below are my configuration, the installed packages, and the verification output:
ProductName: macOS
ProductVersion: 14.4

The following NEW packages will be INSTALLED:

  ca-certificates    pkgs/main/osx-64::ca-certificates-2023.12.12-hecd8cb5_0 
  libcxx             pkgs/main/osx-64::libcxx-14.0.6-h9765a3e_0 
  libffi             pkgs/main/osx-64::libffi-3.4.4-hecd8cb5_0 
  ncurses            pkgs/main/osx-64::ncurses-6.4-hcec6c5f_0 
  openssl            pkgs/main/osx-64::openssl-3.0.13-hca72f7f_0 
  pip                pkgs/main/osx-64::pip-23.3.1-py39hecd8cb5_0 
  python             pkgs/main/osx-64::python-3.9.18-h5ee71fb_0 
  readline           pkgs/main/osx-64::readline-8.2-hca72f7f_0 
  setuptools         pkgs/main/osx-64::setuptools-68.2.2-py39hecd8cb5_0 
  sqlite             pkgs/main/osx-64::sqlite-3.41.2-h6c40b1e_0 
  tk                 pkgs/main/osx-64::tk-8.6.12-h5d9f67b_0 
  tzdata             pkgs/main/noarch::tzdata-2024a-h04d1e81_0 
  wheel              pkgs/main/osx-64::wheel-0.41.2-py39hecd8cb5_0 
  xz                 pkgs/main/osx-64::xz-5.4.6-h6c40b1e_0 
  zlib               pkgs/main/osx-64::zlib-1.2.13-h4dc903c_0 
Package            Version
------------------ -------
importlib_metadata 7.0.2
jax                0.4.20
jax-metal          0.0.5
jaxlib             0.4.20
ml-dtypes          0.3.2
numpy              1.26.4
opt-einsum         3.3.0
pip                24.0
scipy              1.12.0
setuptools         68.2.2
six                1.16.0
wheel              0.41.2
zipp               3.17.0
python -c 'import jax; print(jax.numpy.arange(10))'
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-08 17:33:36.946600: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: AMD Radeon Pro Vega 20

systemMemory: 32.00 GB
maxCacheSize: 1.99 GB

[0 1 2 3 4 5 6 7 8 9]
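When comparing a working environment with a failing one, as in the listing above, the relevant facts are the package versions and the interpreter's architecture. A minimal stdlib sketch (the helper name and package list are illustrative) that reports both. It relies on importlib.metadata treating "jax-metal" and "jax_metal" as the same distribution, which recent Python versions do; on older ones you may need the underscore spelling.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names as they appear in `pip list` in this thread.
PACKAGES = ("jax", "jaxlib", "jax-metal", "ml-dtypes", "numpy", "scipy")


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    import platform
    import sys

    # The interpreter architecture matters as much as the versions here.
    print("python:", sys.version.split()[0], platform.machine())
    for name, ver in installed_versions().items():
        print(f"{name:<10} {ver or 'not installed'}")
```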

curlup (Contributor) commented Mar 14, 2024

Right, I think I figured it out: in my case it was because Python was running as i386 rather than arm64.
After switching the architecture and installing a native (arm64) Python, it worked.
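Note that platform.machine() alone cannot tell an Intel Mac apart from an x86_64 Python running under Rosetta 2 on Apple Silicon: both report x86_64. A sketch that combines it with the macOS sysctl key sysctl.proc_translated, assuming Apple's documented behaviour for that key (1 for a Rosetta-translated process, 0 for a native one, absent on Intel Macs); proc_translated() and describe() are hypothetical helper names:

```python
import platform
import subprocess


def proc_translated():
    """Return True under Rosetta 2, False if native, None if unknown.

    Queries the macOS sysctl key sysctl.proc_translated. The key does not
    exist on Intel Macs or non-macOS systems, so any failure (missing
    sysctl binary, unknown key, unexpected output) is reported as None.
    """
    try:
        out = subprocess.run(
            ["sysctl", "-n", "sysctl.proc_translated"],
            capture_output=True, text=True, check=True,
        ).stdout.strip()
    except (OSError, subprocess.CalledProcessError):
        return None
    if out == "1":
        return True
    if out == "0":
        return False
    return None


def describe(machine, translated):
    """Classify the interpreter from platform.machine() and the Rosetta flag."""
    if machine == "arm64":
        return "native arm64 Python"
    if machine == "x86_64" and translated:
        return "x86_64 Python under Rosetta 2: install a native arm64 Python"
    if machine == "x86_64":
        return "x86_64 Python on an Intel Mac (or Rosetta status unknown)"
    return f"unrecognized architecture: {machine}"


if __name__ == "__main__":
    print(describe(platform.machine(), proc_translated()))
```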

phisanti commented

I just tried installing by following the instructions on Apple's website (https://developer.apple.com/metal/jax/), and it failed with the same error as everyone else here, on an M2. How did you switch to a native python3?

I just ran the following code:

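import platform

# platform.machine() reports the architecture the running Python
# interpreter was built for, not the Mac's CPU: an x86_64 Python on an
# Apple Silicon (M1/M2) Mac runs under Rosetta 2 and still reports 'x86_64'.
machine = platform.machine()

if machine == 'arm64':
    print("Your Python version is ARM64")
elif machine == 'i386':
    print("Your Python version is i386 (32-bit)")
elif machine == 'x86_64':
    print("Your Python version is x86_64 (64-bit)")
else:
    print(f"Unknown machine architecture: {machine}")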

and the output is:

Your Python version is x86_64 (64-bit)

curlup (Contributor) commented Apr 5, 2024

@phisanti You switch the architecture in your terminal with the arch command, then install Python afresh (it will be a different Python) and follow Apple's jax-metal install instructions.
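After installing Python afresh, it is worth confirming that the new interpreter binary actually contains arm64 code; from the shell, file or lipo -info on the interpreter path does this. An equivalent stdlib-only sketch, with a hypothetical function name, that reads the Mach-O header directly, assuming the magic numbers and CPU type constants from Apple's mach-o/fat.h, mach-o/loader.h and mach/machine.h headers:

```python
import struct
import sys

# Magic numbers from <mach-o/fat.h> and <mach-o/loader.h>.
FAT_MAGIC, FAT_MAGIC_64 = 0xCAFEBABE, 0xCAFEBABF  # universal header, big-endian
MH_MAGIC, MH_MAGIC_64 = 0xFEEDFACE, 0xFEEDFACF    # thin binary, read little-endian

# CPU types from <mach/machine.h>; 64-bit types set the CPU_ARCH_ABI64 bit.
CPU_ARCH_ABI64 = 0x01000000
CPU_TYPES = {
    7: "i386",
    7 | CPU_ARCH_ABI64: "x86_64",
    12: "arm",
    12 | CPU_ARCH_ABI64: "arm64",
}


def macho_archs(header: bytes) -> list:
    """Return the CPU architectures listed in a Mach-O header.

    Handles universal ("fat") binaries, whose header is big-endian, and
    thin binaries in the little-endian byte order of x86_64 and arm64 Macs.
    """
    (magic,) = struct.unpack(">I", header[:4])
    if magic in (FAT_MAGIC, FAT_MAGIC_64):
        (count,) = struct.unpack(">I", header[4:8])
        # fat_arch entries are 20 bytes; fat_arch_64 entries are 32 bytes.
        entry_size = 20 if magic == FAT_MAGIC else 32
        archs = []
        for i in range(count):
            start = 8 + i * entry_size
            (cputype,) = struct.unpack(">I", header[start:start + 4])
            archs.append(CPU_TYPES.get(cputype, hex(cputype)))
        return archs
    magic, cputype = struct.unpack("<II", header[:8])
    if magic in (MH_MAGIC, MH_MAGIC_64):
        return [CPU_TYPES.get(cputype, hex(cputype))]
    raise ValueError("not a Mach-O binary")


if __name__ == "__main__":
    # Checks the running interpreter; change path to inspect another binary.
    path = sys.executable
    with open(path, "rb") as binary:
        header = binary.read(4096)
    try:
        print(path, macho_archs(header))
    except (ValueError, struct.error):
        print(path, "is not a Mach-O binary")
```

A universal2 interpreter lists both x86_64 and arm64; which slice actually runs then depends on how it is launched (for example arch -arm64 versus arch -x86_64), which is what the arch step above controls.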

phisanti commented Apr 7, 2024

@curlup thanks for the tip. It worked for me!

6 participants