diff --git a/test/test_env.py b/test/test_env.py index adce3e48326..32e9ffccb55 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -311,6 +311,62 @@ def test_rollout_predictability(device): ).all() +# Check that `env.step` fills in ("next", "terminated") with the same values as +# ("next", "done") when `_step` returns "done" but no "terminated" entry. +def test_done_key_completion_done(): + class DoneEnv(CountingEnv): + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + self.count += 1 + tensordict = TensorDict( + source={ + "observation": self.count.clone(), + "done": self.count > self.max_steps,  # no "terminated" entry + "reward": torch.zeros_like(self.count, dtype=torch.float), + }, + batch_size=self.batch_size, + device=self.device, + ) + return tensordict + + env = DoneEnv(max_steps=torch.tensor([[0], [1]]), batch_size=(2,)) + td = env.reset() + env.rand_action(td) + td = env.step(td) + assert torch.equal(td[("next", "done")], torch.tensor([[True], [False]])) + assert torch.equal(td[("next", "terminated")], torch.tensor([[True], [False]])) + + +# Check that `env.step` fills in ("next", "done") with the same values as +# ("next", "terminated") when `_step` returns "terminated" but no "done" entry. 
+def test_done_key_completion_terminated(): + class TerminatedEnv(CountingEnv): + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + self.count += 1 + tensordict = TensorDict( + source={ + "observation": self.count.clone(), + "terminated": self.count > self.max_steps,  # no "done" entry + "reward": torch.zeros_like(self.count, dtype=torch.float), + }, + batch_size=self.batch_size, + device=self.device, + ) + return tensordict + + env = TerminatedEnv(max_steps=torch.tensor([[0], [1]]), batch_size=(2,)) + td = env.reset() + env.rand_action(td) + td = env.step(td) + assert torch.equal(td[("next", "done")], torch.tensor([[True], [False]])) + assert torch.equal(td[("next", "terminated")], torch.tensor([[True], [False]])) + + + @pytest.mark.skipif(not _has_gym, reason="no gym") @pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED]) @pytest.mark.parametrize("frame_skip", [1]) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index e30de3534d9..eaf701fde34 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1515,7 +1515,8 @@ def _complete_done( shape = (*leading_dim, *item.shape) if val is not None: if val.shape != shape: - data.set(key, val.reshape(shape)) + val = val.reshape(shape) + data.set(key, val) vals[key] = val if len(vals) < i + 1: @@ -1535,6 +1536,7 @@ def _complete_done( "Cannot infer the value of terminated when only done and truncated are present." 
) data.set("terminated", val) + data_keys.add("terminated") elif ( key == "terminated" and val is not None @@ -1542,11 +1544,10 @@ def _complete_done( and "done" not in data_keys ): if "truncated" in data_keys: - done = val | data.get("truncated") - data.set("done", done) - else: - data.set("done", val) - elif val is None: + val = val | data.get("truncated") + data.set("done", val) + data_keys.add("done") + elif val is None and key not in data_keys: # we must keep this here: we only want to fill with 0s if we're sure # done should not be copied to terminated or terminated to done # in this case, just fill with 0s