Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] SliceSampler #1748

Merged
merged 10 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
init
  • Loading branch information
vmoens committed Dec 15, 2023
commit c12291de0c176a83da422507790aaae5d9934d88
7 changes: 1 addition & 6 deletions torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,16 +439,11 @@ def sample(
if not self._prefetch:
ret = self._sample(batch_size)
else:
if len(self._prefetch_queue) == 0:
ret = self._sample(batch_size)
else:
with self._futures_lock:
ret = self._prefetch_queue.popleft().result()

with self._futures_lock:
while len(self._prefetch_queue) < self._prefetch_cap:
fut = self._prefetch_executor.submit(self._sample, batch_size)
self._prefetch_queue.append(fut)
ret = self._prefetch_queue.popleft().result()

if return_info:
return ret
Expand Down
141 changes: 140 additions & 1 deletion torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch

from tensordict import MemoryMappedTensor
from tensordict.utils import NestedKey

from ..._extension import EXTENSION_WARNING

Expand Down Expand Up @@ -223,7 +224,7 @@ class PrioritizedSampler(Sampler):
eps (float, optional): delta added to the priorities to ensure that the buffer
does not contain null priorities. Defaults to 1e-8.
reduction (str, optional): the reduction method for multidimensional
tensordicts (ie stored trajectories). Can be one of "max", "min",
tensordicts (ie stored trajectory). Can be one of "max", "min",
"median" or "mean".

"""
Expand Down Expand Up @@ -466,3 +467,141 @@ def loads(self, path):
self._sum_tree[i] = elt
for i, elt in enumerate(mm_mt.tolist()):
self._min_tree[i] = elt


class SliceSampler(Sampler):
"""Samples slices of data along the first dimension, given start and stop signals.

Keyword Args:
num_slices (int): the number of slices to be sampled. The batch-size
must be greater or equal to the ``num_slices`` argument.
end_key (NestedKey, optional): the key indicating the end of a
trajectory (or episode). Defaults to ``"done"``.
traj_key (NestedKey, optional): the key indicating the trajectories.
Defaults to ``"episode"`` (commonly used across datasets in TorchRL).
cache_values (bool, optional): to be used with static datasets.
Will cache the start and end signal of the trajectory.

.. note:: To recover the trajectory splits in the storage,
:class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
attempt to find the ``traj_key`` entry in the storage. If it cannot be
found, the ``end_key`` will be used to reconstruct the episodes.

"""

def __init__(
self,
*,
num_slices: int,
end_key: NestedKey = ("next", "done"),
traj_key: NestedKey = "episode",
cache_values: bool = False,
):
self.num_slices = num_slices
self.end_key = end_key
self.traj_key = traj_key
self.cache_values = cache_values
self._fetch_traj = True
self._cache = {}

def _find_start_stop_traj(self, *, trajectory=None, end=None):
if trajectory is not None:
end = trajectory[:-1] != trajectory[1:]
end = torch.cat([end, torch.ones_like(end[:1])], 0)
stop_idx = end.view(-1).nonzero().view(-1)
start_idx = torch.cat([torch.zeros_like(stop_idx[:1]), stop_idx[:-1]])
lengths = stop_idx - start_idx
return stop_idx, lengths

def _tensor_slices_from_startend(self, seq_length, start):
return (torch.arange(seq_length).unsqueeze(0) + start.unsqueeze(1)).view(-1)

def _get_stop_and_length(self, storage, fallback=True):
if self.cache_values and "stop-and-length" in self._cache:
return self._cache.get("stop-and-length")

if self._fetch_traj:
# We first try with the traj_key
try:
# In some cases, the storage hides the data behind "_data".
# In the future, this may be deprecated, and we don't want to mess
# with the keys provided by the user so we fall back on a proxy to
# the traj key.
try:
trajectory = storage.get(self._used_traj_key)
except KeyError:
trajectory = storage.get(("_data", self.traj_key))
# cache that value for future use
self._used_traj_key = ("_data", self.traj_key)
return self._cache.setdefault(
"stop-and-length", self._find_start_stop_traj(trajectory=trajectory)
)
except KeyError:
if fallback:
self._fetch_traj = False
return self._get_stop_and_length(storage, fallback=False)
raise

else:
try:
# In some cases, the storage hides the data behind "_data".
# In the future, this may be deprecated, and we don't want to mess
# with the keys provided by the user so we fall back on a proxy to
# the traj key.
try:
done = storage.get(self._used_end_key)
except KeyError:
done = storage.get(("_data", self.end_key))
# cache that value for future use
self._used_end_key = ("_data", self.end_key)
return self._cache.setdefault(
"stop-and-length", self._find_start_stop_traj(end=done)
)
except KeyError:
if fallback:
self._fetch_traj = True
return self._get_stop_and_length(storage, fallback=False)
raise

def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
seq_length = batch_size // self.num_slices
# pick up as many trajs as we need
stop_idx, lengths = self._get_stop_and_length(storage)
traj_idx = torch.randint(len(lengths), (self.num_slices,))
starts = (
torch.rand(self.num_slices) * (stop_idx[traj_idx] - seq_length)
).floor()
index = self._tensor_slices_from_startend(seq_length, starts)
return index.to(torch.long), {}

@property
def _used_traj_key(self):
return self.__dict__.get("__used_traj_key", self.traj_key)

@_used_traj_key.setter
def _used_traj_key(self, value):
self.__used_traj_key = value

@property
def _used_end_key(self):
return self.__dict__.get("__used_end_key", self.end_key)

@_used_end_key.setter
def _used_end_key(self, value):
self.__used_end_key = value

def _empty(self):
pass

def dumps(self, path):
# no op - cache does not need to be saved
...

def loads(self, path):
# no op
...

def __getstate__(self):
state = copy(self.__dict__)
state["_cache"] = {}
return state
Loading