CUDA XlaRuntimeError with MPI on jax==0.4.31 #22995

Open
MasterSkepticista opened this issue Aug 12, 2024 · 5 comments
Labels
bug Something isn't working

Comments


MasterSkepticista commented Aug 12, 2024

Description

Hi,

jax.jit on a function seems to fail when running in an OpenMPI environment. A minimal working example (MWE) is shown below:

# error.py
# Run as: mpirun -n 8 python error.py

import os
from absl import logging
import jax, jax.numpy as jnp

logging.set_verbosity("info")
os.environ["no_proxy"] = "x.x.x.x"  # Internal use.
jax.distributed.initialize()

print("Hello from process %d holding %d device(s)" % (jax.process_index(), jax.local_device_count()))

def dot_product_attention(
    query: jnp.ndarray,
    key: jnp.ndarray,
    value: jnp.ndarray,
    *,
    dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
  # Standard scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.
  depth = query.shape[-1]
  query = query / jnp.sqrt(depth).astype(dtype)
  attn_weights = jnp.einsum('...qhd,...khd->...hqk', query, key)
  attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
  return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value)

x = jnp.ones((1, 512, 8, 32), dtype=jnp.bfloat16)
f = lambda x: dot_product_attention(x, x, x)

print(jax.jit(f)(x))

The error occurs either on a subset of processes (in which case I still see the output tensor from the remaining ones) or on all processes (in which case the job hangs or exits). I can confirm this error does not appear in jax==0.4.30.
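
For reference, pinning the previous release restores the working behavior (a hedged example; the exact CUDA extra depends on how JAX was installed):

pip install -U "jax[cuda12]==0.4.30"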

System info (python version, jaxlib version, accelerator, etc.)

Error log
JAX detected proxy variable(s) in the environment as distributed setup: no_proxy https_proxy HTTPS_PROXY HTTP_PROXY http_proxy. On some systems, this may cause a hang of distributed.initialize and you may need to unset these ENV variable(s)
(the warning above is printed once by each of the 8 processes)
Hello from process 3 holding 1 device(s)
Hello from process 5 holding 1 device(s)
Hello from process 1 holding 1 device(s)
Hello from process 7 holding 1 device(s)
Hello from process 0 holding 1 device(s)
Hello from process 4 holding 1 device(s)
Hello from process 6 holding 1 device(s)
Hello from process 2 holding 1 device(s)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/karan/workspace/jax_gpt2/error.py", line 14, in <module>
    print(jax.jit(f)(x))
jaxlib.xla_extension.XlaRuntimeError: DEADLINE_EXCEEDED: GetKeyValue() timed out with key: cuda:gemm_fusion_autotuning_results_5_0 and duration: -1ms
(the traceback above is repeated, partly interleaved, by each failing process; one process reports key cuda:gemm_fusion_autotuning_results_5_1 instead of cuda:gemm_fusion_autotuning_results_5_0)
--------------------------------------------------------------------------
Primary job  terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
--------------------------------------------------------------------------
mpirun detected that one or more processes exited with non-zero status, thus causing
the job to be terminated. The first process to do so was:

  Process name: [[53590,1],2]
  Exit code:    1
--------------------------------------------------------------------------

System info:

jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0]
jax.devices (8 total, 8 local): [CudaDevice(id=0) CudaDevice(id=1) ... CudaDevice(id=6) CudaDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='ubuntu', release='6.5.0-35-generic', version='#35~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue May  7 09:00:52 UTC 2', machine='x86_64')

Truncated nvidia-smi info: 
NVIDIA-SMI 555.42.06              
Driver Version: 555.42.06      
CUDA Version: 12.5
GPU: RTX A6000
MasterSkepticista added the bug label Aug 12, 2024
vfdev-5 (Collaborator) commented Aug 14, 2024

@MasterSkepticista the error is related to fetching cuda:gemm_fusion_autotuning_results across shards and may be related to openxla/xla#13108 (cc @sergachev). To disable the autotuning and make your MWE work, you could try running it with:

XLA_FLAGS=--xla_gpu_shard_autotuning=false  mpirun -n 8 python error.py
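
A minimal sketch of the same workaround applied from inside the script (an untested alternative, assuming XLA_FLAGS is read when JAX first initializes its GPU backend, so the flag must be set before that happens):

# Hypothetical variant of error.py: append the flag to XLA_FLAGS from Python
# before importing JAX, so the GPU backend sees it at initialization.
import os
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false"

import jax  # imported only after the flag is set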

Let me know if this workaround helps

sergachev (Contributor) commented:

openxla/xla#13108 was reverted.

--xla_gpu_shard_autotuning=false disables sharding of autotuning, not the autotuning itself.

sergachev (Contributor) commented:

I can reproduce with jax==0.4.31, and --xla_gpu_shard_autotuning=false helps; it looks like openxla/xla#13108 got into this JAX release before it was reverted. Thank you for cc'ing me, I'll investigate why it fails.

MasterSkepticista (Author) commented:

@vfdev-5 Your suggestion worked.
@sergachev I observed that JAX was built against openxla/xla@95e3eea, which was before the revert.

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Aug 16, 2024
[GPU] Change sharded GEMM autotuning timeout from infinity to 24 hours.

Imported from GitHub PR openxla/xla#16153

Fixes a problem with MPI: jax-ml/jax#22995
Copybara import of the project:

--
df0dfdd323385fbd07fdb2f909240b9f6264c712 by Ilia Sergachev <isergachev@nvidia.com>:

[GPU] Change sharded GEMM autotuning timeout from infinity to 24 hours.

--
ea59210f7ec7bad918304af63684beb8dc8100e7 by Ilia Sergachev <isergachev@nvidia.com>:

Infinite duration causes issues with MPI.

Merging this change closes #16153

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16153 from openxla:fix_autotuner_timeout ea59210f7ec7bad918304af63684beb8dc8100e7
PiperOrigin-RevId: 663683904
copybara-service bot pushed a commit to openxla/xla that referenced this issue Aug 16, 2024
[GPU] Change sharded GEMM autotuning timeout from infinity to 24 hours.

Imported from GitHub PR #16153

Fixes a problem with MPI: jax-ml/jax#22995
Copybara import of the project:

--
df0dfdd by Ilia Sergachev <isergachev@nvidia.com>:

[GPU] Change sharded GEMM autotuning timeout from infinity to 24 hours.

--
ea59210 by Ilia Sergachev <isergachev@nvidia.com>:

Infinite duration causes issues with MPI.

Merging this change closes #16153

COPYBARA_INTEGRATE_REVIEW=#16153 from openxla:fix_autotuner_timeout ea59210
PiperOrigin-RevId: 663760476
sergachev (Contributor) commented:

I sent a fix to XLA which makes the reproducer from this bug work. Independently of that, sharded autotuning was enabled again yesterday and will likely get into the next JAX release.
