[BugFix] Fix sliced PRB when only traj is provided (pytorch#2228)
vmoens authored Jun 14, 2024
1 parent ce92e35 commit 35df59e
Showing 2 changed files with 33 additions and 13 deletions.
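
The scenario this fix targets, sketched below for orientation (not part of the commit): a prioritized slice sampler that is given only a trajectory key, with no end/done signal. The construction mirrors the test touched in this diff; argument names such as traj_key and num_slices, and the import paths, are assumed from the TorchRL API rather than taken from the patch itself.

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
    from torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler

    # three trajectories identified only by the "trajectory" entry
    trajectory = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
    data = TensorDict(
        {"trajectory": trajectory, "steps": torch.arange(12)},
        batch_size=[12],
    )
    rb = TensorDictReplayBuffer(
        storage=LazyTensorStorage(12),
        sampler=PrioritizedSliceSampler(
            max_capacity=12,
            alpha=1.0,
            beta=1.0,
            traj_key="trajectory",  # only the trajectory is provided, no end_key
            num_slices=2,
        ),
        batch_size=8,
    )
    rb.extend(data)
    sample = rb.sample()  # 2 slices of 4 steps each, drawn within trajectories
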
17 changes: 9 additions & 8 deletions test/test_rb.py
@@ -2230,12 +2230,12 @@ def test_slice_sampler(
     def test_slice_sampler_at_capacity(self, sampler):
         torch.manual_seed(0)

-        trajectory0 = torch.tensor([3, 3, 0, 1, 1, 1, 2, 2, 2, 3])
-        trajectory1 = torch.arange(2).repeat_interleave(5)
+        trajectory0 = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])
+        trajectory1 = torch.arange(2).repeat_interleave(6)
         trajectory = torch.stack([trajectory0, trajectory1], 0)

         td = TensorDict(
-            {"trajectory": trajectory, "steps": torch.arange(10).expand(2, 10)}, [2, 10]
+            {"trajectory": trajectory, "steps": torch.arange(12).expand(2, 12)}, [2, 12]
         )

         rb = ReplayBuffer(
@@ -2469,7 +2469,8 @@ def test_slice_sampler_strictlength(self):
     @pytest.mark.parametrize("ndim", [1, 2])
     @pytest.mark.parametrize("strict_length", [True, False])
     @pytest.mark.parametrize("circ", [False, True])
-    def test_slice_sampler_prioritized(self, ndim, strict_length, circ):
+    @pytest.mark.parametrize("at_capacity", [False, True])
+    def test_slice_sampler_prioritized(self, ndim, strict_length, circ, at_capacity):
         torch.manual_seed(0)
         out = []
         for t in range(5):
@@ -2491,9 +2492,9 @@ def test_slice_sampler_prioritized(self, ndim, strict_length, circ):
         if ndim == 2:
             data = torch.stack([data, data])
         rb = TensorDictReplayBuffer(
-            storage=LazyTensorStorage(data.numel(), ndim=ndim),
+            storage=LazyTensorStorage(data.numel() - at_capacity, ndim=ndim),
             sampler=PrioritizedSliceSampler(
-                max_capacity=data.numel(),
+                max_capacity=data.numel() - at_capacity,
                 alpha=1.0,
                 beta=1.0,
                 end_key="done",
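
An illustrative aside, not from the patch: parametrizing the storage size as data.numel() - at_capacity makes the buffer one slot too small when at_capacity=True, so extending it with the full dataset forces the circular storage to wrap and the sampler to run against a full buffer. A minimal sketch of that effect, assuming the round-robin writer wraps when an extension exceeds capacity:

    import torch
    from torchrl.data import LazyTensorStorage, ReplayBuffer

    data = torch.arange(10)
    rb = ReplayBuffer(storage=LazyTensorStorage(len(data) - 1))
    rb.extend(data)       # 10 items into 9 slots: the writer wraps around
    print(len(rb))        # 9 -> the storage is now at capacity (_is_full)
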
@@ -2530,8 +2531,8 @@ def test_slice_sampler_prioritized(self, ndim, strict_length, circ):
         assert (samples["traj"] == 0).any()
         # Check that all samples of the first traj contain all elements (since it's too short to fullfill 10 elts)
         sc = samples[samples["traj"] == 0]["step_count"]
-        assert (sc == 0).sum() == (sc == 1).sum()
-        assert (sc == 0).sum() == (sc == 4).sum()
+        assert (sc == 1).sum() == (sc == 2).sum()
+        assert (sc == 1).sum() == (sc == 4).sum()
         assert rb._sampler._cache
         rb.extend(data)
         assert not rb._sampler._cache
29 changes: 24 additions & 5 deletions torchrl/data/replay_buffers/samplers.py
@@ -1068,7 +1068,9 @@ def _get_stop_and_length(self, storage, fallback=True):
                         "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
                     )
             vals = self._find_start_stop_traj(
-                trajectory=trajectory, at_capacity=storage._is_full
+                trajectory=trajectory,
+                at_capacity=storage._is_full,
+                cursor=getattr(storage, "_last_cursor", None),
             )
             if self.cache_values:
                 self._cache["stop-and-length"] = vals
@@ -1803,7 +1805,7 @@ def _padded_indices(self, shapes, arange) -> torch.Tensor: # noqa: F811
             .flip(0)
         )

-    def _preceding_stop_idx(self, storage, lengths, seq_length):
+    def _preceding_stop_idx(self, storage, lengths, seq_length, start_idx):
         preceding_stop_idx = self._cache.get("preceding_stop_idx")
         if preceding_stop_idx is not None:
             return preceding_stop_idx
@@ -1828,6 +1830,13 @@ def _preceding_stop_idx(self, storage, lengths, seq_length):
         # Mask the rightmost values of that padded tensor
         preceding_stop_idx = pad[:, -seq_length + 1 + span_right :]
         preceding_stop_idx = preceding_stop_idx[preceding_stop_idx >= 0]
+        if storage._is_full:
+            preceding_stop_idx = (
+                preceding_stop_idx
+                + np.ravel_multi_index(
+                    tuple(start_idx[0].tolist()), storage._total_shape
+                )
+            ) % storage._total_shape.numel()
         if self.cache_values:
             self._cache["preceding_stop_idx"] = preceding_stop_idx
         return preceding_stop_idx
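
One way to read the new index arithmetic above (a toy illustration, not from the commit): when the storage is full, the stop indices are computed on an unrolled view that begins at the write cursor, so they are rotated by the cursor's flat offset, obtained with np.ravel_multi_index, and wrapped with a modulo to land on physical slots. Shapes and values below are made up:

    import numpy as np

    total_shape = (2, 6)          # stand-in for storage._total_shape
    cursor = (0, 2)               # stand-in for start_idx[0]
    offset = np.ravel_multi_index(cursor, total_shape)   # 0 * 6 + 2 = 2
    logical_stops = np.array([5, 11])                     # stops in unrolled order
    physical_stops = (logical_stops + offset) % np.prod(total_shape)
    print(physical_stops)         # [7 1]
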
@@ -1838,7 +1847,9 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
         start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
         seq_length, num_slices = self._adjusted_batch_size(batch_size)

-        preceding_stop_idx = self._preceding_stop_idx(storage, lengths, seq_length)
+        preceding_stop_idx = self._preceding_stop_idx(
+            storage, lengths, seq_length, start_idx
+        )
         if storage.ndim > 1:
             # we need to convert indices of the permuted, flatten storage to indices in a flatten storage (not permuted)
             # This is because the lengths come as they would for a permuted storage
@@ -1851,12 +1862,14 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
             )

         # force to not sample index at the end of a trajectory
+        vals = torch.tensor(self._sum_tree[preceding_stop_idx.cpu().numpy()])
         self._sum_tree[preceding_stop_idx.cpu().numpy()] = 0.0
         # and no need to update self._min_tree

         starts, info = PrioritizedSampler.sample(
             self, storage=storage, batch_size=batch_size // seq_length
         )
+        self._sum_tree[preceding_stop_idx.cpu().numpy()] = vals
         # We must truncate the seq_length if (1) not strict length or (2) span[1]
         if self.span[1] or not self.strict_length:
             if not isinstance(starts, torch.Tensor):
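
The two added lines implement a save/zero/restore pattern: the priorities of indices that sit too close to a trajectory end are stashed, zeroed in the sum tree so PrioritizedSampler.sample cannot draw them, then written back right after sampling so the masked entries keep their original priorities. A plain-numpy sketch of the same pattern (purely illustrative; the real code operates on the sampler's sum-segment tree):

    import numpy as np

    priorities = np.array([0.5, 1.0, 2.0, 0.1, 4.0])
    forbidden = np.array([2, 4])              # stand-in for preceding_stop_idx

    saved = priorities[forbidden].copy()      # vals = ...
    priorities[forbidden] = 0.0               # mask out trajectory tails
    probs = priorities / priorities.sum()
    start = np.random.choice(len(priorities), p=probs)   # ~ PrioritizedSampler.sample
    priorities[forbidden] = saved             # restore the original priorities
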
@@ -1866,7 +1879,13 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
             # Find the stop that comes after the start index
             # say start_tensor has shape [N, X] and stop_idx has shape [M, X]
             # diff will have shape [M, N, X]
-            diff = stop_idx.unsqueeze(1) - starts_tensor.unsqueeze(0)
+            stop_idx_corr = stop_idx.clone()
+            stop_idx_corr[:, 0] = torch.where(
+                stop_idx[:, 0] < start_idx[:, 0],
+                stop_idx[:, 0] + storage._len_along_dim0,
+                stop_idx[:, 0],
+            )
+            diff = stop_idx_corr.unsqueeze(1) - starts_tensor.unsqueeze(0)
             # filter out all items that don't belong to the same dim in the storage
             mask = (diff[:, :, 1:] != 0).any(-1)
             diff = diff[:, :, 0]
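
The torch.where correction above handles trajectories that wrap around the circular storage: when a stop index along the time dimension is smaller than its start index, the trajectory has crossed the end of the buffer, and shifting the stop by the buffer length makes start/stop distances comparable on an unrolled axis. A toy example with made-up numbers:

    import torch

    length = 10                               # stand-in for storage._len_along_dim0
    start = torch.tensor([8, 2])
    stop = torch.tensor([1, 5])
    stop_corr = torch.where(stop < start, stop + length, stop)
    print(stop_corr)   # tensor([11,  5]) -> the first slice spans 8..11, i.e. 8, 9, 0, 1
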
@@ -1876,7 +1895,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
             diff[diff < 0] = diff.max() + 1
             # Take the arg min along dim 0 (thereby reducing dim M)
             idx = diff.argmin(dim=0)
-            stops = stop_idx[idx, 0]
+            stops = stop_idx_corr[idx, 0]
             # TODO: here things may not work bc we could have spanning trajs,
             # though I cannot show that it breaks in the tests
             if starts_tensor.ndim > 1:
