
[Feature] Add PrioritizedSliceSampler #1875

Merged: 11 commits merged on Feb 7, 2024
amend
vmoens committed Feb 7, 2024
commit 8463e091748b1f7b5f1bb9b6ed7090e21dbccdd7
1 change: 1 addition & 0 deletions docs/source/reference/data.rst
@@ -134,6 +134,7 @@ using the following components:

Sampler
PrioritizedSampler
+ PrioritizedSliceSampler
RandomSampler
SamplerWithoutReplacement
SliceSampler
24 changes: 12 additions & 12 deletions torchrl/data/replay_buffers/samplers.py
@@ -546,7 +546,7 @@ class SliceSampler(Sampler):
allowed to appear in the batch.
Be mindful that this can result in effective `batch_size` shorter
than the one asked for! Trajectories can be split using
- :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
+ :func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``.

.. note:: To recover the trajectory splits in the storage,
:class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
@@ -583,7 +583,7 @@ class SliceSampler(Sampler):
>>> print("episodes", sample.get("episode").unique())
episodes tensor([1, 2, 3, 4], dtype=torch.int32)

- :class:`torchrl.data.replay_buffers.SliceSampler` is default-compatible with
+ :class:`~torchrl.data.replay_buffers.SliceSampler` is default-compatible with
most of TorchRL's datasets:

Examples:
@@ -962,7 +962,7 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
allowed to appear in the batch.
Be mindful that this can result in effective `batch_size` shorter
than the one asked for! Trajectories can be split using
- :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
+ :func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``.
shuffle (bool, optional): if ``False``, the order of the trajectories
is not shuffled. Defaults to ``True``.

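The `shuffle` flag above only controls the order in which trajectories are visited; the without-replacement guarantee (each trajectory consumed once per pass) holds either way. A minimal sketch of that distinction, using a hypothetical `trajectory_order` helper that is not part of TorchRL:

```python
import torch

def trajectory_order(num_trajs: int, shuffle: bool = True) -> torch.Tensor:
    # Without replacement: every trajectory id appears exactly once per pass.
    # `shuffle` only decides whether the pass follows storage order or a
    # random permutation.
    if shuffle:
        return torch.randperm(num_trajs)
    return torch.arange(num_trajs)

print(trajectory_order(4, shuffle=False).tolist())  # → [0, 1, 2, 3]
order = trajectory_order(4)  # a random permutation of [0, 1, 2, 3]
```

Either way, `sorted(order.tolist())` is `[0, 1, 2, 3]`: shuffling never drops or repeats a trajectory within a pass.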
@@ -1003,7 +1003,7 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
>>> print("sample:", sample)
>>> print("trajectories in sample", sample.get("episode").unique())

- :class:`torchrl.data.replay_buffers.SliceSamplerWithoutReplacement` is default-compatible with
+ :class:`~torchrl.data.replay_buffers.SliceSamplerWithoutReplacement` is default-compatible with
most of TorchRL's datasets, and allows users to consume datasets in a dataloader-like fashion:

Examples:
@@ -1131,7 +1131,7 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
allowed to appear in the batch.
Be mindful that this can result in effective `batch_size` shorter
than the one asked for! Trajectories can be split using
- :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
+ :func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``.

Examples:
>>> import torch
@@ -1154,19 +1154,19 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
>>> rb.extend(data)
>>> sample, info = rb.sample(return_info=True)
>>> print("episode", sample["episode"].tolist())
episode [2, 2, 2, 2, 1, 1]
>>> print("steps", sample["steps"].tolist())
steps [1, 2, 0, 1, 1, 2]
>>> print("weight", info["_weight"].tolist())
weight [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
>>> priority = torch.tensor([0,3,3,0,0,0,1,1,1])
>>> rb.update_priority(torch.arange(0,9,1), priority=priority)
>>> sample, info = rb.sample(return_info=True)
>>> print("episode", sample["episode"].tolist())
- >>> print("steps", sample["steps"].tolist())
- >>> print("weight", info["_weight"].tolist())
- episode [2, 2, 2, 2, 1, 1]
- steps [1, 2, 0, 1, 1, 2]
- weight [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+ episode [2, 2, 2, 2, 2, 2]
+ >>> print("steps", sample["steps"].tolist())
+ steps [1, 2, 0, 1, 0, 1]
+ >>> print("weight", info["_weight"].tolist())
+ weight [9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06]
"""

@@ -1254,8 +1254,8 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
self, storage=storage, batch_size=batch_size // seq_length
)
# TODO: update PrioritizedSampler.sample to return torch tensors
- starts = torch.from_numpy(starts).to(device=lengths.device)
- info["_weight"] = torch.from_numpy(info["_weight"]).to(device=lengths.device)
+ starts = torch.as_tensor(starts, device=lengths.device)
+ info["_weight"] = torch.as_tensor(info["_weight"], device=lengths.device)

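The change above swaps `torch.from_numpy(x).to(device=...)` for `torch.as_tensor(x, device=...)`. `as_tensor` accepts both numpy arrays and tensors, shares memory with a CPU numpy input when possible, and handles device placement in one call, which matches the TODO's goal of tolerating either input type. A small standalone illustration:

```python
import numpy as np
import torch

arr = np.arange(4)
t1 = torch.as_tensor(arr)   # accepts a numpy array (shares memory on CPU)
t2 = torch.as_tensor(t1)    # accepts a tensor too: returned as-is, no copy
arr[0] = 7                  # the shared buffer is visible through t1
print(t2 is t1, t1[0].item())  # → True 7
# torch.from_numpy, by contrast, raises a TypeError on a tensor input.
```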
# extends starting indices of each slice with sequence_length to get indices of all steps
index = self._tensor_slices_from_startend(seq_length, starts)
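`_tensor_slices_from_startend` is a TorchRL internal; the comment above describes its job, which is to expand each sampled start index into `seq_length` consecutive step indices. A broadcasting sketch of that idea, with an assumed standalone helper in place of the method:

```python
import torch

def slices_from_starts(starts: torch.Tensor, seq_length: int) -> torch.Tensor:
    # Broadcast a column of start indices against a row of offsets
    # [0, 1, ..., seq_length - 1], then flatten: the result lists the
    # index of every step in every sampled slice, slice by slice.
    offsets = torch.arange(seq_length, device=starts.device)
    return (starts[:, None] + offsets).flatten()

starts = torch.tensor([0, 5])
print(slices_from_starts(starts, 3).tolist())  # → [0, 1, 2, 5, 6, 7]
```

With the example above, the two sampled slices of length 3 starting at steps 0 and 5 become one flat index tensor suitable for gathering from storage.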