Skip to content

Commit

Permalink
[Feature] Replay buffer checkpointers (pytorch#2137)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored May 16, 2024
1 parent 259f20d commit 73d09c3
Show file tree
Hide file tree
Showing 10 changed files with 1,523 additions and 300 deletions.
97 changes: 87 additions & 10 deletions docs/source/reference/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -136,23 +136,31 @@ using the following components:
:template: rl_template.rst


Sampler
FlatStorageCheckpointer
H5StorageCheckpointer
ImmutableDatasetWriter
LazyMemmapStorage
LazyTensorStorage
ListStorage
ListStorageCheckpointer
NestedStorageCheckpointer
PrioritizedSampler
PrioritizedSliceSampler
RandomSampler
RoundRobinWriter
Sampler
SamplerWithoutReplacement
SliceSampler
SliceSamplerWithoutReplacement
Storage
ListStorage
LazyTensorStorage
LazyMemmapStorage
StorageCheckpointerBase
StorageEnsembleCheckpointer
TensorDictMaxValueWriter
TensorDictRoundRobinWriter
TensorStorage
TensorStorageCheckpointer
Writer
ImmutableDatasetWriter
RoundRobinWriter
TensorDictRoundRobinWriter
TensorDictMaxValueWriter


Storage choice is very influential on replay buffer sampling latency, especially
in distributed reinforcement learning settings with larger data volumes.
Expand Down Expand Up @@ -384,22 +392,72 @@ TorchRL offers two distinctive ways of accomplishing this:
Checkpointing Replay Buffers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. _checkpoint-rb:

Each component of the replay buffer can potentially be stateful and, as such,
require a dedicated way of being serialized.
Our replay buffer enjoys two separate APIs for saving their state on disk:
:meth:`~torchrl.data.ReplayBuffer.dumps` and :meth:`~torchrl.data.ReplayBuffer.loads` will save the
data of each component except transforms (storage, writer, sampler) using memory-mapped
tensors and json files for the metadata. This will work across all classes except
tensors and json files for the metadata.

This will work across all classes except
:class:`~torchrl.data.replay_buffers.storages.ListStorage`, which content
cannot be anticipated (and as such does not comply with memory-mapped data
structures such as those that can be found in the tensordict library).

This API guarantees that a buffer that is saved and then loaded back will be in
the exact same state, whether we look at the status of its sampler (eg, priority trees)
its writer (eg, max writer heaps) or its storage.
Under the hood, :meth:`~torchrl.data.ReplayBuffer.dumps` will just call the public

Under the hood, a naive call to :meth:`~torchrl.data.ReplayBuffer.dumps` will just call the public
`dumps` method in a specific folder for each of its components (except transforms
which we don't assume to be serializable using memory-mapped tensors in general).

Saving data in :ref:`TED-format <TED-format>` may however consume much more memory than required. If continuous
trajectories are stored in a buffer, we can avoid saving duplicated observations by saving all the
observations at the root plus only the last element of the `"next"` sub-tensordict's observations, which
can reduce the storage consumption up to two times. To enable this, three checkpointer classes are available:
:class:`~torchrl.data.FlatStorageCheckpointer` will discard duplicated observations to compress the TED format. At
load time, this class will re-write the observations in the correct format. If the buffer is saved on disk,
the operations executed by this checkpointer will not require any additional RAM.
The :class:`~torchrl.data.NestedStorageCheckpointer` will save the trajectories using nested tensors to make the data
representation more apparent (each item along the first dimension representing a distinct trajectory).
Finally, the :class:`~torchrl.data.H5StorageCheckpointer` will save the buffer in an H5DB format, enabling users to
compress the data and save some more space.

.. warning:: The checkpointers make some restrictive assumption about the replay buffers. First, it is assumed that
the ``done`` state accurately represents the end of a trajectory (except for the last trajectory which was written
for which the writer cursor indicates where to place the truncated signal). For MARL usage, one should note that
only done states that have as many elements as the root tensordict are allowed:
if the done state has extra elements that are not represented in
the batch-size of the storage, these checkpointers will fail. For example, a done state with shape ``torch.Size([3, 4, 5])``
within a storage of shape ``torch.Size([3, 4])`` is not allowed.

Here is a concrete example of how an H5DB checkpointer could be used in practice:

>>> from torchrl.data import ReplayBuffer, H5StorageCheckpointer, LazyMemmapStorage
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import GymEnv, SerialEnv
>>> import torch
>>> env = SerialEnv(3, lambda: GymEnv("CartPole-v1", device=None))
>>> env.set_seed(0)
>>> torch.manual_seed(0)
>>> collector = SyncDataCollector(
>>> env, policy=env.rand_step, total_frames=200, frames_per_batch=22
>>> )
>>> rb = ReplayBuffer(storage=LazyMemmapStorage(100, ndim=2))
>>> rb_test = ReplayBuffer(storage=LazyMemmapStorage(100, ndim=2))
>>> rb.storage.checkpointer = H5StorageCheckpointer()
>>> rb_test.storage.checkpointer = H5StorageCheckpointer()
>>> for i, data in enumerate(collector):
... rb.extend(data)
... assert rb._storage.max_size == 102
... rb.dumps(path_to_save_dir)
... rb_test.loads(path_to_save_dir)
... assert_allclose_td(rb_test[:], rb[:])


Whenever saving data using :meth:`~torchrl.data.ReplayBuffer.dumps` is not possible, an
alternative way is to use :meth:`~torchrl.data.ReplayBuffer.state_dict`, which returns a data
structure that can be saved using :func:`torch.save` and loaded using :func:`torch.load`
Expand Down Expand Up @@ -520,6 +578,19 @@ should have a considerably lower memory footprint than observations, for instanc
This format eliminates any ambiguity regarding the matching of an observation with
its action, info, or done state.

Flattening TED to reduce memory consumption
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TED copies the observations twice in the memory, which can impact the feasibility of using this format
in practice. Since it is being used mostly for ease of representation, one can store the data
in a flat manner but represent it as TED during training.

This is particularly useful when serializing replay buffers:
For instance, the :class:`~torchrl.data.TED2Flat` class ensures that a TED-formatted data
structure is flattened before being written to disk, whereas the :class:`~torchrl.data.Flat2TED`
load hook will unflatten this structure during deserialization.


Dimensionality of the Tensordict
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down Expand Up @@ -869,6 +940,12 @@ Utils
consolidate_spec
check_no_exclusive_keys
contains_lazy_spec
Nested2TED
Flat2TED
H5Combine
H5Split
TED2Flat
TED2Nested

.. currentmodule:: torchrl.envs.transforms.rb_transforms

Expand Down
39 changes: 39 additions & 0 deletions examples/replay-buffers/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""An example of a replay buffer being checkpointed at each iteration.
To explore this feature, try replacing the H5StorageCheckpointer with a NestedStorageCheckpointer or a
FlatStorageCheckpointer instance!
"""
import tempfile

import tensordict.utils
import torch

from torchrl.collectors import SyncDataCollector
from torchrl.data import H5StorageCheckpointer, LazyMemmapStorage, ReplayBuffer
from torchrl.envs import GymEnv, SerialEnv

with tempfile.TemporaryDirectory() as path_to_save_dir:
env = SerialEnv(3, lambda: GymEnv("CartPole-v1", device=None))
env.set_seed(0)
torch.manual_seed(0)
collector = SyncDataCollector(
env, policy=env.rand_step, total_frames=200, frames_per_batch=22
)
rb = ReplayBuffer(storage=LazyMemmapStorage(100, ndim=2))
rb_test = ReplayBuffer(storage=LazyMemmapStorage(100, ndim=2))
rb.storage.checkpointer = H5StorageCheckpointer()
rb_test.storage.checkpointer = H5StorageCheckpointer()
for data in collector:
rb.extend(data)
assert rb._storage.max_size == 102
rb.dumps(path_to_save_dir)
rb_test.loads(path_to_save_dir)
tensordict.assert_allclose_td(rb_test[:], rb[:])
# Print the directory structure:
tensordict.utils.print_directory_tree(path_to_save_dir)
53 changes: 52 additions & 1 deletion test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest
import torch

from _utils_internal import get_default_devices, make_tc
from _utils_internal import CARTPOLE_VERSIONED, get_default_devices, make_tc

from mocking_classes import CountingEnv
from packaging import version
Expand All @@ -35,7 +35,9 @@
from torchrl.collectors import RandomPolicy, SyncDataCollector
from torchrl.collectors.utils import split_trajectories
from torchrl.data import (
FlatStorageCheckpointer,
MultiStep,
NestedStorageCheckpointer,
PrioritizedReplayBuffer,
RemoteTensorDictReplayBuffer,
ReplayBuffer,
Expand All @@ -44,6 +46,7 @@
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers import samplers, writers
from torchrl.data.replay_buffers.checkpointers import H5StorageCheckpointer
from torchrl.data.replay_buffers.samplers import (
PrioritizedSampler,
PrioritizedSliceSampler,
Expand Down Expand Up @@ -2901,6 +2904,54 @@ def test_done_slicesampler(self, strict_length):
assert (split_trajectories(sample)["next", "done"].sum(-2) == 1).all()


@pytest.mark.skipif(not _has_gym, reason="gym required")
class TestCheckpointers:
@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
@pytest.mark.parametrize(
"checkpointer",
[FlatStorageCheckpointer, H5StorageCheckpointer, NestedStorageCheckpointer],
)
def test_simple_env(self, storage_type, checkpointer, tmpdir):
env = GymEnv(CARTPOLE_VERSIONED(), device=None)
env.set_seed(0)
torch.manual_seed(0)
collector = SyncDataCollector(
env, policy=env.rand_step, total_frames=200, frames_per_batch=22
)
rb = ReplayBuffer(storage=storage_type(100))
rb_test = ReplayBuffer(storage=storage_type(100))
rb.storage.checkpointer = checkpointer()
rb_test.storage.checkpointer = checkpointer()
for data in collector:
rb.extend(data)
rb.dumps(tmpdir)
rb_test.loads(tmpdir)
assert_allclose_td(rb_test[:], rb[:])

@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
@pytest.mark.parametrize(
"checkpointer",
[FlatStorageCheckpointer, NestedStorageCheckpointer, H5StorageCheckpointer],
)
def test_multi_env(self, storage_type, checkpointer, tmpdir):
env = SerialEnv(3, lambda: GymEnv(CARTPOLE_VERSIONED(), device=None))
env.set_seed(0)
torch.manual_seed(0)
collector = SyncDataCollector(
env, policy=env.rand_step, total_frames=200, frames_per_batch=22
)
rb = ReplayBuffer(storage=storage_type(100, ndim=2))
rb_test = ReplayBuffer(storage=storage_type(100, ndim=2))
rb.storage.checkpointer = checkpointer()
rb_test.storage.checkpointer = checkpointer()
for data in collector:
rb.extend(data)
assert rb._storage.max_size == 102
rb.dumps(tmpdir)
rb_test.loads(tmpdir)
assert_allclose_td(rb_test[:], rb[:])


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
13 changes: 13 additions & 0 deletions torchrl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,18 @@

from .postprocs import MultiStep
from .replay_buffers import (
Flat2TED,
FlatStorageCheckpointer,
H5Combine,
H5Split,
H5StorageCheckpointer,
ImmutableDatasetWriter,
LazyMemmapStorage,
LazyTensorStorage,
ListStorage,
ListStorageCheckpointer,
Nested2TED,
NestedStorageCheckpointer,
PrioritizedReplayBuffer,
PrioritizedSampler,
RandomSampler,
Expand All @@ -21,12 +29,17 @@
SliceSampler,
SliceSamplerWithoutReplacement,
Storage,
StorageCheckpointerBase,
StorageEnsemble,
StorageEnsembleCheckpointer,
TED2Flat,
TED2Nested,
TensorDictMaxValueWriter,
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
TensorDictRoundRobinWriter,
TensorStorage,
TensorStorageCheckpointer,
Writer,
WriterEnsemble,
)
Expand Down
10 changes: 10 additions & 0 deletions torchrl/data/replay_buffers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .checkpointers import (
FlatStorageCheckpointer,
H5StorageCheckpointer,
ListStorageCheckpointer,
NestedStorageCheckpointer,
StorageCheckpointerBase,
StorageEnsembleCheckpointer,
TensorStorageCheckpointer,
)
from .replay_buffers import (
PrioritizedReplayBuffer,
RemoteTensorDictReplayBuffer,
Expand All @@ -29,6 +38,7 @@
StorageEnsemble,
TensorStorage,
)
from .utils import Flat2TED, H5Combine, H5Split, Nested2TED, TED2Flat, TED2Nested
from .writers import (
ImmutableDatasetWriter,
RoundRobinWriter,
Expand Down
Loading

0 comments on commit 73d09c3

Please sign in to comment.