Skip to content

Commit

Permalink
[BugFix] Fix max value within buffer during update priority (pytorch#…
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 21, 2024
1 parent 1d729e8 commit 3c6c801
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
15 changes: 12 additions & 3 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2629,12 +2629,21 @@ def test_prb_update_max_priority(self, max_priority_within_buffer):
for data in torch.arange(20):
idx = rb.add(data)
rb.update_priority(idx, 21 - data)
if data <= 10 or not max_priority_within_buffer:
if data <= 10:
# The max is always going to be the first value
assert rb._sampler._max_priority[0] == 21
assert rb._sampler._max_priority[1] == 0
else:
assert rb._sampler._max_priority[0] == 10
elif not max_priority_within_buffer:
# The max is the historical max, which was at idx 0
assert rb._sampler._max_priority[0] == 21
assert rb._sampler._max_priority[1] == 0
else:
# the max is the current max. Find it and compare
sumtree = torch.as_tensor(
[rb._sampler._sum_tree[i] for i in range(rb._sampler._max_capacity)]
)
assert rb._sampler._max_priority[0] == sumtree.max()
assert rb._sampler._max_priority[1] == sumtree.argmax()
idx = rb.extend(torch.arange(10))
rb.update_priority(idx, 12)
if max_priority_within_buffer:
Expand Down
18 changes: 15 additions & 3 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,12 +575,24 @@ def update_priority(
priority = priority[valid_index]

max_p, max_p_idx = priority.max(dim=0)
max_priority = self._max_priority[0]
if max_priority is None or max_p > max_priority:
self._max_priority = (max_p, max_p_idx)
cur_max_priority, cur_max_priority_index = self._max_priority
if cur_max_priority is None or max_p > cur_max_priority:
cur_max_priority, cur_max_priority_index = self._max_priority = (
max_p,
index[max_p_idx] if index.ndim else index,
)
priority = torch.pow(priority + self._eps, self._alpha)
self._sum_tree[index] = priority
self._min_tree[index] = priority
if (
self._max_priority_within_buffer
and cur_max_priority_index is not None
and (index == cur_max_priority_index).any()
):
maxval, maxidx = torch.tensor(
[self._sum_tree[i] for i in range(self._max_capacity)]
).max(0)
self._max_priority = (maxval, maxidx)

def mark_update(
self, index: Union[int, torch.Tensor], *, storage: Storage | None = None
Expand Down

0 comments on commit 3c6c801

Please sign in to comment.