Skip to content

Commit

Permalink
[Feature] break_when_all_done in rollout
Browse files Browse the repository at this point in the history
ghstack-source-id: 103fd4f3ba8eb8d6e916b6921ab14f95c920f3b5
Pull Request resolved: pytorch#2381
  • Loading branch information
vmoens committed Aug 12, 2024
1 parent 430f1bd commit bb0ddb5
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 19 deletions.
17 changes: 11 additions & 6 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,8 +1077,9 @@ def _step(
if partial_steps is not None and partial_steps.all():
partial_steps = None
if partial_steps is not None:
partial_steps = partial_steps.view(tensordict.shape)
tensordict = tensordict[partial_steps]
workers_range = partial_steps.nonzero().squeeze().tolist()
workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
tensordict_in = tensordict
else:
workers_range = range(self.num_workers)
Expand Down Expand Up @@ -1468,8 +1469,9 @@ def _step_and_maybe_reset_no_buffers(
if partial_steps is not None and partial_steps.all():
partial_steps = None
if partial_steps is not None:
partial_steps = partial_steps.view(tensordict.shape)
tensordict = tensordict[partial_steps]
workers_range = partial_steps.nonzero().squeeze().tolist()
workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
else:
workers_range = range(self.num_workers)

Expand Down Expand Up @@ -1518,7 +1520,8 @@ def step_and_maybe_reset(
if partial_steps is not None and partial_steps.all():
partial_steps = None
if partial_steps is not None:
workers_range = partial_steps.nonzero().squeeze().tolist()
partial_steps = partial_steps.view(tensordict.shape)
workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
shared_tensordict_parent = TensorDict.lazy_stack(
[self.shared_tensordict_parent[i] for i in workers_range]
)
Expand Down Expand Up @@ -1648,8 +1651,9 @@ def _step_no_buffers(
if partial_steps is not None and partial_steps.all():
partial_steps = None
if partial_steps is not None:
partial_steps = partial_steps.view(tensordict.shape)
tensordict = tensordict[partial_steps]
workers_range = partial_steps.nonzero().squeeze().tolist()
workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
else:
workers_range = range(self.num_workers)

Expand Down Expand Up @@ -1691,7 +1695,8 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
if partial_steps is not None and partial_steps.all():
partial_steps = None
if partial_steps is not None:
workers_range = partial_steps.nonzero().squeeze().tolist()
partial_steps = partial_steps.view(tensordict.shape)
workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
shared_tensordict_parent = TensorDict.lazy_stack(
[self.shared_tensordicts[i] for i in workers_range]
)
Expand Down Expand Up @@ -1723,7 +1728,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
# if we have input "next" data (eg, RNNs which pass the next state)
# the sub-envs will need to process them through step_and_maybe_reset.
# We keep track of which keys are present to let the worker know what
# should be passd to the env (we don't want to pass done states for instance)
# should be passed to the env (we don't want to pass done states for instance)
next_td_keys = list(next_td_passthrough.keys(True, True))
data = [
{"next_td_passthrough_keys": next_td_keys}
Expand Down
74 changes: 61 additions & 13 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2317,9 +2317,11 @@ def rollout(
max_steps: int,
policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
callback: Optional[Callable[[TensorDictBase, ...], Any]] = None,
*,
auto_reset: bool = True,
auto_cast_to_device: bool = False,
break_when_any_done: bool = True,
break_when_any_done: bool | None = None,
break_when_all_done: bool | None = None,
return_contiguous: bool = True,
tensordict: Optional[TensorDictBase] = None,
set_truncated: bool = False,
Expand All @@ -2342,13 +2344,16 @@ def rollout(
TensorDict. Defaults to ``None``. The output of ``callback`` will not be collected, it is the user
responsibility to save any result within the callback call if data needs to be carried over beyond
the call to ``rollout``.
Keyword Args:
auto_reset (bool, optional): if ``True``, resets automatically the environment
if it is in a done state when the rollout is initiated.
Default is ``True``.
auto_cast_to_device (bool, optional): if ``True``, the device of the tensordict is automatically cast to the
policy device before the policy is used. Default is ``False``.
break_when_any_done (bool): breaks if any of the done state is True. If False, a reset() is
called on the sub-envs that are done. Default is True.
break_when_all_done (bool): TODO
return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True.
tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial
tensordict must be provided. Rollout will check if this tensordict has done flags and reset the
Expand Down Expand Up @@ -2545,6 +2550,19 @@ def rollout(
... )
"""
if break_when_any_done is None: # True by default
if break_when_all_done: # all overrides
break_when_any_done = False
else:
break_when_any_done = True
if break_when_all_done is None:
# There is no case where break_when_all_done is True by default
break_when_all_done = False
if break_when_all_done and break_when_any_done:
raise TypeError(
"Cannot have both break_when_all_done and break_when_any_done True at the same time."
)

if policy is not None:
policy = _make_compatible_policy(
policy, self.observation_spec, env=self, fast_wrap=True
Expand Down Expand Up @@ -2578,8 +2596,12 @@ def rollout(
"env_device": env_device,
"callback": callback,
}
if break_when_any_done:
tensordicts = self._rollout_stop_early(**kwargs)
if break_when_any_done or break_when_all_done:
tensordicts = self._rollout_stop_early(
break_when_all_done=break_when_all_done,
break_when_any_done=break_when_any_done,
**kwargs,
)
else:
tensordicts = self._rollout_nonstop(**kwargs)
batch_size = self.batch_size if tensordict is None else tensordict.batch_size
Expand Down Expand Up @@ -2639,6 +2661,8 @@ def _step_mdp(self):
def _rollout_stop_early(
self,
*,
break_when_any_done,
break_when_all_done,
tensordict,
auto_cast_to_device,
max_steps,
Expand All @@ -2651,6 +2675,7 @@ def _rollout_stop_early(
if auto_cast_to_device:
sync_func = _get_sync_func(policy_device, env_device)
tensordicts = []
partial_steps = True
for i in range(max_steps):
if auto_cast_to_device:
if policy_device is not None:
Expand All @@ -2668,23 +2693,46 @@ def _rollout_stop_early(
tensordict.clear_device_()
tensordict = self.step(tensordict)
td_append = tensordict.copy()
if break_when_all_done:
if partial_steps is not True:
# At least one partial step has been done
del td_append["_partial_steps"]
td_append = torch.where(
partial_steps.view(td_append.shape), td_append, tensordicts[-1]
)

tensordicts.append(td_append)

if i == max_steps - 1:
# we don't truncate as one could potentially continue the run
break
tensordict = self._step_mdp(tensordict)

# done and truncated are in done_keys
# We read if any key is done.
any_done = _terminated_or_truncated(
tensordict,
full_done_spec=self.output_spec["full_done_spec"],
key=None,
)

if any_done:
break
if break_when_any_done:
# done and truncated are in done_keys
# We read if any key is done.
any_done = _terminated_or_truncated(
tensordict,
full_done_spec=self.output_spec["full_done_spec"],
key=None,
)
if any_done:
break
else:
_terminated_or_truncated(
tensordict,
full_done_spec=self.output_spec["full_done_spec"],
key="_partial_steps",
write_full_false=False,
)
partial_step_curr = tensordict.get("_partial_steps", None)
if partial_step_curr is not None:
partial_step_curr = ~partial_step_curr
partial_steps = partial_steps & partial_step_curr
if partial_steps is not True:
if not partial_steps.any():
break
tensordict.set("_partial_steps", partial_steps)

if callback is not None:
callback(self, tensordict)
Expand Down

0 comments on commit bb0ddb5

Please sign in to comment.