[BugFix] Fix strict-length for spanning trajectories (pytorch#1982)
vmoens authored Mar 4, 2024
1 parent 9337550 commit 6fb16a2
Showing 2 changed files with 192 additions and 48 deletions.
64 changes: 50 additions & 14 deletions test/test_rb.py
@@ -1904,7 +1904,7 @@ def test_slice_sampler(
)

done = torch.zeros(100, 1, dtype=torch.bool)
done[torch.tensor([29, 54, 69])] = 1
done[torch.tensor([29, 54, 69, 99])] = 1

data = TensorDict(
{
@@ -1995,6 +1995,35 @@ def test_slice_sampler(
truncated = info[("next", "truncated")]
assert truncated.view(num_slices, -1)[:, -1].all()

@pytest.mark.parametrize("sampler", [SliceSampler, SliceSamplerWithoutReplacement])
def test_slice_sampler_at_capacity(self, sampler):
torch.manual_seed(0)

trajectory0 = torch.tensor([3, 3, 0, 1, 1, 1, 2, 2, 2, 3])
trajectory1 = torch.arange(2).repeat_interleave(5)
trajectory = torch.stack([trajectory0, trajectory1], 0)

td = TensorDict(
{"trajectory": trajectory, "steps": torch.arange(10).expand(2, 10)}, [2, 10]
)

rb = ReplayBuffer(
sampler=sampler(traj_key="trajectory", num_slices=2),
storage=LazyTensorStorage(20, ndim=2),
batch_size=6,
)

rb.extend(td)

for s in rb:
if (s["steps"] == 9).any():
n = (s["steps"] == 9).nonzero()
assert ((s["steps"] == 0).nonzero() == n + 1).all()
assert ((s["steps"] == 1).nonzero() == n + 2).all()
break
else:
raise AssertionError

def test_slice_sampler_errors(self):
device = "cpu"
batch_size, num_slices = 100, 20
@@ -2651,19 +2680,26 @@ def test_rb_multidim_collector(
if transform:
for t in transform:
rb.append_transform(t())
for data in collector:
rb.extend(data)
if isinstance(rb, TensorDictReplayBuffer) and transform is not None:
# this should fail bc we can't set the indices after executing the transform.
with pytest.raises(RuntimeError, match="Failed to set the metadata"):
rb.sample()
return
s = rb.sample()
rbtot = rb[:]
assert rbtot.shape[0] == 2
assert len(rb) == rbtot.numel()
if transform is not None:
assert s.ndim == 2
try:
for i, data in enumerate(collector): # noqa: B007
rb.extend(data)
if isinstance(rb, TensorDictReplayBuffer) and transform is not None:
# this should fail bc we can't set the indices after executing the transform.
with pytest.raises(
RuntimeError, match="Failed to set the metadata"
):
rb.sample()
return
s = rb.sample()
rbtot = rb[:]
assert rbtot.shape[0] == 2
assert len(rb) == rbtot.numel()
if transform is not None:
assert s.ndim == 2
except Exception:
print(f"Failing at iter {i}") # noqa: T201
print(f"rb {rb}") # noqa: T201
raise


if __name__ == "__main__":
176 changes: 142 additions & 34 deletions torchrl/data/replay_buffers/samplers.py
@@ -610,10 +610,16 @@ class SliceSampler(Sampler):
To be used whenever the ``end_key`` or ``traj_key`` is expensive to get,
or when this signal is readily available. Must be used with ``cache_values=True``
and cannot be used in conjunction with ``end_key`` or ``traj_key``.
If provided, it is assumed that the storage is at capacity and that
if the last element of the ``ends`` tensor is ``False``,
the same trajectory spans across end and beginning.
trajectories (torch.Tensor, optional): a 1d integer tensor containing the run ids.
To be used whenever the ``end_key`` or ``traj_key`` is expensive to get,
or when this signal is readily available. Must be used with ``cache_values=True``
and cannot be used in conjunction with ``end_key`` or ``traj_key``.
If provided, it is assumed that the storage is at capacity and that
if the last element of the trajectory tensor is identical to the first,
the same trajectory spans across end and beginning.
cache_values (bool, optional): to be used with static datasets.
Will cache the start and end signal of the trajectory.
truncated_key (NestedKey, optional): If not ``None``, this argument
@@ -625,7 +631,8 @@ class SliceSampler(Sampler):
returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
strict_length (bool, optional): if ``False``, trajectories of length
shorter than `slice_len` (or `batch_size // num_slices`) will be
allowed to appear in the batch.
allowed to appear in the batch. If ``True``, trajectories shorter
than required will be filtered out.
Be mindful that this can result in effective `batch_size` shorter
than the one asked for! Trajectories can be split using
:func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``.
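
For context, a minimal usage sketch of the ``ends`` argument documented above (import paths, shapes and values are illustrative assumptions, not part of this diff):

    import torch
    from torchrl.data import LazyTensorStorage, ReplayBuffer
    from torchrl.data.replay_buffers import SliceSampler

    # Boolean end-of-trajectory mask for a storage assumed to be at capacity: the last
    # entry is False, so the trailing trajectory is taken to wrap around to the start
    # of the buffer (the spanning case this commit addresses).
    ends = torch.zeros(100, dtype=torch.bool)
    ends[torch.tensor([29, 54, 69])] = True

    sampler = SliceSampler(num_slices=4, ends=ends, cache_values=True)
    rb = ReplayBuffer(storage=LazyTensorStorage(100), sampler=sampler, batch_size=32)
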
@@ -728,7 +735,10 @@ def __init__(
raise RuntimeError(
"To be used, trajectories requires `cache_values` to be set to `True`."
)
vals = self._find_start_stop_traj(trajectory=trajectories)
vals = self._find_start_stop_traj(
trajectory=trajectories,
at_capacity=True,
)
self._cache["stop-and-length"] = vals

elif ends is not None:
@@ -742,7 +752,7 @@ def __init__(
raise RuntimeError(
"To be used, ends requires `cache_values` to be set to `True`."
)
vals = self._find_start_stop_traj(end=ends)
vals = self._find_start_stop_traj(end=ends, at_capacity=True)
self._cache["stop-and-length"] = vals

else:
@@ -760,7 +770,7 @@
)

@staticmethod
def _find_start_stop_traj(*, trajectory=None, end=None):
def _find_start_stop_traj(*, trajectory=None, end=None, at_capacity: bool):
if trajectory is not None:
# slower
# _, stop_idx = torch.unique_consecutive(trajectory, return_counts=True)
@@ -773,30 +783,60 @@ def _find_start_stop_traj(*, trajectory=None, end=None):

# faster
end = trajectory[:-1] != trajectory[1:]
end = torch.cat([end, torch.ones_like(end[:1])], 0)
end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
length = trajectory.shape[0]
else:
# TODO: check that storage is at capacity here, if not we need to assume that the last element of end is True

# We presume that not done at the end means that the traj spans across end and beginning of storage
length = end.shape[0]

if not at_capacity:
end = torch.index_fill(
end,
index=torch.tensor(-1, device=end.device, dtype=torch.long),
dim=0,
value=1,
)
elif not end.any(0).all():
# we must have at least one end per traj to delimit trajectories
# so if no end can be found we set it manually
mask = ~end.any(0, True)
mask = torch.cat([torch.zeros_like(end[:-1]), mask])
end = torch.masked_fill(mask, end, 1)
ndim = end.ndim
if ndim == 0:
raise RuntimeError(
"Expected the end-of-trajectory signal to be at least 1-dimensional."
)
# Using transpose ensures the start and stop are sorted the same way
stop_idx = end.transpose(0, -1).nonzero()
beginnings = torch.cat([torch.ones_like(end[:1]), end[:-1]], 0)
start_idx = beginnings.transpose(0, -1).nonzero()
start_idx = torch.cat([start_idx[:, -1:], start_idx[:, :-1]], -1)
# beginnings = torch.cat([end[-1:], end[:-1]], 0)
# start_idx = beginnings.transpose(0, -1).nonzero()
# start_idx = torch.cat([start_idx[:, -1:], start_idx[:, :-1]], -1)
stop_idx = torch.cat([stop_idx[:, -1:], stop_idx[:, :-1]], -1)

# First build the start indices as the stop + 1, we'll shift it later
start_idx = stop_idx.clone()
start_idx[:, 0] += 1
start_idx[:, 0] %= end.shape[0]
# shift start: to do this, we check when the non-first dim indices are identical
# and get a mask like [False, True, True, False, True, ...] where False means
# that there's a switch from one dim to another (ie, a switch from one element of the batch
# to another). We roll this one step along the time dimension and these two
# masks provide us with the indices of the permutation matrix we need
# to apply to start_idx.
start_idx_mask = (start_idx[1:, 1:] == start_idx[:-1, 1:]).all(-1)
m1 = torch.cat([torch.zeros_like(start_idx_mask[:1]), start_idx_mask])
m2 = torch.cat([start_idx_mask, torch.zeros_like(start_idx_mask[:1])])
start_idx_replace = torch.empty_like(start_idx)
start_idx_replace[m1] = start_idx[m2]
start_idx_replace[~m1] = start_idx[~m2]
start_idx = start_idx_replace
lengths = stop_idx[:, 0] - start_idx[:, 0] + 1
lengths[lengths < 0] = lengths[lengths < 0] + length
return start_idx, stop_idx, lengths
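
As an illustrative aside (toy tensor; storage assumed at capacity), the wrap-aware end detection used above behaves as follows:

    import torch

    trajectory = torch.tensor([3, 3, 0, 1, 1, 1, 2, 2, 2, 3])
    end = trajectory[:-1] != trajectory[1:]
    end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
    # end is True at indices 1, 2, 5 and 8 only; index 9 is not an end because the
    # trailing 3 continues into the leading 3s at the front of the buffer.
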

def _tensor_slices_from_startend(self, seq_length, start):
def _tensor_slices_from_startend(self, seq_length, start, storage_length):
# start is a 2d tensor resulting from nonzero()
# seq_length is a 1d tensor indicating the desired length of each sequence

@@ -823,6 +863,7 @@ def _start_to_end(st: torch.Tensor, length: int):
for _start, _seq_len in zip(start, seq_length)
]
)
result[:, 0] = result[:, 0] % storage_length
return result
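
The modulo above lets the generated slice indices wrap around the storage; in one dimension (toy values, for illustration only):

    import torch

    storage_length = 20
    result = torch.arange(18, 22)   # a slice starting near the end of the buffer
    print(result % storage_length)  # tensor([18, 19,  0,  1])
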

def _get_stop_and_length(self, storage, fallback=True):
@@ -843,7 +884,9 @@ def _get_stop_and_length(self, storage, fallback=True):
raise RuntimeError(
"Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
)
vals = self._find_start_stop_traj(trajectory=trajectory)
vals = self._find_start_stop_traj(
trajectory=trajectory, at_capacity=storage._is_full
)
if self.cache_values:
self._cache["stop-and-length"] = vals
return vals
@@ -863,7 +906,9 @@ def _get_stop_and_length(self, storage, fallback=True):
raise RuntimeError(
"Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
)
vals = self._find_start_stop_traj(end=done.squeeze()[: len(storage)])
vals = self._find_start_stop_traj(
end=done.squeeze()[: len(storage)], at_capacity=storage._is_full
)
if self.cache_values:
self._cache["stop-and-length"] = vals
return vals
@@ -903,28 +948,80 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
"instead."
)
seq_length, num_slices = self._adjusted_batch_size(batch_size)
return self._sample_slices(lengths, start_idx, stop_idx, seq_length, num_slices)
storage_length = storage.shape[0]
return self._sample_slices(
lengths,
start_idx,
stop_idx,
seq_length,
num_slices,
storage_length=storage_length,
)

def _sample_slices(
self, lengths, start_idx, stop_idx, seq_length, num_slices, traj_idx=None
self,
lengths: torch.Tensor,
start_idx: torch.Tensor,
stop_idx: torch.Tensor,
seq_length: int,
num_slices: int,
storage_length: int,
traj_idx: torch.Tensor | None = None,
) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]:
if traj_idx is None:
traj_idx = torch.randint(
lengths.shape[0], (num_slices,), device=lengths.device
)
else:
num_slices = traj_idx.shape[0]
def get_traj_idx(lengths=lengths):
return torch.randint(lengths.shape[0], (num_slices,), device=lengths.device)

if (lengths < seq_length).any():
if self.strict_length:
raise RuntimeError(
"Some stored trajectories have a length shorter than the slice that was asked for ("
f"min length={lengths.min()}). "
"Create the sampler with `strict_length=False` to allow shorter trajectories to appear "
"in you batch."
)
# make seq_length a tensor with values clamped by lengths
seq_length = lengths[traj_idx].clamp_max(seq_length)
idx = lengths == seq_length
if not idx.any():
raise RuntimeError(
"Did not find a single trajectory with sufficient length."
)
if (
isinstance(seq_length, torch.Tensor)
and seq_length.shape == lengths.shape
):
seq_length = seq_length[idx]
lengths_idx = lengths[idx]
start_idx = start_idx[idx]
stop_idx = stop_idx[idx]

if traj_idx is None:
traj_idx = get_traj_idx(lengths=lengths_idx)
else:
# Here we must filter out the indices that correspond to trajectories
# we don't want to keep. That could potentially lead to an empty sample.
# The difficulty with this adjustment is that traj_idx points to a full
# sequences of lengths, but we filter out part of it so we must
# convert traj_idx to a boolean mask, index this mask with the
# valid indices and then recover the nonzero.
idx_mask = torch.zeros_like(idx)
idx_mask[traj_idx] = True
traj_idx = idx_mask[idx].nonzero().squeeze(-1)
if not traj_idx.numel():
raise RuntimeError(
"None of the provided indices pointed to a trajectory of "
"sufficient length. Consider using strict_length=False for the "
"sampler instead."
)
num_slices = traj_idx.shape[0]

del idx
lengths = lengths_idx
else:
if traj_idx is None:
traj_idx = get_traj_idx()
else:
num_slices = traj_idx.shape[0]

# make seq_length a tensor with values clamped by lengths
seq_length = lengths[traj_idx].clamp_max(seq_length)
else:
if traj_idx is None:
traj_idx = get_traj_idx()
else:
num_slices = traj_idx.shape[0]
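
As a toy illustration of the boolean-mask trick described in the comments above (assume five stored trajectories, of which only trajectories 1, 3 and 4 are long enough):

    import torch

    idx = torch.tensor([False, True, False, True, True])  # trajectories long enough
    traj_idx = torch.tensor([0, 3, 4])                     # trajectories originally requested
    idx_mask = torch.zeros_like(idx)
    idx_mask[traj_idx] = True
    # positions of the surviving requested trajectories within the filtered set
    print(idx_mask[idx].nonzero().squeeze(-1))             # tensor([1, 2]), i.e. trajectories 3 and 4
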

relative_starts = (
(
@@ -936,12 +1033,12 @@ def _sample_slices(
)
starts = torch.cat(
[
start_idx[traj_idx, :1] + relative_starts.unsqueeze(-1),
(start_idx[traj_idx, 0] + relative_starts).unsqueeze(1),
start_idx[traj_idx, 1:],
],
1,
)
index = self._tensor_slices_from_startend(seq_length, starts)
index = self._tensor_slices_from_startend(seq_length, starts, storage_length)
if self.truncated_key is not None:
truncated_key = self.truncated_key
done_key = _replace_last(truncated_key, "done")
@@ -1053,7 +1150,8 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
strict_length (bool, optional): if ``False``, trajectories of length
shorter than `slice_len` (or `batch_size // num_slices`) will be
allowed to appear in the batch.
allowed to appear in the batch. If ``True``, trajectories shorter
than required will be filtered out.
Be mindful that this can result in effective `batch_size` shorter
than the one asked for! Trajectories can be split using
:func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``.
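
A hedged sketch of the ``strict_length=False`` behaviour described above (hypothetical buffer and keys; the effective batch can be smaller than requested when short trajectories are admitted):

    import torch
    from torchrl.data import LazyTensorStorage, ReplayBuffer
    from torchrl.data.replay_buffers import SliceSamplerWithoutReplacement

    sampler = SliceSamplerWithoutReplacement(
        num_slices=4, traj_key="trajectory", strict_length=False
    )
    rb = ReplayBuffer(storage=LazyTensorStorage(100), sampler=sampler, batch_size=32)
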
@@ -1172,8 +1270,15 @@ def sample(
# first get indices of the trajectories we want to retrieve
seq_length, num_slices = self._adjusted_batch_size(batch_size)
indices, _ = SamplerWithoutReplacement.sample(self, storage, num_slices)
storage_length = storage.shape[0]
idx, info = self._sample_slices(
lengths, start_idx, stop_idx, seq_length, num_slices, traj_idx=indices
lengths,
start_idx,
stop_idx,
seq_length,
num_slices,
storage_length,
traj_idx=indices,
)
return idx, info

@@ -1233,7 +1338,8 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
strict_length (bool, optional): if ``False``, trajectories of length
shorter than `slice_len` (or `batch_size // num_slices`) will be
allowed to appear in the batch.
allowed to appear in the batch. If ``True``, trajectories shorter
than required will be filtered out.
Be mindful that this can result in effective `batch_size` shorter
than the one asked for! Trajectories can be split using
:func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``.
@@ -1384,7 +1490,9 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
info["_weight"] = torch.as_tensor(info["_weight"], device=lengths.device)

# extends starting indices of each slice with sequence_length to get indices of all steps
index = self._tensor_slices_from_startend(seq_length, starts)
index = self._tensor_slices_from_startend(
seq_length, starts, storage_length=storage.shape[0]
)

# repeat the weight of each slice to match the number of steps
info["_weight"] = torch.repeat_interleave(info["_weight"], seq_length)
