From 4c8f91f243ef7e44f5056d105e0b531193a463ee Mon Sep 17 00:00:00 2001
From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
Date: Fri, 1 Sep 2023 11:56:13 +0100
Subject: [PATCH 1/2] [BugFix] Fix RewardSum spec transform to mimic reward spec (#1478)

Signed-off-by: Matteo Bettini
---
 torchrl/envs/transforms/transforms.py | 33 ++++++++++++++++++++-------
 1 file changed, 25 insertions(+), 8 deletions(-)

diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 81781b04dec..11715e56033 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -16,6 +16,7 @@
 import torch
 
 from tensordict import unravel_key, unravel_key_list
+from tensordict._tensordict import _unravel_key_to_tuple
 from tensordict.nn import dispatch
 from tensordict.tensordict import TensorDict, TensorDictBase
 from tensordict.utils import expand_as_right, NestedKey
@@ -3835,15 +3836,31 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
         state_spec = input_spec["full_state_spec"]
         if state_spec is None:
             state_spec = CompositeSpec(shape=input_spec.shape, device=input_spec.device)
-        reward_spec = self.parent.reward_spec
+        reward_spec = self.parent.output_spec["full_reward_spec"]
+        reward_spec_keys = list(reward_spec.keys(True, True))
         # Define episode specs for all out_keys
-        for out_key in self.out_keys:
-            episode_spec = UnboundedContinuousTensorSpec(
-                shape=reward_spec.shape,
-                device=reward_spec.device,
-                dtype=reward_spec.dtype,
-            )
-            state_spec[out_key] = episode_spec
+        for in_key, out_key in zip(self.in_keys, self.out_keys):
+            if (
+                in_key in reward_spec_keys
+            ):  # if this out_key has a corresponding key in reward_spec
+                out_key = _unravel_key_to_tuple(out_key)
+                temp_state_spec = state_spec
+                temp_rew_spec = reward_spec
+                for sub_key in out_key[:-1]:
+                    if (
+                        not isinstance(temp_rew_spec, CompositeSpec)
+                        or sub_key not in temp_rew_spec.keys()
+                    ):
+                        break
+                    if sub_key not in temp_state_spec.keys():
+                        temp_state_spec[sub_key] = temp_rew_spec[sub_key].empty()
+                    temp_rew_spec = temp_rew_spec[sub_key]
+                    temp_state_spec = temp_state_spec[sub_key]
+                state_spec[out_key] = reward_spec[in_key].clone()
+            else:
+                raise ValueError(
+                    f"The in_key: {in_key} is not present in the reward spec {reward_spec}."
+                )
 
         input_spec["full_state_spec"] = state_spec
         return input_spec

From dbab7bb593ff33ef0423d607329c3687b13833d4 Mon Sep 17 00:00:00 2001
From: Skander Moalla <37197319+skandermoalla@users.noreply.github.com>
Date: Fri, 1 Sep 2023 12:58:54 +0200
Subject: [PATCH 2/2] [BugFix] Fix NoopResetEnv behavior when trials exceeded. (#1477)

---
 test/mocking_classes.py               | 19 ++++++++++++++
 test/test_transforms.py               | 15 +++++++++++
 torchrl/envs/transforms/transforms.py | 36 +++++++++++++--------------
 3 files changed, 52 insertions(+), 18 deletions(-)

diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index 91676699997..b3ccb9caf51 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -1037,6 +1037,25 @@ def _step(
         return tensordict
 
 
+class IncrementingEnv(CountingEnv):
+    # Same as CountingEnv but always increments the count by 1 regardless of the action.
+    def _step(
+        self,
+        tensordict: TensorDictBase,
+    ) -> TensorDictBase:
+        self.count += 1  # The only difference with CountingEnv.
+        tensordict = TensorDict(
+            source={
+                "observation": self.count.clone(),
+                "done": self.count > self.max_steps,
+                "reward": torch.zeros_like(self.count, dtype=torch.float),
+            },
+            batch_size=self.batch_size,
+            device=self.device,
+        )
+        return tensordict.select().set("next", tensordict)
+
+
 class NestedCountingEnv(CountingEnv):
     # an env with nested reward and done states
     def __init__(
diff --git a/test/test_transforms.py b/test/test_transforms.py
index f20e505145a..40037085e8d 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -26,6 +26,7 @@
     CountingBatchedEnv,
     CountingEnvCountPolicy,
     DiscreteActionConvMockEnvNumpy,
+    IncrementingEnv,
     MockBatchedLockedEnv,
     MockBatchedUnLockedEnv,
     NestedCountingEnv,
@@ -3172,6 +3173,20 @@ def test_noop_reset_env_error(self, random, device, compose):
         ):
             transformed_env.reset()
 
+    @pytest.mark.parametrize("noops", [0, 2, 8])
+    @pytest.mark.parametrize("max_steps", [0, 5, 9])
+    def test_noop_reset_limit_exceeded(self, noops, max_steps):
+        env = IncrementingEnv(max_steps=max_steps)
+        check_env_specs(env)
+        noop_reset_env = NoopResetEnv(noops=noops, random=False)
+        transformed_env = TransformedEnv(env, noop_reset_env)
+        if noops <= max_steps:  # Normal behavior.
+            result = transformed_env.reset()
+            assert result["observation"] == noops
+        elif noops > max_steps:  # Raise error as reset limit exceeded.
+            with pytest.raises(RuntimeError):
+                transformed_env.reset()
+
 
 class TestObservationNorm(TransformBase):
     @pytest.mark.parametrize(
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 11715e56033..6a726883971 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -3123,8 +3123,11 @@ class NoopResetEnv(Transform):
         env (EnvBase): env on which the random actions have to be
             performed. Can be the same env as the one provided to the
             TransformedEnv class
-        noops (int, optional): number of actions performed after reset.
-            Default is `30`.
+        noops (int, optional): upper bound on the number of actions
+            performed after reset. Default is `30`.
+            If noops is too high, such that the env is done or truncated
+            before all the noops can be applied over multiple trials,
+            the transform raises a RuntimeError.
         random (bool, optional): if False, the number of random ops will
             always be equal to the noops value. If True, the number of
             random actions will be randomly selected between 0 and noops.
@@ -3133,10 +3136,7 @@ class NoopResetEnv(Transform):
     """
 
     def __init__(self, noops: int = 30, random: bool = True):
-        """Sample initial states by taking random number of no-ops on reset.
-
-        No-op is assumed to be action 0.
-        """
+        """Sample initial states by taking random number of no-ops on reset."""
         super().__init__([])
         self.noops = noops
         self.random = random
@@ -3171,31 +3171,31 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
         noops = (
             self.noops if not self.random else torch.randint(self.noops, (1,)).item()
         )
 
-        trial = 0
-        while True:
+        trial = 0
+        while trial <= _MAX_NOOPS_TRIALS:
             i = 0
+
             while i < noops:
                 i += 1
                 tensordict = parent.rand_step(tensordict)
                 tensordict = step_mdp(tensordict, exclude_done=False)
-                if tensordict.get(done_key):
+                if tensordict.get(done_key) or tensordict.get(
+                    "truncated", torch.tensor(False)
+                ):
                     tensordict = parent.reset(td_reset.clone(False))
                     break
             else:
                 break
 
             trial += 1
-            if trial > _MAX_NOOPS_TRIALS:
-                tensordict = parent.rand_step(tensordict)
-                if tensordict.get(("next", done_key)):
-                    raise RuntimeError(
-                        f"parent is still done after a single random step (i={i})."
-                    )
-                break
 
-        if tensordict.get(done_key):
-            raise RuntimeError("NoopResetEnv concluded with done environment")
+        else:
+            raise RuntimeError(
+                f"Parent env was repeatedly done or truncated"
+                f" before the sampled number of noops (={noops}) could be applied. "
+            )
+
         return tensordict.exclude(reward_key, inplace=True)
 
     def __repr__(self) -> str:
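Minimal usage sketch of the behavior introduced by PATCH 2/2 (illustrative only, not part of the patch series). It assumes it is run from the repository's test directory so that IncrementingEnv, the helper added in test/mocking_classes.py above, is importable; NoopResetEnv and TransformedEnv are the public torchrl classes touched by the fix.

    import pytest
    from torchrl.envs import TransformedEnv
    from torchrl.envs.transforms import NoopResetEnv
    from mocking_classes import IncrementingEnv  # test helper added by PATCH 2/2

    # The env terminates after 5 steps but 8 noops are requested: once the
    # retry budget (_MAX_NOOPS_TRIALS) is exhausted, reset() raises RuntimeError.
    env = TransformedEnv(IncrementingEnv(max_steps=5), NoopResetEnv(noops=8, random=False))
    with pytest.raises(RuntimeError):
        env.reset()

    # When the noop budget fits within an episode, reset() returns the state
    # reached after exactly `noops` steps, as asserted in the new test above.
    env = TransformedEnv(IncrementingEnv(max_steps=9), NoopResetEnv(noops=8, random=False))
    assert env.reset()["observation"] == 8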