From 73d09c376d9bd5a18e726757e5624343c05537fb Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 16 May 2024 16:12:57 +0100 Subject: [PATCH] [Feature] Replay buffer checkpointers (#2137) --- docs/source/reference/data.rst | 97 +- examples/replay-buffers/checkpoint.py | 39 + test/test_rb.py | 53 +- torchrl/data/__init__.py | 13 + torchrl/data/replay_buffers/__init__.py | 10 + torchrl/data/replay_buffers/checkpointers.py | 375 ++++++++ torchrl/data/replay_buffers/replay_buffers.py | 41 +- torchrl/data/replay_buffers/storages.py | 339 ++----- torchrl/data/replay_buffers/utils.py | 855 +++++++++++++++++- tutorials/sphinx-tutorials/rb_tutorial.py | 1 + 10 files changed, 1523 insertions(+), 300 deletions(-) create mode 100644 examples/replay-buffers/checkpoint.py create mode 100644 torchrl/data/replay_buffers/checkpointers.py diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 3c249bcfbaa..cf3df44487c 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -136,23 +136,31 @@ using the following components: :template: rl_template.rst - Sampler + FlatStorageCheckpointer + H5StorageCheckpointer + ImmutableDatasetWriter + LazyMemmapStorage + LazyTensorStorage + ListStorage + ListStorageCheckpointer + NestedStorageCheckpointer PrioritizedSampler PrioritizedSliceSampler RandomSampler + RoundRobinWriter + Sampler SamplerWithoutReplacement SliceSampler SliceSamplerWithoutReplacement Storage - ListStorage - LazyTensorStorage - LazyMemmapStorage + StorageCheckpointerBase + StorageEnsembleCheckpointer + TensorDictMaxValueWriter + TensorDictRoundRobinWriter TensorStorage + TensorStorageCheckpointer Writer - ImmutableDatasetWriter - RoundRobinWriter - TensorDictRoundRobinWriter - TensorDictMaxValueWriter + Storage choice is very influential on replay buffer sampling latency, especially in distributed reinforcement learning settings with larger data volumes. @@ -384,22 +392,72 @@ TorchRL offers two distinctive ways of accomplishing this: Checkpointing Replay Buffers ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. _checkpoint-rb: + Each component of the replay buffer can potentially be stateful and, as such, require a dedicated way of being serialized. Our replay buffer enjoys two separate APIs for saving their state on disk: :meth:`~torchrl.data.ReplayBuffer.dumps` and :meth:`~torchrl.data.ReplayBuffer.loads` will save the data of each component except transforms (storage, writer, sampler) using memory-mapped -tensors and json files for the metadata. This will work across all classes except +tensors and json files for the metadata. + +This will work across all classes except :class:`~torchrl.data.replay_buffers.storages.ListStorage`, which content cannot be anticipated (and as such does not comply with memory-mapped data structures such as those that can be found in the tensordict library). + This API guarantees that a buffer that is saved and then loaded back will be in the exact same state, whether we look at the status of its sampler (eg, priority trees) its writer (eg, max writer heaps) or its storage. -Under the hood, :meth:`~torchrl.data.ReplayBuffer.dumps` will just call the public + +Under the hood, a naive call to :meth:`~torchrl.data.ReplayBuffer.dumps` will just call the public `dumps` method in a specific folder for each of its components (except transforms which we don't assume to be serializable using memory-mapped tensors in general). +Saving data in :ref:`TED-format ` may however consume much more memory than required. 
If continuous
+trajectories are stored in a buffer, we can avoid saving duplicated observations by saving all the
+observations at the root plus only the last element of the ``"next"`` sub-tensordict's observations, which
+can reduce the storage consumption by up to a factor of two. To enable this, three checkpointer classes are available:
+:class:`~torchrl.data.FlatStorageCheckpointer` will discard duplicated observations to compress the TED format. At
+load time, this class will rewrite the observations in the correct format. If the buffer is saved on disk,
+the operations executed by this checkpointer will not require any additional RAM.
+The :class:`~torchrl.data.NestedStorageCheckpointer` will save the trajectories using nested tensors to make the data
+representation more apparent (each item along the first dimension representing a distinct trajectory).
+Finally, the :class:`~torchrl.data.H5StorageCheckpointer` will save the buffer in HDF5 format, enabling users to
+compress the data and save additional space.
+
+.. warning:: The checkpointers make some restrictive assumptions about the replay buffers. First, it is assumed that
+    the ``done`` state accurately represents the end of a trajectory (except for the last written trajectory, for
+    which the writer cursor indicates where to place the truncated signal). For MARL usage, one should note that
+    only done states that have as many elements as the root tensordict are allowed:
+    if the done state has extra elements that are not represented in
+    the batch-size of the storage, these checkpointers will fail. For example, a done state with shape ``torch.Size([3, 4, 5])``
+    within a storage of shape ``torch.Size([3, 4])`` is not allowed.
+
+Here is a concrete example of how an :class:`~torchrl.data.H5StorageCheckpointer` could be used in practice:
+
+    >>> from torchrl.data import ReplayBuffer, H5StorageCheckpointer, LazyMemmapStorage
+    >>> from torchrl.collectors import SyncDataCollector
+    >>> from torchrl.envs import GymEnv, SerialEnv
+    >>> from tensordict import assert_allclose_td
+    >>> import torch
+    >>> env = SerialEnv(3, lambda: GymEnv("CartPole-v1", device=None))
+    >>> env.set_seed(0)
+    >>> torch.manual_seed(0)
+    >>> collector = SyncDataCollector(
+    ...     env, policy=env.rand_step, total_frames=200, frames_per_batch=22
+    ... )
+    >>> rb = ReplayBuffer(storage=LazyMemmapStorage(100, ndim=2))
+    >>> rb_test = ReplayBuffer(storage=LazyMemmapStorage(100, ndim=2))
+    >>> rb.storage.checkpointer = H5StorageCheckpointer()
+    >>> rb_test.storage.checkpointer = H5StorageCheckpointer()
+    >>> for i, data in enumerate(collector):
+    ...     rb.extend(data)
+    ...     assert rb._storage.max_size == 102
+    ...     rb.dumps(path_to_save_dir)
+    ...     rb_test.loads(path_to_save_dir)
+    ...     assert_allclose_td(rb_test[:], rb[:])
+
+
 Whenever saving data using :meth:`~torchrl.data.ReplayBuffer.dumps` is not possible, an
 alternative way is to use :meth:`~torchrl.data.ReplayBuffer.state_dict`, which returns a data
 structure that can be saved using :func:`torch.save` and loaded using :func:`torch.load`
@@ -520,6 +578,19 @@ should have a considerably lower memory footprint than
 observations, for instance).
 This format eliminates any ambiguity regarding the matching of an observation with its
 action, info, or done state.
+
+Flattening TED to reduce memory consumption
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+TED stores the observations twice in memory, which can impact the feasibility of using this format
+in practice.
Since it is being used mostly for ease of representation, one can store the data +in a flat manner but represent it as TED during training. + +This is particularly useful when serializing replay buffers: +For instance, the :class:`~torchrl.data.TED2Flat` class ensures that a TED-formatted data +structure is flattened before being written to disk, whereas the :class:`~torchrl.data.Flat2TED` +load hook will unflatten this structure during deserialization. + + Dimensionality of the Tensordict ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -869,6 +940,12 @@ Utils consolidate_spec check_no_exclusive_keys contains_lazy_spec + Nested2TED + Flat2TED + H5Combine + H5Split + TED2Flat + TED2Nested .. currentmodule:: torchrl.envs.transforms.rb_transforms diff --git a/examples/replay-buffers/checkpoint.py b/examples/replay-buffers/checkpoint.py new file mode 100644 index 00000000000..bd634acf105 --- /dev/null +++ b/examples/replay-buffers/checkpoint.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""An example of a replay buffer being checkpointed at each iteration. + +To explore this feature, try replacing the H5StorageCheckpointer with a NestedStorageCheckpointer or a +FlatStorageCheckpointer instance! + +""" +import tempfile + +import tensordict.utils +import torch + +from torchrl.collectors import SyncDataCollector +from torchrl.data import H5StorageCheckpointer, LazyMemmapStorage, ReplayBuffer +from torchrl.envs import GymEnv, SerialEnv + +with tempfile.TemporaryDirectory() as path_to_save_dir: + env = SerialEnv(3, lambda: GymEnv("CartPole-v1", device=None)) + env.set_seed(0) + torch.manual_seed(0) + collector = SyncDataCollector( + env, policy=env.rand_step, total_frames=200, frames_per_batch=22 + ) + rb = ReplayBuffer(storage=LazyMemmapStorage(100, ndim=2)) + rb_test = ReplayBuffer(storage=LazyMemmapStorage(100, ndim=2)) + rb.storage.checkpointer = H5StorageCheckpointer() + rb_test.storage.checkpointer = H5StorageCheckpointer() + for data in collector: + rb.extend(data) + assert rb._storage.max_size == 102 + rb.dumps(path_to_save_dir) + rb_test.loads(path_to_save_dir) + tensordict.assert_allclose_td(rb_test[:], rb[:]) + # Print the directory structure: + tensordict.utils.print_directory_tree(path_to_save_dir) diff --git a/test/test_rb.py b/test/test_rb.py index 13b1a50aecc..7f5fdff0bc4 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -17,7 +17,7 @@ import pytest import torch -from _utils_internal import get_default_devices, make_tc +from _utils_internal import CARTPOLE_VERSIONED, get_default_devices, make_tc from mocking_classes import CountingEnv from packaging import version @@ -35,7 +35,9 @@ from torchrl.collectors import RandomPolicy, SyncDataCollector from torchrl.collectors.utils import split_trajectories from torchrl.data import ( + FlatStorageCheckpointer, MultiStep, + NestedStorageCheckpointer, PrioritizedReplayBuffer, RemoteTensorDictReplayBuffer, ReplayBuffer, @@ -44,6 +46,7 @@ TensorDictReplayBuffer, ) from torchrl.data.replay_buffers import samplers, writers +from torchrl.data.replay_buffers.checkpointers import H5StorageCheckpointer from torchrl.data.replay_buffers.samplers import ( PrioritizedSampler, PrioritizedSliceSampler, @@ -2901,6 +2904,54 @@ def test_done_slicesampler(self, strict_length): assert (split_trajectories(sample)["next", "done"].sum(-2) == 1).all() +@pytest.mark.skipif(not _has_gym, reason="gym required") +class 
TestCheckpointers: + @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) + @pytest.mark.parametrize( + "checkpointer", + [FlatStorageCheckpointer, H5StorageCheckpointer, NestedStorageCheckpointer], + ) + def test_simple_env(self, storage_type, checkpointer, tmpdir): + env = GymEnv(CARTPOLE_VERSIONED(), device=None) + env.set_seed(0) + torch.manual_seed(0) + collector = SyncDataCollector( + env, policy=env.rand_step, total_frames=200, frames_per_batch=22 + ) + rb = ReplayBuffer(storage=storage_type(100)) + rb_test = ReplayBuffer(storage=storage_type(100)) + rb.storage.checkpointer = checkpointer() + rb_test.storage.checkpointer = checkpointer() + for data in collector: + rb.extend(data) + rb.dumps(tmpdir) + rb_test.loads(tmpdir) + assert_allclose_td(rb_test[:], rb[:]) + + @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) + @pytest.mark.parametrize( + "checkpointer", + [FlatStorageCheckpointer, NestedStorageCheckpointer, H5StorageCheckpointer], + ) + def test_multi_env(self, storage_type, checkpointer, tmpdir): + env = SerialEnv(3, lambda: GymEnv(CARTPOLE_VERSIONED(), device=None)) + env.set_seed(0) + torch.manual_seed(0) + collector = SyncDataCollector( + env, policy=env.rand_step, total_frames=200, frames_per_batch=22 + ) + rb = ReplayBuffer(storage=storage_type(100, ndim=2)) + rb_test = ReplayBuffer(storage=storage_type(100, ndim=2)) + rb.storage.checkpointer = checkpointer() + rb_test.storage.checkpointer = checkpointer() + for data in collector: + rb.extend(data) + assert rb._storage.max_size == 102 + rb.dumps(tmpdir) + rb_test.loads(tmpdir) + assert_allclose_td(rb_test[:], rb[:]) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index e2c0b97c2fc..14c1bcbb6c6 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -5,10 +5,18 @@ from .postprocs import MultiStep from .replay_buffers import ( + Flat2TED, + FlatStorageCheckpointer, + H5Combine, + H5Split, + H5StorageCheckpointer, ImmutableDatasetWriter, LazyMemmapStorage, LazyTensorStorage, ListStorage, + ListStorageCheckpointer, + Nested2TED, + NestedStorageCheckpointer, PrioritizedReplayBuffer, PrioritizedSampler, RandomSampler, @@ -21,12 +29,17 @@ SliceSampler, SliceSamplerWithoutReplacement, Storage, + StorageCheckpointerBase, StorageEnsemble, + StorageEnsembleCheckpointer, + TED2Flat, + TED2Nested, TensorDictMaxValueWriter, TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer, TensorDictRoundRobinWriter, TensorStorage, + TensorStorageCheckpointer, Writer, WriterEnsemble, ) diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py index c9aadcb992b..25822dcfe4c 100644 --- a/torchrl/data/replay_buffers/__init__.py +++ b/torchrl/data/replay_buffers/__init__.py @@ -3,6 +3,15 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from .checkpointers import ( + FlatStorageCheckpointer, + H5StorageCheckpointer, + ListStorageCheckpointer, + NestedStorageCheckpointer, + StorageCheckpointerBase, + StorageEnsembleCheckpointer, + TensorStorageCheckpointer, +) from .replay_buffers import ( PrioritizedReplayBuffer, RemoteTensorDictReplayBuffer, @@ -29,6 +38,7 @@ StorageEnsemble, TensorStorage, ) +from .utils import Flat2TED, H5Combine, H5Split, Nested2TED, TED2Flat, TED2Nested from .writers import ( ImmutableDatasetWriter, RoundRobinWriter, diff --git a/torchrl/data/replay_buffers/checkpointers.py b/torchrl/data/replay_buffers/checkpointers.py new file mode 100644 index 00000000000..27e91b84e29 --- /dev/null +++ b/torchrl/data/replay_buffers/checkpointers.py @@ -0,0 +1,375 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import abc +import json +import warnings +from pathlib import Path + +import numpy as np +import torch +from tensordict import ( + is_tensor_collection, + NonTensorData, + PersistentTensorDict, + TensorDict, +) +from tensordict.memmap import MemoryMappedTensor +from tensordict.utils import _STRDTYPE2DTYPE +from torchrl.data.replay_buffers.utils import ( + _save_pytree, + Flat2TED, + H5Combine, + H5Split, + Nested2TED, + TED2Flat, + TED2Nested, +) + + +class StorageCheckpointerBase: + """Public base class for storage checkpointers. + + Each storage checkpointer must implement a `save` and `load` method that take as input a storage and a + path. + + """ + + @abc.abstractmethod + def dumps(self, storage, path): + ... + + @abc.abstractmethod + def loads(self, storage, path): + ... + + +class ListStorageCheckpointer(StorageCheckpointerBase): + """A storage checkpointer for ListStoage. + + Currently not implemented. + + """ + + @staticmethod + def dumps(storage, path): + raise NotImplementedError( + "ListStorage doesn't support serialization via `dumps` - `loads` API." + ) + + @staticmethod + def loads(storage, path): + raise NotImplementedError( + "ListStorage doesn't support serialization via `dumps` - `loads` API." + ) + + +class TensorStorageCheckpointer(StorageCheckpointerBase): + """A storage checkpointer for TensorStorages. + + This class supports TensorDict-based storages as well as pytrees. + + This class will call save and load hooks if provided. These hooks should take as input the + data being transformed as well as the path where the data should be saved. + + """ + + _save_hooks = [] + _load_hooks = [] + + def dumps(self, storage, path): + path = Path(path) + path.mkdir(exist_ok=True) + + if not storage.initialized: + raise RuntimeError("Cannot save a non-initialized storage.") + metadata = {} + _storage = storage._storage + for hook in self._save_hooks: + _storage = hook(_storage, path=path) + if is_tensor_collection(_storage): + if ( + _storage.is_memmap() + and Path(_storage.saved_path).absolute() == Path(path).absolute() + ): + _storage.memmap_refresh_() + else: + # try to load the path and overwrite. 
+ _storage.memmap( + path, + copy_existing=True, # num_threads=torch.get_num_threads() + ) + is_pytree = False + else: + _save_pytree(_storage, metadata, path) + is_pytree = True + + with open(path / "storage_metadata.json", "w") as file: + json.dump( + { + "metadata": metadata, + "is_pytree": is_pytree, + "len": storage._len, + }, + file, + ) + + def loads(self, storage, path): + with open(path / "storage_metadata.json", "r") as file: + metadata = json.load(file) + is_pytree = metadata["is_pytree"] + _len = metadata["len"] + if is_pytree: + if self._load_hooks: + raise RuntimeError( + "Loading hooks are not compatible with PyTree storages." + ) + path = Path(path) + for local_path, md in metadata["metadata"].items(): + # load tensor + local_path_dot = local_path.replace(".", "/") + total_tensor_path = path / (local_path_dot + ".memmap") + shape = torch.Size(md["shape"]) + dtype = _STRDTYPE2DTYPE[md["dtype"]] + tensor = MemoryMappedTensor.from_filename( + filename=total_tensor_path, shape=shape, dtype=dtype + ) + # split path + local_path = local_path.split(".") + # replace potential dots + local_path = [_path.replace("__", ".") for _path in local_path] + if storage.initialized: + # copy in-place + _storage_tensor = storage._storage + # in this case there is a single tensor, so we skip + if local_path != ["_-single-tensor-_"]: + for _path in local_path: + if _path.isdigit(): + _path_attempt = int(_path) + try: + _storage_tensor = _storage_tensor[_path_attempt] + continue + except IndexError: + pass + _storage_tensor = _storage_tensor[_path] + _storage_tensor.copy_(tensor) + else: + raise RuntimeError( + "Cannot fill a non-initialized pytree-based TensorStorage." + ) + else: + _storage = TensorDict.load_memmap(path) + if storage.initialized: + dest = storage._storage + else: + # TODO: This could load the RAM a lot, maybe try to catch this within the hook and use memmap instead + dest = None + for hook in self._load_hooks: + _storage = hook(_storage, out=dest) + if not storage.initialized: + # this should not be reached if is_pytree=True + storage._init(_storage[0]) + storage._storage.update_(_storage) + else: + storage._storage.copy_(_storage) + storage._len = _len + + +class FlatStorageCheckpointer(TensorStorageCheckpointer): + """Saves the storage in a compact form, saving space on the TED format. + + This class explicitly assumes and does NOT check that: + + - done states (including terminated and truncated) at the root are always False; + - observations in the "next" tensordict are shifted by one step in the future (this + is not the case when a multi-step transform is used for instance) unless `done` is `True` + in which case the observation in `("next", key)` at time `t` and the one in `key` at time + `t+1` should not match. + + .. seealso: The full list of arguments can be found in :class:`~torchrl.data.TED2Flat`. + + """ + + def __init__(self, done_keys=None, reward_keys=None): + kwargs = {} + if done_keys is not None: + kwargs["done_keys"] = done_keys + if reward_keys is not None: + kwargs["reward_keys"] = reward_keys + self._save_hooks = [TED2Flat(**kwargs)] + self._load_hooks = [Flat2TED(**kwargs)] + + def _save_shift_is_full(self, storage): + is_full = storage._is_full + last_cursor = storage._last_cursor + for hook in self._save_hooks: + if hasattr(hook, "is_full"): + hook.is_full = is_full + if last_cursor is None: + warnings.warn( + "las_cursor is None. The replay buffer " + "may not be saved properly in this setting. 
To solve this issue, make "
+                "sure the storage updates the _last_cursor value during calls to `set`."
+            )
+        shift = self._get_shift_from_last_cursor(last_cursor)
+        for hook in self._save_hooks:
+            if hasattr(hook, "shift"):
+                hook.shift = shift
+
+    def dumps(self, storage, path):
+        self._save_shift_is_full(storage)
+        return super().dumps(storage, path)
+
+    def _get_shift_from_last_cursor(self, last_cursor):
+        if isinstance(last_cursor, slice):
+            return last_cursor.stop + 1
+        if isinstance(last_cursor, int):
+            return last_cursor + 1
+        if isinstance(last_cursor, torch.Tensor):
+            return last_cursor.reshape(-1)[-1].item() + 1
+        if isinstance(last_cursor, np.ndarray):
+            return last_cursor.reshape(-1)[-1].item() + 1
+        raise ValueError(f"Unrecognised last_cursor type {type(last_cursor)}.")
+
+
+class NestedStorageCheckpointer(FlatStorageCheckpointer):
+    """Saves the storage in a compact form, saving space on the TED format and using memory-mapped nested tensors.
+
+    This class explicitly assumes and does NOT check that:
+
+    - done states (including terminated and truncated) at the root are always False;
+    - observations in the "next" tensordict are shifted by one step in the future (this
+      is not the case when a multi-step transform is used, for instance).
+
+    .. seealso:: The full list of arguments can be found in :class:`~torchrl.data.TED2Flat`.
+
+    """
+
+    def __init__(self, done_keys=None, reward_keys=None, **kwargs):
+        if done_keys is not None:
+            kwargs["done_keys"] = done_keys
+        if reward_keys is not None:
+            kwargs["reward_keys"] = reward_keys
+        self._save_hooks = [TED2Nested(**kwargs)]
+        self._load_hooks = [Nested2TED(**kwargs)]
+
+
+class H5StorageCheckpointer(NestedStorageCheckpointer):
+    """Saves the storage in a compact form, saving space on the TED format and using the HDF5 format to save the data.
+
+    This class explicitly assumes and does NOT check that:
+
+    - done states (including terminated and truncated) at the root are always False;
+    - observations in the "next" tensordict are shifted by one step in the future (this
+      is not the case when a multi-step transform is used, for instance).
+
+    Keyword Args:
+        checkpoint_file: the filename where to save the checkpointed data.
+            This will be ignored iff the path passed to dumps / loads ends with the ``.h5``
+            suffix. Defaults to ``"checkpoint.h5"``.
+        h5_kwargs (Dict[str, Any] or Tuple[Tuple[str, Any], ...]): kwargs to be
+            passed to :meth:`h5py.File.create_dataset`.
+
+    .. note:: To prevent out-of-memory issues, the data of the H5 file will be temporarily written
+        to memory-mapped tensors stored in the shared file system. The physical memory usage may increase
+        during loading as a consequence.
+
+    .. seealso:: The full list of arguments can be found in :class:`~torchrl.data.TED2Flat`. Note that this class only
+        supports keyword arguments.
+ + """ + + def __init__( + self, + *, + checkpoint_file: str = "checkpoint.h5", + done_keys=None, + reward_keys=None, + h5_kwargs=None, + **kwargs, + ): + ted2_kwargs = kwargs + if done_keys is not None: + ted2_kwargs["done_keys"] = done_keys + if reward_keys is not None: + ted2_kwargs["reward_keys"] = reward_keys + self._save_hooks = [TED2Nested(**ted2_kwargs), H5Split()] + self._load_hooks = [H5Combine(), Nested2TED(**ted2_kwargs)] + self.kwargs = {} if h5_kwargs is None else dict(h5_kwargs) + self.checkpoint_file = checkpoint_file + + def dumps(self, storage, path): + path = self._get_path(path) + + self._save_shift_is_full(storage) + + if not storage.initialized: + raise RuntimeError("Cannot save a non-initialized storage.") + _storage = storage._storage + length = storage._len + for hook in self._save_hooks: + # we don't pass a path here since we're not reusing the tensordict + _storage = hook(_storage) + if is_tensor_collection(_storage): + # try to load the path and overwrite. + data = PersistentTensorDict.from_dict(_storage, path, **self.kwargs) + data["_len"] = NonTensorData(data=length) + else: + raise ValueError("Only tensor collections are supported.") + + def loads(self, storage, path): + path = self._get_path(path) + data = PersistentTensorDict.from_h5(path) + if storage.initialized: + dest = storage._storage + else: + # TODO: This could load the RAM a lot, maybe try to catch this within the hook and use memmap instead + dest = None + _len = data["_len"] + for hook in self._load_hooks: + data = hook(data, out=dest) + if not storage.initialized: + # this should not be reached if is_pytree=True + storage._init(data[0]) + storage._storage.update_(data) + else: + storage._storage.copy_(data) + storage._len = _len + + def _get_path(self, path): + path = Path(path) + if path.suffix == ".h5": + return str(path.absolute()) + try: + path.mkdir(exist_ok=True) + except Exception: + raise RuntimeError(f"Failed to create the checkpoint directory {path}.") + path = path / self.checkpoint_file + return str(path.absolute()) + + +class StorageEnsembleCheckpointer(StorageCheckpointerBase): + """Checkpointer for ensemble storages.""" + + @staticmethod + def dumps(storage, path: Path): + path = Path(path).absolute() + storages = storage._storages + for i, storage in enumerate(storages): + storage.dumps(path / str(i)) + if storage._transforms is not None: + for i, transform in enumerate(storage._transforms): + torch.save(transform.state_dict(), path / f"{i}_transform.pt") + + @staticmethod + def loads(storage, path: Path): + path = Path(path).absolute() + for i, _storage in enumerate(storage._storages): + _storage.loads(path / str(i)) + if storage._transforms is not None: + for i, transform in enumerate(storage._transforms): + transform.load_state_dict(torch.load(path / f"{i}_transform.pt")) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 77d206b9a11..389c16fa785 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -203,6 +203,7 @@ def __init__( transform: "Transform" | None = None, # noqa-F821 batch_size: int | None = None, dim_extend: int | None = None, + checkpointer: "StorageCheckpointerBase" | None = None, # noqa: F821 ) -> None: self._storage = storage if storage is not None else ListStorage(max_size=1_000) self._storage.attach(self) @@ -260,6 +261,7 @@ def __init__( if dim_extend is not None and dim_extend < 0: raise ValueError("dim_extend must be a positive 
value.") self.dim_extend = dim_extend + self._storage.checkpointer = checkpointer @property def dim_extend(self): @@ -355,7 +357,7 @@ def __getitem__(self, index: int | torch.Tensor | NestedKey) -> Any: if isinstance(index, str) or (isinstance(index, tuple) and unravel_key(index)): return self[:][index] if isinstance(index, tuple): - if len(index) > 1: + if len(index) == 1: return self[index[0]] else: return self[:][index] @@ -468,6 +470,35 @@ def loads(self, path): metadata = json.load(file) self._batch_size = metadata["batch_size"] + def save(self, *args, **kwargs): + """Alias for :meth:`~.dumps`.""" + return self.dumps(*args, **kwargs) + + def dump(self, *args, **kwargs): + """Alias for :meth:`~.dumps`.""" + return self.dumps(*args, **kwargs) + + def load(self, *args, **kwargs): + """Alias for :meth:`~.loads`.""" + return self.loads(*args, **kwargs) + + def register_save_hook(self, hook: Callable[[Any], Any]): + """Registers a save hook for the storage. + + .. note:: Hooks are currently not serialized when saving a replay buffer: they must + be manually re-initialized every time the buffer is created. + """ + self._storage.register_save_hook(hook) + + def register_load_hook(self, hook: Callable[[Any], Any]): + """Registers a load hook for the storage. + + .. note:: Hooks are currently not serialized when saving a replay buffer: they must + be manually re-initialized every time the buffer is created. + + """ + self._storage.register_load_hook(hook) + def add(self, data: Any) -> int: """Add a single element to the replay buffer. @@ -992,12 +1023,12 @@ class TensorDictReplayBuffer(ReplayBuffer): """ - def __init__(self, *, priority_key: str = "td_error", **kw) -> None: - writer = kw.get("writer", None) + def __init__(self, *, priority_key: str = "td_error", **kwargs) -> None: + writer = kwargs.get("writer", None) if writer is None: - kw["writer"] = TensorDictRoundRobinWriter() + kwargs["writer"] = TensorDictRoundRobinWriter() - super().__init__(**kw) + super().__init__(**kwargs) self.priority_key = priority_key def _get_priority_item(self, tensordict: TensorDictBase) -> float: diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 6058de290e3..373f05b9852 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -5,14 +5,12 @@ from __future__ import annotations import abc -import json import os import textwrap import warnings from collections import OrderedDict from copy import copy from multiprocessing.context import get_spawning_popen -from pathlib import Path from typing import Any, Dict, List, Sequence, Union import numpy as np @@ -25,16 +23,17 @@ TensorDictBase, ) from tensordict.memmap import MemoryMappedTensor -from tensordict.utils import _STRDTYPE2DTYPE from torch import multiprocessing as mp -from torch.utils._pytree import LeafSpec, tree_flatten, tree_map, tree_unflatten +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten from torchrl._utils import implement_for, logger as torchrl_logger -from torchrl.data.replay_buffers.utils import _is_int, INT_CLASSES - -SINGLE_TENSOR_BUFFER_NAME = os.environ.get( - "SINGLE_TENSOR_BUFFER_NAME", "_-single-tensor-_" +from torchrl.data.replay_buffers.checkpointers import ( + ListStorageCheckpointer, + StorageCheckpointerBase, + StorageEnsembleCheckpointer, + TensorStorageCheckpointer, ) +from torchrl.data.replay_buffers.utils import _init_pytree, _is_int, INT_CLASSES class Storage: @@ -50,9 +49,23 @@ class Storage: ndim = 1 max_size: int 
+ _default_checkpointer: StorageCheckpointerBase - def __init__(self, max_size: int) -> None: + def __init__( + self, max_size: int, checkpointer: StorageCheckpointerBase | None = None + ) -> None: self.max_size = int(max_size) + self.checkpointer = checkpointer + + @property + def checkpointer(self): + return self._checkpointer + + @checkpointer.setter + def checkpointer(self, value: StorageCheckpointerBase | None) -> None: + if value is None: + value = self._default_checkpointer() + self._checkpointer = value @property def _is_full(self): @@ -76,13 +89,11 @@ def set(self, cursor: int, data: Any): def get(self, index: int) -> Any: ... - @abc.abstractmethod def dumps(self, path): - ... + self.checkpointer.dumps(self, path) - @abc.abstractmethod def loads(self, path): - ... + self.checkpointer.loads(self, path) def attach(self, buffer: Any) -> None: """This function attaches a sampler to this storage. @@ -160,6 +171,18 @@ def flatten(self): f"Please report this exception as well as the use case (incl. buffer construction) on github." ) + def save(self, *args, **kwargs): + """Alias for :meth:`~.dumps`.""" + return self.dumps(*args, **kwargs) + + def dump(self, *args, **kwargs): + """Alias for :meth:`~.dumps`.""" + return self.dumps(*args, **kwargs) + + def load(self, *args, **kwargs): + """Alias for :meth:`~.loads`.""" + return self.loads(*args, **kwargs) + class ListStorage(Storage): """A storage stored in a list. @@ -173,20 +196,12 @@ class ListStorage(Storage): """ + _default_checkpointer = ListStorageCheckpointer + def __init__(self, max_size: int): super().__init__(max_size) self._storage = [] - def dumps(self, path): - raise NotImplementedError( - "ListStorage doesn't support serialization via `dumps` - `loads` API." - ) - - def loads(self, path): - raise NotImplementedError( - "ListStorage doesn't support serialization via `dumps` - `loads` API." - ) - def set(self, cursor: Union[int, Sequence[int], slice], data: Any): if not isinstance(cursor, INT_CLASSES): if (isinstance(cursor, torch.Tensor) and cursor.numel() <= 1) or ( @@ -351,6 +366,7 @@ class TensorStorage(Storage): """ _storage = None + _default_checkpointer = TensorStorageCheckpointer def __init__( self, @@ -388,82 +404,7 @@ def __init__( else "auto" ) self._storage = storage - - def dumps(self, path): - path = Path(path) - path.mkdir(exist_ok=True) - - if not self.initialized: - raise RuntimeError("Cannot save a non-initialized storage.") - metadata = {} - if is_tensor_collection(self._storage): - # try to load the path and overwrite. 
- self._storage.memmap( - path, copy_existing=True, num_threads=torch.get_num_threads() - ) - is_pytree = False - else: - _save_pytree(self._storage, metadata, path) - is_pytree = True - - with open(path / "storage_metadata.json", "w") as file: - json.dump( - { - "metadata": metadata, - "is_pytree": is_pytree, - "len": self._len, - }, - file, - ) - - def loads(self, path): - with open(path / "storage_metadata.json", "r") as file: - metadata = json.load(file) - is_pytree = metadata["is_pytree"] - _len = metadata["len"] - if is_pytree: - path = Path(path) - for local_path, md in metadata["metadata"].items(): - # load tensor - local_path_dot = local_path.replace(".", "/") - total_tensor_path = path / (local_path_dot + ".memmap") - shape = torch.Size(md["shape"]) - dtype = _STRDTYPE2DTYPE[md["dtype"]] - tensor = MemoryMappedTensor.from_filename( - filename=total_tensor_path, shape=shape, dtype=dtype - ) - # split path - local_path = local_path.split(".") - # replace potential dots - local_path = [_path.replace("__", ".") for _path in local_path] - if self.initialized: - # copy in-place - _storage_tensor = self._storage - # in this case there is a single tensor, so we skip - if local_path != ["_-single-tensor-_"]: - for _path in local_path: - if _path.isdigit(): - _path_attempt = int(_path) - try: - _storage_tensor = _storage_tensor[_path_attempt] - continue - except IndexError: - pass - _storage_tensor = _storage_tensor[_path] - _storage_tensor.copy_(tensor) - else: - raise RuntimeError( - "Cannot fill a non-initialized pytree-based TensorStorage." - ) - else: - _storage = TensorDict.load_memmap(path) - if not self.initialized: - # this should not be reached if is_pytree=True - self._init(_storage[0]) - self._storage.update_(_storage) - else: - self._storage.copy_(_storage) - self._len = _len + self._last_cursor = None @property def _len(self): @@ -689,6 +630,8 @@ def set( data: Union[TensorDictBase, torch.Tensor], ): + self._last_cursor = cursor + if isinstance(data, list): # flip list try: @@ -726,6 +669,8 @@ def set( # noqa: F811 data: Union[TensorDictBase, torch.Tensor], ): + self._last_cursor = cursor + if isinstance(data, list): # flip list try: @@ -892,6 +837,8 @@ class LazyTensorStorage(TensorStorage): """ + _default_checkpointer = TensorStorageCheckpointer + def __init__( self, max_size: int, @@ -911,10 +858,12 @@ def _init( def max_size_along_dim0(data_shape): if self.ndim > 1: - return ( + result = ( -(self.max_size // -data_shape[: self.ndim - 1].numel()), *data_shape, ) + self.max_size = torch.Size(result).numel() + return result return (self.max_size, *data_shape) if is_tensor_collection(data): @@ -1003,6 +952,8 @@ class LazyMemmapStorage(LazyTensorStorage): """ + _default_checkpointer = TensorStorageCheckpointer + def __init__( self, max_size: int, @@ -1094,10 +1045,12 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: def max_size_along_dim0(data_shape): if self.ndim > 1: - return ( + result = ( -(self.max_size // -data_shape[: self.ndim - 1].numel()), *data_shape, ) + self.max_size = torch.Size(result).numel() + return result return (self.max_size, *data_shape) if is_tensor_collection(data): @@ -1147,6 +1100,8 @@ class StorageEnsemble(Storage): """ + _default_checkpointer = StorageEnsembleCheckpointer + def __init__( self, *storages: Storage, @@ -1194,22 +1149,6 @@ def _convert_id(self, sub): def _get_storage(self, sub): return self._storages[sub] - def dumps(self, path: Path): - path = Path(path).absolute() - for i, storage in enumerate(self._storages): - 
storage.dumps(path / str(i)) - if self._transforms is not None: - for i, transform in enumerate(self._transforms): - torch.save(transform.state_dict(), path / f"{i}_transform.pt") - - def loads(self, path: Path): - path = Path(path).absolute() - for i, storage in enumerate(self._storages): - storage.loads(path / str(i)) - if self._transforms is not None: - for i, transform in enumerate(self._transforms): - transform.load_state_dict(torch.load(path / f"{i}_transform.pt")) - def state_dict(self) -> Dict[str, Any]: raise NotImplementedError @@ -1327,172 +1266,6 @@ def _make_empty_memmap(shape, dtype, path): return MemoryMappedTensor.empty(shape=shape, dtype=dtype, filename=path) -@implement_for("torch", "2.3", None) -def _path2str(path, default_name=None): - # Uses the Keys defined in pytree to build a path - from torch.utils._pytree import MappingKey, SequenceKey - - if default_name is None: - default_name = SINGLE_TENSOR_BUFFER_NAME - if not path: - return default_name - if isinstance(path, tuple): - return "/".join([_path2str(_sub, default_name=default_name) for _sub in path]) - if isinstance(path, MappingKey): - if not isinstance(path.key, (int, str, bytes)): - raise ValueError("Values must be of type int, str or bytes in PyTree maps.") - result = str(path.key) - if result == default_name: - raise RuntimeError( - "A tensor had the same identifier as the default name used when the buffer contains " - f"a single tensor (name={default_name}). This behaviour is not allowed. Please rename your " - f"tensor in the map/dict or set a new default name with the environment variable SINGLE_TENSOR_BUFFER_NAME." - ) - return result - if isinstance(path, SequenceKey): - return str(path.idx) - - -@implement_for("torch", None, "2.3") -def _path2str(path, default_name=None): # noqa: F811 - raise RuntimeError - - -def _get_paths(spec, cumulpath=""): - # alternative way to build a path without the keys - if isinstance(spec, LeafSpec): - yield cumulpath if cumulpath else SINGLE_TENSOR_BUFFER_NAME - - contexts = spec.context - children_specs = spec.children_specs - if contexts is None: - contexts = range(len(children_specs)) - - for context, spec in zip(contexts, children_specs): - cpath = "/".join((cumulpath, str(context))) if cumulpath else str(context) - yield from _get_paths(spec, cpath) - - -def _save_pytree_common(tensor_path, path, tensor, metadata): - if "." in tensor_path: - tensor_path.replace(".", "__") - total_tensor_path = path / (tensor_path + ".memmap") - if os.path.exists(total_tensor_path): - MemoryMappedTensor.from_filename( - shape=tensor.shape, - filename=total_tensor_path, - dtype=tensor.dtype, - ).copy_(tensor) - else: - os.makedirs(total_tensor_path.parent, exist_ok=True) - MemoryMappedTensor.from_tensor( - tensor, - filename=total_tensor_path, - copy_existing=True, - copy_data=True, - ) - key = tensor_path.replace("/", ".") - if key in metadata: - raise KeyError( - "At least two values have conflicting representations in " - f"the data structure to be serialized: {key}." 
- ) - metadata[key] = { - "dtype": str(tensor.dtype), - "shape": list(tensor.shape), - } - - -@implement_for("torch", "2.3", None) -def _save_pytree(_storage, metadata, path): - from torch.utils._pytree import tree_map_with_path - - def save_tensor( - tensor_path: tuple, tensor: torch.Tensor, metadata=metadata, path=path - ): - tensor_path = _path2str(tensor_path) - _save_pytree_common(tensor_path, path, tensor, metadata) - - tree_map_with_path(save_tensor, _storage) - - -@implement_for("torch", None, "2.3") -def _save_pytree(_storage, metadata, path): # noqa: F811 - - flat_storage, storage_specs = tree_flatten(_storage) - storage_paths = _get_paths(storage_specs) - - def save_tensor( - tensor_path: str, tensor: torch.Tensor, metadata=metadata, path=path - ): - _save_pytree_common(tensor_path, path, tensor, metadata) - - for tensor, tensor_path in zip(flat_storage, storage_paths): - save_tensor(tensor_path, tensor) - - -def _init_pytree_common(tensor_path, scratch_dir, max_size_fn, tensor): - if "." in tensor_path: - tensor_path.replace(".", "__") - if scratch_dir is not None: - total_tensor_path = Path(scratch_dir) / (tensor_path + ".memmap") - if os.path.exists(total_tensor_path): - raise RuntimeError( - f"The storage of tensor {total_tensor_path} already exists. " - f"To load an existing replay buffer, use storage.loads. " - f"Choose a different path to store your buffer or delete the existing files." - ) - os.makedirs(total_tensor_path.parent, exist_ok=True) - else: - total_tensor_path = None - out = MemoryMappedTensor.empty( - shape=max_size_fn(tensor.shape), - filename=total_tensor_path, - dtype=tensor.dtype, - ) - try: - filesize = os.path.getsize(tensor.filename) / 1024 / 1024 - torchrl_logger.debug( - f"The storage was created in {out.filename} and occupies {filesize} Mb of storage." 
- ) - except (RuntimeError, AttributeError): - pass - return out - - -@implement_for("torch", "2.3", None) -def _init_pytree(scratch_dir, max_size_fn, data): - from torch.utils._pytree import tree_map_with_path - - # If not a tensorclass/tensordict, it must be a tensor(-like) or a PyTree - # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype - def save_tensor(tensor_path: tuple, tensor: torch.Tensor): - tensor_path = _path2str(tensor_path) - return _init_pytree_common(tensor_path, scratch_dir, max_size_fn, tensor) - - out = tree_map_with_path(save_tensor, data) - return out - - -@implement_for("torch", None, "2.3") -def _init_pytree(scratch_dir, max_size, data): # noqa: F811 - - flat_data, data_specs = tree_flatten(data) - data_paths = _get_paths(data_specs) - data_paths = list(data_paths) - - # If not a tensorclass/tensordict, it must be a tensor(-like) or a PyTree - # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype - def save_tensor(tensor_path: str, tensor: torch.Tensor): - return _init_pytree_common(tensor_path, scratch_dir, max_size, tensor) - - out = [] - for tensor, tensor_path in zip(flat_data, data_paths): - out.append(save_tensor(tensor_path, tensor)) - - return tree_unflatten(out, data_specs) - - def _flip_list(data): if all(is_tensor_collection(_data) for _data in data): return torch.stack(data) diff --git a/torchrl/data/replay_buffers/utils.py b/torchrl/data/replay_buffers/utils.py index 86337a56742..1fe0eb077c5 100644 --- a/torchrl/data/replay_buffers/utils.py +++ b/torchrl/data/replay_buffers/utils.py @@ -5,13 +5,33 @@ # import tree from __future__ import annotations +import contextlib + +import math +import os import typing +from pathlib import Path from typing import Any, Callable, Union import numpy as np import torch - +from tensordict import ( + LazyStackedTensorDict, + MemoryMappedTensor, + NonTensorData, + TensorDict, + TensorDictBase, + unravel_key, +) from torch import Tensor +from torch.nn import functional as F +from torch.utils._pytree import LeafSpec, tree_flatten, tree_unflatten +from torchrl._utils import implement_for, logger as torchrl_logger + +SINGLE_TENSOR_BUFFER_NAME = os.environ.get( + "SINGLE_TENSOR_BUFFER_NAME", "_-single-tensor-_" +) + INT_CLASSES_TYPING = Union[int, np.integer] if hasattr(typing, "get_args"): @@ -96,3 +116,836 @@ def _is_int(index): if isinstance(index, (np.ndarray, torch.Tensor)): return index.ndim == 0 return False + + +class TED2Flat: + """A storage saving hook to serialize TED data in a compact format. + + Args: + done_key (NestedKey, optional): the key where the done states should be read. + Defaults to ``("next", "done")``. + shift_key (NestedKey, optional): the key where the shift will be written. + Defaults to "shift". + is_full_key (NestedKey, optional): the key where the is_full attribute will be written. + Defaults to "is_full". + done_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the done entries. + Defaults to ("done", "truncated", "terminated") + reward_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the reward entries. 
+ Defaults to ("reward",) + + + Examples: + >>> import tempfile + >>> + >>> from tensordict import TensorDict + >>> + >>> from torchrl.collectors import SyncDataCollector + >>> from torchrl.data import ReplayBuffer, TED2Flat, LazyMemmapStorage + >>> from torchrl.envs import GymEnv + >>> import torch + >>> + >>> env = GymEnv("CartPole-v1") + >>> env.set_seed(0) + >>> torch.manual_seed(0) + >>> collector = SyncDataCollector(env, policy=env.rand_step, total_frames=200, frames_per_batch=200) + >>> rb = ReplayBuffer(storage=LazyMemmapStorage(200)) + >>> rb.register_save_hook(TED2Flat()) + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... for i, data in enumerate(collector): + ... rb.extend(data) + ... rb.dumps(tmpdir) + ... # load the data to represent it + ... td = TensorDict.load(tmpdir + "/storage/") + ... print(td) + TensorDict( + fields={ + action: MemoryMappedTensor(shape=torch.Size([200, 2]), device=cpu, dtype=torch.int64, is_shared=True), + collector: TensorDict( + fields={ + traj_ids: MemoryMappedTensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=True)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False), + done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True), + observation: MemoryMappedTensor(shape=torch.Size([220, 4]), device=cpu, dtype=torch.float32, is_shared=True), + reward: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=True), + terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True), + truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + + """ + + _shift: int = None + _is_full: bool = None + + def __init__( + self, + done_key=("next", "done"), + shift_key="shift", + is_full_key="is_full", + done_keys=("done", "truncated", "terminated"), + reward_keys=("reward",), + ): + self.done_key = done_key + self.shift_key = shift_key + self.is_full_key = is_full_key + self.done_keys = {unravel_key(key) for key in done_keys} + self.reward_keys = {unravel_key(key) for key in reward_keys} + + @property + def shift(self): + return self._shift + + @shift.setter + def shift(self, value: int): + self._shift = value + + @property + def is_full(self): + return self._is_full + + @is_full.setter + def is_full(self, value: int): + self._is_full = value + + def __call__(self, data: TensorDictBase, path: Path = None): + # Get the done state + shift = self.shift + is_full = self.is_full + + # Create an output storage + output = TensorDict() + output.set_non_tensor(self.is_full_key, is_full) + output.set_non_tensor(self.shift_key, shift) + output.set_non_tensor("_storage_shape", tuple(data.shape)) + output.memmap_(path) + + # Preallocate the output + done = data.get(self.done_key).squeeze(-1).clone() + if not is_full: + # shift is the cursor place + done[shift - 1] = True + else: + done = done.roll(-shift, dims=0) + done[-1] = True + ntraj = done.sum() + + # Get the keys that require extra storage + keys_to_expand = set(data.get("next").keys(True, True)) - ( + self.done_keys.union(self.reward_keys) + ) + + total_keys = data.exclude("next").keys(True, True) + total_keys = set(total_keys).union(set(data.get("next").keys(True, True))) + + len_with_offset = data.numel() + ntraj # + done[0].numel() + for key in total_keys: + if key in (self.done_keys.union(self.reward_keys)): + entry = data.get(("next", key)) + else: + 
entry = data.get(key) + + if key in keys_to_expand: + shape = torch.Size([len_with_offset, *entry.shape[data.ndim :]]) + dtype = entry.dtype + output.make_memmap(key, shape=shape, dtype=dtype) + else: + shape = torch.Size([data.numel(), *entry.shape[data.ndim :]]) + output.make_memmap(key, shape=shape, dtype=entry.dtype) + + if data.ndim == 1: + return self._call( + data=data, + output=output, + is_full=is_full, + shift=shift, + done=done, + total_keys=total_keys, + keys_to_expand=keys_to_expand, + ) + + with data.flatten(1, -1) if data.ndim > 2 else contextlib.nullcontext( + data + ) as data_flat: + if data.ndim > 2: + done = done.flatten(1, -1) + traj_per_dim = done.sum(0) + nsteps = data_flat.shape[0] + + start = 0 + start_with_offset = start + stop_with_offset = 0 + stop = 0 + for data_slice, done_slice, traj_for_dim in zip( + data_flat.unbind(1), done.unbind(1), traj_per_dim + ): + stop_with_offset = stop_with_offset + nsteps + traj_for_dim + cur_slice_offset = slice(start_with_offset, stop_with_offset) + start_with_offset = stop_with_offset + + stop = stop + data.shape[0] + cur_slice = slice(start, stop) + start = stop + + def _index( + key, + val, + keys_to_expand=keys_to_expand, + cur_slice=cur_slice, + cur_slice_offset=cur_slice_offset, + ): + if key in keys_to_expand: + return val[cur_slice_offset] + return val[cur_slice] + + out_slice = output.named_apply(_index, nested_keys=True) + self._call( + data=data_slice, + output=out_slice, + is_full=is_full, + shift=shift, + done=done_slice, + total_keys=total_keys, + keys_to_expand=keys_to_expand, + ) + return output + + def _call(self, *, data, output, is_full, shift, done, total_keys, keys_to_expand): + # capture for each item in data where the observation should be written + idx = torch.arange(data.shape[0]) + idx_done = (idx + done.cumsum(0))[done] + idx += torch.nn.functional.pad(done, [1, 0])[:-1].cumsum(0) + + for key in total_keys: + if key in (self.done_keys.union(self.reward_keys)): + entry = data.get(("next", key)) + else: + entry = data.get(key) + + if key in keys_to_expand: + mmap = output.get(key) + shifted_next = data.get(("next", key)) + if is_full: + _roll_inplace(entry, shift=-shift, out=mmap, index_dest=idx) + _roll_inplace( + shifted_next, + shift=-shift, + out=mmap, + index_dest=idx_done, + index_source=done, + ) + else: + mmap[idx] = entry + mmap[idx_done] = shifted_next[done] + elif is_full: + mmap = output.get(key) + _roll_inplace(entry, shift=-shift, out=mmap) + else: + mmap = output.get(key) + mmap.copy_(entry) + return output + + +class Flat2TED: + """A storage loading hook to deserialize flattened TED data to TED format. + + Args: + done_key (NestedKey, optional): the key where the done states should be read. + Defaults to ``("next", "done")``. + shift_key (NestedKey, optional): the key where the shift will be written. + Defaults to "shift". + is_full_key (NestedKey, optional): the key where the is_full attribute will be written. + Defaults to "is_full". + done_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the done entries. + Defaults to ("done", "truncated", "terminated") + reward_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the reward entries. 
+ Defaults to ("reward",) + + Examples: + >>> import tempfile + >>> + >>> from tensordict import TensorDict + >>> + >>> from torchrl.collectors import SyncDataCollector + >>> from torchrl.data import ReplayBuffer, TED2Flat, LazyMemmapStorage, Flat2TED + >>> from torchrl.envs import GymEnv + >>> import torch + >>> + >>> env = GymEnv("CartPole-v1") + >>> env.set_seed(0) + >>> torch.manual_seed(0) + >>> collector = SyncDataCollector(env, policy=env.rand_step, total_frames=200, frames_per_batch=200) + >>> rb = ReplayBuffer(storage=LazyMemmapStorage(200)) + >>> rb.register_save_hook(TED2Flat()) + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... for i, data in enumerate(collector): + ... rb.extend(data) + ... rb.dumps(tmpdir) + ... # load the data to represent it + ... td = TensorDict.load(tmpdir + "/storage/") + ... + ... rb_load = ReplayBuffer(storage=LazyMemmapStorage(200)) + ... rb_load.register_load_hook(Flat2TED()) + ... rb_load.load(tmpdir) + ... print("storage after loading", rb_load[:]) + ... assert (rb[:] == rb_load[:]).all() + storage after loading TensorDict( + fields={ + action: MemoryMappedTensor(shape=torch.Size([200, 2]), device=cpu, dtype=torch.int64, is_shared=False), + collector: TensorDict( + fields={ + traj_ids: MemoryMappedTensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: MemoryMappedTensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False), + reward: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + observation: MemoryMappedTensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False) + + + """ + + def __init__( + self, + done_key="done", + shift_key="shift", + is_full_key="is_full", + done_keys=("done", "truncated", "terminated"), + reward_keys=("reward",), + ): + self.done_key = done_key + self.shift_key = shift_key + self.is_full_key = is_full_key + self.done_keys = {unravel_key(key) for key in done_keys} + self.reward_keys = {unravel_key(key) for key in reward_keys} + + def __call__(self, data: TensorDictBase, out: TensorDictBase = None): + _storage_shape = data.get_non_tensor("_storage_shape", default=None) + if isinstance(_storage_shape, int): + _storage_shape = torch.Size([_storage_shape]) + shift = data.get_non_tensor(self.shift_key, default=None) + is_full = data.get_non_tensor(self.is_full_key, default=None) + done = ( + data.get("done") + .reshape((*_storage_shape[1:], -1)) + .contiguous() + .permute(-1, *range(0, len(_storage_shape) - 1)) + .clone() + ) + if not is_full: + # shift is the cursor place + done[shift - 1] = True + else: + # done = 
done.roll(-shift, dims=0) + done[-1] = True + + if _storage_shape is not None and len(_storage_shape) > 1: + # iterate over data and allocate + if out is None: + out = TensorDict(batch_size=_storage_shape) + for i in range(out.ndim): + if i >= 2: + # FLattening the lazy stack will make the data unavailable - we need to find a way to make this + # possible. + raise RuntimeError( + "Checkpointing an uninitialized buffer with more than 2 dimensions is currently not supported. " + "Please file an issue on GitHub to ask for this feature!" + ) + out = LazyStackedTensorDict(*out.unbind(i), stack_dim=i) + + # Create a function that reads slices of the input data + with out.flatten(1, -1) if out.ndim > 2 else contextlib.nullcontext( + out + ) as out_flat: + nsteps = done.shape[0] + n_elt_batch = done.shape[1:].numel() + traj_per_dim = done.sum(0) + + start = 0 + start_with_offset = start + stop_with_offset = 0 + stop = 0 + + for out_unbound, traj_for_dim in zip(out_flat.unbind(-1), traj_per_dim): + stop_with_offset = stop_with_offset + nsteps + traj_for_dim + cur_slice_offset = slice(start_with_offset, stop_with_offset) + start_with_offset = stop_with_offset + + stop = stop + nsteps + cur_slice = slice(start, stop) + start = stop + + def _index( + key, + val, + cur_slice=cur_slice, + nsteps=nsteps, + n_elt_batch=n_elt_batch, + cur_slice_offset=cur_slice_offset, + ): + if val.shape[0] != (nsteps * n_elt_batch): + return val[cur_slice_offset] + return val[cur_slice] + + data_slice = data.named_apply( + _index, nested_keys=True, batch_size=[] + ) + self._call( + data=data_slice, + out=out_unbound, + is_full=is_full, + shift=shift, + _storage_shape=_storage_shape, + ) + return out + return self._call( + data=data, + out=out, + is_full=is_full, + shift=shift, + _storage_shape=_storage_shape, + ) + + def _call(self, *, data, out, _storage_shape, shift, is_full): + done = data.get(self.done_key) + done = done.clone() + + nsteps = done.shape[0] + + # capture for each item in data where the observation should be written + idx = torch.arange(done.shape[0]) + padded_done = F.pad(done.squeeze(-1), [1, 0]) + root_idx = idx + padded_done[:-1].cumsum(0) + next_idx = root_idx + 1 + + if out is None: + out = TensorDict({}, [nsteps]) + + def maybe_roll(entry, out=None): + if is_full and shift is not None: + if out is not None: + _roll_inplace(entry, shift=shift, out=out) + return + else: + return entry.roll(shift, dims=0) + if out is not None: + out.copy_(entry) + return + return entry + + root_idx = maybe_roll(root_idx) + next_idx = maybe_roll(next_idx) + if not is_full: + next_idx = next_idx[:-1] + + for key, entry in data.items(True, True): + if entry.shape[0] == nsteps: + if key in (self.done_keys.union(self.reward_keys)): + if key != "reward" and key not in out.keys(True, True): + # Create a done state at the root full of 0s + out.set(key, torch.zeros_like(entry), inplace=True) + entry = maybe_roll(entry, out=out.get(("next", key), None)) + if entry is not None: + out.set(("next", key), entry, inplace=True) + else: + # action and similar + entry = maybe_roll(entry, out=out.get(key, default=None)) + if entry is not None: + # then out is not locked + out.set(key, entry, inplace=True) + else: + dest_next = out.get(("next", key), None) + if dest_next is not None: + if not is_full: + dest_next = dest_next[:-1] + dest_next.copy_(entry[next_idx]) + else: + if not is_full: + val = entry[next_idx] + val = torch.cat([val, torch.zeros_like(val[:1])]) + out.set(("next", key), val, inplace=True) + else: + out.set(("next", 
+
+                dest = out.get(key, None)
+                if dest is not None:
+                    dest.copy_(entry[root_idx])
+                else:
+                    out.set(key, entry[root_idx], inplace=True)
+        return out
+
+
+class TED2Nested(TED2Flat):
+    """Converts a TED-formatted dataset into a tensordict populated with nested tensors where each row is a trajectory."""
+
+    _shift: int = None
+    _is_full: bool = None
+
+    def __call__(self, data: TensorDictBase, path: Path = None):
+        data = super().__call__(data, path=path)
+
+        shift = self.shift
+        is_full = self.is_full
+        storage_shape = data.get_non_tensor("_storage_shape", (-1,))
+        # place time at the end
+        storage_shape = (*storage_shape[1:], storage_shape[0])
+
+        done = data.get("done")
+        done = done.squeeze(-1).clone()
+        if not is_full:
+            done.view(storage_shape)[..., shift - 1] = True
+        # else:
+        done.view(storage_shape)[..., -1] = True
+
+        ntraj = done.sum()
+
+        nz = done.nonzero(as_tuple=True)[0]
+        traj_lengths = torch.cat([nz[:1] + 1, nz.diff()])
+        # if not is_full:
+        #     traj_lengths = torch.cat(
+        #         [traj_lengths, (done.shape[0] - traj_lengths.sum()).unsqueeze(0)]
+        #     )
+
+        keys_to_expand, keys_to_keep = zip(
+            *[
+                (key, None) if val.shape[0] != done.shape[0] else (None, key)
+                for key, val in data.items(True, True)
+            ]
+        )
+        keys_to_expand = [key for key in keys_to_expand if key is not None]
+        keys_to_keep = [key for key in keys_to_keep if key is not None]
+
+        out = TensorDict({}, batch_size=[ntraj])
+        out.update(dict(data.non_tensor_items()))
+
+        out.memmap_(path)
+
+        traj_lengths = traj_lengths.unsqueeze(-1)
+        if not is_full:
+            # Increment by one only the trajectories that are not terminal
+            traj_lengths_expand = traj_lengths + (
+                traj_lengths.cumsum(0) % storage_shape[-1] != 0
+            )
+        else:
+            traj_lengths_expand = traj_lengths + 1
+        for key in keys_to_expand:
+            val = data.get(key)
+            shape = torch.cat(
+                [
+                    traj_lengths_expand,
+                    torch.tensor(val.shape[1:], dtype=torch.long).repeat(
+                        traj_lengths.numel(), 1
+                    ),
+                ],
+                -1,
+            )
+            # This works because the storage location is the same as the previous one - no copy is done
+            # but a new shape is written
+            out.make_memmap_from_storage(
+                key, val.untyped_storage(), dtype=val.dtype, shape=shape
+            )
+        for key in keys_to_keep:
+            val = data.get(key)
+            shape = torch.cat(
+                [
+                    traj_lengths,
+                    torch.tensor(val.shape[1:], dtype=torch.long).repeat(
+                        traj_lengths.numel(), 1
+                    ),
+                ],
+                -1,
+            )
+            out.make_memmap_from_storage(
+                key, val.untyped_storage(), dtype=val.dtype, shape=shape
+            )
+        return out
+
+
+class Nested2TED(Flat2TED):
+    """Converts a nested tensordict where each row is a trajectory into the TED format."""
+
+    def __call__(self, data, out: TensorDictBase = None):
+        # Get a flat representation of data
+        def flatten_het_dim(tensor):
+            shape = [tensor.size(i) for i in range(2, tensor.ndim)]
+            tensor = torch.tensor(tensor.untyped_storage(), dtype=tensor.dtype).view(
+                -1, *shape
+            )
+            return tensor
+
+        data = data.apply(flatten_het_dim, batch_size=[])
+        data.auto_batch_size_()
+        return super().__call__(data, out=out)
+
+
+class H5Split(TED2Flat):
+    """Splits a dataset prepared with TED2Nested into a TensorDict where each trajectory is stored as views on their parent nested tensors."""
+
+    _shift: int = None
+    _is_full: bool = None
+
+    def __call__(self, data):
+        nzeros = int(math.ceil(math.log10(data.shape[0])))
+
+        result = TensorDict(
+            {
+                f"traj_{str(i).zfill(nzeros)}": _data
+                for i, _data in enumerate(data.filter_non_tensor_data().unbind(0))
+            }
+        ).update(dict(data.non_tensor_items()))
+
+        return result
+
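+# Usage sketch for the nested save/load pair defined above (illustrative only,
+# mirroring the TED2Flat/Flat2TED doctest earlier in this file; ``tmpdir``, the
+# buffer size and ``data`` are placeholder values and the stored batch is
+# assumed to be TED-formatted):
+#
+#     from torchrl.data import LazyMemmapStorage, ReplayBuffer
+#
+#     rb = ReplayBuffer(storage=LazyMemmapStorage(200))
+#     rb.register_save_hook(TED2Nested())
+#     rb.extend(data)
+#     rb.dumps(tmpdir)
+#
+#     rb_load = ReplayBuffer(storage=LazyMemmapStorage(200))
+#     rb_load.register_load_hook(Nested2TED())
+#     rb_load.load(tmpdir)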
+ +class H5Combine: + """Combines trajectories in a persistent tensordict into a single standing tensordict stored in filesystem.""" + + def __call__(self, data, out=None): + # TODO: this load the entire H5 in memory, which can be problematic + # Ideally we would want to load it on a memmap tensordict + # We currently ignore out in this call but we should leverage that + values = [val for key, val in data.items() if key.startswith("traj")] + metadata_keys = [key for key in data.keys() if not key.startswith("traj")] + result = TensorDict({key: NonTensorData(data[key]) for key in metadata_keys}) + + # Create a memmap in file system (no files associated) + result.memmap_() + + # Create each entry + def initialize(key, *x): + result.make_memmap( + key, + shape=torch.stack([torch.tensor(_x.shape) for _x in x]), + dtype=x[0].dtype, + ) + return + + values[0].named_apply( + initialize, + *values[1:], + nested_keys=True, + batch_size=[], + filter_empty=True, + ) + + # Populate the entries + def populate(key, *x): + dest = result.get(key) + for i, _x in enumerate(x): + dest[i].copy_(_x) + + values[0].named_apply( + populate, + *values[1:], + nested_keys=True, + batch_size=[], + filter_empty=True, + ) + return result + + +@implement_for("torch", "2.3", None) +def _path2str(path, default_name=None): + # Uses the Keys defined in pytree to build a path + from torch.utils._pytree import MappingKey, SequenceKey + + if default_name is None: + default_name = SINGLE_TENSOR_BUFFER_NAME + if not path: + return default_name + if isinstance(path, tuple): + return "/".join([_path2str(_sub, default_name=default_name) for _sub in path]) + if isinstance(path, MappingKey): + if not isinstance(path.key, (int, str, bytes)): + raise ValueError("Values must be of type int, str or bytes in PyTree maps.") + result = str(path.key) + if result == default_name: + raise RuntimeError( + "A tensor had the same identifier as the default name used when the buffer contains " + f"a single tensor (name={default_name}). This behaviour is not allowed. Please rename your " + f"tensor in the map/dict or set a new default name with the environment variable SINGLE_TENSOR_BUFFER_NAME." + ) + return result + if isinstance(path, SequenceKey): + return str(path.idx) + + +@implement_for("torch", None, "2.3") +def _path2str(path, default_name=None): # noqa: F811 + raise RuntimeError + + +def _save_pytree_common(tensor_path, path, tensor, metadata): + if "." in tensor_path: + tensor_path.replace(".", "__") + total_tensor_path = path / (tensor_path + ".memmap") + if os.path.exists(total_tensor_path): + MemoryMappedTensor.from_filename( + shape=tensor.shape, + filename=total_tensor_path, + dtype=tensor.dtype, + ).copy_(tensor) + else: + os.makedirs(total_tensor_path.parent, exist_ok=True) + MemoryMappedTensor.from_tensor( + tensor, + filename=total_tensor_path, + copy_existing=True, + copy_data=True, + ) + key = tensor_path.replace("/", ".") + if key in metadata: + raise KeyError( + "At least two values have conflicting representations in " + f"the data structure to be serialized: {key}." 
+ ) + metadata[key] = { + "dtype": str(tensor.dtype), + "shape": list(tensor.shape), + } + + +@implement_for("torch", "2.3", None) +def _save_pytree(_storage, metadata, path): + from torch.utils._pytree import tree_map_with_path + + def save_tensor( + tensor_path: tuple, tensor: torch.Tensor, metadata=metadata, path=path + ): + tensor_path = _path2str(tensor_path) + _save_pytree_common(tensor_path, path, tensor, metadata) + + tree_map_with_path(save_tensor, _storage) + + +@implement_for("torch", None, "2.3") +def _save_pytree(_storage, metadata, path): # noqa: F811 + + flat_storage, storage_specs = tree_flatten(_storage) + storage_paths = _get_paths(storage_specs) + + def save_tensor( + tensor_path: str, tensor: torch.Tensor, metadata=metadata, path=path + ): + _save_pytree_common(tensor_path, path, tensor, metadata) + + for tensor, tensor_path in zip(flat_storage, storage_paths): + save_tensor(tensor_path, tensor) + + +def _get_paths(spec, cumulpath=""): + # alternative way to build a path without the keys + if isinstance(spec, LeafSpec): + yield cumulpath if cumulpath else SINGLE_TENSOR_BUFFER_NAME + + contexts = spec.context + children_specs = spec.children_specs + if contexts is None: + contexts = range(len(children_specs)) + + for context, spec in zip(contexts, children_specs): + cpath = "/".join((cumulpath, str(context))) if cumulpath else str(context) + yield from _get_paths(spec, cpath) + + +def _init_pytree_common(tensor_path, scratch_dir, max_size_fn, tensor): + if "." in tensor_path: + tensor_path.replace(".", "__") + if scratch_dir is not None: + total_tensor_path = Path(scratch_dir) / (tensor_path + ".memmap") + if os.path.exists(total_tensor_path): + raise RuntimeError( + f"The storage of tensor {total_tensor_path} already exists. " + f"To load an existing replay buffer, use storage.loads. " + f"Choose a different path to store your buffer or delete the existing files." + ) + os.makedirs(total_tensor_path.parent, exist_ok=True) + else: + total_tensor_path = None + out = MemoryMappedTensor.empty( + shape=max_size_fn(tensor.shape), + filename=total_tensor_path, + dtype=tensor.dtype, + ) + try: + filesize = os.path.getsize(tensor.filename) / 1024 / 1024 + torchrl_logger.debug( + f"The storage was created in {out.filename} and occupies {filesize} Mb of storage." 
+        )
+    except (RuntimeError, AttributeError):
+        pass
+    return out
+
+
+@implement_for("torch", "2.3", None)
+def _init_pytree(scratch_dir, max_size_fn, data):
+    from torch.utils._pytree import tree_map_with_path
+
+    # If not a tensorclass/tensordict, it must be a tensor(-like) or a PyTree
+    # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
+    def save_tensor(tensor_path: tuple, tensor: torch.Tensor):
+        tensor_path = _path2str(tensor_path)
+        return _init_pytree_common(tensor_path, scratch_dir, max_size_fn, tensor)
+
+    out = tree_map_with_path(save_tensor, data)
+    return out
+
+
+@implement_for("torch", None, "2.3")
+def _init_pytree(scratch_dir, max_size, data):  # noqa: F811
+
+    flat_data, data_specs = tree_flatten(data)
+    data_paths = _get_paths(data_specs)
+    data_paths = list(data_paths)
+
+    # If not a tensorclass/tensordict, it must be a tensor(-like) or a PyTree
+    # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
+    def save_tensor(tensor_path: str, tensor: torch.Tensor):
+        return _init_pytree_common(tensor_path, scratch_dir, max_size, tensor)
+
+    out = []
+    for tensor, tensor_path in zip(flat_data, data_paths):
+        out.append(save_tensor(tensor_path, tensor))
+
+    return tree_unflatten(out, data_specs)
+
+
+def _roll_inplace(tensor, shift, out, index_dest=None, index_source=None):
+    # slice 0
+    source0 = tensor[:-shift]
+    if index_source is not None:
+        source0 = source0[index_source[shift:]]
+
+    slice0_shift = source0.shape[0]
+    if index_dest is not None:
+        out[index_dest[-slice0_shift:]] = source0
+    else:
+        slice0 = out[-slice0_shift:]
+        slice0.copy_(source0)
+
+    # slice 1
+    source1 = tensor[-shift:]
+    if index_source is not None:
+        source1 = source1[index_source[:shift]]
+    if index_dest is not None:
+        out[index_dest[:-slice0_shift]] = source1
+    else:
+        slice1 = out[:-slice0_shift]
+        slice1.copy_(source1)
+    return out
diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py
index 33cbd44f951..fc3a3ae954c 100644
--- a/tutorials/sphinx-tutorials/rb_tutorial.py
+++ b/tutorials/sphinx-tutorials/rb_tutorial.py
@@ -874,3 +874,4 @@ def assert0(x):
 # :class:`~torchrl.data.PrioritizedSliceSampler` and
 # :class:`~torchrl.data.SliceSamplerWithoutReplacement`, or other writers
 # such as :class:`~torchrl.data.TensorDictMaxValueWriter`.
+# - Check how to checkpoint ReplayBuffers in :ref:`the doc <checkpoint-rb>`.
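
For reference, a small illustration of the ``_roll_inplace`` helper added above (the values
are arbitrary; when no ``index_dest``/``index_source`` arguments are passed, the helper is
expected to behave like ``torch.roll`` while writing into a pre-allocated output):

    >>> import torch
    >>> from torchrl.data.replay_buffers.utils import _roll_inplace
    >>> tensor = torch.arange(5)
    >>> out = torch.empty_like(tensor)
    >>> _roll_inplace(tensor, shift=2, out=out)
    tensor([3, 4, 0, 1, 2])
    >>> torch.roll(tensor, 2)
    tensor([3, 4, 0, 1, 2])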