Skip to content

Commit

Permalink
[BugFix] Fix and test PRB priority update across dims and rb types (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 24, 2024
1 parent 3c6c801 commit 1062e3e
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 5 deletions.
147 changes: 147 additions & 0 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2653,6 +2653,153 @@ def test_prb_update_max_priority(self, max_priority_within_buffer):
assert rb._sampler._max_priority[0] == 21
assert rb._sampler._max_priority[1] == 0

def test_prb_ndim(self):
    """Check priority updates on prioritized buffers for 1d and 2d data.

    Six scenarios are run in sequence: a ``ReplayBuffer`` and a
    ``TensorDictReplayBuffer`` (both with a ``PrioritizedSampler``) and a
    ``TensorDictPrioritizedReplayBuffer``, each first with default storage
    and then with ``LazyTensorStorage(100, ndim=2)``.
    """
    torch.manual_seed(0)
    np.random.seed(0)

    def read_priorities(buffer):
        # Entries 0..9 of the sampler's sum-tree, gathered into one tensor.
        return torch.tensor([buffer._sampler._sum_tree[i] for i in range(10)])

    def make_sampler():
        return PrioritizedSampler(max_capacity=100, alpha=1.0, beta=1.0)

    def make_flat_data():
        # 10 items with batch size [10].
        return TensorDict({"a": torch.arange(10), "p": torch.ones(10) / 2}, [10])

    def make_grid_data():
        # 10 items laid out with batch size [2, 5].
        return TensorDict(
            {"a": torch.arange(5).expand(2, 5), "p": torch.ones(2, 5) / 2}, [2, 5]
        )

    # --- Scenario 1: 1d data, ReplayBuffer --------------------------------
    buffer = ReplayBuffer(
        sampler=make_sampler(),
        storage=LazyTensorStorage(100),
        batch_size=4,
    )
    payload = make_flat_data()
    written = buffer.extend(payload)
    # The test expects each written entry to read 1 right after ``extend``.
    assert (read_priorities(buffer) == 1).all()
    buffer.update_priority(written, 2)
    assert (read_priorities(buffer) == 2).all()
    _, info = buffer.sample(return_info=True)
    buffer.update_priority(info["index"], 3)
    sampled_priorities = read_priorities(buffer)[info["index"]]
    assert (sampled_priorities == 3).all()

    # --- Scenario 2: 1d data, TensorDictReplayBuffer ----------------------
    buffer = TensorDictReplayBuffer(
        sampler=make_sampler(),
        storage=LazyTensorStorage(100),
        batch_size=4,
    )
    payload = make_flat_data()
    written = buffer.extend(payload)
    assert (read_priorities(buffer) == 1).all()
    buffer.update_priority(written, 2)
    assert (read_priorities(buffer) == 2).all()
    batch = buffer.sample()
    buffer.update_priority(batch["index"], 3)
    sampled_priorities = read_priorities(buffer)[batch["index"]]
    assert (sampled_priorities == 3).all()

    # --- Scenario 3: 1d data, TensorDictPrioritizedReplayBuffer -----------
    buffer = TensorDictPrioritizedReplayBuffer(
        alpha=1.0,
        beta=1.0,
        storage=LazyTensorStorage(100),
        batch_size=4,
        priority_key="p",
    )
    payload = make_flat_data()
    written = buffer.extend(payload)
    assert (read_priorities(buffer) == 1).all()
    buffer.update_priority(written, 2)
    assert (read_priorities(buffer) == 2).all()
    batch = buffer.sample()

    # Push the priority through the tensordict's "p" entry this time.
    batch["p"] = torch.ones(4) * 10_000
    buffer.update_tensordict_priority(batch)
    sampled_priorities = read_priorities(buffer)[batch["index"]]
    assert (sampled_priorities == 10_000).all()

    resampled = buffer.sample()
    # With such a large priority, every index drawn now must be one of
    # the indices drawn in ``batch``.
    matches = resampled["index"].unsqueeze(0) == batch["index"].unsqueeze(1)
    assert matches.any(0).all()

    # --- Scenario 4: 2d data, ReplayBuffer --------------------------------
    buffer = ReplayBuffer(
        sampler=make_sampler(),
        storage=LazyTensorStorage(100, ndim=2),
        batch_size=4,
    )
    payload = make_grid_data()
    written = buffer.extend(payload)
    assert (read_priorities(buffer) == 1).all()
    buffer.update_priority(written, 2)
    assert (read_priorities(buffer) == 2).all()

    _, info = buffer.sample(return_info=True)
    buffer.update_priority(info["index"], 3)
    priority_grid = read_priorities(buffer).reshape((5, 2))
    assert (priority_grid[info["index"]] == 3).all()

    # --- Scenario 5: 2d data, TensorDictReplayBuffer ----------------------
    buffer = TensorDictReplayBuffer(
        sampler=make_sampler(),
        storage=LazyTensorStorage(100, ndim=2),
        batch_size=4,
    )
    payload = make_grid_data()
    written = buffer.extend(payload)
    assert (read_priorities(buffer) == 1).all()
    buffer.update_priority(written, 2)
    assert (read_priorities(buffer) == 2).all()

    batch = buffer.sample()
    buffer.update_priority(batch["index"], 10_000)
    priority_grid = read_priorities(buffer).reshape((5, 2))
    # The stored "index" is split along its last dim to index the grid.
    assert (priority_grid[batch["index"].unbind(-1)] == 10_000).all()

    resampled = buffer.sample()
    # Compare full index rows: each resampled row must equal some row
    # of the earlier batch.
    matches = resampled["index"].unsqueeze(0) == batch["index"].unsqueeze(1)
    assert matches.all(-1).any(0).all()

    # --- Scenario 6: 2d data, TensorDictPrioritizedReplayBuffer -----------
    buffer = TensorDictPrioritizedReplayBuffer(
        alpha=1.0,
        beta=1.0,
        storage=LazyTensorStorage(100, ndim=2),
        batch_size=4,
        priority_key="p",
    )
    payload = make_grid_data()
    written = buffer.extend(payload)
    assert (read_priorities(buffer) == 1).all()
    # Here the priority is passed as a 0-dim tensor rather than an int.
    buffer.update_priority(written, torch.ones(()) * 2)
    assert (read_priorities(buffer) == 2).all()
    batch = buffer.sample()
    # Use a priority large enough that the buffer will draw these items again.
    batch["p"] = torch.ones(4) * 10_000
    buffer.update_tensordict_priority(batch)
    priority_grid = read_priorities(buffer).reshape((5, 2))
    assert (priority_grid[batch["index"].unbind(-1)] == 10_000).all()

    resampled = buffer.sample()
    matches = resampled["index"].unsqueeze(0) == batch["index"].unsqueeze(1)
    assert matches.all(-1).any(0).all()


def test_prioritized_slice_sampler_doc_example():
sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9)
Expand Down
16 changes: 11 additions & 5 deletions torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,9 +585,12 @@ def extend(self, data: Sequence) -> torch.Tensor:

