
[BUG] examples of brax in brax.py are failing on CUDA enabled machine #2319

Closed
@Jendker

Description

Describe the bug

The docstring examples in https://github.com/pytorch/rl/blob/main/torchrl/envs/libs/brax.py fail on a machine where CUDA is available.

To Reproduce

Steps to reproduce the behavior.

On a machine where CUDA is available, confirm with:

>>> import torch
>>> torch.cuda.is_available()
True

Run any of the examples from brax.py, for example:

from torchrl.envs import BraxEnv
env = BraxEnv("ant")
env.set_seed(0)
td = env.reset()
td["action"] = env.action_spec.rand()
td = env.step(td)
print(td)

This results in:

ValueError: Received incompatible devices for jitted computation. Got ARG_SHARDING with device ids [0] on platform GPU and ARG_SHARDING with device ids [0] on platform CPU
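The error says that a jitted JAX computation received some arguments placed on the GPU and others on the CPU. The following small diagnostic sketch is not part of torchrl or brax, and the helper name `report_devices` is hypothetical. It reports which accelerator PyTorch and JAX each see, which can help confirm whether JAX defaults to the GPU on this machine while other arrays end up on the CPU:

```python
def report_devices() -> dict:
    """Report which accelerator PyTorch and JAX each see (diagnostic sketch)."""
    report = {}

    # PyTorch: is a CUDA device available? None means torch is not installed.
    try:
        import torch
        report["torch_cuda_available"] = torch.cuda.is_available()
    except ImportError:
        report["torch_cuda_available"] = None

    # JAX: which backend does it use by default ("gpu" or "cpu"),
    # and which devices does it list? None / [] means jax is not installed.
    try:
        import jax
        report["jax_default_backend"] = jax.default_backend()
        report["jax_devices"] = [str(d) for d in jax.devices()]
    except ImportError:
        report["jax_default_backend"] = None
        report["jax_devices"] = []

    return report


if __name__ == "__main__":
    print(report_devices())
```

With `jax[cuda12]` installed, one would expect `jax_default_backend` to be `"gpu"` on this machine. That is consistent with the GPU side of the error message, but it does not by itself show where the CPU-side arguments come from.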

Expected behavior

The examples should run without errors.

System info

Describe the characteristics of your environment:

  • Describe how the library was installed (pip, source, ...)
  • Python version
  • Versions of any other relevant libraries
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)

0.5.0+f840a1a 2.0.1 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0] linux
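The template also asks for versions of other relevant libraries. For a JAX device error, the jax, jaxlib and brax versions are likely relevant alongside torchrl's. The following minimal sketch uses only the standard library's `importlib.metadata`. The package list and the helper name `package_versions` are assumptions chosen for illustration:

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Assumed list of distributions that plausibly matter for this bug; adjust as needed.
PACKAGES = [
    "torchrl", "tensordict", "torch", "numpy",
    "jax", "jaxlib", "jax-cuda12-plugin", "brax",
]


def package_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, ver in package_versions(PACKAGES).items():
        print(f"{name}: {ver or 'not installed'}")
    print(sys.version, sys.platform)
```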

Installed with requirements.txt:

torch==2.3.1
torchvision==0.18.1
mujoco
mujoco-mjx
gymnasium
matplotlib
wandb
joblib
hydra-core
ipython
brax
jax[cuda12]==0.4.28

and the following commands:

$ cd path/to/root
$ git clone https://github.com/pytorch/tensordict
$ git clone https://github.com/pytorch/rl
$ cd tensordict
$ python setup.py develop
$ cd ../rl
$ python setup.py develop

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Labels: bug
