Skip to content

Commit

Permalink
[BugFix] Fix brax wrapping (#2190)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored May 31, 2024
1 parent c0c32a0 commit 765952a
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 6 deletions.
28 changes: 27 additions & 1 deletion test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
RenameTransform,
)
from torchrl.envs.batched_envs import SerialEnv
from torchrl.envs.libs.brax import _has_brax, BraxEnv
from torchrl.envs.libs.brax import _has_brax, BraxEnv, BraxWrapper
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv, DMControlWrapper
from torchrl.envs.libs.envpool import _has_envpool, MultiThreadedEnvWrapper
from torchrl.envs.libs.gym import (
Expand Down Expand Up @@ -1882,6 +1882,32 @@ def test_env_device(self, env_name, frame_skip, transformed_out, device):
@pytest.mark.skipif(not _has_brax, reason="brax not installed")
@pytest.mark.parametrize("envname", ["fast"])
class TestBrax:
@pytest.mark.parametrize("requires_grad", [False, True])
def test_brax_constructor(self, envname, requires_grad):
env0 = BraxEnv(envname, requires_grad=requires_grad)
env1 = BraxWrapper(env0._env, requires_grad=requires_grad)

env0.set_seed(0)
torch.manual_seed(0)
init = env0.reset()
if requires_grad:
init = init.apply(
lambda x: x.requires_grad_(True) if x.is_floating_point() else x
)
r0 = env0.rollout(10, tensordict=init, auto_reset=False)
assert r0.requires_grad == requires_grad

env1.set_seed(0)
torch.manual_seed(0)
init = env1.reset()
if requires_grad:
init = init.apply(
lambda x: x.requires_grad_(True) if x.is_floating_point() else x
)
r1 = env1.rollout(10, tensordict=init, auto_reset=False)
assert r1.requires_grad == requires_grad
assert_allclose_td(r0.data, r1.data)

def test_brax_seeding(self, envname):
final_seed = []
tdreset = []
Expand Down
13 changes: 8 additions & 5 deletions torchrl/envs/libs/brax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Dict, Optional, Union

import torch
from packaging import version
from tensordict import TensorDict, TensorDictBase

from torchrl.data.tensor_specs import (
Expand All @@ -15,9 +16,6 @@
UnboundedContinuousTensorSpec,
)
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.utils import _classproperty

_has_brax = importlib.util.find_spec("brax") is not None
from torchrl.envs.libs.jax_utils import (
_extract_spec,
_ndarray_to_tensor,
Expand All @@ -27,6 +25,9 @@
_tree_flatten,
_tree_reshape,
)
from torchrl.envs.utils import _classproperty

_has_brax = importlib.util.find_spec("brax") is not None


def _get_envs():
Expand Down Expand Up @@ -204,12 +205,14 @@ def __init__(self, env=None, categorical_action_encoding=False, **kwargs):

def _check_kwargs(self, kwargs: Dict):
brax = self.lib
if version.parse(brax.__version__) < version.parse("0.10.4"):
raise ImportError("Brax v0.10.4 or greater is required.")

if "env" not in kwargs:
raise TypeError("Could not find environment key 'env' in kwargs.")
env = kwargs["env"]
if not isinstance(env, brax.envs.env.Env):
raise TypeError("env is not of type 'brax.envs.env.Env'.")
if not isinstance(env, brax.envs.Env):
raise TypeError("env is not of type 'brax.envs.Env'.")

def _build_env(
self,
Expand Down

0 comments on commit 765952a

Please sign in to comment.