def update_priority(
self,
index: Union[int, torch.Tensor],
index: Union[int, torch.Tensor, Tuple[torch.Tensor]],
priority: Union[int, torch.Tensor],
) -> None:
if isinstance(index, tuple):
index = torch.stack(index, -1)
priority = torch.as_tensor(priority)
if self.dim_extend > 0 and priority.ndim > 1:
priority = self._transpose(priority).flatten()
# priority = priority.flatten()
Expand Down Expand Up @@ -1095,7 +1098,7 @@ def _get_priority_vector(self, tensordict: TensorDictBase) -> torch.Tensor:
dtype=torch.float,
device=tensordict.device,
).expand(tensordict.shape[0])
if self._storage.ndim > 1:
if self._storage.ndim > 1 and priority.ndim >= self._storage.ndim:
# We have to flatten the priority otherwise we'll be aggregating
# the priority across batches
priority = priority.flatten(0, self._storage.ndim - 1)
Expand Down Expand Up @@ -1172,9 +1175,12 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None:
else:
priority = torch.as_tensor(self._get_priority_item(data))
index = data.get("index")
while index.shape != priority.shape:
# reduce index
index = index[..., 0]
if self._storage.ndim > 1 and index.ndim == 2:
index = index.unbind(-1)
else:
while index.shape != priority.shape:
# reduce index
index = index[..., 0]
return self.update_priority(index, priority)

def sample(
Expand Down

0 comments on commit 1062e3e

Please sign in to comment.