Skip to content

Commit

Permalink
[BugFix] Fix NoopResetEnv behavior when trials exceeded. (pytorch#1477)
Browse files Browse the repository at this point in the history
  • Loading branch information
skandermoalla authored Sep 1, 2023
1 parent 4c8f91f commit dbab7bb
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 18 deletions.
19 changes: 19 additions & 0 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,25 @@ def _step(
return tensordict


class IncrementingEnv(CountingEnv):
# Same as CountingEnv but always increments the count by 1 regardless of the action.
def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
self.count += 1 # The only difference with CountingEnv.
tensordict = TensorDict(
source={
"observation": self.count.clone(),
"done": self.count > self.max_steps,
"reward": torch.zeros_like(self.count, dtype=torch.float),
},
batch_size=self.batch_size,
device=self.device,
)
return tensordict.select().set("next", tensordict)


class NestedCountingEnv(CountingEnv):
# an env with nested reward and done states
def __init__(
Expand Down
15 changes: 15 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
CountingBatchedEnv,
CountingEnvCountPolicy,
DiscreteActionConvMockEnvNumpy,
IncrementingEnv,
MockBatchedLockedEnv,
MockBatchedUnLockedEnv,
NestedCountingEnv,
Expand Down Expand Up @@ -3172,6 +3173,20 @@ def test_noop_reset_env_error(self, random, device, compose):
):
transformed_env.reset()

@pytest.mark.parametrize("noops", [0, 2, 8])
@pytest.mark.parametrize("max_steps", [0, 5, 9])
def test_noop_reset_limit_exceeded(self, noops, max_steps):
env = IncrementingEnv(max_steps=max_steps)
check_env_specs(env)
noop_reset_env = NoopResetEnv(noops=noops, random=False)
transformed_env = TransformedEnv(env, noop_reset_env)
if noops <= max_steps: # Normal behavior.
result = transformed_env.reset()
assert result["observation"] == noops
elif noops > max_steps: # Raise error as reset limit exceeded.
with pytest.raises(RuntimeError):
transformed_env.reset()


class TestObservationNorm(TransformBase):
@pytest.mark.parametrize(
Expand Down
36 changes: 18 additions & 18 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3123,8 +3123,11 @@ class NoopResetEnv(Transform):
env (EnvBase): env on which the random actions have to be
performed. Can be the same env as the one provided to the
TransformedEnv class
noops (int, optional): number of actions performed after reset.
Default is `30`.
noops (int, optional): upper-bound on the number of actions
performed after reset. Default is `30`.
If noops is too high such that it results in the env being
done or truncated before the all the noops are applied,
in multiple trials, the transform raises a RuntimeError.
random (bool, optional): if False, the number of random ops will
always be equal to the noops value. If True, the number of
random actions will be randomly selected between 0 and noops.
Expand All @@ -3133,10 +3136,7 @@ class NoopResetEnv(Transform):
"""

def __init__(self, noops: int = 30, random: bool = True):
"""Sample initial states by taking random number of no-ops on reset.
No-op is assumed to be action 0.
"""
"""Sample initial states by taking random number of no-ops on reset."""
super().__init__([])
self.noops = noops
self.random = random
Expand Down Expand Up @@ -3171,31 +3171,31 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
noops = (
self.noops if not self.random else torch.randint(self.noops, (1,)).item()
)
trial = 0

while True:
trial = 0
while trial <= _MAX_NOOPS_TRIALS:
i = 0

while i < noops:
i += 1
tensordict = parent.rand_step(tensordict)
tensordict = step_mdp(tensordict, exclude_done=False)
if tensordict.get(done_key):
if tensordict.get(done_key) or tensordict.get(
"truncated", torch.tensor(False)
):
tensordict = parent.reset(td_reset.clone(False))
break
else:
break

trial += 1
if trial > _MAX_NOOPS_TRIALS:
tensordict = parent.rand_step(tensordict)
if tensordict.get(("next", done_key)):
raise RuntimeError(
f"parent is still done after a single random step (i={i})."
)
break

if tensordict.get(done_key):
raise RuntimeError("NoopResetEnv concluded with done environment")
else:
raise RuntimeError(
f"Parent env was repeatedly done or truncated"
f" before the sampled number of noops (={noops}) could be applied. "
)

return tensordict.exclude(reward_key, inplace=True)

def __repr__(self) -> str:
Expand Down

0 comments on commit dbab7bb

Please sign in to comment.