From 4d52d5ffe085aff01d6eee6d3e901aa25d4c7561 Mon Sep 17 00:00:00 2001 From: Remi Date: Wed, 7 Feb 2024 21:17:43 +0100 Subject: [PATCH] [Feature] Add PrioritizedSliceSampler (#1875) --- docs/source/reference/data.rst | 1 + test/test_rb.py | 149 ++++++++++++-- torchrl/data/replay_buffers/__init__.py | 1 + torchrl/data/replay_buffers/samplers.py | 246 +++++++++++++++++++++++- 4 files changed, 379 insertions(+), 18 deletions(-) diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 6ed32ebe921..8ab6401b314 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -134,6 +134,7 @@ using the following components: Sampler PrioritizedSampler + PrioritizedSliceSampler RandomSampler SamplerWithoutReplacement SliceSampler diff --git a/test/test_rb.py b/test/test_rb.py index 548e4ba9726..697981909b5 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -40,6 +40,7 @@ from torchrl.data.replay_buffers import samplers, writers from torchrl.data.replay_buffers.samplers import ( PrioritizedSampler, + PrioritizedSliceSampler, RandomSampler, SamplerEnsemble, SamplerWithoutReplacement, @@ -1834,13 +1835,14 @@ def test_sampler_without_rep_state_dict(self, backend): assert (s.exclude("index") == 0).all() @pytest.mark.parametrize( - "batch_size,num_slices,slice_len", + "batch_size,num_slices,slice_len,prioritized", [ - [100, 20, None], - [120, 30, None], - [100, None, 5], - [120, None, 4], - [101, None, 101], + [100, 20, None, True], + [100, 20, None, False], + [120, 30, None, False], + [100, None, 5, False], + [120, None, 4, False], + [101, None, 101, False], ], ) @pytest.mark.parametrize("episode_key", ["episode", ("some", "episode")]) @@ -1853,6 +1855,7 @@ def test_slice_sampler( batch_size, num_slices, slice_len, + prioritized, episode_key, done_key, match_episode, @@ -1897,19 +1900,34 @@ def test_slice_sampler( else: strict_length = True - sampler = SliceSampler( - num_slices=num_slices, - traj_key=episode_key, - end_key=done_key, - slice_len=slice_len, - strict_length=strict_length, - ) + if prioritized: + num_steps = data.shape[0] + sampler = PrioritizedSliceSampler( + max_capacity=num_steps, + alpha=0.7, + beta=0.9, + num_slices=num_slices, + traj_key=episode_key, + end_key=done_key, + slice_len=slice_len, + strict_length=strict_length, + ) + index = torch.arange(0, num_steps, 1) + sampler.extend(index) + else: + sampler = SliceSampler( + num_slices=num_slices, + traj_key=episode_key, + end_key=done_key, + slice_len=slice_len, + strict_length=strict_length, + ) if slice_len is not None: num_slices = batch_size // slice_len trajs_unique_id = set() too_short = False count_unique = set() - for _ in range(10): + for _ in range(30): index, info = sampler.sample(storage, batch_size=batch_size) if _data_prefix: samples = storage._storage["_data"][index] @@ -1918,6 +1936,7 @@ def test_slice_sampler( if strict_length: # check that trajs are ok samples = samples.view(num_slices, -1) + assert samples["another_episode"].unique( dim=1 ).squeeze().shape == torch.Size([num_slices]) @@ -1936,6 +1955,7 @@ def test_slice_sampler( raise AssertionError( f"Not all items can be sampled: {set(range(100))-count_unique} are missing" ) + if strict_length: assert not too_short else: @@ -2071,6 +2091,107 @@ def test_slice_sampler_without_replacement( assert truncated.view(num_slices, -1)[:, -1].all() +def test_prioritized_slice_sampler_doc_example(): + sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9) + rb = TensorDictReplayBuffer( + 
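+        # batch_size=6 with num_slices=3 yields 3 slices of 2 steps each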
storage=LazyMemmapStorage(9), sampler=sampler, batch_size=6 + ) + data = TensorDict( + { + "observation": torch.randn(9, 16), + "action": torch.randn(9, 1), + "episode": torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=torch.long), + "steps": torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2], dtype=torch.long), + ("next", "observation"): torch.randn(9, 16), + ("next", "reward"): torch.randn(9, 1), + ("next", "done"): torch.tensor( + [0, 0, 1, 0, 0, 1, 0, 0, 1], dtype=torch.bool + ).unsqueeze(1), + }, + batch_size=[9], + ) + rb.extend(data) + sample, info = rb.sample(return_info=True) + # print("episode", sample["episode"].tolist()) + # print("steps", sample["steps"].tolist()) + # print("weight", info["_weight"].tolist()) + + priority = torch.tensor([0, 3, 3, 0, 0, 0, 1, 1, 1]) + rb.update_priority(torch.arange(0, 9, 1), priority=priority) + sample, info = rb.sample(return_info=True) + # print("episode", sample["episode"].tolist()) + # print("steps", sample["steps"].tolist()) + # print("weight", info["_weight"].tolist()) + + +@pytest.mark.parametrize("device", get_default_devices()) +def test_prioritized_slice_sampler_episodes(device): + num_slices = 10 + batch_size = 20 + + episode = torch.zeros(100, dtype=torch.int, device=device) + episode[:30] = 1 + episode[30:55] = 2 + episode[55:70] = 3 + episode[70:] = 4 + steps = torch.cat( + [torch.arange(30), torch.arange(25), torch.arange(15), torch.arange(30)], 0 + ) + done = torch.zeros(100, 1, dtype=torch.bool) + done[torch.tensor([29, 54, 69])] = 1 + + data = TensorDict( + { + "observation": torch.randn(100, 16), + "action": torch.randn(100, 4), + "episode": episode, + "steps": steps, + ("next", "observation"): torch.randn(100, 16), + ("next", "reward"): torch.randn(100, 1), + ("next", "done"): done, + }, + batch_size=[100], + device=device, + ) + + num_steps = data.shape[0] + sampler = PrioritizedSliceSampler( + max_capacity=num_steps, + alpha=0.7, + beta=0.9, + num_slices=num_slices, + ) + + rb = TensorDictReplayBuffer( + storage=LazyMemmapStorage(100), + sampler=sampler, + batch_size=batch_size, + ) + rb.extend(data) + + episodes = [] + for _ in range(10): + sample = rb.sample() + episodes.append(sample["episode"]) + assert {1, 2, 3, 4} == set( + torch.cat(episodes).cpu().tolist() + ), "all episodes are expected to be sampled at least once" + + index = torch.arange(0, num_steps, 1) + new_priorities = torch.cat( + [torch.ones(30), torch.zeros(25), torch.ones(15), torch.zeros(30)], 0 + ) + sampler.update_priority(index, new_priorities) + + episodes = [] + for _ in range(10): + sample = rb.sample() + episodes.append(sample["episode"]) + assert {1, 3} == set( + torch.cat(episodes).cpu().tolist() + ), "after priority update, only episode 1 and 3 are expected to be sampled" + + class TestEnsemble: def _make_data(self, data_type): if data_type is torch.Tensor: diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py index 77e7501de0c..c9aadcb992b 100644 --- a/torchrl/data/replay_buffers/__init__.py +++ b/torchrl/data/replay_buffers/__init__.py @@ -13,6 +13,7 @@ ) from .samplers import ( PrioritizedSampler, + PrioritizedSliceSampler, RandomSampler, Sampler, SamplerEnsemble, diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 21245e37acd..96d73375ea9 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -596,7 +596,7 @@ class SliceSampler(Sampler): allowed to appear in the batch. 
Be mindful that this can result in effective
            `batch_size` shorter than the one asked for! Trajectories can be split using
-            :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
+            :func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``.
 
     .. note:: To recover the trajectory splits in the storage,
         :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
@@ -633,7 +633,7 @@ class SliceSampler(Sampler):
         >>> print("episodes", sample.get("episode").unique())
         episodes tensor([1, 2, 3, 4], dtype=torch.int32)
 
-    :class:`torchrl.data.replay_buffers.SliceSampler` is default-compatible with
+    :class:`~torchrl.data.replay_buffers.SliceSampler` is default-compatible with
     most of TorchRL's datasets:
 
     Examples:
@@ -1012,7 +1012,7 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
             allowed to appear in the batch.
             Be mindful that this can result in effective
             `batch_size` shorter than the one asked for! Trajectories can be split using
-            :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
+            :func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``.
         shuffle (bool, optional): if ``False``, the order of the trajectories is not
             shuffled. Defaults to ``True``.
 
@@ -1053,7 +1053,7 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
         >>> print("sample:", sample)
         >>> print("trajectories in sample", sample.get("episode").unique())
 
-    :class:`torchrl.data.replay_buffers.SliceSamplerWithoutReplacement` is default-compatible with
+    :class:`~torchrl.data.replay_buffers.SliceSamplerWithoutReplacement` is default-compatible with
     most of TorchRL's datasets, and allows users to consume datasets in a dataloader-like fashion:
 
     Examples:
@@ -1129,6 +1129,244 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         return SamplerWithoutReplacement.load_state_dict(self, state_dict)
 
 
+class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
+    """Samples slices of data along the first dimension, given start and stop signals, using prioritized sampling.
+
+    This class samples sub-trajectories with replacement, following the priority
+    weighting presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015.
+    Prioritized experience replay."
+    (https://arxiv.org/abs/1511.05952)
+
+    For more information, see :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` and :class:`~torchrl.data.replay_buffers.samplers.PrioritizedSampler`.
+
+    Args:
+        max_capacity (int): maximum capacity of the buffer.
+        alpha (float): exponent α determines how much prioritization is used,
+            with α = 0 corresponding to the uniform case.
+        beta (float): importance sampling negative exponent.
+        eps (float, optional): delta added to the priorities to ensure that the buffer
+            does not contain null priorities. Defaults to 1e-8.
+        dtype (torch.dtype, optional): the dtype used to store the priorities.
+            Defaults to ``torch.float``.
+        reduction (str, optional): the reduction method for multidimensional
+            tensordicts (i.e., stored trajectories). Can be one of "max", "min",
+            "median" or "mean". Defaults to ``"max"``.
+
+    Keyword Args:
+        num_slices (int): the number of slices to be sampled. The batch-size
+            must be greater than or equal to the ``num_slices`` argument. Exclusive
+            with ``slice_len``.
+        slice_len (int): the length of the slices to be sampled. The batch-size
+            must be greater than or equal to the ``slice_len`` argument and divisible
+            by it. Exclusive with ``num_slices``.
+        end_key (NestedKey, optional): the key indicating the end of a
+            trajectory (or episode). Defaults to ``("next", "done")``.
+        traj_key (NestedKey, optional): the key indicating the trajectories.
+ Defaults to ``"episode"`` (commonly used across datasets in TorchRL). + ends (torch.Tensor, optional): a 1d boolean tensor containing the end of run signals. + To be used whenever the ``end_key`` or ``traj_key`` is expensive to get, + or when this signal is readily available. Must be used with ``cache_values=True`` + and cannot be used in conjunction with ``end_key`` or ``traj_key``. + trajectories (torch.Tensor, optional): a 1d integer tensor containing the run ids. + To be used whenever the ``end_key`` or ``traj_key`` is expensive to get, + or when this signal is readily available. Must be used with ``cache_values=True`` + and cannot be used in conjunction with ``end_key`` or ``traj_key``. + cache_values (bool, optional): to be used with static datasets. + Will cache the start and end signal of the trajectory. + truncated_key (NestedKey, optional): If not ``None``, this argument + indicates where a truncated signal should be written in the output + data. This is used to indicate to value estimators where the provided + trajectory breaks. Defaults to ``("next", "truncated")``. + This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer` + instances (otherwise the truncated key is returned in the info dictionary + returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method). + strict_length (bool, optional): if ``False``, trajectories of length + shorter than `slice_len` (or `batch_size // num_slices`) will be + allowed to appear in the batch. + Be mindful that this can result in effective `batch_size` shorter + than the one asked for! Trajectories can be split using + :func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``. + + Examples: + >>> import torch + >>> from torchrl.data.replay_buffers import TensorDictReplayBuffer, LazyMemmapStorage, PrioritizedSliceSampler + >>> from tensordict import TensorDict + >>> sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9) + >>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(9), sampler=sampler, batch_size=6) + >>> data = TensorDict( + ... { + ... "observation": torch.randn(9,16), + ... "action": torch.randn(9, 1), + ... "episode": torch.tensor([0,0,0,1,1,1,2,2,2], dtype=torch.long), + ... "steps": torch.tensor([0,1,2,0,1,2,0,1,2], dtype=torch.long), + ... ("next", "observation"): torch.randn(9,16), + ... ("next", "reward"): torch.randn(9,1), + ... ("next", "done"): torch.tensor([0,0,1,0,0,1,0,0,1], dtype=torch.bool).unsqueeze(1), + ... }, + ... batch_size=[9], + ... 
)
+        >>> rb.extend(data)
+        >>> sample, info = rb.sample(return_info=True)
+        >>> print("episode", sample["episode"].tolist())
+        episode [2, 2, 2, 2, 1, 1]
+        >>> print("steps", sample["steps"].tolist())
+        steps [1, 2, 0, 1, 1, 2]
+        >>> print("weight", info["_weight"].tolist())
+        weight [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+        >>> priority = torch.tensor([0, 3, 3, 0, 0, 0, 1, 1, 1])
+        >>> rb.update_priority(torch.arange(0, 9, 1), priority=priority)
+        >>> sample, info = rb.sample(return_info=True)
+        >>> print("episode", sample["episode"].tolist())
+        episode [2, 2, 2, 2, 2, 2]
+        >>> print("steps", sample["steps"].tolist())
+        steps [1, 2, 0, 1, 0, 1]
+        >>> print("weight", info["_weight"].tolist())
+        weight [9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06]
+    """
+
+    def __init__(
+        self,
+        max_capacity: int,
+        alpha: float,
+        beta: float,
+        eps: float = 1e-8,
+        dtype: torch.dtype = torch.float,
+        reduction: str = "max",
+        *,
+        num_slices: int | None = None,
+        slice_len: int | None = None,
+        end_key: NestedKey | None = None,
+        traj_key: NestedKey | None = None,
+        ends: torch.Tensor | None = None,
+        trajectories: torch.Tensor | None = None,
+        cache_values: bool = False,
+        truncated_key: NestedKey | None = ("next", "truncated"),
+        strict_length: bool = True,
+    ) -> None:
+        SliceSampler.__init__(
+            self,
+            num_slices=num_slices,
+            slice_len=slice_len,
+            end_key=end_key,
+            traj_key=traj_key,
+            cache_values=cache_values,
+            truncated_key=truncated_key,
+            strict_length=strict_length,
+            ends=ends,
+            trajectories=trajectories,
+        )
+        PrioritizedSampler.__init__(
+            self,
+            max_capacity=max_capacity,
+            alpha=alpha,
+            beta=beta,
+            eps=eps,
+            dtype=dtype,
+            reduction=reduction,
+        )
+
+    def __getstate__(self):
+        state = SliceSampler.__getstate__(self)
+        state.update(PrioritizedSampler.__getstate__(self))
+        return state
+
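+    # Note on the scheme implemented in `sample` below: the priorities of the last
+    # `seq_length - 1` steps of every trajectory are zeroed in the sum-tree, so that
+    # a sampled start index always leaves room for a full slice within its
+    # trajectory. For example, with stop_idx = [2, 5, 8] and seq_length = 3,
+    # subtractive_idx = [0, 1], the masked indices are [2, 1, 5, 4, 8, 7], and only
+    # {0, 3, 6} remain as possible slice starts.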
+    def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
+        # Sample `batch_size` indices representing the start of a slice.
+        # The sampling is based on a weight vector.
+        start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
+        seq_length, num_slices = self._adjusted_batch_size(batch_size)
+
+        num_trajs = lengths.shape[0]
+        traj_idx = torch.arange(0, num_trajs, 1, device=lengths.device)
+
+        if (lengths < seq_length).any():
+            if self.strict_length:
+                raise RuntimeError(
+                    "Some stored trajectories have a length shorter than the slice that was asked for. "
+                    "Create the sampler with `strict_length=False` to allow shorter trajectories to appear "
+                    "in your batch."
+                )
+            # make seq_length a tensor with values clamped by lengths
+            seq_length = lengths[traj_idx].clamp_max(seq_length)
+
+        # build the list of indices that we do not want to sample: all the steps that
+        # are within `seq_length - 1` steps of a trajectory end, with the end of the
+        # trajectory (`stop_idx`) included
+        if isinstance(seq_length, int):
+            subtractive_idx = torch.arange(
+                0, seq_length - 1, 1, device=stop_idx.device, dtype=stop_idx.dtype
+            )
+            preceding_stop_idx = (
+                stop_idx[..., None] - subtractive_idx[None, ...]
+            ).view(-1)
+        else:
+            raise NotImplementedError(
+                "seq_length as a tensor is not supported for now"
+            )
+
+        # give the indices at the end of a trajectory a null priority so they
+        # cannot be sampled as slice starts
+        self._sum_tree[preceding_stop_idx] = 0.0
+        # and no need to update self._min_tree
+
+        starts, info = PrioritizedSampler.sample(
+            self, storage=storage, batch_size=batch_size // seq_length
+        )
+        # TODO: update PrioritizedSampler.sample to return torch tensors
+        starts = torch.as_tensor(starts, device=lengths.device)
+        info["_weight"] = torch.as_tensor(info["_weight"], device=lengths.device)
+
+        # extend the starting indices of each slice by seq_length to get the indices of all steps
+        index = self._tensor_slices_from_startend(seq_length, starts)
+        # repeat the weight of each slice to match the number of steps
+        info["_weight"] = torch.repeat_interleave(info["_weight"], seq_length)
+
+        # sanity check
+        if index.shape[0] != batch_size:
+            raise ValueError(
+                f"Number of indices is expected to match the batch size ({index.shape[0]} != {batch_size})."
+            )
+
+        if self.truncated_key is not None:
+            # the following logic is borrowed from SliceSampler
+            truncated_key = self.truncated_key
+            done_key = _replace_last(truncated_key, "done")
+            terminated_key = _replace_last(truncated_key, "terminated")
+
+            truncated = torch.zeros(
+                (*index.shape, 1), dtype=torch.bool, device=index.device
+            )
+            if isinstance(seq_length, int):
+                truncated.view(num_slices, -1)[:, -1] = 1
+            else:
+                truncated[seq_length.cumsum(0) - 1] = 1
+            traj_terminated = stop_idx[traj_idx] == start_idx[traj_idx] + seq_length - 1
+            terminated = torch.zeros_like(truncated)
+            if traj_terminated.any():
+                # slices that stop exactly at a trajectory end are terminated, not truncated
+                if isinstance(seq_length, int):
+                    terminated.view(num_slices, -1)[traj_terminated] = 1
+                else:
+                    terminated[(seq_length.cumsum(0) - 1)[traj_terminated]] = 1
+            truncated = truncated & ~terminated
+            done = terminated | truncated
+
+            info.update(
+                {
+                    truncated_key: truncated,
+                    done_key: done,
+                    terminated_key: terminated,
+                }
+            )
+        return index.to(torch.long), info
+
+    def _empty(self):
+        # no-op for SliceSampler
+        PrioritizedSampler._empty(self)
+
+    def dumps(self, path):
+        # no-op for SliceSampler
+        PrioritizedSampler.dumps(self, path)
+
+    def loads(self, path):
+        # no-op for SliceSampler
+        return PrioritizedSampler.loads(self, path)
+
+    def state_dict(self):
+        # no-op for SliceSampler
+        return PrioritizedSampler.state_dict(self)
+
+
 class SamplerEnsemble(Sampler):
     """An ensemble of samplers.
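
For context on the PrioritizedSliceSampler added above: in a training loop the sampler is typically driven by per-step priority updates. A minimal sketch, reusing the patterns from the doctest; ``torch.rand`` stands in for a real per-step TD-error, and ``sample["index"]`` is the flat storage index that TensorDictReplayBuffer writes into each sampled batch:

    import torch
    from tensordict import TensorDict
    from torchrl.data.replay_buffers import (
        LazyMemmapStorage,
        PrioritizedSliceSampler,
        TensorDictReplayBuffer,
    )

    # 4 episodes of 25 steps each, stored flat in a 100-slot buffer
    sampler = PrioritizedSliceSampler(max_capacity=100, num_slices=4, alpha=0.7, beta=0.9)
    rb = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(100), sampler=sampler, batch_size=32
    )
    data = TensorDict(
        {
            "observation": torch.randn(100, 16),
            "episode": torch.arange(100) // 25,
            ("next", "done"): (torch.arange(100) % 25 == 24).unsqueeze(-1),
        },
        batch_size=[100],
    )
    rb.extend(data)
    sample, info = rb.sample(return_info=True)
    # re-prioritize the sampled steps with a placeholder error signal
    td_error = torch.rand(sample.shape[0])
    rb.update_priority(sample["index"], priority=td_error)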