[Feature] Add PrioritizedSliceSampler (pytorch#1875)
Cadene authored Feb 7, 2024
1 parent b34e2d2 commit 4d52d5f
Showing 4 changed files with 379 additions and 18 deletions.
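For orientation before the per-file diffs: a minimal usage sketch of the new sampler, condensed from the test_prioritized_slice_sampler_doc_example test added below. The hyperparameters, tensor shapes, and priority-update call mirror that test; the trimmed-down TensorDict fields are illustrative only, not canonical API documentation.

import torch
from tensordict import TensorDict

from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler

# Store 9 transitions and sample them as 3 slices per batch of 6,
# weighted by priority (alpha/beta values taken from the doc-example test).
sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9)
rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(9), sampler=sampler, batch_size=6
)
data = TensorDict(
    {
        "observation": torch.randn(9, 16),
        "episode": torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=torch.long),
        ("next", "done"): torch.tensor(
            [0, 0, 1, 0, 0, 1, 0, 0, 1], dtype=torch.bool
        ).unsqueeze(1),
    },
    batch_size=[9],
)
rb.extend(data)

sample, info = rb.sample(return_info=True)  # importance weights in info["_weight"]

# Raising the priority of selected steps biases which trajectories get sliced.
priority = torch.tensor([0, 3, 3, 0, 0, 0, 1, 1, 1])
rb.update_priority(torch.arange(0, 9, 1), priority=priority)
sample, info = rb.sample(return_info=True)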
1 change: 1 addition & 0 deletions docs/source/reference/data.rst
@@ -134,6 +134,7 @@ using the following components:

Sampler
PrioritizedSampler
PrioritizedSliceSampler
RandomSampler
SamplerWithoutReplacement
SliceSampler
149 changes: 135 additions & 14 deletions test/test_rb.py
@@ -40,6 +40,7 @@
from torchrl.data.replay_buffers import samplers, writers
from torchrl.data.replay_buffers.samplers import (
PrioritizedSampler,
PrioritizedSliceSampler,
RandomSampler,
SamplerEnsemble,
SamplerWithoutReplacement,
@@ -1834,13 +1835,14 @@ def test_sampler_without_rep_state_dict(self, backend):
assert (s.exclude("index") == 0).all()

@pytest.mark.parametrize(
- "batch_size,num_slices,slice_len",
+ "batch_size,num_slices,slice_len,prioritized",
[
- [100, 20, None],
- [120, 30, None],
- [100, None, 5],
- [120, None, 4],
- [101, None, 101],
+ [100, 20, None, True],
+ [100, 20, None, False],
+ [120, 30, None, False],
+ [100, None, 5, False],
+ [120, None, 4, False],
+ [101, None, 101, False],
],
)
@pytest.mark.parametrize("episode_key", ["episode", ("some", "episode")])
@@ -1853,6 +1855,7 @@ def test_slice_sampler(
batch_size,
num_slices,
slice_len,
prioritized,
episode_key,
done_key,
match_episode,
@@ -1897,19 +1900,34 @@ def test_slice_sampler(
else:
strict_length = True

- sampler = SliceSampler(
-     num_slices=num_slices,
-     traj_key=episode_key,
-     end_key=done_key,
-     slice_len=slice_len,
-     strict_length=strict_length,
- )
+ if prioritized:
+     num_steps = data.shape[0]
+     sampler = PrioritizedSliceSampler(
+         max_capacity=num_steps,
+         alpha=0.7,
+         beta=0.9,
+         num_slices=num_slices,
+         traj_key=episode_key,
+         end_key=done_key,
+         slice_len=slice_len,
+         strict_length=strict_length,
+     )
+     index = torch.arange(0, num_steps, 1)
+     sampler.extend(index)
+ else:
+     sampler = SliceSampler(
+         num_slices=num_slices,
+         traj_key=episode_key,
+         end_key=done_key,
+         slice_len=slice_len,
+         strict_length=strict_length,
+     )
if slice_len is not None:
num_slices = batch_size // slice_len
trajs_unique_id = set()
too_short = False
count_unique = set()
- for _ in range(10):
+ for _ in range(30):
index, info = sampler.sample(storage, batch_size=batch_size)
if _data_prefix:
samples = storage._storage["_data"][index]
@@ -1918,6 +1936,7 @@ def test_slice_sampler(
if strict_length:
# check that trajs are ok
samples = samples.view(num_slices, -1)

assert samples["another_episode"].unique(
dim=1
).squeeze().shape == torch.Size([num_slices])
@@ -1936,6 +1955,7 @@ def test_slice_sampler(
raise AssertionError(
f"Not all items can be sampled: {set(range(100))-count_unique} are missing"
)

if strict_length:
assert not too_short
else:
@@ -2071,6 +2091,107 @@ def test_slice_sampler_without_replacement(
assert truncated.view(num_slices, -1)[:, -1].all()


def test_prioritized_slice_sampler_doc_example():
sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9)
rb = TensorDictReplayBuffer(
storage=LazyMemmapStorage(9), sampler=sampler, batch_size=6
)
data = TensorDict(
{
"observation": torch.randn(9, 16),
"action": torch.randn(9, 1),
"episode": torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=torch.long),
"steps": torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2], dtype=torch.long),
("next", "observation"): torch.randn(9, 16),
("next", "reward"): torch.randn(9, 1),
("next", "done"): torch.tensor(
[0, 0, 1, 0, 0, 1, 0, 0, 1], dtype=torch.bool
).unsqueeze(1),
},
batch_size=[9],
)
rb.extend(data)
sample, info = rb.sample(return_info=True)
# print("episode", sample["episode"].tolist())
# print("steps", sample["steps"].tolist())
# print("weight", info["_weight"].tolist())

priority = torch.tensor([0, 3, 3, 0, 0, 0, 1, 1, 1])
rb.update_priority(torch.arange(0, 9, 1), priority=priority)
sample, info = rb.sample(return_info=True)
# print("episode", sample["episode"].tolist())
# print("steps", sample["steps"].tolist())
# print("weight", info["_weight"].tolist())


@pytest.mark.parametrize("device", get_default_devices())
def test_prioritized_slice_sampler_episodes(device):
num_slices = 10
batch_size = 20

episode = torch.zeros(100, dtype=torch.int, device=device)
episode[:30] = 1
episode[30:55] = 2
episode[55:70] = 3
episode[70:] = 4
steps = torch.cat(
[torch.arange(30), torch.arange(25), torch.arange(15), torch.arange(30)], 0
)
done = torch.zeros(100, 1, dtype=torch.bool)
done[torch.tensor([29, 54, 69])] = 1

data = TensorDict(
{
"observation": torch.randn(100, 16),
"action": torch.randn(100, 4),
"episode": episode,
"steps": steps,
("next", "observation"): torch.randn(100, 16),
("next", "reward"): torch.randn(100, 1),
("next", "done"): done,
},
batch_size=[100],
device=device,
)

num_steps = data.shape[0]
sampler = PrioritizedSliceSampler(
max_capacity=num_steps,
alpha=0.7,
beta=0.9,
num_slices=num_slices,
)

rb = TensorDictReplayBuffer(
storage=LazyMemmapStorage(100),
sampler=sampler,
batch_size=batch_size,
)
rb.extend(data)

episodes = []
for _ in range(10):
sample = rb.sample()
episodes.append(sample["episode"])
assert {1, 2, 3, 4} == set(
torch.cat(episodes).cpu().tolist()
), "all episodes are expected to be sampled at least once"

index = torch.arange(0, num_steps, 1)
new_priorities = torch.cat(
[torch.ones(30), torch.zeros(25), torch.ones(15), torch.zeros(30)], 0
)
sampler.update_priority(index, new_priorities)

episodes = []
for _ in range(10):
sample = rb.sample()
episodes.append(sample["episode"])
assert {1, 3} == set(
torch.cat(episodes).cpu().tolist()
), "after priority update, only episode 1 and 3 are expected to be sampled"


class TestEnsemble:
def _make_data(self, data_type):
if data_type is torch.Tensor:
1 change: 1 addition & 0 deletions torchrl/data/replay_buffers/__init__.py
@@ -13,6 +13,7 @@
)
from .samplers import (
PrioritizedSampler,
PrioritizedSliceSampler,
RandomSampler,
Sampler,
SamplerEnsemble,
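A quick smoke check of the export added above; both import paths appear in this commit (the subpackage __init__ re-exports from .samplers, and the test file imports from torchrl.data.replay_buffers.samplers), so the two should resolve to the same class:

from torchrl.data.replay_buffers import PrioritizedSliceSampler
from torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler as _PSS

# Both names refer to the class re-exported by the subpackage __init__.
assert PrioritizedSliceSampler is _PSS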