[BugFix] Fix max-priority update #2215

Merged
merged 6 commits on Jun 8, 2024
Changes from 1 commit
init
vmoens committed Jun 7, 2024
commit fae8ce7f61bc49d0b807298dc2899e4d5a7afe95
19 changes: 19 additions & 0 deletions test/test_rb.py
@@ -2459,6 +2459,25 @@ def test_slice_sampler_prioritized_span(self, ndim, strict_length, circ, span):
else:
assert found_traj_0

def test_prb_update_max_priority(self):
rb = ReplayBuffer(
storage=LazyTensorStorage(10),
sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
)
for data in torch.arange(20):
idx = rb.add(data)
rb.update_priority(idx, 21 - data)
if data <= 9:
assert rb._sampler._max_priority[0] == 21
assert rb._sampler._max_priority[1] == 0
else:
assert rb._sampler._max_priority[0] == 11
assert rb._sampler._max_priority[1] == 0
idx = rb.extend(torch.arange(10))
rb.update_priority(idx, 12)
assert rb._sampler._max_priority[0] == 12
assert rb._sampler._max_priority[1] == 0


def test_prioritized_slice_sampler_doc_example():
sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9)
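As a side note (not part of the diff): the new test exercises a small piece of bookkeeping — the sampler caches the maximum priority together with the storage slot that holds it, and forgets the cache when that slot is overwritten. Below is a minimal, library-free sketch of that idea; the names (`MaxPriorityCache`, `on_write`) are purely illustrative, not TorchRL API.

```python
# Illustrative sketch only: mirrors the bookkeeping tested above, without TorchRL.
class MaxPriorityCache:
    def __init__(self):
        self.value = None   # current maximum priority seen
        self.index = None   # storage slot that holds it

    def update(self, slot, priority):
        # Called when a priority is (re)assigned to a slot.
        if self.value is None or priority > self.value:
            self.value, self.index = priority, slot

    def on_write(self, slot):
        # Called when new data overwrites a slot: if it held the argmax, forget it.
        if slot == self.index:
            self.value, self.index = None, None

    def default_priority(self, eps=1e-8, alpha=1.0):
        # New items get the cached max (falling back to 1.0 when nothing is cached).
        return ((self.value if self.value is not None else 1.0) + eps) ** alpha


cache = MaxPriorityCache()
for step in range(20):
    slot = step % 10               # a circular buffer of capacity 10
    cache.on_write(slot)           # writing may evict the cached maximum
    cache.update(slot, 21 - step)  # priorities decrease over time, as in the test
print(cache.value, cache.index)    # 11 at slot 0, not the stale 21
```

Running the loop mirrors the first half of `test_prb_update_max_priority`: once slot 0 is rewritten at step 10, the stale maximum of 21 is dropped and 11 becomes the new cached value.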
61 changes: 56 additions & 5 deletions torchrl/data/replay_buffers/samplers.py
@@ -376,14 +376,60 @@ def _init(self):
raise NotImplementedError(
f"dtype {self.dtype} not supported by PrioritizedSampler"
)
self._max_priority = 1.0
self._max_priority = None

def _empty(self):
self._init()

@property
def _max_priority(self):
max_priority_index = self.__dict__.get("_max_priority")
if max_priority_index is None:
return (None, None)
return max_priority_index

@_max_priority.setter
def _max_priority(self, value):
self.__dict__["_max_priority"] = value

def _maybe_erase_max_priority(self, index):
max_priority_index = self._max_priority[1]
if max_priority_index is None:
return

def check_index(index=index, max_priority_index=max_priority_index):
if isinstance(index, torch.Tensor):
# index can be 1d or 2d
if index.ndim == 1:
is_overwritten = (index == max_priority_index).any()
else:
is_overwritten = (index == max_priority_index).all(-1).any()
elif isinstance(index, int):
is_overwritten = index == max_priority_index
elif isinstance(index, slice):
# This won't work if called recursively
is_overwritten = max_priority_index in range(
                    *index.indices(self._max_capacity)
)
elif isinstance(index, tuple):
is_overwritten = isinstance(max_priority_index, tuple)
if is_overwritten:
for idx, mpi in zip(index, max_priority_index):
is_overwritten &= check_index(idx, mpi)
else:
raise TypeError(f"index of type {type(index)} is not recognized.")
return is_overwritten

is_overwritten = check_index()
if is_overwritten:
self._max_priority = None

@property
def default_priority(self) -> float:
return (self._max_priority + self._eps) ** self._alpha
mp = self._max_priority[0]
if mp is None:
mp = 1
return (mp + self._eps) ** self._alpha

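The overwrite check above dispatches on the index type: a plain int, a 1-D or 2-D tensor of indices, a slice, or a tuple of per-dimension indices. A standalone sketch modeled on that logic — `was_overwritten` is a hypothetical helper written for illustration, not something defined in this PR — may make the dispatch easier to follow:

```python
import torch

def was_overwritten(index, mpi, capacity):
    """Return True if a write at `index` touches the slot `mpi` holding the cached max."""
    if isinstance(index, torch.Tensor):
        if index.ndim == 1:
            return bool((index == mpi).any())
        # 2-D index: each row is a coordinate; a hit means one row matches entirely.
        return bool((index == mpi).all(-1).any())
    if isinstance(index, int):
        return index == mpi
    if isinstance(index, slice):
        return mpi in range(*index.indices(capacity))
    if isinstance(index, tuple):
        # Coordinate tuple: every dimension must hit for the slot to be overwritten.
        return isinstance(mpi, tuple) and all(
            was_overwritten(i, m, capacity) for i, m in zip(index, mpi)
        )
    raise TypeError(f"Unsupported index type: {type(index)}")

# A few checks with a hypothetical capacity of 10 and cached argmax at slot 3:
assert was_overwritten(torch.tensor([1, 3, 5]), 3, 10)   # 1-D write hits slot 3
assert not was_overwritten(slice(4, 8), 3, 10)           # slice misses slot 3
assert was_overwritten((torch.tensor([3]), torch.tensor([0])), (3, 0), 10)
```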
def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
if len(storage) == 0:
@@ -422,11 +468,13 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:

return index, {"_weight": weight}

def add(self, index: int) -> None:
def add(self, index: torch.Tensor | int) -> None:
super().add(index)
self._maybe_erase_max_priority(index)

def extend(self, index: torch.Tensor) -> None:
def extend(self, index: torch.Tensor | tuple) -> None:
super().extend(index)
self._maybe_erase_max_priority(index)

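`add` and `extend` now invalidate the cached maximum whenever a write may touch the slot that holds it; otherwise a stale maximum would keep inflating `default_priority` for every new item. Rough numbers (with placeholder `eps`/`alpha`, not necessarily the sampler's settings) illustrate the effect:

```python
eps, alpha = 1e-8, 1.0   # placeholder values for illustration only
stale_max = 21.0         # priority of an entry that has since been overwritten

# Before the fix: the scalar max never decreases, so new items inherit ~21.
print((stale_max + eps) ** alpha)
# After the fix: the cache is erased on overwrite and the default falls back to ~1,
# until update_priority() establishes a fresh maximum.
print((1.0 + eps) ** alpha)
```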
@torch.no_grad()
def update_priority(
@@ -494,7 +542,10 @@ def update_priority(
if priority.ndim:
priority = priority[valid_index]

self._max_priority = priority.max().clamp_min(self._max_priority).item()
max_p, max_p_idx = priority.max(dim=0)
max_priority = self._max_priority[0]
if max_priority is None or max_p > max_priority:
self._max_priority = (max_p, max_p_idx)
priority = torch.pow(priority + self._eps, self._alpha)
self._sum_tree[index] = priority
self._min_tree[index] = priority
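`priority.max(dim=0)` returns both the maximum value and its argmax, which is what lets the pair be cached in a single call; the new value only replaces the cache when it beats the previous maximum. Note that the returned index is a position within the `priority` tensor handed to `update_priority`. A quick reminder of the PyTorch behaviour:

```python
import torch

priority = torch.tensor([5.0, 12.0, 7.0])
max_p, max_p_idx = priority.max(dim=0)   # maximum value and its argmax
print(max_p.item(), max_p_idx.item())    # 12.0 1
```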