[Feature] Span slice indices on the left and on the right (pytorch#2107)
vmoens authored Apr 24, 2024
1 parent 6f1c387 commit 93e9e30
Showing 2 changed files with 116 additions and 7 deletions.
53 changes: 53 additions & 0 deletions test/test_rb.py
@@ -2175,6 +2175,59 @@ def test_slice_sampler_without_replacement(
        done_recon = info[("next", "truncated")] | info[("next", "terminated")]
        assert done_recon.view(num_slices, -1)[:, -1].all()

    def test_slice_sampler_left_right(self):
        torch.manual_seed(0)
        data = TensorDict(
            {"obs": torch.arange(1, 11).repeat(10), "eps": torch.arange(100) // 10 + 1},
            [100],
        )

        for N in (2, 4):
            rb = TensorDictReplayBuffer(
                sampler=SliceSampler(num_slices=10, traj_key="eps", span=(N, N)),
                batch_size=50,
                storage=LazyMemmapStorage(100),
            )
            rb.extend(data)

            for _ in range(10):
                sample = rb.sample()
                sample = split_trajectories(sample)
                assert (sample["next", "truncated"].squeeze(-1).sum(-1) == 1).all()
                assert ((sample["obs"] == 0).sum(-1) <= N).all(), sample["obs"]
                assert ((sample["eps"] == 0).sum(-1) <= N).all()
                for i in range(sample.shape[0]):
                    curr_eps = sample[i]["eps"]
                    curr_eps = curr_eps[curr_eps != 0]
                    assert curr_eps.unique().numel() == 1

    def test_slice_sampler_left_right_ndim(self):
        torch.manual_seed(0)
        data = TensorDict(
            {"obs": torch.arange(1, 11).repeat(12), "eps": torch.arange(120) // 10 + 1},
            [120],
        )
        data = data.reshape(4, 30)

        for N in (2, 4):
            rb = TensorDictReplayBuffer(
                sampler=SliceSampler(num_slices=10, traj_key="eps", span=(N, N)),
                batch_size=50,
                storage=LazyMemmapStorage(100, ndim=2),
            )
            rb.extend(data)

            for _ in range(10):
                sample = rb.sample()
                sample = split_trajectories(sample)
                assert (sample["next", "truncated"].squeeze(-1).sum(-1) <= 1).all()
                assert ((sample["obs"] == 0).sum(-1) <= N).all(), sample["obs"]
                assert ((sample["eps"] == 0).sum(-1) <= N).all()
                for i in range(sample.shape[0]):
                    curr_eps = sample[i]["eps"]
                    curr_eps = curr_eps[curr_eps != 0]
                    assert curr_eps.unique().numel() == 1

    def test_slicesampler_strictlength(self):

        torch.manual_seed(0)
70 changes: 63 additions & 7 deletions torchrl/data/replay_buffers/samplers.py
@@ -674,6 +674,13 @@ class SliceSampler(Sampler):
            the :meth:`~sample` method will be compiled with :func:`~torch.compile`.
            Keyword arguments can also be passed to torch.compile with this arg.
            Defaults to ``False``.
        span (bool, int, Tuple[bool | int, bool | int], optional): if provided, the sampled
            trajectory may span beyond the left and/or the right boundary of the stored
            trajectory. When it does, fewer elements are returned than were requested.
            A boolean value means that at least one element will be sampled per trajectory.
            An integer `i` means that at least `slice_len - i` samples will be gathered for
            each sampled trajectory. A tuple gives fine-grained control over the span on
            the left (beginning of the stored trajectory) and on the right (end of the
            stored trajectory). Defaults to ``False``.
    .. note:: To recover the trajectory splits in the storage,
        :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
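(Editor's illustration, not part of the diff: a minimal sketch of how the span argument documented above might be used. The data layout and sizes are assumptions that mirror the tests in this commit.)

import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import SliceSampler

# ten episodes of length 10, identified by the "eps" entry
data = TensorDict(
    {"obs": torch.arange(100), "eps": torch.arange(100) // 10 + 1},
    [100],
)
rb = TensorDictReplayBuffer(
    # span=(2, 2): slice starts may fall up to 2 steps before an episode
    # begins, or close enough to its end that the slice runs past it, so
    # sampled slices may come back shorter than requested
    sampler=SliceSampler(num_slices=10, traj_key="eps", span=(2, 2)),
    batch_size=50,
    storage=LazyMemmapStorage(100),
)
rb.extend(data)
sample = rb.sample()  # some slices may hold fewer than 5 elements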
@@ -753,6 +760,7 @@ def __init__(
        truncated_key: NestedKey | None = ("next", "truncated"),
        strict_length: bool = True,
        compile: bool | dict = False,
        span: bool | Tuple[bool | int, bool | int] = False,
    ):
        self.num_slices = num_slices
        self.slice_len = slice_len
@@ -763,6 +771,11 @@ def __init__(
        self._fetch_traj = True
        self.strict_length = strict_length
        self._cache = {}

        if isinstance(span, (bool, int)):
            # a single bool/int applies the same span on the left and the right
            span = (span, span)
        self.span = span

        if trajectories is not None:
            if traj_key is not None or end_key:
                raise RuntimeError(
@@ -916,6 +929,7 @@ def _end_to_start_stop(end, length):
        return start_idx, stop_idx, lengths

    def _start_to_end(self, st: torch.Tensor, length: int):

        arange = torch.arange(length, device=st.device, dtype=st.dtype)
        ndims = st.shape[-1] - 1 if st.ndim else 0
        if ndims:
@@ -1128,21 +1142,63 @@ def _get_index(
        storage_length: int,
        traj_idx: torch.Tensor | None = None,
    ) -> Tuple[torch.Tensor, dict]:
        # end_point is the exclusive upper bound for the slice start
        last_indexable_start = lengths[traj_idx] - seq_length + 1
        if not self.span[1]:
            end_point = last_indexable_start
        elif self.span[1] is True:
            # behaves like span_right = seq_length - 1: at least one element
            # of the trajectory is kept in each slice
            end_point = lengths[traj_idx]
        else:
            span_right = self.span[1]
            if span_right >= seq_length:
                raise ValueError(
                    "The right and left span must be strictly lower than the sequence length"
                )
            # symmetric with the left span: slices keep at least
            # seq_length - span_right elements of the trajectory
            end_point = last_indexable_start + span_right

        if not self.span[0]:
            start_point = 0
        elif self.span[0] is True:
            start_point = -seq_length + 1
        else:
            span_left = self.span[0]
            if span_left >= seq_length:
                raise ValueError(
                    "The right and left span must be strictly lower than the sequence length"
                )
            start_point = -span_left

        relative_starts = (
            torch.rand(num_slices, device=lengths.device) * (end_point - start_point)
        ).floor().to(start_idx.dtype) + start_point
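        # relative starts may now be negative (left span) or lie so close to
        # the trajectory end that fewer than seq_length steps remain (right
        # span); both cases are corrected below by shortening the slice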

        if self.span[0]:
            out_of_traj = relative_starts < 0
            if out_of_traj.any():
                # a negative start means sampling fewer elements: shift the
                # start to 0 and shorten the slice accordingly
                seq_length = torch.where(
                    ~out_of_traj, seq_length, seq_length + relative_starts
                )
                relative_starts = torch.where(~out_of_traj, relative_starts, 0)
        if self.span[1]:
            out_of_traj = relative_starts + seq_length > lengths[traj_idx]
            if out_of_traj.any():
                # a start too close to the trajectory end means sampling fewer
                # elements: clamp the slice length to what remains
                seq_length = torch.minimum(
                    seq_length, lengths[traj_idx] - relative_starts
                )

        starts = torch.cat(
            [
                (start_idx[traj_idx, 0] + relative_starts).unsqueeze(1),
                start_idx[traj_idx, 1:],
            ],
            1,
        )

        index = self._tensor_slices_from_startend(seq_length, starts, storage_length)
        if self.truncated_key is not None:
            truncated_key = self.truncated_key
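(Editor's illustration, not part of the diff: a self-contained sketch of the start-point arithmetic performed above, with made-up values — one trajectory of length 10, slices of length 5, span=(2, 2).)

import torch

length, seq_length, span_left, span_right = 10, 5, 2, 2

start_point = -span_left                             # earliest (possibly negative) slice start
end_point = length - seq_length + 1 + span_right     # exclusive upper bound for slice starts

relative_starts = (
    torch.rand(1000) * (end_point - start_point)
).floor().to(torch.long) + start_point

# fix-up pass, mirroring the diff: negative starts shorten the slice from
# the left; starts too close to the end shorten it from the right
seq_lengths = torch.full_like(relative_starts, seq_length)
neg = relative_starts < 0
seq_lengths = torch.where(neg, seq_lengths + relative_starts, seq_lengths)
relative_starts = torch.where(neg, torch.zeros_like(relative_starts), relative_starts)
seq_lengths = torch.minimum(seq_lengths, length - relative_starts)

# in this toy setting every slice keeps at least 3 of its 5 requested steps
assert (seq_lengths >= seq_length - max(span_left, span_right)).all()
assert (seq_lengths <= seq_length).all()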
