[Feature] SliceSampler (#1748)
vmoens authored Dec 19, 2023
1 parent 2e1d60c commit 4d3a0c6
Showing 7 changed files with 772 additions and 45 deletions.
84 changes: 52 additions & 32 deletions docs/source/reference/data.rst
@@ -35,6 +35,8 @@ We also give users the ability to compose a replay buffer using the following components:
    PrioritizedSampler
    RandomSampler
    SamplerWithoutReplacement
    SliceSampler
    SliceSamplerWithoutReplacement
    Storage
    ListStorage
    LazyTensorStorage
@@ -109,41 +111,59 @@ It is not too difficult to store trajectories in the replay buffer.
One element to pay attention to is that the size of the replay buffer is always
the size of the leading dimension of the storage: in other words, creating a
replay buffer with a storage of size 1M when storing multidimensional data
does not mean storing 1M frames but 1M trajectories. However, if trajectories
(or episodes/rollouts) are flattened before being stored, the capacity is still
1M, but it then counts steps rather than trajectories.
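
As a small illustrative sketch (the shapes, keys and capacity below are
arbitrary assumptions, not prescriptions), extending a buffer consumes one
slot per element along the leading batch dimension of the data passed to it:

.. code-block:: Python

  >>> import torch
  >>> from tensordict import TensorDict
  >>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
  >>>
  >>> # 10 trajectories of 200 steps each, one buffer entry per trajectory
  >>> trajs = TensorDict({"obs": torch.randn(10, 200, 4)}, [10])
  >>> rb_trajs = TensorDictReplayBuffer(storage=LazyMemmapStorage(10_000))
  >>> rb_trajs.extend(trajs)
  >>> len(rb_trajs)
  10
  >>> # the same data flattened to 2000 single steps, one buffer entry per step
  >>> steps = TensorDict({"obs": trajs["obs"].reshape(-1, 4)}, [2000])
  >>> rb_steps = TensorDictReplayBuffer(storage=LazyMemmapStorage(10_000))
  >>> rb_steps.extend(steps)
  >>> len(rb_steps)
  2000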

When sampling trajectories, it may be desirable to sample sub-trajectories
to diversify learning or make the sampling more efficient.
TorchRL offers two distinct ways of accomplishing this:

- The :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` samples a
  given number of slices of trajectories stored one after another along the
  leading dimension of the :class:`~torchrl.data.replay_buffers.storages.TensorStorage`.
  This is the recommended way of sampling sub-trajectories in TorchRL,
  **especially** when using offline datasets (which are stored using that
  convention). This strategy requires the trajectories to be flattened before
  extending the replay buffer and to be reshaped after sampling. The
  :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` documentation
  gives extensive details about this storage and sampling strategy. A minimal
  usage sketch is also given after the example below.

- Trajectories can also be stored independently, with each element of the
  leading dimension pointing to a different trajectory. This requires the
  trajectories to have congruent shapes (or to be padded). We provide a custom
  :class:`~torchrl.envs.Transform` class named
  :class:`~torchrl.envs.RandomCropTensorDict` that allows sampling of
  sub-trajectories in the buffer. Note that, unlike the
  :class:`~torchrl.data.replay_buffers.samplers.SliceSampler`-based strategy,
  this approach does not require an ``"episode"`` or ``"done"`` key pointing at
  the start and stop signals.
  Here is an example of how this class can be used:

.. code-block:: Python

  >>> import torch
  >>> from tensordict import TensorDict
  >>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
  >>> from torchrl.envs import RandomCropTensorDict
  >>>
  >>> obs = torch.randn(100, 50, 1)
  >>> data = TensorDict({"obs": obs[:-1], "next": {"obs": obs[1:]}}, [99])
  >>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000))
  >>> rb.extend(data)
  >>> # subsample trajectories of length 10
  >>> rb.append_transform(RandomCropTensorDict(sub_seq_len=10))
  >>> print(rb.sample(128))
  TensorDict(
      fields={
          index: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int32, is_shared=False),
          next: TensorDict(
              fields={
                  obs: Tensor(shape=torch.Size([10, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
              batch_size=torch.Size([10]),
              device=None,
              is_shared=False),
          obs: Tensor(shape=torch.Size([10, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
      batch_size=torch.Size([10]),
      device=None,
      is_shared=False)
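
For the first, :class:`~torchrl.data.replay_buffers.samplers.SliceSampler`-based
strategy, a minimal usage sketch could look as follows. The trajectory layout,
keys and batch size below are illustrative assumptions rather than
requirements; the class documentation describes the full set of options:

.. code-block:: Python

  >>> import torch
  >>> from tensordict import TensorDict
  >>> from torchrl.data import LazyMemmapStorage, SliceSampler, TensorDictReplayBuffer
  >>>
  >>> # four trajectories of unequal length, flattened along the leading dimension
  >>> episode = torch.zeros(100, dtype=torch.int)
  >>> episode[:30], episode[30:55], episode[55:70], episode[70:] = 1, 2, 3, 4
  >>> data = TensorDict({"obs": torch.randn(100, 3), "episode": episode}, [100])
  >>> rb = TensorDictReplayBuffer(
  ...     storage=LazyMemmapStorage(100),
  ...     sampler=SliceSampler(num_slices=4, traj_key="episode"),
  ...     batch_size=32,
  ... )
  >>> rb.extend(data)
  >>> sample = rb.sample()            # 32 steps = 4 slices of 8 contiguous steps
  >>> sample = sample.reshape(4, -1)  # one sub-trajectory per row

Because the storage is flat, the sampled batch comes back as a single dimension
of ``batch_size`` steps; reshaping it to ``(num_slices, -1)`` recovers one
contiguous sub-trajectory per row.
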
Checkpointing Replay Buffers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
230 changes: 230 additions & 0 deletions test/test_rb.py
@@ -31,6 +31,8 @@
    PrioritizedSampler,
    RandomSampler,
    SamplerWithoutReplacement,
    SliceSampler,
    SliceSamplerWithoutReplacement,
)

from torchrl.data.replay_buffers.storages import (
@@ -1544,6 +1546,234 @@ def test_error_noninit(self):
        self.exec_multiproc_rb(init=False)


class TestSamplers:
    @pytest.mark.parametrize(
        "batch_size,num_slices,slice_len",
        [
            [100, 20, None],
            [120, 30, None],
            [100, None, 5],
            [120, None, 4],
            [101, None, 101],
        ],
    )
    @pytest.mark.parametrize("episode_key", ["episode", ("some", "episode")])
    @pytest.mark.parametrize("done_key", ["done", ("some", "done")])
    @pytest.mark.parametrize("match_episode", [True, False])
    @pytest.mark.parametrize("_data_prefix", [True, False])
    @pytest.mark.parametrize("device", get_default_devices())
    def test_slice_sampler(
        self,
        batch_size,
        num_slices,
        slice_len,
        episode_key,
        done_key,
        match_episode,
        _data_prefix,
        device,
    ):
        torch.manual_seed(0)
        storage = LazyMemmapStorage(100)
        episode = torch.zeros(100, dtype=torch.int, device=device)
        episode[:30] = 1
        episode[30:55] = 2
        episode[55:70] = 3
        episode[70:] = 4
        steps = torch.cat(
            [torch.arange(30), torch.arange(25), torch.arange(15), torch.arange(30)], 0
        )

        done = torch.zeros(100, 1, dtype=torch.bool)
        done[torch.tensor([29, 54, 69])] = 1

        data = TensorDict(
            {
                # we only use episode_key if we want the sampler to access it
                episode_key if match_episode else "whatever_episode": episode,
                "another_episode": episode,
                "obs": torch.randn((3, 4, 5)).expand(100, 3, 4, 5),
                "act": torch.randn((20,)).expand(100, 20),
                "steps": steps,
                "other": torch.randn((20, 50)).expand(100, 20, 50),
                done_key: done,
            },
            [100],
            device=device,
        )
        if _data_prefix:
            data = TensorDict({"_data": data}, [100])
        storage.set(range(100), data)
        if slice_len is not None and slice_len > 15:
            # we may have to sample trajs shorter than slice_len
            strict_length = False
        else:
            strict_length = True

        sampler = SliceSampler(
            num_slices=num_slices,
            traj_key=episode_key,
            end_key=done_key,
            slice_len=slice_len,
            strict_length=strict_length,
        )
        if slice_len is not None:
            num_slices = batch_size // slice_len
        trajs_unique_id = set()
        too_short = False
        for _ in range(5):
            index, info = sampler.sample(storage, batch_size=batch_size)
            if _data_prefix:
                samples = storage._storage["_data"][index]
            else:
                samples = storage._storage[index]
            if strict_length:
                # check that trajs are ok
                samples = samples.view(num_slices, -1)
                assert samples["another_episode"].unique(
                    dim=1
                ).squeeze().shape == torch.Size([num_slices])
                assert (
                    samples["steps"][..., 1:] - 1 == samples["steps"][..., :-1]
                ).all()
            too_short = too_short or index.numel() < batch_size
            trajs_unique_id = trajs_unique_id.union(
                samples["another_episode"].view(-1).tolist()
            )
        if strict_length:
            assert not too_short
        else:
            assert too_short

        assert len(trajs_unique_id) == 4
        truncated = info[("next", "truncated")]
        assert truncated.view(num_slices, -1)[:, -1].all()

    def test_slice_sampler_errors(self):
        device = "cpu"
        _data_prefix = False
        batch_size, num_slices = 100, 20

        episode = torch.zeros(100, dtype=torch.int, device=device)
        episode[:30] = 1
        episode[30:55] = 2
        episode[55:70] = 3
        episode[70:] = 4
        steps = torch.cat(
            [torch.arange(30), torch.arange(25), torch.arange(15), torch.arange(30)], 0
        )

        done = torch.zeros(100, 1, dtype=torch.bool)
        done[torch.tensor([29, 54, 69])] = 1

        data = TensorDict(
            {
                # we only use episode_key if we want the sampler to access it
                "episode": episode,
                "another_episode": episode,
                "obs": torch.randn((3, 4, 5)).expand(100, 3, 4, 5),
                "act": torch.randn((20,)).expand(100, 20),
                "steps": steps,
                "other": torch.randn((20, 50)).expand(100, 20, 50),
                ("next", "done"): done,
            },
            [100],
            device=device,
        )
        if _data_prefix:
            data = TensorDict({"_data": data}, [100])

        data_wrong_done = data.clone(False)
        data_wrong_done.rename_key_("episode", "_")
        data_wrong_done["next", "done"] = done.unsqueeze(1).expand(100, 5, 1)
        storage = LazyMemmapStorage(100)
        storage.set(range(100), data_wrong_done)
        sampler = SliceSampler(num_slices=num_slices)
        with pytest.raises(
            RuntimeError,
            match="Expected the end-of-trajectory signal to be 1-dimensional",
        ):
            index, _ = sampler.sample(storage, batch_size=batch_size)

        storage = ListStorage(100)
        storage.set(range(100), data)
        sampler = SliceSampler(num_slices=num_slices)
        with pytest.raises(
            RuntimeError, match="can only sample from TensorStorage subclasses"
        ):
            index, _ = sampler.sample(storage, batch_size=batch_size)

@pytest.mark.parametrize("batch_size,num_slices", [[20, 4], [4, 2]])
@pytest.mark.parametrize("episode_key", ["episode", ("some", "episode")])
@pytest.mark.parametrize("done_key", ["done", ("some", "done")])
@pytest.mark.parametrize("match_episode", [True, False])
@pytest.mark.parametrize("_data_prefix", [True, False])
@pytest.mark.parametrize("device", get_default_devices())
def test_slice_sampler_without_replacement(
self,
batch_size,
num_slices,
episode_key,
done_key,
match_episode,
_data_prefix,
device,
):
torch.manual_seed(0)
storage = LazyMemmapStorage(100)
episode = torch.zeros(100, dtype=torch.int, device=device)
steps = []
done = torch.zeros(100, 1, dtype=torch.bool)
for i in range(0, 100, 5):
episode[i : i + 5] = i // 5
steps.append(torch.arange(5))
done[i + 4] = 1
steps = torch.cat(steps)

data = TensorDict(
{
# we only use episode_key if we want the sampler to access it
episode_key if match_episode else "whatever_episode": episode,
"another_episode": episode,
"obs": torch.randn((3, 4, 5)).expand(100, 3, 4, 5),
"act": torch.randn((20,)).expand(100, 20),
"steps": steps,
"other": torch.randn((20, 50)).expand(100, 20, 50),
done_key: done,
},
[100],
device=device,
)
if _data_prefix:
data = TensorDict({"_data": data}, [100])
storage.set(range(100), data)
sampler = SliceSamplerWithoutReplacement(
num_slices=num_slices, traj_key=episode_key, end_key=done_key
)
trajs_unique_id = set()
for i in range(5):
index, info = sampler.sample(storage, batch_size=batch_size)
if _data_prefix:
samples = storage._storage["_data"][index]
else:
samples = storage._storage[index]

# check that trajs are ok
samples = samples.view(num_slices, -1)
assert samples["another_episode"].unique(
dim=1
).squeeze().shape == torch.Size([num_slices])
assert (samples["steps"][..., 1:] - 1 == samples["steps"][..., :-1]).all()
cur_episodes = samples["another_episode"].view(-1).tolist()
for ep in cur_episodes:
assert ep not in trajs_unique_id, i
trajs_unique_id = trajs_unique_id.union(
cur_episodes,
)
truncated = info[("next", "truncated")]
assert truncated.view(num_slices, -1)[:, -1].all()


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
2 changes: 2 additions & 0 deletions torchrl/data/__init__.py
@@ -13,6 +13,8 @@
    RemoteTensorDictReplayBuffer,
    ReplayBuffer,
    RoundRobinWriter,
    SliceSampler,
    SliceSamplerWithoutReplacement,
    Storage,
    TensorDictMaxValueWriter,
    TensorDictPrioritizedReplayBuffer,
2 changes: 2 additions & 0 deletions torchrl/data/replay_buffers/__init__.py
@@ -15,6 +15,8 @@
    RandomSampler,
    Sampler,
    SamplerWithoutReplacement,
    SliceSampler,
    SliceSamplerWithoutReplacement,
)
from .storages import (
    LazyMemmapStorage,
7 changes: 1 addition & 6 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -439,16 +439,11 @@ def sample(
        if not self._prefetch:
            ret = self._sample(batch_size)
        else:
            if len(self._prefetch_queue) == 0:
                ret = self._sample(batch_size)
            else:
                with self._futures_lock:
                    ret = self._prefetch_queue.popleft().result()

            with self._futures_lock:
                while len(self._prefetch_queue) < self._prefetch_cap:
                    fut = self._prefetch_executor.submit(self._sample, batch_size)
                    self._prefetch_queue.append(fut)
                ret = self._prefetch_queue.popleft().result()

        if return_info:
            return ret