Skip to content

Commit

Permalink
[Benchmark] Benchmark slice sampler (#1992)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Mar 4, 2024
1 parent 52827c2 commit 9337550
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
20 changes: 19 additions & 1 deletion benchmarks/test_replaybuffer_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement
from torchrl.data.replay_buffers import (
RandomSampler,
SamplerWithoutReplacement,
SliceSampler,
)

_TensorDictPrioritizedReplayBuffer = functools.partial(
TensorDictPrioritizedReplayBuffer, alpha=1, beta=0.9
Expand Down Expand Up @@ -49,6 +53,8 @@ def __call__(self):
},
batch_size=[self.size],
)
if "sampler" in kwargs and isinstance(kwargs["sampler"], SliceSampler):
data["traj"] = torch.arange(self.size) // 123
if self.populated:
rb.extend(data)
return ((rb,), {})
Expand Down Expand Up @@ -77,6 +83,18 @@ def iterate(rb):
[TensorDictReplayBuffer, ListStorage, SamplerWithoutReplacement, 4000],
[TensorDictReplayBuffer, LazyMemmapStorage, SamplerWithoutReplacement, 10_000],
[TensorDictReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement, 10_000],
[
TensorDictReplayBuffer,
LazyMemmapStorage,
functools.partial(SliceSampler, num_slices=8, traj_key="traj"),
10_000,
],
[
TensorDictReplayBuffer,
LazyTensorStorage,
functools.partial(SliceSampler, num_slices=8, traj_key="traj"),
10_000,
],
[_TensorDictPrioritizedReplayBuffer, ListStorage, None, 4000],
[_TensorDictPrioritizedReplayBuffer, LazyMemmapStorage, None, 10_000],
[_TensorDictPrioritizedReplayBuffer, LazyTensorStorage, None, 10_000],
Expand Down
4 changes: 4 additions & 0 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,8 @@ def _get_stop_and_length(self, storage, fallback=True):
else:
try:
trajectory = storage[:].get(self.traj_key)
except KeyError:
raise
except Exception:
raise RuntimeError(
"Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
Expand All @@ -855,6 +857,8 @@ def _get_stop_and_length(self, storage, fallback=True):
try:
try:
done = storage[:].get(self.end_key)
except KeyError:
raise
except Exception:
raise RuntimeError(
"Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
Expand Down

0 comments on commit 9337550

Please sign in to comment.