[BugFix] Fix strict length in PRB+SliceSampler (pytorch#2202)
vmoens authored Jun 7, 2024
1 parent 726e959 commit 332499a
Showing 5 changed files with 368 additions and 151 deletions.
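Note: the sketch below condenses the new test cases in test/test_rb.py into the user-facing configuration this commit targets — a `PrioritizedSliceSampler` with `strict_length=True` fed a trajectory that is too short for the requested slice length. It is illustrative only; exact import paths and defaults may differ across torchrl versions.

```python
# Minimal sketch adapted from the new tests below (not the authoritative repro).
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import PrioritizedSliceSampler

# Five trajectories of increasing length; the first one is short and high-priority.
trajs = []
for t in range(5):
    length = (t + 1) * 5
    done = torch.zeros(length, 1, dtype=torch.bool)
    done[-1] = 1
    trajs.append(
        TensorDict(
            {
                "traj": torch.full((length,), t),
                "step_count": torch.arange(length),
                "done": done,
                "priority": torch.full((length,), 10.0 if t == 0 else 1.0),
            },
            batch_size=length,
        )
    )
data = torch.cat(trajs)

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(data.numel()),
    sampler=PrioritizedSliceSampler(
        max_capacity=data.numel(),
        alpha=1.0,
        beta=1.0,
        end_key="done",
        slice_len=10,
        strict_length=True,  # too-short trajectories should be excluded from batches
        cache_values=True,
    ),
    batch_size=50,
)
index = rb.extend(data)
rb.update_priority(index, data["priority"])
sample = rb.sample()  # 5 slices of 10 steps each
```

The new tests assert that, with `strict_length=True`, the short trajectory never appears in the sampled batches even though it carries the highest priority, and that shorter-than-requested batches only occur when `strict_length=False`.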
163 changes: 158 additions & 5 deletions test/test_rb.py
@@ -2003,6 +2003,7 @@ def test_slice_sampler(
)
index = torch.arange(0, num_steps, 1)
sampler.extend(index)
sampler.update_priority(index, 1)
else:
sampler = SliceSampler(
num_slices=num_slices,
@@ -2017,16 +2018,20 @@
trajs_unique_id = set()
too_short = False
count_unique = set()
for _ in range(30):
for _ in range(50):
index, info = sampler.sample(storage, batch_size=batch_size)
samples = storage._storage[index]
if strict_length:
# check that trajs are ok
samples = samples.view(num_slices, -1)

assert samples["another_episode"].unique(
dim=1
).squeeze().shape == torch.Size([num_slices])
unique_another_episode = (
samples["another_episode"].unique(dim=1).squeeze()
)
assert unique_another_episode.shape == torch.Size([num_slices]), (
num_slices,
samples,
)
assert (
samples["steps"][..., 1:] - 1 == samples["steps"][..., :-1]
).all()
@@ -2262,7 +2267,7 @@ def test_slice_sampler_left_right_ndim(self):
curr_eps = curr_eps[curr_eps != 0]
assert curr_eps.unique().numel() == 1

def test_slicesampler_strictlength(self):
def test_slice_sampler_strictlength(self):

torch.manual_seed(0)

@@ -2306,6 +2311,154 @@ def test_slicesampler_strictlength(self):
else:
assert len(sample["traj"].unique()) == 1

@pytest.mark.parametrize("ndim", [1, 2])
@pytest.mark.parametrize("strict_length", [True, False])
@pytest.mark.parametrize("circ", [False, True])
def test_slice_sampler_prioritized(self, ndim, strict_length, circ):
torch.manual_seed(0)
out = []
for t in range(5):
length = (t + 1) * 5
done = torch.zeros(length, 1, dtype=torch.bool)
done[-1] = 1
priority = 10 if t == 0 else 1
traj = TensorDict(
{
"traj": torch.full((length,), t),
"step_count": torch.arange(length),
"done": done,
"priority": torch.full((length,), priority),
},
batch_size=length,
)
out.append(traj)
data = torch.cat(out)
if ndim == 2:
data = torch.stack([data, data])
rb = TensorDictReplayBuffer(
storage=LazyTensorStorage(data.numel(), ndim=ndim),
sampler=PrioritizedSliceSampler(
max_capacity=data.numel(),
alpha=1.0,
beta=1.0,
end_key="done",
slice_len=10,
strict_length=strict_length,
cache_values=True,
),
batch_size=50,
)
if not circ:
# Simplest case: the buffer is full but no overlap
index = rb.extend(data)
else:
# The buffer is 2/3 -> 1/3 overlapping
rb.extend(data[..., : data.shape[-1] // 3])
index = rb.extend(data)
rb.update_priority(index, data["priority"])
samples = []
found_shorter_batch = False
for _ in range(100):
samples.append(rb.sample())
if samples[-1].numel() < 50:
found_shorter_batch = True
samples = torch.cat(samples)
if strict_length:
assert not found_shorter_batch
else:
assert found_shorter_batch
# the first trajectory has a very high priority, but should only appear
# if strict_length=False.
if strict_length:
assert (samples["traj"] != 0).all(), samples["traj"].unique()
else:
assert (samples["traj"] == 0).any()
# Check that all samples of the first traj contain all of its elements (since it's too short to fulfill 10 elements)
sc = samples[samples["traj"] == 0]["step_count"]
assert (sc == 0).sum() == (sc == 1).sum()
assert (sc == 0).sum() == (sc == 4).sum()
assert rb._sampler._cache
rb.extend(data)
assert not rb._sampler._cache

@pytest.mark.parametrize("ndim", [1, 2])
@pytest.mark.parametrize("strict_length", [True, False])
@pytest.mark.parametrize("circ", [False, True])
@pytest.mark.parametrize(
"span", [False, [False, False], [False, True], 3, [False, 3]]
)
def test_slice_sampler_prioritized_span(self, ndim, strict_length, circ, span):
torch.manual_seed(0)
out = []
# 5 trajs of length 3, 6, 9, 12 and 15
for t in range(5):
length = (t + 1) * 3
done = torch.zeros(length, 1, dtype=torch.bool)
done[-1] = 1
priority = 1
traj = TensorDict(
{
"traj": torch.full((length,), t),
"step_count": torch.arange(length),
"done": done,
"priority": torch.full((length,), priority),
},
batch_size=length,
)
out.append(traj)
data = torch.cat(out)
if ndim == 2:
data = torch.stack([data, data])
rb = TensorDictReplayBuffer(
storage=LazyTensorStorage(data.numel(), ndim=ndim),
sampler=PrioritizedSliceSampler(
max_capacity=data.numel(),
alpha=1.0,
beta=1.0,
end_key="done",
slice_len=5,
strict_length=strict_length,
cache_values=True,
span=span,
),
batch_size=5,
)
if not circ:
# Simplest case: the buffer is full but no overlap
index = rb.extend(data)
else:
# The buffer is 2/3 -> 1/3 overlapping
rb.extend(data[..., : data.shape[-1] // 3])
index = rb.extend(data)
rb.update_priority(index, data["priority"])
found_traj_0 = False
found_traj_4_truncated_left = False
found_traj_4_truncated_right = False
for i, s in enumerate(rb):
t = s["traj"].unique().tolist()
assert len(t) == 1
t = t[0]
if t == 0:
found_traj_0 = True
if t == 4 and s.numel() < 5:
if s["step_count"][0] > 10:
found_traj_4_truncated_right = True
if s["step_count"][0] == 0:
found_traj_4_truncated_left = True
if i == 1000:
break
assert not rb._sampler.span[0]
# if rb._sampler.span[0]:
# assert found_traj_4_truncated_left
if rb._sampler.span[1]:
assert found_traj_4_truncated_right
else:
assert not found_traj_4_truncated_right
if strict_length and not rb._sampler.span[1]:
assert not found_traj_0
else:
assert found_traj_0


def test_prioritized_slice_sampler_doc_example():
sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9)
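For context, the doc-example test above is truncated; a hedged sketch of the buffer-level workflow it exercises is shown below. Field names ("observation", "episode") and the priority update are illustrative and not copied from the collapsed test body.

```python
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import PrioritizedSliceSampler

# Same constructor as in the doc-example test above.
sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9)
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(9), sampler=sampler, batch_size=6)

# Three trajectories of three steps each, identified by an "episode" id.
data = TensorDict(
    {
        "observation": torch.randn(9, 4),
        "episode": torch.tensor([1, 1, 1, 2, 2, 2, 3, 3, 3]),
    },
    batch_size=[9],
)
rb.extend(data)
sample, info = rb.sample(return_info=True)
# Refresh the priorities of the transitions that were just sampled, using the
# indices reported in `info` (see `_sample` in replay_buffers.py below).
rb.update_priority(info["index"], torch.ones(sample.numel()))
```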
16 changes: 11 additions & 5 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -564,17 +564,18 @@ def update_priority(
index: Union[int, torch.Tensor],
priority: Union[int, torch.Tensor],
) -> None:
if self.dim_extend > 0 and priority.ndim > 1:
priority = self._transpose(priority).flatten()
# priority = priority.flatten()
with self._replay_lock:
self._sampler.update_priority(index, priority)
self._sampler.update_priority(index, priority, storage=self.storage)

@pin_memory_output
def _sample(self, batch_size: int) -> Tuple[Any, dict]:
with self._replay_lock:
index, info = self._sampler.sample(self._storage, batch_size)
info["index"] = index
data = self._storage.get(index)
# if self.dim_extend > 0:
# data = self._transpose(data)
if not isinstance(index, INT_CLASSES):
data = self._collate_fn(data)
if self._transform is not None and len(self._transform):
@@ -643,7 +644,7 @@ def sample(self, batch_size: int | None = None, return_info: bool = False) -> An
return ret[0]

def mark_update(self, index: Union[int, torch.Tensor]) -> None:
self._sampler.mark_update(index)
self._sampler.mark_update(index, storage=self._storage)

def append_transform(
self, transform: "Transform", *, invert: bool = False # noqa-F821
@@ -1105,8 +1106,13 @@ def extend(self, tensordicts: TensorDictBase) -> torch.Tensor:
return torch.zeros((0, self._storage.ndim), dtype=torch.long)

index = super()._extend(tensordicts)

# TODO: to be usable directly, the indices should be flipped but the issue
# is that just doing this results in indices that are not sorted like the original data
# so the actualy indices will have to be used on the _storage directly (not on the buffer)
self._set_index_in_td(tensordicts, index)
self.update_tensordict_priority(tensordicts)
# TODO: in principle this is a good idea, but currently it doesn't work and it re-writes a priority that has just been written
# self.update_tensordict_priority(tensordicts)
return index

def _set_index_in_td(self, tensordict, index):
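The changes above route the storage into the sampler's `update_priority` and `mark_update` calls. A plausible user-facing reason, mirroring the new ndim=2 tests in test_rb.py (this rationale is an inference, not stated in the diff): with a multi-dimensional storage, the indices returned by `extend()` are storage coordinates, and the sampler needs the storage geometry to translate them into positions in its flat priority structure. A minimal sketch, with import paths that may vary by version:

```python
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import PrioritizedSliceSampler

length = 12
done = torch.zeros(length, 1, dtype=torch.bool)
done[-1] = 1
traj = TensorDict({"done": done, "priority": torch.ones(length)}, batch_size=length)
data = torch.stack([traj, traj])  # batch shape (2, 12): two parallel copies of the trajectory

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(data.numel(), ndim=2),
    sampler=PrioritizedSliceSampler(
        max_capacity=data.numel(),
        alpha=1.0,
        beta=1.0,
        end_key="done",
        slice_len=4,
    ),
    batch_size=8,
)
index = rb.extend(data)  # with ndim=2, the returned index addresses both storage dimensions
# The buffer now forwards storage=... to the sampler so it can resolve these indices.
rb.update_priority(index, data["priority"])
sample = rb.sample()
```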