diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 32294a25edd..be24a06e39c 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1121,7 +1121,10 @@ def _maybe_set_truncated(self, final_rollout): truncated = final_rollout["next", truncated_key] truncated[last_step] = True final_rollout["next", truncated_key] = truncated - final_rollout["next", _replace_last(truncated_key, "done")] = truncated + done = final_rollout["next", _replace_last(truncated_key, "done")] + final_rollout["next", _replace_last(truncated_key, "done")] = ( + done | truncated + ) return final_rollout @torch.no_grad()