Skip to content

Commit

Permalink
[Doc] Improve PrioritizedSampler doc and get rid of np dependency as …
Browse files Browse the repository at this point in the history
…much as possible (pytorch#1881)
  • Loading branch information
vmoens authored Feb 7, 2024
1 parent 1fe745a commit 144f547
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 35 deletions.
2 changes: 1 addition & 1 deletion test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,7 @@ def test_set_tensorclass(self, max_size, shape, storage):
@pytest.mark.parametrize("priority_key", ["pk", "td_error"])
@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("device", get_default_devices())
def test_prototype_prb(priority_key, contiguous, device):
def test_ptdrb(priority_key, contiguous, device):
torch.manual_seed(0)
np.random.seed(0)
rb = TensorDictReplayBuffer(
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,7 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None:
if data.ndim:
priority = self._get_priority_vector(data)
else:
priority = self._get_priority_item(data)
priority = torch.as_tensor(self._get_priority_item(data))
index = data.get("index")
while index.shape != priority.shape:
# reduce index
Expand Down
116 changes: 83 additions & 33 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

from torchrl._utils import _replace_last
from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
from torchrl.data.replay_buffers.utils import _to_numpy, INT_CLASSES

try:
from torchrl._torchrl import (
Expand Down Expand Up @@ -250,11 +249,10 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
class PrioritizedSampler(Sampler):
"""Prioritized sampler for replay buffer.
Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015.
Prioritized experience replay."
(https://arxiv.org/abs/1511.05952)
Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay." (https://arxiv.org/abs/1511.05952)
Args:
max_capacity (int): maximum capacity of the buffer.
alpha (float): exponent α determines how much prioritization is used,
with α = 0 corresponding to the uniform case.
beta (float): importance sampling negative exponent.
Expand All @@ -264,6 +262,51 @@ class PrioritizedSampler(Sampler):
tensordicts (ie stored trajectory). Can be one of "max", "min",
"median" or "mean".
Examples:
>>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
>>> from tensordict import TensorDict
>>> rb = ReplayBuffer(storage=LazyTensorStorage(10), sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))
>>> priority = torch.tensor([0, 1000])
>>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
>>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
>>> rb.add(data_0)
>>> rb.add(data_1)
>>> rb.update_priority(torch.tensor([0, 1]), priority=priority)
>>> sample, info = rb.sample(10, return_info=True)
>>> print(sample)
TensorDict(
fields={
action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
obs: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
priority: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
reward: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([10]),
device=cpu,
is_shared=False)
>>> print(info)
{'_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11,
1.e-11, 1.e-11], dtype=float32), 'index': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}
.. note:: Using a :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer` can smoothen the
process of updating the priorities:
>>> from torchrl.data.replay_buffers import TensorDictReplayBuffer as TDRB, LazyTensorStorage, PrioritizedSampler
>>> from tensordict import TensorDict
>>> rb = TDRB(
... storage=LazyTensorStorage(10),
... sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
... priority_key="priority", # This kwarg isn't present in regular RBs
... )
>>> priority = torch.tensor([0, 1000])
>>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
>>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
>>> data = torch.stack([data_0, data_1])
>>> rb.extend(data)
>>> rb.update_priority(data) # Reads the "priority" key as indicated in the constructor
>>> sample, info = rb.sample(10, return_info=True)
>>> print(sample['index']) # The index is packed with the tensordict
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
"""

def __init__(
Expand Down Expand Up @@ -327,15 +370,17 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
raise RuntimeError("negative p_sum")
if p_min <= 0:
raise RuntimeError("negative p_min")
# For some undefined reason, only np.random works here.
# All PT attempts fail, even when subsequently transformed into numpy
mass = np.random.uniform(0.0, p_sum, size=batch_size)
# mass = torch.zeros(batch_size, dtype=torch.double).uniform_(0.0, p_sum)
# mass = torch.rand(batch_size).mul_(p_sum)
index = self._sum_tree.scan_lower_bound(mass)
if not isinstance(index, np.ndarray):
index = np.array([index])
if isinstance(index, torch.Tensor):
index.clamp_max_(len(storage) - 1)
else:
index = np.clip(index, None, len(storage) - 1)
weight = self._sum_tree[index]
index = torch.as_tensor(index)
if not index.ndim:
index = index.unsqueeze(0)
index.clamp_max_(len(storage) - 1)
weight = torch.as_tensor(self._sum_tree[index])

# Importance sampling weight formula:
# w_i = (p_i / sum(p) * N) ^ (-beta)
Expand All @@ -345,9 +390,10 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
# weight_i = ((p_i / sum(p) * N) / (min(p) / sum(p) * N)) ^ (-beta)
# weight_i = (p_i / min(p)) ^ (-beta)
# weight = np.power(weight / (p_min + self._eps), -self._beta)
weight = np.power(weight / p_min, -self._beta)
weight = torch.pow(weight / p_min, -self._beta)
return index, {"_weight": weight}

@torch.no_grad()
def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None:
priority = self.default_priority

Expand All @@ -360,6 +406,11 @@ def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None:
"priority should be a scalar or an iterable of the same "
"length as index"
)
# make sure everything is cast to cpu
if isinstance(index, torch.Tensor) and not index.is_cpu:
index = index.cpu()
if isinstance(priority, torch.Tensor) and not priority.is_cpu:
priority = priority.cpu()

self._sum_tree[index] = priority
self._min_tree[index] = priority
Expand All @@ -377,6 +428,7 @@ def extend(self, index: torch.Tensor) -> None:
index = index.cpu()
self._add_or_extend(index)

@torch.no_grad()
def update_priority(
self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor]
) -> None:
Expand All @@ -389,28 +441,26 @@ def update_priority(
indexed elements.
"""
if isinstance(index, INT_CLASSES):
if not isinstance(priority, float):
if len(priority) != 1:
raise RuntimeError(
f"priority length should be 1, got {len(priority)}"
)
priority = priority.item()
else:
if not (
isinstance(priority, float)
or len(priority) == 1
or len(index) == len(priority)
):
priority = torch.as_tensor(priority, device=torch.device("cpu")).detach()
index = torch.as_tensor(
index, dtype=torch.long, device=torch.device("cpu")
).detach()
# we need to reshape priority if it has more than one elements or if it has
# a different shape than index
if priority.numel() > 1 and priority.shape != index.shape:
try:
priority = priority.reshape(index.shape[:1])
except Exception as err:
raise RuntimeError(
"priority should be a number or an iterable of the same "
"length as index"
)
index = _to_numpy(index)
priority = _to_numpy(priority)

self._max_priority = max(self._max_priority, np.max(priority))
priority = np.power(priority + self._eps, self._alpha)
f"length as index. Got priority of shape {priority.shape} and index "
f"{index.shape}."
) from err
elif priority.numel() <= 1:
priority = priority.squeeze()

self._max_priority = priority.max().clamp_min(self._max_priority).item()
priority = torch.pow(priority + self._eps, self._alpha)
self._sum_tree[index] = priority
self._min_tree[index] = priority

Expand Down Expand Up @@ -1233,7 +1283,7 @@ def __getitem__(self, index):
if isinstance(index, slice) and index == slice(None):
return self
if isinstance(index, (list, range, np.ndarray)):
index = torch.tensor(index)
index = torch.as_tensor(index)
if isinstance(index, torch.Tensor):
if index.ndim > 1:
raise RuntimeError(
Expand Down

0 comments on commit 144f547

Please sign in to comment.