Commit: amend
vmoens committed May 16, 2024
1 parent 94e7934 commit 2a184b3
Showing 6 changed files with 143 additions and 37 deletions.
16 changes: 9 additions & 7 deletions docs/source/reference/data.rst
@@ -392,6 +392,8 @@ TorchRL offers two distinctive ways of accomplishing this:
Checkpointing Replay Buffers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. _checkpoint-rb:

Each component of the replay buffer can potentially be stateful and, as such,
require a dedicated way of being serialized.
Our replay buffer enjoys two separate APIs for saving its state on disk:
@@ -412,7 +414,7 @@ Under the hood, a naive call to :meth:`~torchrl.data.ReplayBuffer.dumps` will just call the
`dumps` method in a specific folder for each of its components (except transforms
which we don't assume to be serializable using memory-mapped tensors in general).

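For illustration, a minimal sketch of this round trip could look as follows (``data`` here stands for
any batch of transitions compatible with the storage, and the path is a placeholder):

>>> from torchrl.data import LazyMemmapStorage, ReplayBuffer
>>> rb = ReplayBuffer(storage=LazyMemmapStorage(100))
>>> rb.extend(data)
>>> rb.dumps("/path/to/checkpoint")  # each stateful component serializes itself in its own sub-folder
>>> rb_restored = ReplayBuffer(storage=LazyMemmapStorage(100))
>>> rb_restored.loads("/path/to/checkpoint")  # restores the saved state from disk
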
Saving data in TED-format may however consume much more memory than required. If continuous
Saving data in :ref:`TED-format <TED-format>` may however consume much more memory than required. If continuous
trajectories are stored in a buffer, we can avoid saving duplicated observations by saving all the
observations at the root plus only the last element of the `"next"` sub-tensordict's observations, which
can reduce the storage consumption by up to a factor of two. To enable this, three checkpointer classes are available:
@@ -426,9 +428,11 @@ compress the data and save some more space.

.. warning:: The checkpointers make some restrictive assumptions about the replay buffers. First, it is assumed that
the ``done`` state accurately represents the end of a trajectory (except for the last trajectory being written,
for which the writer cursor indicates where to place the truncated signal). Furthermore, only done states that have
as many elements as the root tensordict are allowed: if the done state has extra elements that are not represented in
the batch-size of the storage, these checkpointers will fail.
for which the writer cursor indicates where to place the truncated signal). For MARL usage, one should note that
only done states that have as many elements as the root tensordict are allowed:
if the done state has extra elements that are not represented in
the batch-size of the storage, these checkpointers will fail. For example, a done state with shape ``torch.Size([3, 4, 5])``
within a storage of shape ``torch.Size([3, 4])`` is not allowed.

Here is a concrete example of how an H5DB checkpointer could be used in practice:

@@ -449,8 +453,6 @@ Here is a concrete example of how an H5DB checkpointer could be used in practice
>>> for i, data in enumerate(collector):
... rb.extend(data)
... assert rb._storage.max_size == 102
... if i == 0:
... rb_test.extend(data)
... rb.dumps(path_to_save_dir)
... rb_test.loads(path_to_save_dir)
... assert_allclose_td(rb_test[:], rb[:])
@@ -579,7 +581,7 @@ its action, info, or done state.
Flattening TED to reduce memory consumption
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TED copies the observations twice in memory, which can impact the feasibility of using this format
TED copies the observations twice in the memory, which can impact the feasibility of using this format
in practice. Since it is being used mostly for ease of representation, one can store the data
in a flat manner but represent it as TED during training.

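A rough sketch of what this can look like with the checkpointers from this commit (assuming, as in the
example file below, that ``FlatStorageCheckpointer`` is importable from ``torchrl.data``):

>>> from torchrl.data import FlatStorageCheckpointer, LazyMemmapStorage, ReplayBuffer
>>> rb = ReplayBuffer(storage=LazyMemmapStorage(100))
>>> rb.storage.checkpointer = FlatStorageCheckpointer()
>>> rb.extend(data)                   # ``data`` is any TED-formatted batch
>>> rb.dumps("/path/to/checkpoint")   # observations are written only once on disk
>>> rb.loads("/path/to/checkpoint")   # and re-expanded to the TED format when read back
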
39 changes: 39 additions & 0 deletions examples/replay-buffers/checkpoint.py
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""An example of a replay buffer being checkpointed at each iteration.
To explore this feature, try replacing the H5StorageCheckpointer with a NestedStorageCheckpointer or a
FlatStorageCheckpointer instance!
"""
import tempfile

import tensordict.utils
import torch

from torchrl.collectors import SyncDataCollector
from torchrl.data import H5StorageCheckpointer, LazyMemmapStorage, ReplayBuffer
from torchrl.envs import GymEnv, SerialEnv

with tempfile.TemporaryDirectory() as path_to_save_dir:
    env = SerialEnv(3, lambda: GymEnv("CartPole-v1", device=None))
    env.set_seed(0)
    torch.manual_seed(0)
    collector = SyncDataCollector(
        env, policy=env.rand_step, total_frames=200, frames_per_batch=22
    )
    rb = ReplayBuffer(storage=LazyMemmapStorage(100, ndim=2))
    rb_test = ReplayBuffer(storage=LazyMemmapStorage(100, ndim=2))
    rb.storage.checkpointer = H5StorageCheckpointer()
    rb_test.storage.checkpointer = H5StorageCheckpointer()
    for data in collector:
        rb.extend(data)
        assert rb._storage.max_size == 102
        rb.dumps(path_to_save_dir)
        rb_test.loads(path_to_save_dir)
        tensordict.assert_allclose_td(rb_test[:], rb[:])
    # Print the directory structure:
    tensordict.utils.print_directory_tree(path_to_save_dir)
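
As a follow-up to the example above, the new ``done_keys``/``reward_keys`` and ``h5_kwargs`` arguments
added to the checkpointers in this commit could be configured along these lines (a sketch only; the
compression settings are plain ``h5py`` ``create_dataset`` options, not values prescribed by the commit):

# Sketch: same buffers as above, but with explicit done/reward keys
# and gzip-compressed H5 datasets.
rb.storage.checkpointer = H5StorageCheckpointer(
    done_keys=("done", "truncated", "terminated"),
    reward_keys=("reward",),
    h5_kwargs={"compression": "gzip", "compression_opts": 9},
)
rb_test.storage.checkpointer = H5StorageCheckpointer(
    done_keys=("done", "truncated", "terminated"),
    reward_keys=("reward",),
)
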
8 changes: 2 additions & 6 deletions test/test_rb.py
@@ -2922,10 +2922,8 @@ def test_simple_env(self, storage_type, checkpointer, tmpdir):
rb_test = ReplayBuffer(storage=storage_type(100))
rb.storage.checkpointer = checkpointer()
rb_test.storage.checkpointer = checkpointer()
for i, data in enumerate(collector):
for data in collector:
rb.extend(data)
if i == 0:
rb_test.extend(data)
rb.dumps(tmpdir)
rb_test.loads(tmpdir)
assert_allclose_td(rb_test[:], rb[:])
@@ -2946,11 +2944,9 @@ def test_multi_env(self, storage_type, checkpointer, tmpdir):
rb_test = ReplayBuffer(storage=storage_type(100, ndim=2))
rb.storage.checkpointer = checkpointer()
rb_test.storage.checkpointer = checkpointer()
for i, data in enumerate(collector):
for data in collector:
rb.extend(data)
assert rb._storage.max_size == 102
if i == 0:
rb_test.extend(data)
rb.dumps(tmpdir)
rb_test.loads(tmpdir)
assert_allclose_td(rb_test[:], rb[:])
55 changes: 42 additions & 13 deletions torchrl/data/replay_buffers/checkpointers.py
@@ -189,13 +189,18 @@ class FlatStorageCheckpointer(TensorStorageCheckpointer):
in which case the observation in `("next", key)` at time `t` and the one in `key` at time
`t+1` should not match.
.. warning:: Given the above limitations, one should make sure that
.. seealso:: The full list of arguments can be found in :class:`~torchrl.data.TED2Flat`.
"""

def __init__(self):
self._save_hooks = [TED2Flat()]
self._load_hooks = [Flat2TED()]
def __init__(self, done_keys=None, reward_keys=None):
kwargs = {}
if done_keys is not None:
kwargs["done_keys"] = done_keys
if reward_keys is not None:
kwargs["reward_keys"] = reward_keys
self._save_hooks = [TED2Flat(**kwargs)]
self._load_hooks = [Flat2TED(**kwargs)]

def _save_shift_is_full(self, storage):
is_full = storage._is_full
@@ -239,11 +244,18 @@ class NestedStorageCheckpointer(FlatStorageCheckpointer):
- observations in the "next" tensordict are shifted by one step in the future (this
is not the case when a multi-step transform is used for instance).
.. seealso:: The full list of arguments can be found in :class:`~torchrl.data.TED2Flat`.
"""

def __init__(self):
self._save_hooks = [TED2Nested()]
self._load_hooks = [Nested2TED()]
def __init__(self, done_keys=None, reward_keys=None, **kwargs):
# extra keyword arguments are forwarded to the TED2Nested/Nested2TED hooks
if done_keys is not None:
kwargs["done_keys"] = done_keys
if reward_keys is not None:
kwargs["reward_keys"] = reward_keys
self._save_hooks = [TED2Nested(**kwargs)]
self._load_hooks = [Nested2TED(**kwargs)]


class H5StorageCheckpointer(NestedStorageCheckpointer):
@@ -255,22 +267,39 @@ class H5StorageCheckpointer(NestedStorageCheckpointer):
- observations in the "next" tensordict are shifted by one step in the future (this
is not the case when a multi-step transform is used for instance).
Args:
Keyword Args:
checkpoint_file: the filename where to save the checkpointed data.
This will be ignored iff the path passed to dumps / loads ends with the ``.h5``
suffix. Defaults to ``"checkpoint.h5"``.
**kwargs: kwargs to be passed to :meth:`h5py.File.create_dataset`.
h5_kwargs (Dict[str, Any] or Tuple[Tuple[str, Any], ...]): kwargs to be
passed to :meth:`h5py.File.create_dataset`.
.. note:: To prevent out-of-memory issues, the data of the H5 file will be temporarily written
to memory-mapped tensors stored in the shared file system. The physical memory usage may increase
during loading as a consequence.
.. seealso:: The full list of arguments can be found in :class:`~torchrl.data.TED2Flat`. Note that this class only
supports keyword arguments.
"""

def __init__(self, checkpoint_file: str = "checkpoint.h5", **kwargs):
self._save_hooks = [TED2Nested(), H5Split()]
self._load_hooks = [H5Combine(), Nested2TED()]
self.kwargs = kwargs
def __init__(
self,
*,
checkpoint_file: str = "checkpoint.h5",
done_keys=None,
reward_keys=None,
h5_kwargs=None,
**kwargs,
):
ted2_kwargs = kwargs
if done_keys is not None:
ted2_kwargs["done_keys"] = done_keys
if reward_keys is not None:
ted2_kwargs["reward_keys"] = reward_keys
self._save_hooks = [TED2Nested(**ted2_kwargs), H5Split()]
self._load_hooks = [H5Combine(), Nested2TED(**ted2_kwargs)]
self.kwargs = {} if h5_kwargs is None else dict(h5_kwargs)
self.checkpoint_file = checkpoint_file

def dumps(self, storage, path):
61 changes: 50 additions & 11 deletions torchrl/data/replay_buffers/utils.py
@@ -21,6 +21,7 @@
NonTensorData,
TensorDict,
TensorDictBase,
unravel_key,
)
from torch import Tensor
from torch.nn import functional as F
@@ -120,6 +121,19 @@ def _is_int(index):
class TED2Flat:
"""A storage saving hook to serialize TED data in a compact format.
Args:
done_key (NestedKey, optional): the key where the done states should be read.
Defaults to ``("next", "done")``.
shift_key (NestedKey, optional): the key where the shift will be written.
Defaults to "shift".
is_full_key (NestedKey, optional): the key where the is_full attribute will be written.
Defaults to "is_full".
done_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the done entries.
Defaults to ("done", "truncated", "terminated")
reward_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the reward entries.
Defaults to ("reward",)
Examples:
>>> import tempfile
>>>
Expand Down Expand Up @@ -167,11 +181,18 @@ class TED2Flat:
_is_full: bool = None

def __init__(
self, done_key=("next", "done"), shift_key="shift", is_full_key="is_full"
self,
done_key=("next", "done"),
shift_key="shift",
is_full_key="is_full",
done_keys=("done", "truncated", "terminated"),
reward_keys=("reward",),
):
self.done_key = done_key
self.shift_key = shift_key
self.is_full_key = is_full_key
self.done_keys = {unravel_key(key) for key in done_keys}
self.reward_keys = {unravel_key(key) for key in reward_keys}

@property
def shift(self):
@@ -212,19 +233,16 @@ def __call__(self, data: TensorDictBase, path: Path = None):
ntraj = done.sum()

# Get the keys that require extra storage
keys_to_expand = set(data.get("next").keys(True, True)) - {
"terminated",
"done",
"truncated",
"reward",
}
keys_to_expand = set(data.get("next").keys(True, True)) - (
self.done_keys.union(self.reward_keys)
)

total_keys = data.exclude("next").keys(True, True)
total_keys = set(total_keys).union(set(data.get("next").keys(True, True)))

len_with_offset = data.numel() + ntraj # + done[0].numel()
for key in total_keys:
if key in ("done", "truncated", "terminated", "reward"):
if key in (self.done_keys.union(self.reward_keys)):
entry = data.get(("next", key))
else:
entry = data.get(key)
@@ -301,7 +319,7 @@ def _call(self, *, data, output, is_full, shift, done, total_keys, keys_to_expan
idx += torch.nn.functional.pad(done, [1, 0])[:-1].cumsum(0)

for key in total_keys:
if key in ("done", "truncated", "terminated", "reward"):
if key in (self.done_keys.union(self.reward_keys)):
entry = data.get(("next", key))
else:
entry = data.get(key)
@@ -333,6 +351,18 @@ class Flat2TED:
class Flat2TED:
"""A storage loading hook to deserialize flattened TED data to TED format.
Args:
done_key (NestedKey, optional): the key where the done states should be read.
Defaults to ``("next", "done")``.
shift_key (NestedKey, optional): the key where the shift will be written.
Defaults to "shift".
is_full_key (NestedKey, optional): the key where the is_full attribute will be written.
Defaults to "is_full".
done_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the done entries.
Defaults to ("done", "truncated", "terminated")
reward_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the reward entries.
Defaults to ("reward",)
Examples:
>>> import tempfile
>>>
@@ -391,10 +421,19 @@ class Flat2TED:
"""

def __init__(self, done_key="done", shift_key="shift", is_full_key="is_full"):
def __init__(
self,
done_key="done",
shift_key="shift",
is_full_key="is_full",
done_keys=("done", "truncated", "terminated"),
reward_keys=("reward",),
):
self.done_key = done_key
self.shift_key = shift_key
self.is_full_key = is_full_key
self.done_keys = {unravel_key(key) for key in done_keys}
self.reward_keys = {unravel_key(key) for key in reward_keys}

def __call__(self, data: TensorDictBase, out: TensorDictBase = None):
_storage_shape = data.get_non_tensor("_storage_shape", default=None)
@@ -517,7 +556,7 @@ def maybe_roll(entry, out=None):

for key, entry in data.items(True, True):
if entry.shape[0] == nsteps:
if key in ("done", "terminated", "truncated", "reward"):
if key in (self.done_keys.union(self.reward_keys)):
if key != "reward" and key not in out.keys(True, True):
# Create a done state at the root full of 0s
out.set(key, torch.zeros_like(entry), inplace=True)
1 change: 1 addition & 0 deletions tutorials/sphinx-tutorials/rb_tutorial.py
@@ -874,3 +874,4 @@ def assert0(x):
# :class:`~torchrl.data.PrioritizedSliceSampler` and
# :class:`~torchrl.data.SliceSamplerWithoutReplacement`, or other writers
# such as :class:`~torchrl.data.TensorDictMaxValueWriter`.
# - Check how to checkpoint ReplayBuffers in :ref:`the doc <checkpoint-rb>`.
