[BUG] Brax examples in brax.py fail on a CUDA-enabled machine #2319
Closed
Description
Describe the bug
The examples in https://github.com/pytorch/rl/blob/main/torchrl/envs/libs/brax.py fail on a machine where CUDA is available.
To Reproduce
Steps to reproduce the behavior.
On a machine where CUDA is available, confirm with:
>>> import torch
>>> torch.cuda.is_available()
True
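`torch.cuda.is_available()` only shows that PyTorch can see the GPU. Brax runs on JAX, so it may also help to report which backend and devices JAX sees on the same machine. The helper below is a hypothetical diagnostic (`jax_device_summary` is not part of torchrl or Brax). It relies only on `jax.default_backend()` and `jax.devices()`, and it reports JAX as missing rather than failing when JAX is not installed:

```python
import importlib.util


def jax_device_summary():
    """Report which backend and devices JAX sees, or that JAX is missing."""
    if importlib.util.find_spec("jax") is None:
        # JAX is not installed in this environment.
        return {"installed": False, "backend": None, "devices": []}
    import jax

    return {
        "installed": True,
        # Default backend name, e.g. "cpu" or "gpu".
        "backend": jax.default_backend(),
        # String form of every device visible to the default backend.
        "devices": [str(d) for d in jax.devices()],
    }


print(jax_device_summary())
```

On the machine described here, one would expect a GPU backend to be reported, matching the "platform GPU" device in the error below.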
Run any of the examples in brax.py, for example:
from torchrl.envs import BraxEnv
env = BraxEnv("ant")
env.set_seed(0)
td = env.reset()
td["action"] = env.action_spec.rand()
td = env.step(td)
print(td)
This results in:
ValueError: Received incompatible devices for jitted computation. Got ARG_SHARDING with device ids [0] on platform GPU and ARG_SHARDING with device ids [0] on platform CPU
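The message indicates that a jitted JAX computation received one argument committed to the GPU and another committed to the CPU. The pure-JAX sketch below reproduces that general situation. It is not taken from torchrl or Brax, and whether `BraxEnv` hits exactly this path is an assumption. `demo_device_mismatch` is a hypothetical name. The function returns `None` when JAX or a JAX GPU backend is unavailable:

```python
import importlib.util


def demo_device_mismatch():
    """Return (error_message, fixed_result), or None without a JAX GPU."""
    if importlib.util.find_spec("jax") is None:
        return None  # JAX not installed; nothing to demonstrate
    import jax
    import jax.numpy as jnp

    try:
        gpu = jax.devices("gpu")[0]
    except RuntimeError:
        return None  # CPU-only JAX install; this mismatch cannot occur
    cpu = jax.devices("cpu")[0]

    add = jax.jit(lambda x, y: x + y)
    # device_put with an explicit device commits each array to that device.
    on_gpu = jax.device_put(jnp.ones(3), gpu)
    on_cpu = jax.device_put(jnp.ones(3), cpu)
    try:
        add(on_gpu, on_cpu)  # committed to two different devices
        message = None
    except ValueError as err:
        message = str(err)

    # Committing both operands to the same device avoids the error.
    fixed = add(on_gpu, jax.device_put(on_cpu, gpu))
    return message, fixed.tolist()


print(demo_device_mismatch())
```

If this reproduces the same `ValueError`, the likely direction is finding which array (for example, an RNG key or the environment state) ends up on the CPU while the rest of the computation runs on the GPU.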
Expected behavior
The example should run without errors.
System info
Describe the characteristics of your environment:
- Describe how the library was installed (pip, source, ...)
- Python version
- Versions of any other relevant libraries
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.5.0+f840a1a 2.0.1 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0] linux
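For a device bug like this one, the JAX-side versions matter as much as torchrl's. The sketch below collects installed versions with the standard-library `importlib.metadata`. The package list is an assumption chosen from the requirements below plus `jaxlib`, and `collect_versions` is a hypothetical helper. Missing packages are reported as "not installed" instead of raising:

```python
import importlib.metadata
import sys

# Assumed list of relevant distributions; adjust as needed.
PACKAGES = ["torch", "torchrl", "tensordict", "numpy",
            "jax", "jaxlib", "brax", "mujoco", "mujoco-mjx"]


def collect_versions(packages=PACKAGES):
    """Map each distribution name to its installed version string."""
    versions = {"python": sys.version.split()[0], "platform": sys.platform}
    for name in packages:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


for key, value in collect_versions().items():
    print(f"{key}: {value}")
```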
Installed with requirements.txt:
torch==2.3.1
torchvision==0.18.1
mujoco
mujoco-mjx
gymnasium
matplotlib
wandb
joblib
hydra-core
ipython
brax
jax[cuda12]==0.4.28
and then ran the following commands:
$ cd path/to/root
$ git clone https://github.com/pytorch/tensordict
$ git clone https://github.com/pytorch/rl
$ cd tensordict
$ python setup.py develop
$ cd ../rl
$ python setup.py develop
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)