diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index 91676699997..b3ccb9caf51 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -1037,6 +1037,25 @@ def _step(
         return tensordict
 
 
+class IncrementingEnv(CountingEnv):
+    # Same as CountingEnv, but always increments the count by 1 regardless of the action.
+    def _step(
+        self,
+        tensordict: TensorDictBase,
+    ) -> TensorDictBase:
+        self.count += 1  # The only difference with CountingEnv.
+        tensordict = TensorDict(
+            source={
+                "observation": self.count.clone(),
+                "done": self.count > self.max_steps,
+                "reward": torch.zeros_like(self.count, dtype=torch.float),
+            },
+            batch_size=self.batch_size,
+            device=self.device,
+        )
+        return tensordict.select().set("next", tensordict)
+
+
 class NestedCountingEnv(CountingEnv):
     # an env with nested reward and done states
     def __init__(
diff --git a/test/test_transforms.py b/test/test_transforms.py
index f20e505145a..40037085e8d 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -26,6 +26,7 @@
     CountingBatchedEnv,
     CountingEnvCountPolicy,
     DiscreteActionConvMockEnvNumpy,
+    IncrementingEnv,
     MockBatchedLockedEnv,
     MockBatchedUnLockedEnv,
     NestedCountingEnv,
@@ -3172,6 +3173,20 @@ def test_noop_reset_env_error(self, random, device, compose):
         ):
             transformed_env.reset()
 
+    @pytest.mark.parametrize("noops", [0, 2, 8])
+    @pytest.mark.parametrize("max_steps", [0, 5, 9])
+    def test_noop_reset_limit_exceeded(self, noops, max_steps):
+        env = IncrementingEnv(max_steps=max_steps)
+        check_env_specs(env)
+        noop_reset_env = NoopResetEnv(noops=noops, random=False)
+        transformed_env = TransformedEnv(env, noop_reset_env)
+        if noops <= max_steps:  # Normal behavior.
+            result = transformed_env.reset()
+            assert result["observation"] == noops
+        elif noops > max_steps:  # An error is raised as the reset limit is exceeded.
+            with pytest.raises(RuntimeError):
+                transformed_env.reset()
+
 
 class TestObservationNorm(TransformBase):
     @pytest.mark.parametrize(
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 11715e56033..6a726883971 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -3123,8 +3123,11 @@ class NoopResetEnv(Transform):
         env (EnvBase): env on which the random actions have to be
             performed. Can be the same env as the one provided to the
             TransformedEnv class
-        noops (int, optional): number of actions performed after reset.
-            Default is `30`.
+        noops (int, optional): upper bound on the number of actions
+            performed after reset. Default is `30`.
+            If the env is done or truncated before all the noops can be
+            applied, across several successive trials, the transform
+            raises a RuntimeError.
         random (bool, optional): if False, the number of random ops will
             always be equal to the noops value. If True, the number of
             random actions will be randomly selected between 0 and noops.
@@ -3133,10 +3136,7 @@ class NoopResetEnv(Transform):
     """
 
     def __init__(self, noops: int = 30, random: bool = True):
-        """Sample initial states by taking random number of no-ops on reset.
-
-        No-op is assumed to be action 0.
-        """
+        """Sample initial states by taking a random number of no-ops on reset."""
         super().__init__([])
         self.noops = noops
         self.random = random
@@ -3171,31 +3171,31 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
         noops = (
             self.noops if not self.random else torch.randint(self.noops, (1,)).item()
         )
         trial = 0
-        while True:
+        while trial <= _MAX_NOOPS_TRIALS:
             i = 0
             while i < noops:
                 i += 1
                 tensordict = parent.rand_step(tensordict)
                 tensordict = step_mdp(tensordict, exclude_done=False)
-                if tensordict.get(done_key):
+                if tensordict.get(done_key) or tensordict.get(
+                    "truncated", torch.tensor(False)
+                ):
                     tensordict = parent.reset(td_reset.clone(False))
                     break
             else:
                 break
             trial += 1
-            if trial > _MAX_NOOPS_TRIALS:
-                tensordict = parent.rand_step(tensordict)
-                if tensordict.get(("next", done_key)):
-                    raise RuntimeError(
-                        f"parent is still done after a single random step (i={i})."
-                    )
-                break
-        if tensordict.get(done_key):
-            raise RuntimeError("NoopResetEnv concluded with done environment")
+        else:
+            raise RuntimeError(
+                f"Parent env was repeatedly done or truncated"
+                f" before the sampled number of noops (={noops}) could be applied."
+            )
+
         return tensordict.exclude(reward_key, inplace=True)
 
     def __repr__(self) -> str: