Commit f6fbc44

[Refactor,Feature] Refactor collector shapes and stack_result in sync collector (pytorch#1994)
vmoens authored Mar 19, 2024
1 parent 4bce371 commit f6fbc44
Showing 3 changed files with 452 additions and 121 deletions.
108 changes: 107 additions & 1 deletion docs/source/reference/collectors.rst
@@ -43,7 +43,7 @@ avoid OOM errors. Finally, the choice of the batch size and passing device (ie the
device where the data will be stored while waiting to be passed to the collection
worker) may also impact the memory management. The key parameters to control are
:obj:`devices` which controls the execution devices (ie the device of the policy)
-and :obj:`storing_devices` which will control the device where the environment and
+and :obj:`storing_device` which will control the device where the environment and
data are stored during a rollout. A good heuristic is usually to use the same device
for storage and compute, which is the default behaviour when only the `devices` argument
is being passed.
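
For instance (a minimal sketch, assuming ``env`` and ``policy`` are defined; the
single-process collector exposes these as ``device`` and ``storing_device``):

>>> collector = SyncDataCollector(env,
... policy,
... frames_per_batch=64,
... total_frames=-1,
... device="cuda:0",  # execution device for the policy
... storing_device="cpu")  # device where the collected data is stored
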
@@ -59,6 +59,112 @@ Besides those compute parameters, users may choose to configure the following parameters:
- exploration_type: the exploration strategy to be used with the policy.
- reset_when_done: whether environments should be reset when reaching a done state.
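
For instance (a minimal sketch, assuming ``env`` and ``policy`` are defined):

>>> from torchrl.envs.utils import ExplorationType
>>> collector = SyncDataCollector(env,
... policy,
... frames_per_batch=64,
... total_frames=1_000,
... exploration_type=ExplorationType.RANDOM,
... reset_when_done=True)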

Collectors and batch size
-------------------------

Because each collector has its own way of organizing the environments that are
run within it, the data will come with a different batch size depending on the
specificities of the collector. The following table summarizes what is to be
expected when collecting data:


+--------------------+---------------------+--------------------------------------------+------------------------------+
|                    | SyncDataCollector   | MultiSyncDataCollector (n=B)               | MultiaSyncDataCollector (n=B)|
+====================+=====================+=============+==============+===============+==============================+
| `cat_results`      | NA                  | `"stack"`   | `0`          | `-1`          | NA                           |
+--------------------+---------------------+-------------+--------------+---------------+------------------------------+
| Single env         | `[T]`               | `[B, T]`    | `[B*(T//B)]` | `[B*(T//B)]`  | `[T]`                        |
+--------------------+---------------------+-------------+--------------+---------------+------------------------------+
| Batched env (n=P)  | `[P, T]`            | `[B, P, T]` | `[B * P, T]` | `[P, T * B]`  | `[P, T]`                     |
+--------------------+---------------------+-------------+--------------+---------------+------------------------------+

In each of these cases, the last dimension (``T`` for ``time``) is adapted such
that the batch size equals the ``frames_per_batch`` argument passed to the
collector.
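
For instance (a minimal sketch, assuming ``make_env`` and ``policy`` are defined
as elsewhere on this page), two single-env sub-collectors with
``cat_results="stack"`` deliver ``[B, T]`` batches:

>>> collector = MultiSyncDataCollector([make_env, make_env],  # B=2 sub-collectors
... policy,
... frames_per_batch=200,
... total_frames=-1,
... cat_results="stack")
>>> for data in collector:
...     print(data.batch_size)  # torch.Size([2, 100]): T is set so that B * T == frames_per_batch
...     break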

.. warning:: :class:`~torchrl.collectors.collectors.MultiSyncDataCollector` should not be
  used with ``cat_results=0``, as the data will be stacked along the batch
  dimension with batched environments, or the time dimension for single environments,
  which can introduce some confusion when swapping one for the other.
  ``cat_results="stack"`` is a better and more consistent way of interacting
  with the environments, as it will keep each dimension separate and provide
  better interchangeability between configurations, collector classes and other
  components.

Whereas :class:`~torchrl.collectors.collectors.MultiSyncDataCollector`
has a dimension corresponding to the number of sub-collectors being run (``B``),
:class:`~torchrl.collectors.collectors.MultiaSyncDataCollector` doesn't. This
is easily understood when considering that :class:`~torchrl.collectors.collectors.MultiaSyncDataCollector`
delivers batches of data on a first-come, first-served basis, whereas
:class:`~torchrl.collectors.collectors.MultiSyncDataCollector` gathers data from
each sub-collector before delivering it.
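
For instance (same assumptions as above), the asynchronous collector delivers
each worker's batch as-is, without a leading ``B`` dimension:

>>> collector = MultiaSyncDataCollector([make_env, make_env],
... policy,
... frames_per_batch=200,
... total_frames=-1)
>>> for data in collector:
...     print(data.batch_size)  # torch.Size([200]): a single worker's batch
...     break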

Collectors and replay buffers interoperability
----------------------------------------------

In the simplest scenario where single transitions have to be sampled
from the replay buffer, little attention has to be given to the way
the collector is built. Flattening the data after collection will
be a sufficient preprocessing step before populating the storage:

>>> memory = ReplayBuffer(
... storage=LazyTensorStorage(N),
... transform=lambda data: data.reshape(-1))  # flatten the batch/time dims into one
>>> for data in collector:
...     memory.extend(data)

If trajectory slices have to be collected, the recommended way to achieve this is to create
a multidimensional buffer and sample using the :class:`~torchrl.data.replay_buffers.SliceSampler`
sampler class. One must ensure that the data passed to the buffer is properly shaped, with the
``time`` and ``batch`` dimensions clearly separated. In practice, the following configurations
will work:

>>> # Single environment: no need for a multi-dimensional buffer
>>> memory = ReplayBuffer(
... storage=LazyTensorStorage(N),
... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1)
>>> for data in collector:
...     memory.extend(data)
>>> # Batched environments: a multi-dim buffer is required
>>> memory = ReplayBuffer(
... storage=LazyTensorStorage(N, ndim=2),
... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> env = ParallelEnv(4, make_env)
>>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1)
>>> for data in collector:
...     memory.extend(data)
>>> # MultiSyncDataCollector + regular env: behaves like a ParallelEnv if cat_results="stack"
>>> memory = ReplayBuffer(
... storage=LazyTensorStorage(N, ndim=2),
... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> collector = MultiSyncDataCollector([make_env] * 4,
... policy,
... frames_per_batch=N,
... total_frames=-1,
... cat_results="stack")
>>> for data in collector:
...     memory.extend(data)
>>> # MultiSyncDataCollector + parallel env: the ndim must be adapted accordingly
>>> memory = ReplayBuffer(
... storage=LazyTensorStorage(N, ndim=3),
... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> collector = MultiSyncDataCollector([ParallelEnv(2, make_env)] * 4,
... policy,
... frames_per_batch=N,
... total_frames=-1,
... cat_results="stack")
>>> for data in collector:
...     memory.extend(data)
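
Sampling from such a buffer then returns trajectory slices (a sketch: with
``num_slices=4``, each sampled batch consists of 4 contiguous sub-trajectories):

>>> sample = memory.sample(32)  # 4 slices of 8 consecutive steps each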

Using replay buffers that sample trajectories with :class:`~torchrl.collectors.MultiaSyncDataCollector`
isn't currently fully supported, as the data batches can come from any worker and in most cases consecutive
batches written in the buffer won't come from the same source (thereby interrupting the trajectories).


Single node data collectors
---------------------------