Skip to content

Commit

Permalink
[BugFix] Make TransformedEnv mirror allow_done_after_reset proper…
Browse files Browse the repository at this point in the history
…ty of base env (#1810)
  • Loading branch information
matteobettini authored Jan 17, 2024
1 parent d1138e2 commit d7e20e1
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 1 deletion.
1 change: 1 addition & 0 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(
super().__init__(
device=kwargs.pop("device", "cpu"),
dtype=torch.get_default_dtype(),
allow_done_after_reset=kwargs.pop("allow_done_after_reset", False),
)
self.set_seed(seed)
self.is_closed = False
Expand Down
15 changes: 15 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7691,6 +7691,21 @@ def test_independent_reward_specs_from_shared_env(self):
assert base_env.reward_spec.space.minimum == -np.inf
assert base_env.reward_spec.space.maximum == np.inf

def test_allow_done_after_reset(self):
base_env = ContinuousActionVecMockEnv(allow_done_after_reset=True)
assert base_env._allow_done_after_reset
t1 = TransformedEnv(
base_env, transform=RewardClipping(clamp_min=0, clamp_max=4)
)
assert t1._allow_done_after_reset
with pytest.raises(
RuntimeError,
match="_allow_done_after_reset is a read-only property for TransformedEnvs",
):
t1._allow_done_after_reset = False
base_env._allow_done_after_reset = False
assert not t1._allow_done_after_reset


def test_nested_transformed_env():
base_env = ContinuousActionVecMockEnv()
Expand Down
14 changes: 13 additions & 1 deletion torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ def __init__(
env = env.to(device)
else:
device = env.device
super().__init__(device=None, **kwargs)
super().__init__(device=None, allow_done_after_reset=None, **kwargs)

if isinstance(env, TransformedEnv):
self._set_env(env.base_env, device)
Expand Down Expand Up @@ -679,6 +679,18 @@ def run_type_checks(self, value):
"run_type_checks is a read-only property for TransformedEnvs"
)

@property
def _allow_done_after_reset(self) -> bool:
return self.base_env._allow_done_after_reset

@_allow_done_after_reset.setter
def _allow_done_after_reset(self, value):
if value is None:
return
raise RuntimeError(
"_allow_done_after_reset is a read-only property for TransformedEnvs"
)

@property
def _inplace_update(self):
return self.base_env._inplace_update
Expand Down

0 comments on commit d7e20e1

Please sign in to comment.