Skip to content

Commit

Permalink
[BugFix] Fix prefetch in samples without replacement - .sample() comp…
Browse files Browse the repository at this point in the history
…atibility issues (pytorch#2226)
  • Loading branch information
vmoens authored Jun 12, 2024
1 parent 0c008db commit f613eef
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
18 changes: 11 additions & 7 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1927,12 +1927,13 @@ def test_sampler_without_rep_state_dict(self, backend):
s = new_replay_buffer.sample(batch_size=1)
assert (s.exclude("index") == 0).all()

def test_sampler_without_replacement_cap_prefetch(self):
@pytest.mark.parametrize("drop_last", [False, True])
def test_sampler_without_replacement_cap_prefetch(self, drop_last):
torch.manual_seed(0)
data = TensorDict({"a": torch.arange(10)}, batch_size=[10])
data = TensorDict({"a": torch.arange(11)}, batch_size=[11])
rb = ReplayBuffer(
storage=LazyTensorStorage(10),
sampler=SamplerWithoutReplacement(),
storage=LazyTensorStorage(11),
sampler=SamplerWithoutReplacement(drop_last=drop_last),
batch_size=2,
prefetch=3,
)
Expand All @@ -1941,10 +1942,13 @@ def test_sampler_without_replacement_cap_prefetch(self):
for _ in range(100):
s = set()
for i, d in enumerate(rb):
assert i <= 4
assert i <= (4 + int(not drop_last)), i
s = s.union(set(d["a"].tolist()))
assert i == 4
assert s == set(range(10))
assert i == (4 + int(not drop_last)), i
if drop_last:
assert s != set(range(11))
else:
assert s == set(range(11))

@pytest.mark.parametrize(
"batch_size,num_slices,slice_len,prioritized",
Expand Down
3 changes: 2 additions & 1 deletion torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,8 @@ def sample(self, batch_size: int | None = None, return_info: bool = False) -> An
while (
len(self._prefetch_queue)
< min(self._sampler._remaining_batches, self._prefetch_cap)
) and not self._sampler.ran_out:
and not self._sampler.ran_out
) or not len(self._prefetch_queue):
fut = self._prefetch_executor.submit(self._sample, batch_size)
self._prefetch_queue.append(fut)
ret = self._prefetch_queue.popleft().result()
Expand Down

0 comments on commit f613eef

Please sign in to comment.