Skip to content

Commit

Permalink
[BugFix] buffer __iter__ for samplers without replacement + prefetch (#2185)
Browse files Browse the repository at this point in the history

Co-authored-by: Vincent Moens <vmoens@meta.com>
  • Loading branch information
JulianKu and vmoens authored Jun 11, 2024
1 parent 672b50e commit 1029f10
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 5 deletions.
19 changes: 19 additions & 0 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1927,6 +1927,25 @@ def test_sampler_without_rep_state_dict(self, backend):
s = new_replay_buffer.sample(batch_size=1)
assert (s.exclude("index") == 0).all()

def test_sampler_without_replacement_cap_prefetch(self):
    """Iterating a prefetching buffer with SamplerWithoutReplacement yields
    each stored element exactly once per epoch and stops after the last batch.

    With 10 items and batch_size=2 an epoch is exactly 5 batches, so the
    enumerate index must never exceed 4 and must end at 4, even though
    prefetch=3 keeps extra sample futures in flight.
    """
    torch.manual_seed(0)
    source = TensorDict({"a": torch.arange(10)}, batch_size=[10])
    buffer = ReplayBuffer(
        storage=LazyTensorStorage(10),
        sampler=SamplerWithoutReplacement(),
        batch_size=2,
        prefetch=3,
    )
    buffer.extend(source)

    # Repeat many epochs: the prefetch queue must not leak extra batches
    # from one epoch into the next.
    for _ in range(100):
        seen = set()
        for step, batch in enumerate(buffer):
            assert step <= 4
            seen |= set(batch["a"].tolist())
        assert step == 4
        assert seen == set(range(10))

@pytest.mark.parametrize(
"batch_size,num_slices,slice_len,prioritized",
[
Expand Down
9 changes: 7 additions & 2 deletions torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,10 @@ def sample(self, batch_size: int | None = None, return_info: bool = False) -> An
ret = self._sample(batch_size)
else:
with self._futures_lock:
while len(self._prefetch_queue) < self._prefetch_cap:
while (
len(self._prefetch_queue)
< min(self._sampler._remaining_batches, self._prefetch_cap)
) and not self._sampler.ran_out:
fut = self._prefetch_executor.submit(self._sample, batch_size)
self._prefetch_queue.append(fut)
ret = self._prefetch_queue.popleft().result()
Expand Down Expand Up @@ -715,7 +718,9 @@ def __iter__(self):
"Cannot iterate over the replay buffer. "
"Batch_size was not specified during construction of the replay buffer."
)
while not self._sampler.ran_out:
while not self._sampler.ran_out or (
self._prefetch and len(self._prefetch_queue)
):
yield self.sample()

def __getstate__(self) -> Dict[str, Any]:
Expand Down
20 changes: 17 additions & 3 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@
class Sampler(ABC):
"""A generic sampler base class for composable Replay Buffers."""

# Some samplers - mainly those without replacement -
# need to keep track of the number of remaining batches
_remaining_batches = int(torch.iinfo(torch.int64).max)

@abstractmethod
def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
...
Expand Down Expand Up @@ -174,7 +178,7 @@ def loads(self, path):
metadata = json.load(file)
self.load_state_dict(metadata)

def _get_sample_list(self, storage: Storage, len_storage: int):
def _get_sample_list(self, storage: Storage, len_storage: int, batch_size: int):
if storage is None:
device = self._sample_list.device
else:
Expand All @@ -185,18 +189,28 @@ def _get_sample_list(self, storage: Storage, len_storage: int):
else:
_sample_list = torch.arange(len_storage, device=device)
self._sample_list = _sample_list
if self.drop_last:
self._remaining_batches = self._sample_list.numel() // batch_size
else:
self._remaining_batches = -(self._sample_list.numel() // -batch_size)

def _single_sample(self, len_storage, batch_size):
    """Pop the next ``batch_size`` indices off the precomputed sample list.

    Updates ``self._remaining_batches`` and ``self.ran_out`` as side effects,
    and regenerates the sample list once the current pass is exhausted.

    Args:
        len_storage (int): current length of the storage being sampled.
        batch_size (int): number of indices to return.

    Returns:
        the next slice of indices (a 1-D tensor, possibly shorter than
        ``batch_size`` on the last batch when ``drop_last`` is ``False``).
    """
    remaining = self._sample_list
    index = remaining[:batch_size]
    leftover = remaining[batch_size:]
    self._sample_list = leftover

    # Batches still available at this batch size: floor division when
    # trailing partial batches are dropped, ceiling division otherwise.
    n_left = leftover.numel()
    if self.drop_last:
        self._remaining_batches = n_left // batch_size
    else:
        self._remaining_batches = -(n_left // -batch_size)

    # The sampler has "run out" when nothing is left, or when only a
    # partial batch remains and partial batches are dropped. This assumes
    # the same batch size is used on every call to sample().
    exhausted = n_left == 0 or (self.drop_last and n_left < batch_size)
    self.ran_out = exhausted
    if exhausted:
        # Rebuild the shuffled index list so the next call starts a new pass.
        self._get_sample_list(
            storage=None, len_storage=len_storage, batch_size=batch_size
        )
    return index
Expand All @@ -213,7 +227,7 @@ def sample(
if not len_storage:
raise RuntimeError("An empty storage was passed")
if self.len_storage != len_storage or self._sample_list is None:
self._get_sample_list(storage, len_storage)
self._get_sample_list(storage, len_storage, batch_size=batch_size)
if len_storage < batch_size and self.drop_last:
raise ValueError(
f"The batch size ({batch_size}) is greater than the storage capacity ({len_storage}). "
Expand Down

0 comments on commit 1029f10

Please sign in to comment.