From 4c8f91f243ef7e44f5056d105e0b531193a463ee Mon Sep 17 00:00:00 2001
From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
Date: Fri, 1 Sep 2023 11:56:13 +0100
Subject: [PATCH 1/2] [BugFix] Fix RewardSum spec transform to mimic reward
spec (#1478)
Signed-off-by: Matteo Bettini
---
torchrl/envs/transforms/transforms.py | 33 ++++++++++++++++++++-------
1 file changed, 25 insertions(+), 8 deletions(-)
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 81781b04dec..11715e56033 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -16,6 +16,7 @@
import torch
from tensordict import unravel_key, unravel_key_list
+from tensordict._tensordict import _unravel_key_to_tuple
from tensordict.nn import dispatch
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import expand_as_right, NestedKey
@@ -3835,15 +3836,31 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
state_spec = input_spec["full_state_spec"]
if state_spec is None:
state_spec = CompositeSpec(shape=input_spec.shape, device=input_spec.device)
- reward_spec = self.parent.reward_spec
+ reward_spec = self.parent.output_spec["full_reward_spec"]
+ reward_spec_keys = list(reward_spec.keys(True, True))
# Define episode specs for all out_keys
- for out_key in self.out_keys:
- episode_spec = UnboundedContinuousTensorSpec(
- shape=reward_spec.shape,
- device=reward_spec.device,
- dtype=reward_spec.dtype,
- )
- state_spec[out_key] = episode_spec
+ for in_key, out_key in zip(self.in_keys, self.out_keys):
+ if (
+ in_key in reward_spec_keys
+ ): # if this out_key has a corresponding key in reward_spec
+ out_key = _unravel_key_to_tuple(out_key)
+ temp_state_spec = state_spec
+ temp_rew_spec = reward_spec
+ for sub_key in out_key[:-1]:
+ if (
+ not isinstance(temp_rew_spec, CompositeSpec)
+ or sub_key not in temp_rew_spec.keys()
+ ):
+ break
+ if sub_key not in temp_state_spec.keys():
+ temp_state_spec[sub_key] = temp_rew_spec[sub_key].empty()
+ temp_rew_spec = temp_rew_spec[sub_key]
+ temp_state_spec = temp_state_spec[sub_key]
+ state_spec[out_key] = reward_spec[in_key].clone()
+ else:
+ raise ValueError(
+ f"The in_key: {in_key} is not present in the reward spec {reward_spec}."
+ )
input_spec["full_state_spec"] = state_spec
return input_spec
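For readers reviewing this change, here is a minimal sketch (not part of the patch) of the behavior it is meant to guarantee: the episode-reward entry that RewardSum registers in the state spec should now be a clone of the matching entry in the parent's full reward spec, instead of a fresh UnboundedContinuousTensorSpec built from the aggregate reward_spec shape. The GymEnv("CartPole-v1") backend and the "episode_reward" out_key below are illustrative assumptions only; any TorchRL env exposing a "reward" key would do.

    from torchrl.envs import TransformedEnv
    from torchrl.envs.libs.gym import GymEnv  # assumes gym/gymnasium is installed
    from torchrl.envs.transforms import RewardSum

    base_env = GymEnv("CartPole-v1")
    env = TransformedEnv(
        base_env, RewardSum(in_keys=["reward"], out_keys=["episode_reward"])
    )

    # Spec registered by the transform for its out_key ...
    episode_spec = env.input_spec["full_state_spec"]["episode_reward"]
    # ... should mirror the parent's reward spec entry for the matching in_key.
    reward_spec = base_env.output_spec["full_reward_spec"]["reward"]

    assert episode_spec.shape == reward_spec.shape
    assert episode_spec.dtype == reward_spec.dtype
    assert episode_spec.device == reward_spec.device
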
From dbab7bb593ff33ef0423d607329c3687b13833d4 Mon Sep 17 00:00:00 2001
From: Skander Moalla <37197319+skandermoalla@users.noreply.github.com>
Date: Fri, 1 Sep 2023 12:58:54 +0200
Subject: [PATCH 2/2] [BugFix] Fix NoopResetEnv behavior when trials exceeded.
(#1477)
---
test/mocking_classes.py | 19 ++++++++++++++
test/test_transforms.py | 15 +++++++++++
torchrl/envs/transforms/transforms.py | 36 +++++++++++++--------------
3 files changed, 52 insertions(+), 18 deletions(-)
diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index 91676699997..b3ccb9caf51 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -1037,6 +1037,25 @@ def _step(
return tensordict
+class IncrementingEnv(CountingEnv):
+ # Same as CountingEnv but always increments the count by 1 regardless of the action.
+ def _step(
+ self,
+ tensordict: TensorDictBase,
+ ) -> TensorDictBase:
+ self.count += 1 # The only difference with CountingEnv.
+ tensordict = TensorDict(
+ source={
+ "observation": self.count.clone(),
+ "done": self.count > self.max_steps,
+ "reward": torch.zeros_like(self.count, dtype=torch.float),
+ },
+ batch_size=self.batch_size,
+ device=self.device,
+ )
+ return tensordict.select().set("next", tensordict)
+
+
class NestedCountingEnv(CountingEnv):
# an env with nested reward and done states
def __init__(
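A side note on why a new mock is needed rather than reusing CountingEnv: NoopResetEnv performs random actions, and CountingEnv only advances its count when the sampled action is 1, so the observation after N no-ops would not be deterministic. The sketch below is an assumption-laden illustration (it presumes being run from torchrl's test/ directory, where mocking_classes.py is importable; IncrementingEnv is a test helper, not a public API) showing that the count advances by exactly one per step regardless of the action:

    from mocking_classes import IncrementingEnv

    env = IncrementingEnv(max_steps=10)
    rollout = env.rollout(3)  # three random actions
    # The count advances once per step no matter which actions were sampled.
    assert rollout[-1]["next", "observation"].item() == 3
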
diff --git a/test/test_transforms.py b/test/test_transforms.py
index f20e505145a..40037085e8d 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -26,6 +26,7 @@
CountingBatchedEnv,
CountingEnvCountPolicy,
DiscreteActionConvMockEnvNumpy,
+ IncrementingEnv,
MockBatchedLockedEnv,
MockBatchedUnLockedEnv,
NestedCountingEnv,
@@ -3172,6 +3173,20 @@ def test_noop_reset_env_error(self, random, device, compose):
):
transformed_env.reset()
+ @pytest.mark.parametrize("noops", [0, 2, 8])
+ @pytest.mark.parametrize("max_steps", [0, 5, 9])
+ def test_noop_reset_limit_exceeded(self, noops, max_steps):
+ env = IncrementingEnv(max_steps=max_steps)
+ check_env_specs(env)
+ noop_reset_env = NoopResetEnv(noops=noops, random=False)
+ transformed_env = TransformedEnv(env, noop_reset_env)
+ if noops <= max_steps: # Normal behavior.
+ result = transformed_env.reset()
+ assert result["observation"] == noops
+ elif noops > max_steps: # Raise error as reset limit exceeded.
+ with pytest.raises(RuntimeError):
+ transformed_env.reset()
+
class TestObservationNorm(TransformBase):
@pytest.mark.parametrize(
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 11715e56033..6a726883971 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -3123,8 +3123,11 @@ class NoopResetEnv(Transform):
env (EnvBase): env on which the random actions have to be
performed. Can be the same env as the one provided to the
TransformedEnv class
- noops (int, optional): number of actions performed after reset.
- Default is `30`.
+ noops (int, optional): upper bound on the number of actions
+ performed after reset. Default is `30`.
+ If noops is too high and the env is done or truncated before
+ all the noops can be applied, across multiple trials, the
+ transform raises a RuntimeError.
random (bool, optional): if False, the number of random ops will
always be equal to the noops value. If True, the number of
random actions will be randomly selected between 0 and noops.
@@ -3133,10 +3136,7 @@ class NoopResetEnv(Transform):
"""
def __init__(self, noops: int = 30, random: bool = True):
- """Sample initial states by taking random number of no-ops on reset.
-
- No-op is assumed to be action 0.
- """
+ """Sample initial states by taking random number of no-ops on reset."""
super().__init__([])
self.noops = noops
self.random = random
@@ -3171,31 +3171,31 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
noops = (
self.noops if not self.random else torch.randint(self.noops, (1,)).item()
)
- trial = 0
- while True:
+ trial = 0
+ while trial <= _MAX_NOOPS_TRIALS:
i = 0
+
while i < noops:
i += 1
tensordict = parent.rand_step(tensordict)
tensordict = step_mdp(tensordict, exclude_done=False)
- if tensordict.get(done_key):
+ if tensordict.get(done_key) or tensordict.get(
+ "truncated", torch.tensor(False)
+ ):
tensordict = parent.reset(td_reset.clone(False))
break
else:
break
trial += 1
- if trial > _MAX_NOOPS_TRIALS:
- tensordict = parent.rand_step(tensordict)
- if tensordict.get(("next", done_key)):
- raise RuntimeError(
- f"parent is still done after a single random step (i={i})."
- )
- break
- if tensordict.get(done_key):
- raise RuntimeError("NoopResetEnv concluded with done environment")
+ else:
+ raise RuntimeError(
+ f"Parent env was repeatedly done or truncated"
+ f" before the sampled number of noops (={noops}) could be applied. "
+ )
+
return tensordict.exclude(reward_key, inplace=True)
def __repr__(self) -> str:
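One detail of the rewritten reset() worth calling out: it relies on Python's while/else, where the else branch runs only if the loop condition becomes false without hitting a break. Here is a small, self-contained sketch of that control flow (the trial budget and function name are illustrative, not taken from the library):

    _MAX_NOOPS_TRIALS = 3  # illustrative budget, not the library's value

    def apply_noops(succeeds_on_trial: int) -> str:
        trial = 0
        while trial <= _MAX_NOOPS_TRIALS:
            if trial == succeeds_on_trial:
                # All noops were applied without the env terminating:
                # leave via break, which skips the else branch entirely.
                break
            trial += 1
        else:
            # Reached only when every trial was exhausted; this is where the
            # patched reset() raises its RuntimeError.
            return "RuntimeError"
        return f"ok after trial {trial}"

    assert apply_noops(1) == "ok after trial 1"
    assert apply_noops(99) == "RuntimeError"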