[BugFix] Fix max-priority update (pytorch#2215)
vmoens authored Jun 8, 2024
1 parent 0813dc0 commit 4d37ee1
Showing 2 changed files with 98 additions and 5 deletions.
test/test_rb.py (29 additions & 0 deletions)
@@ -2459,6 +2459,35 @@ def test_slice_sampler_prioritized_span(self, ndim, strict_length, circ, span):
        else:
            assert found_traj_0

    @pytest.mark.parametrize("max_priority_within_buffer", [True, False])
    def test_prb_update_max_priority(self, max_priority_within_buffer):
        rb = ReplayBuffer(
            storage=LazyTensorStorage(11),
            sampler=PrioritizedSampler(
                max_capacity=11,
                alpha=1.0,
                beta=1.0,
                max_priority_within_buffer=max_priority_within_buffer,
            ),
        )
        for data in torch.arange(20):
            idx = rb.add(data)
            rb.update_priority(idx, 21 - data)
            if data <= 10 or not max_priority_within_buffer:
                assert rb._sampler._max_priority[0] == 21
                assert rb._sampler._max_priority[1] == 0
            else:
                assert rb._sampler._max_priority[0] == 10
                assert rb._sampler._max_priority[1] == 0
        idx = rb.extend(torch.arange(10))
        rb.update_priority(idx, 12)
        if max_priority_within_buffer:
            assert rb._sampler._max_priority[0] == 12
            assert rb._sampler._max_priority[1] == 0
        else:
            assert rb._sampler._max_priority[0] == 21
            assert rb._sampler._max_priority[1] == 0


def test_prioritized_slice_sampler_doc_example():
    sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9)
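This change replaces the scalar `_max_priority` with a `(value, index)` pair so that, when `max_priority_within_buffer=True`, the cached maximum is invalidated once the entry holding it is overwritten. As a standalone illustration of the API exercised by the test above (a minimal sketch, assuming a torchrl build that includes this commit; the capacity and priority values are arbitrary):

    import torch
    from torchrl.data.replay_buffers import (
        LazyTensorStorage,
        PrioritizedSampler,
        ReplayBuffer,
    )

    rb = ReplayBuffer(
        storage=LazyTensorStorage(11),
        sampler=PrioritizedSampler(
            max_capacity=11,
            alpha=1.0,
            beta=1.0,
            max_priority_within_buffer=True,
        ),
    )
    # Fill the buffer with descending priorities: the maximum (11.0) lands in slot 0.
    idx = rb.extend(torch.arange(11))
    rb.update_priority(idx, torch.arange(11, 0, -1, dtype=torch.float))
    # Overwriting slot 0 hits the entry that holds the cached maximum; under
    # this flag the tracked (value, index) pair is reset rather than kept stale.
    rb.extend(torch.arange(5))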
torchrl/data/replay_buffers/samplers.py (69 additions & 5 deletions)
Expand Up @@ -278,6 +278,9 @@ class PrioritizedSampler(Sampler):
        reduction (str, optional): the reduction method for multidimensional
            tensordicts (i.e., stored trajectories). Can be one of "max", "min",
            "median" or "mean".
        max_priority_within_buffer (bool, optional): if ``True``, the max-priority
            is tracked within the buffer. When ``False``, the max-priority tracks
            the maximum value since the instantiation of the sampler.
            Defaults to ``False``.

    Examples:
        >>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
@@ -334,6 +337,7 @@ def __init__(
        eps: float = 1e-8,
        dtype: torch.dtype = torch.float,
        reduction: str = "max",
        max_priority_within_buffer: bool = False,
    ) -> None:
        if alpha < 0:
            raise ValueError(
@@ -348,6 +352,7 @@ def __init__(
        self._eps = eps
        self.reduction = reduction
        self.dtype = dtype
        self._max_priority_within_buffer = max_priority_within_buffer
        self._init()

    def __repr__(self):
@@ -376,14 +381,62 @@ def _init(self):
            raise NotImplementedError(
                f"dtype {self.dtype} not supported by PrioritizedSampler"
            )
-        self._max_priority = 1.0
+        self._max_priority = None

    def _empty(self):
        self._init()

    @property
    def _max_priority(self):
        max_priority_index = self.__dict__.get("_max_priority")
        if max_priority_index is None:
            return (None, None)
        return max_priority_index

    @_max_priority.setter
    def _max_priority(self, value):
        self.__dict__["_max_priority"] = value

    def _maybe_erase_max_priority(self, index):
        if not self._max_priority_within_buffer:
            return
        max_priority_index = self._max_priority[1]
        if max_priority_index is None:
            return

        def check_index(index=index, max_priority_index=max_priority_index):
            if isinstance(index, torch.Tensor):
                # index can be 1d or 2d
                if index.ndim == 1:
                    is_overwritten = (index == max_priority_index).any()
                else:
                    is_overwritten = (index == max_priority_index).all(-1).any()
            elif isinstance(index, int):
                is_overwritten = index == max_priority_index
            elif isinstance(index, slice):
                # This won't work if called recursively
                is_overwritten = max_priority_index in range(
                    *index.indices(self._max_capacity)
                )
            elif isinstance(index, tuple):
                is_overwritten = isinstance(max_priority_index, tuple)
                if is_overwritten:
                    for idx, mpi in zip(index, max_priority_index):
                        is_overwritten &= check_index(idx, mpi)
            else:
                raise TypeError(f"index of type {type(index)} is not recognized.")
            return is_overwritten

        is_overwritten = check_index()
        if is_overwritten:
            self._max_priority = None

    @property
    def default_priority(self) -> float:
-        return (self._max_priority + self._eps) ** self._alpha
+        mp = self._max_priority[0]
+        if mp is None:
+            mp = 1
+        return (mp + self._eps) ** self._alpha

    def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
        if len(storage) == 0:
@@ -422,11 +475,13 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:

        return index, {"_weight": weight}

-    def add(self, index: int) -> None:
+    def add(self, index: torch.Tensor | int) -> None:
        super().add(index)
        self._maybe_erase_max_priority(index)

-    def extend(self, index: torch.Tensor) -> None:
+    def extend(self, index: torch.Tensor | tuple) -> None:
        super().extend(index)
        self._maybe_erase_max_priority(index)

    @torch.no_grad()
    def update_priority(
@@ -494,7 +549,10 @@ def update_priority(
        if priority.ndim:
            priority = priority[valid_index]

-        self._max_priority = priority.max().clamp_min(self._max_priority).item()
+        max_p, max_p_idx = priority.max(dim=0)
+        max_priority = self._max_priority[0]
+        if max_priority is None or max_p > max_priority:
+            self._max_priority = (max_p, max_p_idx)
        priority = torch.pow(priority + self._eps, self._alpha)
        self._sum_tree[index] = priority
        self._min_tree[index] = priority
@@ -1563,6 +1621,10 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
            that at least `slice_len - i` samples will be gathered for each sampled trajectory.
            Using tuples allows fine-grained control over the span on the left (beginning
            of the stored trajectory) and on the right (end of the stored trajectory).
        max_priority_within_buffer (bool, optional): if ``True``, the max-priority
            is tracked within the buffer. When ``False``, the max-priority tracks
            the maximum value since the instantiation of the sampler.
            Defaults to ``False``.

    Examples:
        >>> import torch
@@ -1621,6 +1683,7 @@ def __init__(
        strict_length: bool = True,
        compile: bool | dict = False,
        span: bool | int | Tuple[bool | int, bool | int] = False,
        max_priority_within_buffer: bool = False,
    ):
        SliceSampler.__init__(
            self,
@@ -1644,6 +1707,7 @@ def __init__(
            eps=eps,
            dtype=dtype,
            reduction=reduction,
            max_priority_within_buffer=max_priority_within_buffer,
        )
        if self.span[0]:
            # Span left is hard to achieve because we need to sample 'negative' starts, but to sample
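Because `PrioritizedSliceSampler.__init__` forwards the new keyword to `PrioritizedSampler.__init__` (last hunk above), the same flag is available for slice-based prioritized sampling. A brief sketch mirroring the constructor from `test_prioritized_slice_sampler_doc_example`; the hyperparameters are illustrative only:

    from torchrl.data.replay_buffers import PrioritizedSliceSampler

    # Identical to the doc-example constructor, but the tracked max-priority
    # is confined to entries currently alive in the buffer.
    sampler = PrioritizedSliceSampler(
        max_capacity=9,
        num_slices=3,
        alpha=0.7,
        beta=0.9,
        max_priority_within_buffer=True,
    )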
