[Feature] Passing lists of keyword arguments in reset for batched envs (#2076)
vmoens authored Apr 11, 2024
1 parent c9296c1 commit 0d00748
Showing 1 changed file with 29 additions and 6 deletions:
torchrl/envs/batched_envs.py
@@ -180,6 +180,15 @@ class BatchedEnvBase(EnvBase):
    TorchRL if not initiated differently before first import).
    To be used only with :class:`~torchrl.envs.ParallelEnv` subclasses.
+   .. note::
+     One can pass keyword arguments to each sub-environment using the following
+     technique: every keyword argument in :meth:`~.reset` will be passed to each
+     environment except for the ``list_of_kwargs`` argument which, if present,
+     should contain a list of the same length as the number of workers, with the
+     worker-specific keyword arguments stored in a dictionary.
+     If a partial reset is queried, the elements of ``list_of_kwargs`` corresponding
+     to sub-environments that are not reset are ignored.
    Examples:
        >>> from torchrl.envs import GymEnv, ParallelEnv, SerialEnv, EnvCreator
        >>> make_env = EnvCreator(lambda: GymEnv("Pendulum-v1"))  # EnvCreator ensures that the env is sharable. Optional in most cases.
@@ -868,6 +877,11 @@ def set_seed(

    @_check_start
    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
+       list_of_kwargs = kwargs.pop("list_of_kwargs", [kwargs] * self.num_workers)
+       if kwargs is not list_of_kwargs[0] and kwargs:
+           # an explicit list was provided and extra kwargs remain:
+           # merge those extra kwargs into every per-worker dict
+           for elt in list_of_kwargs:
+               elt.update(kwargs)
        if tensordict is not None:
            needs_resetting = _aggregate_end_of_traj(
                tensordict, reset_keys=self.reset_keys
@@ -907,7 +921,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:

        for i, tensordict_ in tds:
            _env = self._envs[i]
-           _td = _env.reset(tensordict=tensordict_, **kwargs)
+           _td = _env.reset(tensordict=tensordict_, **list_of_kwargs[i])
            try:
                self.shared_tensordicts[i].update_(
                    _td,
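The kwargs-merging logic above can be exercised in isolation. A minimal sketch of its semantics (plain Python; the helper name ``merge_reset_kwargs`` is hypothetical):

    def merge_reset_kwargs(kwargs, num_workers):
        # Default: every worker shares the same kwargs dict.
        list_of_kwargs = kwargs.pop("list_of_kwargs", [kwargs] * num_workers)
        if kwargs is not list_of_kwargs[0] and kwargs:
            # An explicit list was provided and global kwargs remain:
            # fold the globals into each per-worker dict.
            for elt in list_of_kwargs:
                elt.update(kwargs)
        return list_of_kwargs

    >>> merge_reset_kwargs({"options": {"hard": True}}, num_workers=2)
    [{'options': {'hard': True}}, {'options': {'hard': True}}]
    >>> merge_reset_kwargs(
    ...     {"list_of_kwargs": [{"seed": 0}, {"seed": 1}], "options": {"hard": True}},
    ...     num_workers=2,
    ... )
    [{'seed': 0, 'options': {'hard': True}}, {'seed': 1, 'options': {'hard': True}}]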
@@ -1439,6 +1453,11 @@ def select_and_clone(name, tensor):
    @torch.no_grad()
    @_check_start
    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
+       list_of_kwargs = kwargs.pop("list_of_kwargs", [kwargs] * self.num_workers)
+       if kwargs is not list_of_kwargs[0] and kwargs:
+           # an explicit list was provided and extra kwargs remain:
+           # merge those extra kwargs into every per-worker dict
+           for elt in list_of_kwargs:
+               elt.update(kwargs)
        if tensordict is not None:
            needs_resetting = _aggregate_end_of_traj(
                tensordict, reset_keys=self.reset_keys
@@ -1499,9 +1518,9 @@ def tentative_update(val, other):
                self.shared_tensordicts[i].apply_(
                    tentative_update, tensordict_, default=None
                )
-               out = ("reset", tdkeys)
+               out = ("reset", (tdkeys, list_of_kwargs[i]))
            else:
-               out = ("reset", False)
+               out = ("reset", (False, list_of_kwargs[i]))
            outs.append((i, out))

        self._sync_m2w()
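For context, a self-contained sketch of the new message shape: the parent now sends ``("reset", (selected_keys_or_False, per_worker_kwargs))`` over each worker's pipe. Plain ``multiprocessing``, with a hypothetical worker function and payload values:

    import multiprocessing as mp

    def _worker(pipe):
        # receive one message from the parent and unpack the new payload format
        msg, data = pipe.recv()
        if msg == "reset":
            # data bundles the selected reset keys (or False for a no-op reset)
            # with the worker-specific reset kwargs
            selected_reset_keys, reset_kwargs = data
            print(selected_reset_keys, reset_kwargs)

    if __name__ == "__main__":
        parent_pipe, child_pipe = mp.Pipe()
        proc = mp.Process(target=_worker, args=(child_pipe,))
        proc.start()
        parent_pipe.send(("reset", (["_reset"], {"seed": 3})))
        proc.join()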
@@ -1740,10 +1759,14 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
                raise RuntimeError("call 'init' before resetting")
            # we use 'data' to pass the keys that need to be passed to reset,
            # because passing the entire buffer may have unwanted consequences
+           selected_reset_keys, reset_kwargs = data
            cur_td = env.reset(
-               tensordict=root_shared_tensordict.select(*data, strict=False)
-               if data
-               else None
+               tensordict=root_shared_tensordict.select(
+                   *selected_reset_keys, strict=False
+               )
+               if selected_reset_keys
+               else None,
+               **reset_kwargs,
            )
            shared_tensordict.update_(
                cur_td,
