CUDA XlaRuntimeError with MPI on jax==0.4.31 #22995
Comments
@MasterSkepticista the error looks related to sharded autotuning. Can you try running with `XLA_FLAGS=--xla_gpu_shard_autotuning=false mpirun -n 8 python error.py`? Let me know if this workaround helps.
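If editing the `mpirun` command line is inconvenient, the same workaround can be applied from inside the script. This is a minimal sketch, assuming `XLA_FLAGS` is picked up when `jax` is first imported, so the environment variable must be set before that import:

```python
# Sketch: apply the workaround without changing the mpirun invocation.
# Assumes XLA reads XLA_FLAGS at first jax import, so this must run first.
import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false"
)

import jax  # imported only after the flag is set
```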
openxla/xla#13108 was reverted. `--xla_gpu_shard_autotuning=false` disables sharding of autotuning, not the autotuning itself.
I can reproduce with `jax==0.4.31`, and `--xla_gpu_shard_autotuning=false` helps - it looks like openxla/xla#13108 got into this JAX release before it was reverted. Thank you for cc'ing me; I'll investigate why it fails.
@vfdev-5 Your suggestion worked.
[GPU] Change sharded GEMM autotuning timeout from infinity to 24 hours.

Imported from GitHub PR #16153. Fixes a problem with MPI: jax-ml/jax#22995

Copybara import of the project:

-- df0dfdd by Ilia Sergachev <isergachev@nvidia.com>:
[GPU] Change sharded GEMM autotuning timeout from infinity to 24 hours.

-- ea59210 by Ilia Sergachev <isergachev@nvidia.com>:
Infinite duration causes issues with MPI.

Merging this change closes #16153

COPYBARA_INTEGRATE_REVIEW=#16153 from openxla:fix_autotuner_timeout ea59210
PiperOrigin-RevId: 663760476
I sent a fix to XLA which makes the reproducer from this bug work. Independently of that, sharded autotuning was enabled again yesterday, and it will likely get into the next JAX release.
Description
Hi,

`jax.jit` on a function seems to fail when running in an OpenMPI environment. An MWE is shown below. The error can occur on select processes (in which case I see the output tensor) or on all processes (it hangs/exits). I can confirm this error does not appear in `jax==0.4.30`.
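The original MWE snippet did not survive extraction. Below is a hypothetical reconstruction of what such a reproducer could look like, assuming one GPU per MPI rank, `mpi4py`, and a launch via `mpirun -n 8 python error.py`; the function and variable names are illustrative, not the reporter's actual code:

```python
# error.py - hypothetical reconstruction of the missing MWE, not the
# reporter's original script. Assumes one GPU per MPI rank.
import os
from mpi4py import MPI

rank = MPI.COMM_WORLD.Get_rank()

# Bind each MPI rank to its own GPU before importing jax.
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # A matrix product, i.e. a GEMM that triggers GPU autotuning.
    return x @ x.T

x = jnp.ones((1024, 1024))
print(f"rank {rank}:", f(x).sum())
```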
System info (python version, jaxlib version, accelerator, etc.)

Error log