
[Feature] SliceSampler #1748

Merged: 10 commits merged into main from continuous-rb-sampler on Dec 19, 2023

Conversation

vmoens (Contributor) commented Dec 15, 2023

In this PR I propose a new sampler that samples slices of data in a storage.

Setting

Let RB be a replay buffer of capacity M (for simplicity, assume the buffer is full).
We have N trajectories stored in this buffer, with lengths T_min <= T_i <= T_max for i in 1..N.
We want to sample batches of data of some arbitrary length such that the batch size B = S * U (potentially with a tensordict of shape [S, U] or [S * U], TBD), where S is the number of slices and U a fixed sub-trajectory length.
Assume U <= T_min, i.e., every stored trajectory is at least as long as a slice (shorter trajectories are discussed in the questions below).
A slice sampled from the RB can start anywhere from t_start = 0 to t_start = T_i - U, and ends at t_start + U.
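To make the setting concrete, here is a minimal sketch of the sampling logic described above (a toy illustration, not the torchrl implementation; it assumes all T_i >= U):

import torch

# Pick S trajectories uniformly, then a valid start t_start in [0, T_i - U]
# for each of them; the slice spans [t_start, t_start + U).
def sample_slice_bounds(traj_lengths: torch.Tensor, S: int, U: int):
    traj_idx = torch.randint(len(traj_lengths), (S,))
    max_start = traj_lengths[traj_idx] - U  # valid because T_i >= U
    t_start = (torch.rand(S) * (max_start + 1)).long()
    return traj_idx, t_start, t_start + U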

Usage

In many cases, users want to sample sub-trajectories out of complete trajectories. These sub-trajectories need not start at the beginning of a trajectory or end at its end (hence, we say that users want slices of trajectories). Typical use cases are:

  • RNN/Transformer-based policies
  • Offline RL datasets

Proposed class

The SliceSampler we propose here does exactly what was suggested above.

To keep the batch size unchanged, the user has to specify how many slices (S) they want out of the batch size. Alternatively, we could ask for a sub-trajectory length (U); one of the two may be given, but not both.
E.g.,

class SliceSampler(Sampler):
    def __init__(self, ..., num_slices=None, traj_len=None):
        # exactly one of num_slices and traj_len must be non-None
        ...

We can sample trajectories when we know the start and stop signals. To gather these, users can pass either the key pointing to the trajectory indicator (traj_key) or to a done state (done_key).

Note: passing "done" can be dangerous, since an incomplete trajectory will not be marked as "done" at its last step! Using the "truncated" signal from the collector could be an alternative, but in general it is safer to use the trajectory indicator. The user is responsible for making sure that two consecutive trajectories do not share the same indicator.
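For illustration, a hedged usage sketch (argument names as described in this PR; the "episode" key is hypothetical):

from torchrl.data.replay_buffers.samplers import SliceSampler

# Safer: identify trajectories via an episode-id entry stored with the data.
sampler = SliceSampler(num_slices=10, traj_key="episode")

# Riskier (see the note above): derive boundaries from the done state.
# sampler = SliceSampler(num_slices=10, done_key=("next", "done"))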

Computing the start and stop signals can be expensive, and for static datasets one can cache them. This speeds up sampling by a significant factor (0.3 ms with caching vs 1 ms without, on my machine).

Example code:

import timeit
import torch
import tqdm

from tensordict import TensorDict
from torchrl.data.replay_buffers import LazyMemmapStorage, \
    TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler

prefetch = 0
for cache in [True, False]:
    rb = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(1_000_000),
        sampler=SliceSampler(cache_values=cache, num_slices=10),
        batch_size=320,
        prefetch=prefetch
        )
    episode = torch.zeros(1000, dtype=torch.int)
    episode[:300] = 1
    episode[300:550] = 2
    episode[550:700] = 3
    episode[700:] = 4
    data = TensorDict(
        {
            "episode": episode,
            "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
            "act": torch.randn((20,)).expand(1000, 20),
            "other": torch.randn((20, 50)).expand(1000, 20, 50),
        }, [1000]
    )
    for _ in tqdm.tqdm(range(1000)):
        rb.extend(data)

    # with number=1000, the total time in seconds equals ms per sample
    print(cache, min(*timeit.repeat("rb.sample()", globals={"rb": rb}, number=1000)))

Results: ~0.347 ms per sample with caching, ~1.150 ms without.

Why is this PR important

This PR paves the way for a new API to store trajectories, where one can flatten the data coming from the collector and extend the RB with it. Storing trajectories of different lengths will be easy, and sampling trajectories should be cheaper. This will also help make the dataset and dynamic RB APIs uniform.
Pseudocode:

for data in collector: # data is of shape [C, D]
    replay_buffer.extend(data.view(-1)) # adds the C consecutive trajectories of length D
    samples = replay_buffer.sample(B) # B elements divided in S slices of length U
    samples = samples.view(S, U)

Ultimately, we will be able to deprecate the hacky "_data" key in TensorDictReplayBuffers that is bothering us so much!

cc: @matteobettini @albertbou92 @BY571 @skandermoalla @nicklashansen @truncs @btx0424 @smorad @vikashplus

Questions

  • Should batch_size indicate the number of slices (ie, trajectories) sampled? In other words: should users ask for B or S? Pros for B: clearer API (I ask for B transitions, I get B transitions, whatever the sampler). Pros for S: users may think in terms of a "number of items", where the items are trajectories. I'm still in favour of B, because otherwise we assume that the sampler impacts the shape of the samples, which isn't great for composability (swapping samplers would change the amount of data consumed, i.e. the UTD of the algo).
  • Related: should we return batches of size [S, U] or [S * U]? I prefer the second, for the same reason as before (composability).
  • Last one: what do we do if some trajectories are too short (T_min < U)? My take: there should be an arg pad in the constructor, None by default. If pad=True, we pad with 0s. If pad=smth and smth is not None, we pad with smth. If pad=None (default), an exception is raised if T_min < U (or, cheaper: if some T_i < U is encountered during training). See the padding sketch after this list.
  • Should we give a device to the sampler to speed up trajectory recovery (with cuda)?
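To make the pad proposal concrete, here is a minimal sketch of the intended semantics (maybe_pad is a hypothetical helper, not part of the PR):

import torch
import torch.nn.functional as F

def maybe_pad(slice_data: torch.Tensor, U: int, pad=None):
    # slice_data has shape [T_i, ...], possibly with T_i < U
    missing = U - slice_data.shape[0]
    if missing <= 0:
        return slice_data
    if pad is None:
        raise RuntimeError("trajectory shorter than the slice length U")
    fill = 0.0 if pad is True else pad
    spec = [0, 0] * (slice_data.dim() - 1) + [0, missing]  # right-pad time dim
    return F.pad(slice_data, spec, value=fill)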

Related issues
#1671
#1079

TODO:

  • Doc
  • Tests
  • Allow padding => @matteobettini Samplers cannot do padding (they just return indices), so we simply return fewer indices when trajectories don't comply with the minimum length and strict_length=False (see the sketch after this list).
  • allow sampling without replacement (what does that look like?)
  • sampling each transition with equal probability? => NOT PLANNED
  • prepare ordered dataset iterations => follow-up PR
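A hedged sketch of the strict_length behaviour mentioned in the padding item above (strict_length is the argument name used by the merged sampler; the "episode" key is hypothetical):

from torchrl.data.replay_buffers.samplers import SliceSampler

# With strict_length=False, trajectories shorter than the requested slice
# length yield fewer indices instead of raising an error.
sampler = SliceSampler(num_slices=10, traj_key="episode", strict_length=False)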

pytorch-bot (bot) commented Dec 15, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1748

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (8 Unrelated Failures)

As of commit 4873db8 with merge base 08f0bed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Dec 15, 2023
github-actions (bot) commented Dec 15, 2023

⚠️ Warning: Result of CPU Benchmark Tests

Total Benchmarks: 89. Improved: 22. Worsened: 4.

(Detailed per-benchmark results omitted.)

github-actions (bot) commented Dec 15, 2023

⚠️ Warning: Result of GPU Benchmark Tests

Total Benchmarks: 92. Improved: 10. Worsened: 3.

(Detailed per-benchmark results omitted.)

@vmoens added the enhancement label on Dec 15, 2023
@vmoens marked this pull request as ready for review on December 15, 2023 12:03
truncs commented Dec 15, 2023

This is great! Some thoughts:

  1. I think sampling just trajectories would be great, but I can see how some API users would like more control over how the chunks are converted to trajectories.
  2. Trajectories of insufficient chunk length should be padded up to the chunk size and saved, IMHO. Another option could be to let API users pass a function to handle this case, where they can decide whether to throw an exception or just pad the insufficient chunks.

skandermoalla (Contributor) commented:

I'm not sure I understand the use case for this. Could you give more context?

It seems useful to methods that use full episodes.

Should batch_size indicate the number of slices (ie, trajectories) sampled?

Yes, it seems a bit odd and error-prone to have half the information in the slice variable and the other half in the batch size.

Related: Should we return batches of size [M, N] where M is the number of slices and N the trajectory lengths?

This would allow having a deterministic batch size, but there would be lots of zeros. I would prefer a deterministic batch size by default, and if there's a way to save memory on the tail it'd be great.

Last one: what do we do if some trajectories have an insufficient length? Should they be considered (and possibly padded)? IMO we should just raise an exception or a warning, but I'm open to other suggestions.

What does insufficient length mean here?
It seems like sampling and giving a mask to the user would be the easiest.

Should we give a device to the transform to speed up trajectory recovery (with cuda)?

The transform?

vmoens (Contributor, Author) commented Dec 15, 2023

I'm not sure I understand the use case for this. Could you give more context?

Sure!
Currently, when people need to sample trajectories, we advise them to build RB storages with shape [N, T], but that's clunky: if I tell you "build an RB of 1M elements", you have to divide 1M by the trajectory length.
It also means that you have to pad, or use trajectories that all have the same length. On top of that, getting sub-trajectories is hard and inefficient. Finally, trajectory (episode) lengths cannot in general be assumed to be identical, so storing them as [N x T] is limiting: people will eventually want to sample some trajectories or some parts of trajectories...

It seems useful to methods that use full episodes.

Should batch_size indicate the number of slices (ie, trajectories) sampled?

Yes, it seems a bit odd and error-prone to have half the information in the slice variable and the other half in the batch size.

On the other hand, you will get more elements than you asked for! That could be surprising too.

Related: Should we return batches of size [M, N] where M is the number of slices and N the trajectory lengths?

This would allow having a deterministic batch size, but there would be lots of zeros. I would prefer a deterministic batch size by default, and if there's a way to save memory on the tail it'd be great.

Why lots of zeroes? What I meant is: should this reshape happen within the transform?

>>> for batch in data:
...     batch = batch.reshape(num_slices, -1)
...     break

Last one: what do we do if some trajectories have an insufficient length? Should they be considered (and possibly padded)? IMO we should just raise an exception or a warning, but I'm open to other suggestions.

What does insufficient length mean here? It seems like sampling and giving a mask to the user would be the easiest.

Say some trajectories in your buffer have length 10, and you ask for a batch of 320 in total across 10 splits: you will want a traj length of 32, which is greater than 10 -> should we pad, or just throw an exception?

Should we give a device to the transform to speed up trajectory recovery (with cuda)?

The transform?

Sorry, it's Friday! I meant the sampler.

matteobettini (Contributor) commented:

This is very cool, and an easy way to sample trajectories is much needed!

I am a bit confused by the description though; it seems to me that N and M are overloaded in different parts of the text:

for data in collector: # data is of shape [N, T]
    replay_buffer.extend(data.view(-1)) # adds the N consecutive trajectories of length T
    samples = replay_buffer.sample(M) # M elements divided in N slices
    samples = samples.view(N, -1)
  • Why is T the trajectory length? In normal collectors, T is just the number of collected frames divided by the batch size of collection.

I think we might benefit from rewriting the description with clearer symbols and a recap of the current state of affairs prior to this change.

My opinion

  1. I think users should request a desired trajectory length rather than a number of slices (at construction of the sampler). This is more intuitive, as most people care about how long the trajectory fragment is (leaving it as a byproduct of number_of_slices is counterintuitive to me). Therefore I would prefer something like TrajectorySampler(cache_values=cache, traj_length=10). For this reason I also suggest naming the sampler TrajectorySampler.
  2. Regarding the batch size, when users request buffer.sample(n_samples), I think they should get a td with batch_size [n_slices, traj_length], where n_slices is computed as n_samples / traj_length (n_samples and traj_length are the user-defined quantities and n_slices is the byproduct).
  3. If the actual episode length differs from the traj_length provided at construction (longer for some episodes, shorter for others), we should:
    • if shorter: pad, and give a warning that some fragments are padded
    • if longer: just cut it at traj_length and put the overflowing bit in the next slice

vmoens (Contributor, Author) commented Dec 16, 2023

Good points.
I think we should raise an error if the length is shorter unless users explicitly set a pad arg to True, because literally everyone in RL disables warnings, and no one ever looks at the logs.

vmoens (Contributor, Author) commented Dec 16, 2023

if longer: just cut it at traj_length and put the overflowing bit in the next slice

I don't get that

matteobettini (Contributor) commented:

if longer: just cut it at traj_length and put the overflowing bit in the next slice

I don't get that

If the user requests 100 samples, my buffer has one episode which is 100 steps long, and the user asked to sample trajectory fragments of length 70, we should return 2 fragments in a td of batch size [2, 70]:

  • the first fragment is the first 70 steps
  • the second fragment is the last 30 steps plus 40 padded steps (supposing padding is requested)

vmoens (Contributor, Author) commented Dec 16, 2023

That isn't the contract of this class, which is to randomly sample subsequences (this is why it is a slice sampler rather than a trajectory sampler: it samples incomplete episodes). This is also why we add a truncated signal at the end.
The implementation you're suggesting seems more constraining to me, and surprising: consecutive samples of a batch should not be correlated.
I don't see why we should complete the trajectory in the next batch; what's the advantage? If users want full trajectories of length 100, why not ask for that?
In this implementation, I could also expect a slice of length 70 without padding because all my trajectories are longer than 70, and here I'll get some trajectories that are padded. Pretty sure that will lead to issues on the repo!

vmoens (Contributor, Author) commented Dec 16, 2023

Another issue is that this seems to assume that all trajectories have the same length. Say this is not the case (e.g., all our datasets) and you want sub-trajs of length 100. One trajectory may be 101 transitions long, so you will get 2 batches, one of 100 and another padded of 1. Users will never see the second batch contiguously with the first (it will appear as if this one element alone isn't part of the same trajectory). Also, it's not clear what the batch size is here: 200? 101? What if users wanted only 100 transitions? What if they wanted 200: should they get samples from another trajectory too, thereby having a varying output size from the buffer?

vmoens (Contributor, Author) commented Dec 16, 2023

I updated the description!

matteobettini (Contributor) commented:

Thanks for the update! The new description clears up a lot of things, and the choice of leaving it to the users whether they give S or U is great imo!

I guess the point I got confused about is whether sampling is handled with or without replacement.
Ideally, both on-policy algos (like PPO) and off-policy algos may want to use this, so the replacement strategy of this sampler might be important.

What I understand from your answers is that this sampler samples with replacement (so the data of 2 slices sampled in the same call could be the same).

I still think I am not getting some API details of this sampler, so here are some examples that would help me clarify:

  1. From what I understand, data from different episodes will never be present in the same slice, i.e. a slice will always contain a fragment from the same episode.
  2. Are we sampling from trajectories independently? For example, if I have a buffer containing 2 episodes, one of length 30 and the other of length 70, and the user asks for 100 samples in 2 slices of length 50 each, can you explain what would be given to the user?
    • Are we going to sample 2 slices of 50 from the episode with length 70?
    • Are we going to sample one slice from the episode with length 70 and one padded one from the episode with length 30?
    • Are we going to sample in a random way and pad if there are not 50 steps after the selected starting indices?

vmoens (Contributor, Author) commented Dec 16, 2023

What I understand from your answers is that this sampler samples with replacement (so the data of 2 slices sampled in the same call could be the same)

Yes, but that can be changed.

From what I understand, data from different episodes will never be present in the same slice, i.e. a slice will always contain a fragment from the same episode.

Yes, that is the point, but we could also loosen it (RNNs and value estimators allow mixed trajectories in torchrl, so that wouldn't be an issue for them).

Are we sampling from trajectories independently? For example, if I have a buffer containing 2 episodes, one of length 30 and the other of length 70, and the user asks for 100 samples in 2 slices of length 50 each, can you explain what would be given to the user?

We sample the episodes first, then within the episodes. So each element does not have the same probability of ending up in the batch. I'm not sure how to implement the other option mathematically: perhaps sample items from the buffer and then sample a slice around each of them. That would mean that some trajectories have a higher priority than others, which is as opinionated as the first option.

If I have a buffer containing 2 episodes, one of length 30 and the other of length 70, and the user asks for 100 samples in 2 slices of length 50 each

Currently, you would get one sample of length 50 from trajectory 1, and a sample of length 30 + 20 pads from trajectory 0.

vmoens (Contributor, Author) commented Dec 16, 2023

@matteobettini What does sampling without replacement mean here?
If I have a trajectory of length 101 and ask for length 100, the second time I sample from it I will have only 1 sample left. This means that the samples at the end of the epoch will be degenerate. If we sample another slice that contains that element, longer trajectories will have fewer repetitions across samples (and it won't be truly without replacement). The only option I see would be to sample each trajectory only once (that makes more sense to me), but then we would end up with a different likelihood per transition in the buffer (which I don't think is an issue).

Remember that in some settings people allocate themselves a budget of "episodes" rather than "transitions", which kind of makes sense: for cartpole, for instance, solving it in 100 episodes rather than 200 is way more impressive than in 1000 samples rather than 2000 (since success = longer episodes).

matteobettini (Contributor) commented Dec 16, 2023

Are we sampling from trajectories independently? For example, if I have a buffer containing 2 episodes, one of length 30 and the other of length 70, and the user asks for 100 samples in 2 slices of length 50 each, can you explain what would be given to the user?

We sample the episodes first, then within the episodes. So each element does not have the same probability of ending up in the batch.

If I have a buffer containing 2 episodes, one of length 30 and the other of length 70, and the user asks for 100 samples in 2 slices of length 50 each

Currently, you would get one sample of length 50 from trajectory 1, and a sample of length 30 + 20 pads from trajectory 0.

Now I got it! Thanks for clarifying!

It was not clear to me that we first sample episode indices and then slices.
So the sampling with replacement happens at 2 levels: for episodes (first) and for slices within episodes (second).
Or is the first sampling of episodes without replacement?

Yes, I guess this has some implications. Basically, each episode will be weighted equally, so an episode of length 1 (where the agent died in one step) will get a dedicated slice of length U (potentially polluting the data).

The other implementation I was referring to is that you sample S indices at random in the flattened buffer data and then look U steps into the future from each index (if there are not U steps in the future, you pad the remaining ones).

This implementation removes the hierarchical sampling (of episodes first and then slices), but (as you said) gives less weight to shorter trajectories.
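For contrast, a minimal sketch of this flat-index alternative (a toy illustration, not what the merged sampler does):

import torch

# Sample S start indices uniformly over the flattened buffer, take the next
# U steps, and mask out steps that cross an episode boundary or the buffer
# end (those positions would be padded).
def flat_slice_sample(episode_ids: torch.Tensor, S: int, U: int):
    starts = torch.randint(len(episode_ids), (S,))
    flat = starts.unsqueeze(1) + torch.arange(U)       # (S, U) flat indices
    valid = flat < len(episode_ids)                    # stay inside the buffer
    idx = flat.clamp(max=len(episode_ids) - 1)
    mask = valid & (episode_ids[idx] == episode_ids[starts].unsqueeze(1))
    return idx, mask  # positions where mask is False fall outside the episode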

vmoens (Contributor, Author) commented Dec 19, 2023

@matteobettini I added SliceSamplerWithoutReplacement, which inherits from both parent classes in a rather straightforward way.
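A hedged usage sketch (assuming the new class accepts the same construction arguments as SliceSampler):

from torchrl.data.replay_buffers.samplers import SliceSamplerWithoutReplacement

# Same slice semantics, but trajectories are not resampled within an epoch.
sampler = SliceSamplerWithoutReplacement(num_slices=10, traj_key="episode")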

vmoens merged commit 4d3a0c6 into main on Dec 19, 2023 (54 of 62 checks passed).
vmoens deleted the continuous-rb-sampler branch on December 19, 2023 18:18.
skandermoalla (Contributor) commented:

Sorry for being late to the party. This is really cool and much clearer now!

Raising a broader question: the description of the PR seems outdated again. For recent features, do we want to rebuild the documentation and serve it somewhere, or rely on the PR descriptions as temporary documentation?

vmoens (Contributor, Author) commented Dec 25, 2023

Good point, I can update it! But the doc should be accurate.
