[Feature] Composite replay buffers (#1768)
vmoens authored Jan 9, 2024
1 parent fd27cb7 commit 11a82c3
Showing 10 changed files with 1,229 additions and 51 deletions.
79 changes: 79 additions & 0 deletions docs/source/reference/data.rst
@@ -284,6 +284,85 @@ Here's an example:
    RobosetExperienceReplay
    VD4RLExperienceReplay

Composing datasets
~~~~~~~~~~~~~~~~~~

In offline RL, it is common to work with more than one dataset at the same time.
Moreover, TorchRL usually uses a fine-grained dataset nomenclature in which
each task is exposed as a separate dataset, whereas other libraries may group
these tasks in a more compact way. To let users compose multiple datasets,
TorchRL provides the :class:`~torchrl.data.replay_buffers.ReplayBufferEnsemble`
primitive, which samples from multiple datasets at once.
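
As a rough illustration of what "sampling from multiple datasets at once" means,
the following plain-Python sketch picks buffers according to a probability vector
``p`` and draws an equal sub-batch from each pick. It is a conceptual stand-in
only: ``sample_from_ensemble``, its arguments, and the even split of the batch are
assumptions made for this sketch, not the TorchRL implementation.

```python
import random


def sample_from_ensemble(buffers, p, batch_size, num_buffers_sampled=None, seed=None):
    """Conceptual stand-in for ensemble sampling (hypothetical helper, not TorchRL code).

    buffers: list of lists, each one playing the role of a dataset.
    p: probability of picking each buffer.
    Returns a list of (buffer_index, sub_batch) pairs.
    """
    rng = random.Random(seed)
    if num_buffers_sampled is None:
        num_buffers_sampled = len(buffers)
    # Assumption: the requested batch is split evenly across the picked buffers.
    sub_batch_size = batch_size // num_buffers_sampled
    # Pick buffers (with replacement) according to p.
    picked = rng.choices(range(len(buffers)), weights=p, k=num_buffers_sampled)
    # Draw a sub-batch (with replacement) from each picked buffer.
    return [
        (idx, [rng.choice(buffers[idx]) for _ in range(sub_batch_size)])
        for idx in picked
    ]


# Two toy "datasets", one per task.
task_a = [("task_a", i) for i in range(10)]
task_b = [("task_b", i) for i in range(10)]
batch = sample_from_ensemble([task_a, task_b], p=[0.5, 0.5], batch_size=10, seed=0)
```

Here a request for 10 items yields 2 sub-batches of 5 items each, which is the
same ``[2, 5, ...]`` leading layout that the assertions in the TorchRL example of
this section check for.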

If the individual dataset formats differ, :class:`~torchrl.envs.Transform` instances
can be used to harmonize them. In the following example, we create two dummy datasets
with semantically identical entries stored under different keys
(``("some", "key")`` and ``"another_key"``) and show how they can be renamed to a
common key. We also resize the images so that they can be stacked together during
sampling.

    >>> from torchrl.envs import Compose, ToTensorImage, Resize, RenameTransform
    >>> from torchrl.data import TensorDictReplayBuffer, ReplayBufferEnsemble, LazyMemmapStorage
    >>> from tensordict import TensorDict
    >>> import torch
    >>> # Each buffer converts and resizes its images, then renames its own key to "renamed".
    >>> rb0 = TensorDictReplayBuffer(
    ...     storage=LazyMemmapStorage(10),
    ...     transform=Compose(
    ...         ToTensorImage(in_keys=["pixels", ("next", "pixels")]),
    ...         Resize(32, in_keys=["pixels", ("next", "pixels")]),
    ...         RenameTransform([("some", "key")], ["renamed"]),
    ...     ),
    ... )
    >>> rb1 = TensorDictReplayBuffer(
    ...     storage=LazyMemmapStorage(10),
    ...     transform=Compose(
    ...         ToTensorImage(in_keys=["pixels", ("next", "pixels")]),
    ...         Resize(32, in_keys=["pixels", ("next", "pixels")]),
    ...         RenameTransform(["another_key"], ["renamed"]),
    ...     ),
    ... )
    >>> # The ensemble-level transform is applied to the combined sample.
    >>> rb = ReplayBufferEnsemble(
    ...     rb0,
    ...     rb1,
    ...     p=[0.5, 0.5],
    ...     transform=Resize(33, in_keys=["pixels"], out_keys=["pixels33"]),
    ... )
    >>> data0 = TensorDict(
    ...     {
    ...         "pixels": torch.randint(255, (10, 244, 244, 3)),
    ...         ("next", "pixels"): torch.randint(255, (10, 244, 244, 3)),
    ...         ("some", "key"): torch.randn(10),
    ...     },
    ...     batch_size=[10],
    ... )
    >>> data1 = TensorDict(
    ...     {
    ...         "pixels": torch.randint(255, (10, 64, 64, 3)),
    ...         ("next", "pixels"): torch.randint(255, (10, 64, 64, 3)),
    ...         "another_key": torch.randn(10),
    ...     },
    ...     batch_size=[10],
    ... )
    >>> # Each sub-buffer is populated individually through indexing.
    >>> rb[0].extend(data0)
    >>> rb[1].extend(data1)
    >>> for _ in range(2):
    ...     sample = rb.sample(10)
    ...     assert sample["next", "pixels"].shape == torch.Size([2, 5, 3, 32, 32])
    ...     assert sample["pixels"].shape == torch.Size([2, 5, 3, 32, 32])
    ...     assert sample["pixels33"].shape == torch.Size([2, 5, 3, 33, 33])
    ...     assert sample["renamed"].shape == torch.Size([2, 5])
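
To see why the renaming step matters, the sketch below mimics it with nested
Python dictionaries: two samples only stack cleanly once their entries share the
same key. The ``rename_key`` helper and the toy ``reward`` entry are hypothetical
and used for illustration only; they do not reflect how ``RenameTransform`` is
implemented.

```python
import copy


def rename_key(sample, old, new):
    """Return a copy of `sample` with the (possibly nested) key `old` moved to `new`.

    Hypothetical helper for illustration; nested keys are given as tuples.
    """
    out = copy.deepcopy(sample)
    path = (old,) if isinstance(old, str) else tuple(old)
    node = out
    # Walk down to the dictionary that holds the leaf entry.
    for part in path[:-1]:
        node = node[part]
    out[new] = node.pop(path[-1])
    return out


# Two toy samples with the same content stored under different keys.
sample0 = {"some": {"key": [0.1, 0.2]}, "reward": [1.0, 0.0]}
sample1 = {"another_key": [0.3, 0.4], "reward": [0.0, 1.0]}

harmonized = [
    rename_key(sample0, ("some", "key"), "renamed"),
    rename_key(sample1, "another_key", "renamed"),
]
# Stacking is only well-defined once the entries line up under the same keys.
stacked = {key: [s[key] for s in harmonized] for key in ("renamed", "reward")}
```

Without the renaming step, ``"renamed"`` would be missing from both samples and
there would be no shared entry to stack along the leading buffer dimension.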

.. currentmodule:: torchrl.data.replay_buffers


.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    ReplayBufferEnsemble
    SamplerEnsemble
    StorageEnsemble
    WriterEnsemble

TensorSpec
----------
