[Doc] Dynamic envs (pytorch#2191)
vmoens authored May 31, 2024
1 parent 7b145b5 commit 1405600
Showing 3 changed files with 113 additions and 15 deletions.
110 changes: 110 additions & 0 deletions docs/source/reference/envs.rst
@@ -548,6 +548,116 @@ observation of an episode should be replaced with some placeholder or not.
[-3.2109e-02, 5.8997e-01, -6.1507e-02, -8.7363e-01],
[-2.0310e-02, 3.9574e-01, -7.8980e-02, -6.0090e-01]])

Dynamic Specs
-------------

.. _dynamic_envs:

Environments are usually run in parallel by creating memory buffers that pass information from one
process to another. In some cases, it is impossible to know in advance whether an environment will have
consistent inputs or outputs during a rollout, because their shapes may vary. We refer to this as dynamic specs.

TorchRL can handle dynamic specs, but batched environments and collectors must be aware that an environment
uses them. In practice, this is detected automatically.

To indicate that a tensor has a variable size along a dimension, set the size of that dimension to ``-1`` in the
shape of the spec. Because the resulting data cannot be stacked contiguously, calls to ``env.rollout`` must be made
with the ``return_contiguous=False`` argument.
Here is a working example:

>>> from typing import Optional
>>>
>>> from torchrl.envs import EnvBase
>>> from torchrl.data import UnboundedContinuousTensorSpec, CompositeSpec, BoundedTensorSpec, BinaryDiscreteTensorSpec
>>> import torch
>>> from tensordict import TensorDict, TensorDictBase
>>>
>>> class EnvWithDynamicSpec(EnvBase):
...     def __init__(self, max_count=5):
...         super().__init__(batch_size=())
...         # The second dimension of the observation is variable (-1).
...         self.observation_spec = CompositeSpec(
...             observation=UnboundedContinuousTensorSpec(shape=(3, -1, 2)),
...         )
...         self.action_spec = BoundedTensorSpec(low=-1, high=1, shape=(2,))
...         self.full_done_spec = CompositeSpec(
...             done=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool),
...             terminated=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool),
...             truncated=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool),
...         )
...         self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float)
...         self.count = 0
...         self.max_count = max_count
...
...     def _reset(self, tensordict=None):
...         self.count = 0
...         data = TensorDict(
...             {
...                 "observation": torch.full(
...                     (3, self.count + 1, 2),
...                     self.count,
...                     dtype=self.observation_spec["observation"].dtype,
...                 )
...             }
...         )
...         data.update(self.done_spec.zero())
...         return data
...
...     def _step(
...         self,
...         tensordict: TensorDictBase,
...     ) -> TensorDictBase:
...         # The observation grows by one element along dim 1 at every step.
...         self.count += 1
...         done = self.count >= self.max_count
...         observation = TensorDict(
...             {
...                 "observation": torch.full(
...                     (3, self.count + 1, 2),
...                     self.count,
...                     dtype=self.observation_spec["observation"].dtype,
...                 )
...             }
...         )
...         done = self.full_done_spec.zero() | done
...         reward = self.full_reward_spec.zero()
...         return observation.update(done).update(reward)
...
...     def _set_seed(self, seed: Optional[int]):
...         self.manual_seed = seed
...         return seed
>>> env = EnvWithDynamicSpec()
>>> print(env.rollout(5, return_contiguous=False))
LazyStackedTensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: LazyStackedTensorDict(
            fields={
                done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([5, 3, -1, 2]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([5]),
            device=None,
            is_shared=False,
            stack_dim=0),
        observation: Tensor(shape=torch.Size([5, 3, -1, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    exclusive_fields={
    },
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False,
    stack_dim=0)
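
The ``-1`` entries in the printed shapes mark the dimension whose size changes from one step to the next. As a rough illustration only, the reported shape can be pictured as follows: dimensions on which all stacked items agree keep their size, and the others are shown as ``-1``. The helper ``lazy_stack_shape`` below is hypothetical and written for this note; it is not part of TorchRL or tensordict and does not claim to mirror their internals.

```python
from typing import Sequence, Tuple


def lazy_stack_shape(shapes: Sequence[Tuple[int, ...]]) -> Tuple[int, ...]:
    """Hypothetical helper: shape to report for a stack of same-rank items.

    Dimensions that agree across all items keep their size; dimensions
    that differ are reported as -1.
    """
    if not shapes:
        raise ValueError("need at least one shape")
    ndim = len(shapes[0])
    if any(len(shape) != ndim for shape in shapes):
        raise ValueError("all items must have the same number of dimensions")
    merged = []
    for sizes in zip(*shapes):
        # `sizes` holds the size of one dimension across all stacked items.
        merged.append(sizes[0] if len(set(sizes)) == 1 else -1)
    # The leading dimension is the number of stacked items.
    return (len(shapes), *merged)


# Root observations of the rollout above: step k has shape (3, k + 1, 2).
step_shapes = [(3, k + 1, 2) for k in range(5)]
stacked_shape = lazy_stack_shape(step_shapes)
print(stacked_shape)  # (5, 3, -1, 2)
```

When every item has the same shape, the same helper simply returns the dense stacked shape, which is the case where a contiguous rollout would be possible.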

.. warning:: The absence of memory buffers in :class:`~torchrl.envs.ParallelEnv` and in data collectors can impact
  performance of these classes dramatically. Any such usage should be carefully benchmarked against a plain execution on
  a single process, as serializing and deserializing large numbers of tensors can be very expensive.

Currently, :func:`~torchrl.envs.utils.check_env_specs` will pass for dynamic specs where a shape varies along some
dimensions, but not when a key is present during a step and absent during others, or when the number of dimensions
varies.
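
The rule stated above can be pictured with a small shape check. This is only a sketch: ``shape_matches_spec`` is a hypothetical name, and the function does not reproduce how ``check_env_specs`` is implemented. It only mirrors the described behaviour for shapes, assuming that ``-1`` accepts any size, that every other dimension must match exactly, and that a different number of dimensions is rejected.

```python
from typing import Tuple


def shape_matches_spec(spec_shape: Tuple[int, ...], shape: Tuple[int, ...]) -> bool:
    """Hypothetical check: does ``shape`` fit a spec shape that may contain -1?"""
    # A different number of dimensions is never accepted.
    if len(spec_shape) != len(shape):
        return False
    # -1 in the spec acts as a wildcard for that dimension.
    return all(s == -1 or s == d for s, d in zip(spec_shape, shape))


spec_shape = (3, -1, 2)
print(shape_matches_spec(spec_shape, (3, 1, 2)))     # True
print(shape_matches_spec(spec_shape, (3, 5, 2)))     # True
print(shape_matches_spec(spec_shape, (4, 5, 2)))     # False: fixed dim differs
print(shape_matches_spec(spec_shape, (3, 5, 2, 1)))  # False: extra dimension
```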

Transforms
----------
17 changes: 2 additions & 15 deletions torchrl/envs/batched_envs.py
@@ -188,6 +188,8 @@ class BatchedEnvBase(EnvBase):
occur via circular preallocated memory buffers. Defaults to ``True`` unless
one of the environments has dynamic specs.
.. note:: Learn more about dynamic specs and environments :ref:`here <dynamic_envs>`.
.. note::
One can pass keyword arguments to each sub-environment using the following
technique: every keyword argument in :meth:`~.reset` will be passed to each
@@ -993,21 +995,6 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
else:
out_tds[i] = _td
if not self._use_buffers:
# first_non_none = None
# first_none = None
# for i, item in enumerate(out_tds):
# if item is not None:
# first_non_none = i
# if first_none is not None:
# break
# else:
# first_none = i
# if first_non_none is not None:
# break
# if first_none is not None:
# empty_td = out_tds[first_non_none].empty(recurse=True)
# out_tds = [item if item is not None else empty_td for item in out_tds]
#
result = LazyStackedTensorDict.maybe_dense_stack(out_tds)
return result

1 change: 1 addition & 0 deletions torchrl/envs/common.py
@@ -335,6 +335,7 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit):
dtype=torch.bool,
domain=discrete), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
.. note:: Learn more about dynamic specs and environments :ref:`here <dynamic_envs>`.
"""

def __init__(