[Feature] pickle-free RB checkpointing (pytorch#1733)
vmoens authored Dec 6, 2023
1 parent 841f8d9 commit 25bd8a5
Showing 6 changed files with 749 additions and 134 deletions.
69 changes: 69 additions & 0 deletions docs/source/reference/data.rst
@@ -87,6 +87,49 @@ write onto the storage. The following code snippet exemplifies this feature:
... assert (rb["_data", "a"][:10] == 0).all() # data from main process
... assert (rb["_data", "a"][10:20] == 1).all() # data from remote process

Sharing replay buffers across processes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Replay buffers can be shared between processes as long as their components are
sharable. This feature allows multiple processes to collect data and populate a shared
replay buffer collaboratively, rather than centralizing the data on the main process,
which can incur some data-transmission overhead.

Sharable storages include :class:`~torchrl.data.replay_buffers.storages.LazyMemmapStorage`
or any subclass of :class:`~torchrl.data.replay_buffers.storages.TensorStorage`
as long as they are instantiated and their content is stored as memory-mapped
tensors. Stateful writers such as :class:`~torchrl.data.replay_buffers.writers.TensorDictRoundRobinWriter`
are currently not sharable, and the same goes for stateful samplers such as
:class:`~torchrl.data.replay_buffers.samplers.PrioritizedSampler`.

A shared replay buffer can be read and extended on any process that has access
to it, as the following example shows:

>>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
>>> import torch
>>> from torch import multiprocessing as mp
>>> from tensordict import TensorDict
>>>
>>> def worker(rb):
... td = TensorDict({"a": torch.ones(10)}, [10])
... # Extends the shared replay buffer on a subprocess
... rb.extend(td)
>>>
>>> if __name__ == "__main__":
... rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(21))
... td = TensorDict({"a": torch.zeros(10)}, [10])
...     # Extends the replay buffer on the main process
... rb.extend(td)
...
... proc = mp.Process(target=worker, args=(rb,))
... proc.start()
... proc.join()
...     # Checks that the length of the buffer equals the length of both
... # extensions (local and remote)
... assert len(rb) == 20


Storing trajectories
~~~~~~~~~~~~~~~~~~~~
@@ -131,6 +174,32 @@ can be used:
device=None,
is_shared=False)
Checkpointing Replay Buffers
----------------------------

Each component of the replay buffer can potentially be stateful and, as such,
requires a dedicated way of being serialized.
Our replay buffer offers two separate APIs for saving its state on disk:
:meth:`~torchrl.data.ReplayBuffer.dumps` and :meth:`~torchrl.data.ReplayBuffer.loads` save and restore the
state of each component except transforms (storage, writer, sampler) using memory-mapped
tensors and JSON files for the metadata. This works across all classes except
:class:`~torchrl.data.replay_buffers.storages.ListStorage`, whose content
cannot be anticipated (and as such does not comply with memory-mapped data
structures such as those found in the tensordict library).
This API guarantees that a buffer that is saved and then loaded back will be in
the exact same state, whether we look at the status of its sampler (e.g., priority trees),
its writer (e.g., max writer heaps) or its storage.
Under the hood, :meth:`~torchrl.data.ReplayBuffer.dumps` simply calls the public
`dumps` method of each of its components in a dedicated folder (except transforms,
which we do not assume to be serializable as memory-mapped tensors in general).

Whenever saving data using :meth:`~torchrl.data.ReplayBuffer.dumps` is not possible, an
alternative is to use :meth:`~torchrl.data.ReplayBuffer.state_dict`, which returns a data
structure that can be saved with :func:`torch.save` and loaded with :func:`torch.load`
before calling :meth:`~torchrl.data.ReplayBuffer.load_state_dict`. The drawback
of this method is that it will struggle to save big data structures, which is a
common setting when using replay buffers.
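
As a sketch of this fallback (the ``"rb_checkpoint.pt"`` file name is arbitrary):

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
>>>
>>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(100))
>>> rb.extend(TensorDict({"a": torch.zeros(10)}, [10]))
>>> # torch.save relies on pickle under the hood
>>> torch.save(rb.state_dict(), "rb_checkpoint.pt")
>>> rb_load = TensorDictReplayBuffer(storage=LazyMemmapStorage(100))
>>> rb_load.load_state_dict(torch.load("rb_checkpoint.pt"))
>>> assert len(rb_load) == len(rb)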

Datasets
--------
