Skip to content

Commit

Permalink
[BugFix] Fix Brax (#2233)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 18, 2024
1 parent 35df59e commit 45ab9de
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 20 deletions.
35 changes: 20 additions & 15 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1966,12 +1966,13 @@ def test_env_device(self, env_name, frame_skip, transformed_out, device):


@pytest.mark.skipif(not _has_brax, reason="brax not installed")
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("envname", ["fast"])
class TestBrax:
@pytest.mark.parametrize("requires_grad", [False, True])
def test_brax_constructor(self, envname, requires_grad):
env0 = BraxEnv(envname, requires_grad=requires_grad)
env1 = BraxWrapper(env0._env, requires_grad=requires_grad)
def test_brax_constructor(self, envname, requires_grad, device):
env0 = BraxEnv(envname, requires_grad=requires_grad, device=device)
env1 = BraxWrapper(env0._env, requires_grad=requires_grad, device=device)

env0.set_seed(0)
torch.manual_seed(0)
Expand All @@ -1994,12 +1995,12 @@ def test_brax_constructor(self, envname, requires_grad):
assert r1.requires_grad == requires_grad
assert_allclose_td(r0.data, r1.data)

def test_brax_seeding(self, envname):
def test_brax_seeding(self, envname, device):
final_seed = []
tdreset = []
tdrollout = []
for _ in range(2):
env = BraxEnv(envname)
env = BraxEnv(envname, device=device)
torch.manual_seed(0)
np.random.seed(0)
final_seed.append(env.set_seed(0))
Expand All @@ -2012,8 +2013,8 @@ def test_brax_seeding(self, envname):
assert_allclose_td(*tdrollout)

@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
def test_brax_batch_size(self, envname, batch_size):
env = BraxEnv(envname, batch_size=batch_size)
def test_brax_batch_size(self, envname, batch_size, device):
env = BraxEnv(envname, batch_size=batch_size, device=device)
env.set_seed(0)
tdreset = env.reset()
tdrollout = env.rollout(max_steps=50)
Expand All @@ -2023,8 +2024,8 @@ def test_brax_batch_size(self, envname, batch_size):
assert tdrollout.batch_size[:-1] == batch_size

@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
def test_brax_spec_rollout(self, envname, batch_size):
env = BraxEnv(envname, batch_size=batch_size)
def test_brax_spec_rollout(self, envname, batch_size, device):
env = BraxEnv(envname, batch_size=batch_size, device=device)
env.set_seed(0)
check_env_specs(env)

Expand All @@ -2036,7 +2037,7 @@ def test_brax_spec_rollout(self, envname, batch_size):
False,
],
)
def test_brax_consistency(self, envname, batch_size, requires_grad):
def test_brax_consistency(self, envname, batch_size, requires_grad, device):
import jax
import jax.numpy as jnp
from torchrl.envs.libs.jax_utils import (
Expand All @@ -2045,7 +2046,9 @@ def test_brax_consistency(self, envname, batch_size, requires_grad):
_tree_flatten,
)

env = BraxEnv(envname, batch_size=batch_size, requires_grad=requires_grad)
env = BraxEnv(
envname, batch_size=batch_size, requires_grad=requires_grad, device=device
)
env.set_seed(1)
rollout = env.rollout(10)

Expand All @@ -2064,9 +2067,9 @@ def test_brax_consistency(self, envname, batch_size, requires_grad):
torch.testing.assert_close(t1, t2)

@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
def test_brax_grad(self, envname, batch_size):
def test_brax_grad(self, envname, batch_size, device):
batch_size = (1,)
env = BraxEnv(envname, batch_size=batch_size, requires_grad=True)
env = BraxEnv(envname, batch_size=batch_size, requires_grad=True, device=device)
env.set_seed(0)
td1 = env.reset()
action = torch.randn(env.action_spec.shape)
Expand All @@ -2080,10 +2083,12 @@ def test_brax_grad(self, envname, batch_size):
@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
@pytest.mark.parametrize("parallel", [False, True])
def test_brax_parallel(
self, envname, batch_size, parallel, maybe_fork_ParallelEnv, n=1
self, envname, batch_size, parallel, maybe_fork_ParallelEnv, device, n=1
):
def make_brax():
env = BraxEnv(envname, batch_size=batch_size, requires_grad=False)
env = BraxEnv(
envname, batch_size=batch_size, requires_grad=False, device=device
)
env.set_seed(1)
return env

Expand Down
13 changes: 11 additions & 2 deletions torchrl/envs/libs/brax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib.util
import warnings

from typing import Dict, Optional, Union

Expand Down Expand Up @@ -202,6 +203,11 @@ def __init__(self, env=None, categorical_action_encoding=False, **kwargs):
self._seed_calls_reset = None
self._categorical_action_encoding = categorical_action_encoding
super().__init__(**kwargs)
if not self.device:
warnings.warn(
f"No device is set for env {self}. "
f"Setting a device in Brax wrapped environments is strongly recommended."
)

def _check_kwargs(self, kwargs: Dict):
brax = self.lib
Expand Down Expand Up @@ -657,20 +663,23 @@ def _make_none(key, val):

# call vjp to get gradients
grad_state, grad_action = ctx.vjp_fn(grad_next_state_flat)
# assert grad_action.device == ctx.env.device

# reshape batch size
grad_state = _tree_reshape(grad_state, ctx.env.batch_size)
grad_action = _tree_reshape(grad_action, ctx.env.batch_size)
# assert grad_action.device == ctx.env.device

# convert ndarrays to tensors
grad_state_qp = _object_to_tensordict(
grad_state.pipeline_state,
device=ctx.env.device,
batch_size=ctx.env.batch_size,
)
grad_action = _ndarray_to_tensor(grad_action)
grad_action = _ndarray_to_tensor(grad_action).to(ctx.env.device)
grad_state_qp = {
key: val if key not in none_keys else None
for key, val in grad_state_qp.items()
}
return (None, None, grad_action, *grad_state_qp.values())
grads = (grad_action, *grad_state_qp.values())
return (None, None, *grads)
25 changes: 22 additions & 3 deletions torchrl/envs/libs/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _object_to_tensordict(obj, device, batch_size) -> TensorDictBase:

def _tensordict_to_object(tensordict: TensorDictBase, object_example):
"""Converts a TensorDict to a namedtuple or a dataclass."""
from jax import dlpack as jax_dlpack
from jax import dlpack as jax_dlpack, numpy as jnp

t = {}
_fields = _get_object_fields(object_example)
Expand All @@ -125,8 +125,27 @@ def _tensordict_to_object(tensordict: TensorDictBase, object_example):
else:
if value.dtype is torch.bool:
value = value.to(torch.uint8)
value = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(value.contiguous()))
t[name] = value.reshape(example.shape).view(example.dtype)
shape = value.shape
# We need to flatten to fix https://github.com/pytorch/rl/issues/2184
value = value.contiguous()
value = value.detach()
if value.ndim > 1:
value = value.flatten().clone()
else:
# Need this because otherwise an exception is raised
# ValueError: INTERNAL: Address of buffer 1 must be a multiple of 10, but was 0x7efccec00824
value = value.clone()
value = jax_dlpack.from_dlpack(value)
if shape.numel() == 1 and not value.shape:
while value.shape != shape:
value = jnp.expand_dims(value, 0)
if value.dtype != example.dtype:
t[name] = value.view(example.dtype)
else:
t[name] = value
else:
value = jnp.reshape(value, tuple(shape))
t[name] = value.view(example.dtype).reshape(example.shape)
return type(object_example)(**t)


Expand Down

0 comments on commit 45ab9de

Please sign in to comment.