From cdcd9f32a20130f6a40999f218c2163cc0689e4c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 29 Jul 2024 20:08:03 +0100 Subject: [PATCH] [BugFix] `done = done | truncated` in collector (#2333) --- torchrl/collectors/collectors.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 32294a25edd..be24a06e39c 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1121,7 +1121,10 @@ def _maybe_set_truncated(self, final_rollout): truncated = final_rollout["next", truncated_key] truncated[last_step] = True final_rollout["next", truncated_key] = truncated - final_rollout["next", _replace_last(truncated_key, "done")] = truncated + done = final_rollout["next", _replace_last(truncated_key, "done")] + final_rollout["next", _replace_last(truncated_key, "done")] = ( + done | truncated + ) return final_rollout @torch.no_grad()