Skip to content

Commit

Permalink
[BugFix] Allow zero alpha value for PrioritizedSampler (pytorch#2164)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 authored May 30, 2024
1 parent e284b5f commit 0f02c4a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
10 changes: 8 additions & 2 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,11 +895,12 @@ def test_extend_list_pytree(self, max_size, shape, storage):
@pytest.mark.parametrize("priority_key", ["pk", "td_error"])
@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("device", get_default_devices())
def test_ptdrb(priority_key, contiguous, device):
@pytest.mark.parametrize("alpha", [0.0, 0.7])
def test_ptdrb(priority_key, contiguous, alpha, device):
torch.manual_seed(0)
np.random.seed(0)
rb = TensorDictReplayBuffer(
sampler=samplers.PrioritizedSampler(5, alpha=0.7, beta=0.9),
sampler=samplers.PrioritizedSampler(5, alpha=alpha, beta=0.9),
priority_key=priority_key,
batch_size=5,
)
Expand Down Expand Up @@ -934,6 +935,11 @@ def test_ptdrb(priority_key, contiguous, device):
assert (td2[s.get("_idx").squeeze()].get("a") == s.get("a")).all()
assert_allclose_td(td2[s.get("_idx").squeeze()].select("a"), s.select("a"))

if (
alpha == 0.0
): # when alpha is 0.0, sampling is uniform, so no need to check priority sampling
return

# test strong update
# get all indices that match first item
idx = s.get("_idx")
Expand Down
4 changes: 2 additions & 2 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,9 +333,9 @@ def __init__(
dtype: torch.dtype = torch.float,
reduction: str = "max",
) -> None:
if alpha <= 0:
if alpha < 0:
raise ValueError(
f"alpha must be strictly greater than 0, got alpha={alpha}"
f"alpha must be greater than or equal to 0, got alpha={alpha}"
)
if beta < 0:
raise ValueError(f"beta must be greater or equal to 0, got beta={beta}")
Expand Down

0 comments on commit 0f02c4a

Please sign in to comment.