From e53eb738fea2a924903f89d3faf4db7eb096c721 Mon Sep 17 00:00:00 2001
From: Vincent Moens <vincentmoens@gmail.com>
Date: Tue, 6 Feb 2024 17:43:20 +0000
Subject: [PATCH] [BugFix] Fix _reset data passing in parallel env (#1880)

---
 test/test_env.py             | 23 +++++++++++++++++++++++
 torchrl/envs/batched_envs.py | 25 +++++++++++++++++++++++--
 2 files changed, 46 insertions(+), 2 deletions(-)

diff --git a/test/test_env.py b/test/test_env.py
index e316e1ae10f..15bcf5e3fcb 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -624,6 +624,29 @@ def test_parallel_env_with_policy(
         # env_serial.close()
         env0.close()
 
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
+    @pytest.mark.parametrize("heterogeneous", [False, True])
+    def test_transform_env_transform_no_device(self, heterogeneous):
+        # Tests non-regression on 1865
+        def make_env():
+            return TransformedEnv(
+                ContinuousActionVecMockEnv(), StepCounter(max_steps=3)
+            )
+
+        if heterogeneous:
+            make_envs = [EnvCreator(make_env), EnvCreator(make_env)]
+        else:
+            make_envs = make_env
+        penv = ParallelEnv(2, make_envs)
+        r = penv.rollout(6, break_when_any_done=False)
+        assert r.shape == (2, 6)
+        try:
+            env = TransformedEnv(penv)
+            r = env.rollout(6, break_when_any_done=False)
+            assert r.shape == (2, 6)
+        finally:
+            penv.close()
+
     @pytest.mark.skipif(not _has_gym, reason="no gym")
     @pytest.mark.parametrize(
         "env_name",
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index cfb977d4bb2..2a955af1261 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -1284,7 +1284,22 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
                     tensordict_, keys_to_update=list(self._selected_reset_keys)
                 )
                 continue
-            out = ("reset", tensordict_)
+            if tensordict_ is not None:
+                tdkeys = list(tensordict_.keys(True, True))
+
+                # This way we can avoid calling select over all the keys
+                # in the shared tensordict
+                def tentative_update(val, other):
+                    if other is not None:
+                        val.copy_(other)
+                    return val
+
+                self.shared_tensordicts[i].apply_(
+                    tentative_update, tensordict_, default=None
+                )
+                out = ("reset", tdkeys)
+            else:
+                out = ("reset", False)
             channel.send(out)
             workers.append(i)
 
@@ -1509,7 +1524,13 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
                 torchrl_logger.info(f"resetting worker {pid}")
             if not initialized:
                 raise RuntimeError("call 'init' before resetting")
-            cur_td = env.reset(tensordict=data)
+            # we use 'data' to pass the keys that we need to pass to reset,
+            # because passing the entire buffer may have unwanted consequences
+            cur_td = env.reset(
+                tensordict=root_shared_tensordict.select(*data, strict=False)
+                if data
+                else None
+            )
             shared_tensordict.update_(
                 cur_td,
                 keys_to_update=list(_selected_reset_keys),