Commit

[BugFix] Fix sampling without replacement with ndim storages (pytorch…
vmoens authored Mar 7, 2024
1 parent 535bd63 commit 07eb02d
Showing 2 changed files with 39 additions and 20 deletions.
test/test_rb.py: 27 changes (17 additions & 10 deletions)
@@ -2555,12 +2555,19 @@ def test_rb_indexing(self, explicit):
 
 def _rbtype(datatype):
     if datatype in ("pytree", "tensorclass"):
-        return [ReplayBuffer, PrioritizedReplayBuffer]
+        return [
+            (ReplayBuffer, RandomSampler),
+            (PrioritizedReplayBuffer, RandomSampler),
+            (ReplayBuffer, SamplerWithoutReplacement),
+            (PrioritizedReplayBuffer, SamplerWithoutReplacement),
+        ]
     return [
-        ReplayBuffer,
-        PrioritizedReplayBuffer,
-        TensorDictReplayBuffer,
-        TensorDictPrioritizedReplayBuffer,
+        (ReplayBuffer, RandomSampler),
+        (ReplayBuffer, SamplerWithoutReplacement),
+        (PrioritizedReplayBuffer, None),
+        (TensorDictReplayBuffer, RandomSampler),
+        (TensorDictReplayBuffer, SamplerWithoutReplacement),
+        (TensorDictPrioritizedReplayBuffer, None),
     ]
 
 
@@ -2598,19 +2605,19 @@ def _make_data(self, datatype, datadim):
             batch_size=shape,
         )
 
-    datatype_rb_pairs = [
-        [datatype, rbtype]
+    datatype_rb_tuples = [
+        [datatype, *rbtype]
         for datatype in ["pytree", "tensordict", "tensorclass"]
         for rbtype in _rbtype(datatype)
     ]
 
-    @pytest.mark.parametrize("datatype,rbtype", datatype_rb_pairs)
+    @pytest.mark.parametrize("datatype,rbtype,sampler_cls", datatype_rb_tuples)
     @pytest.mark.parametrize("datadim", [1, 2])
     @pytest.mark.parametrize("storage_cls", [LazyMemmapStorage, LazyTensorStorage])
-    def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls):
+    def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls, sampler_cls):
         data = self._make_data(datatype, datadim)
         if rbtype not in (PrioritizedReplayBuffer, TensorDictPrioritizedReplayBuffer):
-            rbtype = functools.partial(rbtype, sampler=RandomSampler())
+            rbtype = functools.partial(rbtype, sampler=sampler_cls())
         else:
             rbtype = functools.partial(rbtype, alpha=0.9, beta=1.1)
 
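The pairs returned by _rbtype give each parametrized test case both a buffer class and a sampler class, so that sampling without replacement is exercised on multidimensional data too. As a rough sketch of how such a pair becomes a buffer factory (make_rb_factory is an illustrative helper, not part of the test suite):

import functools

from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer
from torchrl.data.replay_buffers import SamplerWithoutReplacement


def make_rb_factory(rbtype, sampler_cls):
    # Prioritized buffers build their own sampler from alpha/beta, so they take
    # no sampler class; the other buffers receive a freshly built sampler.
    if sampler_cls is None:
        return functools.partial(rbtype, alpha=0.9, beta=1.1)
    return functools.partial(rbtype, sampler=sampler_cls())


rb = make_rb_factory(ReplayBuffer, SamplerWithoutReplacement)(batch_size=8)
prb = make_rb_factory(PrioritizedReplayBuffer, None)(batch_size=8)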
torchrl/data/replay_buffers/samplers.py: 32 changes (22 additions & 10 deletions)
@@ -177,18 +177,20 @@ def _get_sample_list(self, storage: Storage, len_storage: int):
             device = self._sample_list.device
         else:
             device = storage.device if hasattr(storage, "device") else None
+
         if self.shuffle:
-            self._sample_list = torch.randperm(len_storage, device=device)
+            _sample_list = torch.randperm(len_storage, device=device)
         else:
-            self._sample_list = torch.arange(len_storage, device=device)
+            _sample_list = torch.arange(len_storage, device=device)
+        self._sample_list = _sample_list
 
     def _single_sample(self, len_storage, batch_size):
         index = self._sample_list[:batch_size]
         self._sample_list = self._sample_list[batch_size:]
 
         # check if we have enough elements for one more batch, assuming same batch size
         # will be used each time sample is called
-        if self._sample_list.numel() == 0 or (
+        if self._sample_list.shape[0] == 0 or (
             self.drop_last and len(self._sample_list) < batch_size
         ):
             self.ran_out = True
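The refactor above only changes how the shuffled index list is assigned; the way it is drained batch by batch is unchanged. A minimal, self-contained sketch of that consumption pattern, using plain tensors in place of a Storage:

import torch

sample_list = torch.randperm(10)  # shuffled once, then drained without replacement
batch_size = 4

ran_out = False
while not ran_out:
    index, sample_list = sample_list[:batch_size], sample_list[batch_size:]
    # ran_out flags that the list is exhausted (or, with drop_last, that the
    # next batch would come up short).
    ran_out = sample_list.shape[0] == 0
    print(index.tolist(), "ran_out:", ran_out)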
@@ -201,7 +203,6 @@ def _storage_len(self, storage):
         return len(storage)
 
     def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
-        storage = storage.flatten()
         len_storage = self._storage_len(storage)
         if len_storage == 0:
             raise RuntimeError(_EMPTY_STORAGE_ERROR)
@@ -217,6 +218,8 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
             )
         self.len_storage = len_storage
         index = self._single_sample(len_storage, batch_size)
+        if storage.ndim > 1:
+            index = torch.unravel_index(index, storage.shape)
         # we 'always' return the indices. The 'drop_last' just instructs the
         # sampler to turn to 'ran_out = True` whenever the next sample
         # will be too short. This will be read by the replay buffer
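This is the heart of the fix: the storage is no longer flattened before sampling (see the removal of storage.flatten() above); instead, the flat positions drawn by the sampler are converted back to per-dimension coordinates whenever the storage has more than one dimension. A standalone illustration of what torch.unravel_index (available in recent PyTorch versions) does here, with an assumed (4, 5) storage shape:

import torch

storage_shape = (4, 5)                 # e.g. 4 time slots for 5 environments
data = torch.arange(20).reshape(storage_shape)

flat_index = torch.tensor([0, 7, 19])  # positions in the flattened storage
coords = torch.unravel_index(flat_index, storage_shape)
print(coords)                          # (tensor([0, 1, 3]), tensor([0, 2, 4]))

# Indexing the multidimensional storage with the coordinate tuple retrieves the
# same elements as indexing its flattened view with the flat positions.
assert (data[coords] == data.reshape(-1)[flat_index]).all()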
@@ -834,10 +837,7 @@ def _find_start_stop_traj(*, trajectory=None, end=None, at_capacity: bool):
         )
     # Using transpose ensures the start and stop are sorted the same way
     stop_idx = end.transpose(0, -1).nonzero()
-    # beginnings = torch.cat([end[-1:], end[:-1]], 0)
-    # start_idx = beginnings.transpose(0, -1).nonzero()
-    # start_idx = torch.cat([start_idx[:, -1:], start_idx[:, :-1]], -1)
-    stop_idx = torch.cat([stop_idx[:, -1:], stop_idx[:, :-1]], -1)
+    stop_idx[:, [0, -1]] = stop_idx[:, [-1, 0]].clone()
     # First build the start indices as the stop + 1, we'll shift it later
     start_idx = stop_idx.clone()
     start_idx[:, 0] += 1
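The in-place column swap replaces the previous torch.cat roll. For a 2-D end tensor the two are equivalent, but once end carries extra batch dimensions only the swap puts every batch dimension back in its original position; the roll would move the last batch dimension to the front. A small sketch with an assumed (time=4, batch=2, sub_batch=3) layout, matching the convention above that column 0 of the resulting indices is the time step:

import torch

end = torch.zeros(4, 2, 3, dtype=torch.bool)  # trajectory-end flags
end[3] = True                                 # every trajectory stops at t == 3

# nonzero() on the transposed tensor yields rows ordered [sub_batch, batch, time].
stop_idx = end.transpose(0, -1).nonzero()

# Swapping the first and last columns yields [time, batch, sub_batch] coordinates.
stop_idx[:, [0, -1]] = stop_idx[:, [-1, 0]].clone()
print(stop_idx[:2])  # tensor([[3, 0, 0], [3, 1, 0]])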
@@ -991,6 +991,8 @@ def _sample_slices(
         storage_length: int,
         traj_idx: torch.Tensor | None = None,
     ) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]:
+        # start_idx and stop_idx are 2d tensors organized like a non-zero
+
         def get_traj_idx(lengths=lengths):
             return torch.randint(lengths.shape[0], (num_slices,), device=lengths.device)
 
@@ -999,7 +1001,7 @@ def get_traj_idx(lengths=lengths):
         idx = lengths == seq_length
         if not idx.any():
             raise RuntimeError(
-                "Did not find a single trajectory with sufficient length."
+                f"Did not find a single trajectory with sufficient length (length range: {lengths.min()} - {lengths.max()} / required={seq_length}))."
             )
         if (
             isinstance(seq_length, torch.Tensor)
@@ -1307,14 +1309,24 @@ def sample(
         seq_length, num_slices = self._adjusted_batch_size(batch_size)
         indices, _ = SamplerWithoutReplacement.sample(self, storage, num_slices)
         storage_length = storage.shape[0]
+
+        # traj_idx will either be a single tensor or a tuple that can be reorganized
+        # like a non-zero through stacking.
+        def tuple_to_tensor(traj_idx, lengths=lengths):
+            if isinstance(traj_idx, tuple):
+                traj_idx = torch.arange(len(storage), device=lengths.device).view(
+                    storage.shape
+                )[traj_idx]
+            return traj_idx
+
         idx, info = self._sample_slices(
             lengths,
             start_idx,
             stop_idx,
             seq_length,
             num_slices,
             storage_length,
-            traj_idx=indices,
+            traj_idx=tuple_to_tensor(indices),
         )
         return idx, info
 
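Because SamplerWithoutReplacement.sample now returns a tuple of per-dimension indices when the storage has more than one dimension (the unravel_index change above), the slice sampler first converts that tuple back into flat trajectory indices. A standalone illustration of the arange-view trick used by tuple_to_tensor, with an assumed (3, 4) storage shape:

import torch

storage_shape = (3, 4)
flat = torch.arange(12).view(storage_shape)   # flat[i, j] == i * 4 + j

# A tuple of per-dimension indices, as produced by torch.unravel_index.
traj_idx = (torch.tensor([0, 2]), torch.tensor([3, 1]))

# Indexing the arange view with the tuple recovers the flat positions.
print(flat[traj_idx])                         # tensor([3, 9])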
