From bb0ddb5e3397df10ca7bd22145135b1b9ae3f3a1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 12 Aug 2024 07:46:09 -0400 Subject: [PATCH] [Feature] break_when_all_done in rollout ghstack-source-id: 103fd4f3ba8eb8d6e916b6921ab14f95c920f3b5 Pull Request resolved: https://github.com/pytorch/rl/pull/2381 --- torchrl/envs/batched_envs.py | 17 ++++++--- torchrl/envs/common.py | 74 +++++++++++++++++++++++++++++------- 2 files changed, 72 insertions(+), 19 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 73ecdba64a9..eff1808af34 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1077,8 +1077,9 @@ def _step( if partial_steps is not None and partial_steps.all(): partial_steps = None if partial_steps is not None: + partial_steps = partial_steps.view(tensordict.shape) tensordict = tensordict[partial_steps] - workers_range = partial_steps.nonzero().squeeze().tolist() + workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist() tensordict_in = tensordict else: workers_range = range(self.num_workers) @@ -1468,8 +1469,9 @@ def _step_and_maybe_reset_no_buffers( if partial_steps is not None and partial_steps.all(): partial_steps = None if partial_steps is not None: + partial_steps = partial_steps.view(tensordict.shape) tensordict = tensordict[partial_steps] - workers_range = partial_steps.nonzero().squeeze().tolist() + workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist() else: workers_range = range(self.num_workers) @@ -1518,7 +1520,8 @@ def step_and_maybe_reset( if partial_steps is not None and partial_steps.all(): partial_steps = None if partial_steps is not None: - workers_range = partial_steps.nonzero().squeeze().tolist() + partial_steps = partial_steps.view(tensordict.shape) + workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist() shared_tensordict_parent = TensorDict.lazy_stack( [self.shared_tensordict_parent[i] for i in workers_range] ) @@ -1648,8 +1651,9 @@ def _step_no_buffers( if partial_steps is not None and partial_steps.all(): partial_steps = None if partial_steps is not None: + partial_steps = partial_steps.view(tensordict.shape) tensordict = tensordict[partial_steps] - workers_range = partial_steps.nonzero().squeeze().tolist() + workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist() else: workers_range = range(self.num_workers) @@ -1691,7 +1695,8 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if partial_steps is not None and partial_steps.all(): partial_steps = None if partial_steps is not None: - workers_range = partial_steps.nonzero().squeeze().tolist() + partial_steps = partial_steps.view(tensordict.shape) + workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist() shared_tensordict_parent = TensorDict.lazy_stack( [self.shared_tensordicts[i] for i in workers_range] ) @@ -1723,7 +1728,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # if we have input "next" data (eg, RNNs which pass the next state) # the sub-envs will need to process them through step_and_maybe_reset. 
         # We keep track of which keys are present to let the worker know what
-        # should be passd to the env (we don't want to pass done states for instance)
+        # should be passed to the env (we don't want to pass done states for instance)
         next_td_keys = list(next_td_passthrough.keys(True, True))
         data = [
             {"next_td_passthrough_keys": next_td_keys}
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index 16ce7d0d534..2aacf76168b 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -2317,9 +2317,11 @@ def rollout(
         max_steps: int,
         policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
         callback: Optional[Callable[[TensorDictBase, ...], Any]] = None,
+        *,
         auto_reset: bool = True,
         auto_cast_to_device: bool = False,
-        break_when_any_done: bool = True,
+        break_when_any_done: bool | None = None,
+        break_when_all_done: bool | None = None,
         return_contiguous: bool = True,
         tensordict: Optional[TensorDictBase] = None,
         set_truncated: bool = False,
@@ -2342,6 +2344,8 @@ def rollout(
                 TensorDict. Defaults to ``None``. The output of ``callback`` will not be collected, it is the user
                 responsibility to save any result within the callback call if data needs to be carried over beyond
                 the call to ``rollout``.
+
+        Keyword Args:
             auto_reset (bool, optional): if ``True``, resets automatically the environment
                 if it is in a done state when the rollout is initiated.
                 Default is ``True``.
@@ -2349,6 +2353,7 @@ def rollout(
                 policy device before the policy is used. Default is ``False``.
             break_when_any_done (bool): breaks if any of the done state is True. If False, a reset() is
                 called on the sub-envs that are done. Default is True.
+            break_when_all_done (bool): breaks if all of the done states are True. Sub-envs that are done keep their last step while the remaining sub-envs keep running. Default is False.
             return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True.
             tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial
                 tensordict must be provided. Rollout will check if this tensordict has done flags and reset the
@@ -2545,6 +2550,19 @@ def rollout(
             ...     )

         """
+        if break_when_any_done is None:  # True by default
+            if break_when_all_done:  # all overrides
+                break_when_any_done = False
+            else:
+                break_when_any_done = True
+        if break_when_all_done is None:
+            # There is no case where break_when_all_done is True by default
+            break_when_all_done = False
+        if break_when_all_done and break_when_any_done:
+            raise TypeError(
+                "Cannot have both break_when_all_done and break_when_any_done True at the same time."
+            )
+
         if policy is not None:
             policy = _make_compatible_policy(
                 policy, self.observation_spec, env=self, fast_wrap=True
             )
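The resolution rule above: `break_when_any_done` keeps its legacy default of True unless `break_when_all_done` is explicitly requested, and the two flags are mutually exclusive. A minimal sketch of the same rule as a standalone function (`resolve_break_flags` is a hypothetical name, not part of the patch):

```python
# Hypothetical mirror of the argument resolution above, for illustration only.
def resolve_break_flags(break_when_any_done=None, break_when_all_done=None):
    if break_when_any_done is None:
        # True by default, unless "all" was requested explicitly
        break_when_any_done = not bool(break_when_all_done)
    if break_when_all_done is None:
        # Never True by default
        break_when_all_done = False
    if break_when_all_done and break_when_any_done:
        raise TypeError(
            "Cannot have both break_when_all_done and break_when_any_done "
            "True at the same time."
        )
    return break_when_any_done, break_when_all_done


assert resolve_break_flags() == (True, False)  # legacy behavior preserved
assert resolve_break_flags(break_when_all_done=True) == (False, True)
# resolve_break_flags(True, True) raises TypeError
```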
@@ -2578,8 +2596,12 @@ def rollout(
             "env_device": env_device,
             "callback": callback,
         }
-        if break_when_any_done:
-            tensordicts = self._rollout_stop_early(**kwargs)
+        if break_when_any_done or break_when_all_done:
+            tensordicts = self._rollout_stop_early(
+                break_when_all_done=break_when_all_done,
+                break_when_any_done=break_when_any_done,
+                **kwargs,
+            )
         else:
             tensordicts = self._rollout_nonstop(**kwargs)
         batch_size = self.batch_size if tensordict is None else tensordict.batch_size
@@ -2639,6 +2661,8 @@ def _step_mdp(self):
     def _rollout_stop_early(
         self,
         *,
+        break_when_any_done,
+        break_when_all_done,
         tensordict,
         auto_cast_to_device,
         max_steps,
@@ -2651,6 +2675,7 @@ def _rollout_stop_early(
         if auto_cast_to_device:
             sync_func = _get_sync_func(policy_device, env_device)
         tensordicts = []
+        partial_steps = True
         for i in range(max_steps):
             if auto_cast_to_device:
                 if policy_device is not None:
@@ -2668,6 +2693,14 @@ def _rollout_stop_early(
                     tensordict.clear_device_()
             tensordict = self.step(tensordict)
             td_append = tensordict.copy()
+            if break_when_all_done:
+                if partial_steps is not True:
+                    # At least one sub-env is already done: freeze it at its last valid step
+                    del td_append["_partial_steps"]
+                    td_append = torch.where(
+                        partial_steps.view(td_append.shape), td_append, tensordicts[-1]
+                    )
+
             tensordicts.append(td_append)

             if i == max_steps - 1:
@@ -2675,16 +2708,31 @@ def _rollout_stop_early(
                 break
             tensordict = self._step_mdp(tensordict)

-            # done and truncated are in done_keys
-            # We read if any key is done.
-            any_done = _terminated_or_truncated(
-                tensordict,
-                full_done_spec=self.output_spec["full_done_spec"],
-                key=None,
-            )
-
-            if any_done:
-                break
+            if break_when_any_done:
+                # done and truncated are in done_keys
+                # We read if any key is done.
+                any_done = _terminated_or_truncated(
+                    tensordict,
+                    full_done_spec=self.output_spec["full_done_spec"],
+                    key=None,
+                )
+                if any_done:
+                    break
+            else:
+                _terminated_or_truncated(
+                    tensordict,
+                    full_done_spec=self.output_spec["full_done_spec"],
+                    key="_partial_steps",
+                    write_full_false=False,
+                )
+                partial_step_curr = tensordict.get("_partial_steps", None)
+                if partial_step_curr is not None:
+                    partial_step_curr = ~partial_step_curr
+                    partial_steps = partial_steps & partial_step_curr
+                if partial_steps is not True:
+                    if not partial_steps.any():
+                        break
+                    tensordict.set("_partial_steps", partial_steps)

             if callback is not None:
                 callback(self, tensordict)
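For context on the recurring change in batched_envs.py above: the mask is first reshaped to the input tensordict's batch shape, and worker indices are gathered with `nonzero(as_tuple=True)[0]` because `nonzero().squeeze().tolist()` degenerates to a bare scalar when exactly one worker is selected. A standalone check (not part of the patch):

```python
# Standalone check: with a single selected worker, .squeeze().tolist() yields
# a bare int, which breaks `for i in workers_range` at the call sites above.
import torch

partial_steps = torch.tensor([False, True, False, False])

old_style = partial_steps.nonzero().squeeze().tolist()
new_style = partial_steps.nonzero(as_tuple=True)[0].tolist()

print(old_style)  # 1   -> not iterable as a list of worker indices
print(new_style)  # [1] -> always a list, whatever the number of hits
```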
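Finally, a minimal usage sketch of the new flag, assuming gymnasium's CartPole is available (the env choice and worker count are illustrative, not from the patch):

```python
from torchrl.envs import GymEnv, SerialEnv

# Three CartPole copies stepped together inside a SerialEnv.
env = SerialEnv(3, lambda: GymEnv("CartPole-v1"))

# With break_when_all_done=True, a sub-env that reaches a done state is
# frozen at its last transition (tracked by the "_partial_steps" mask)
# while the others keep running; the loop exits once every sub-env is done.
rollout = env.rollout(200, break_when_all_done=True)
print(rollout.shape)  # torch.Size([3, T]) with T <= 200
```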