[BUG] SliceSampler should return unique IDs when sampling multiple times from the same trajectory #2588
Comments
Taking a fresh look at this again, it seems that a workaround may be to do something like:

sample, info = rb.sample(minibatch_size, return_info=True)
sample["next", "end_of_slice"] = (
    info["next", "truncated"]
    | info["next", "done"]
    | info["next", "terminated"]
)
sample = split_trajectories(sample, done_key="end_of_slice")

But this is hardly ergonomic, and it should at least be shown as an example in the documentation.
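The idea behind this workaround is to OR the three per-step flags into a single end-of-slice mask and cut the flat batch after every step where that mask is set. Below is a minimal pure-Python sketch of that logic, with lists standing in for tensors. `end_of_slice` and `split_at_flags` are hypothetical helpers for illustration, not torchrl functions.

```python
from typing import List, Sequence


def end_of_slice(truncated: Sequence[bool], done: Sequence[bool],
                 terminated: Sequence[bool]) -> List[bool]:
    """Element-wise OR of the three flags, like `truncated | done | terminated`."""
    return [t or d or term for t, d, term in zip(truncated, done, terminated)]


def split_at_flags(values: Sequence[int], end_flags: Sequence[bool]) -> List[List[int]]:
    """Cut a flat batch into segments; a True flag closes the current segment."""
    segments: List[List[int]] = []
    current: List[int] = []
    for value, is_end in zip(values, end_flags):
        current.append(value)
        if is_end:
            segments.append(current)
            current = []
    if current:  # trailing steps without a closing flag form a last segment
        segments.append(current)
    return segments


# Two slices of length 3 from the same episode, placed back to back.
obs = [10, 11, 12, 40, 41, 42]
truncated = [False, False, True, False, False, True]
done = [False] * 6
terminated = [False] * 6

flags = end_of_slice(truncated, done, terminated)
print(split_at_flags(obs, flags))  # [[10, 11, 12], [40, 41, 42]]
```

Because the cut depends only on the per-step flags, two adjacent slices from the same episode stay separate even though they share the same episode id.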
Hey

import torch
from tensordict import TensorDict
from torchrl.collectors.utils import split_trajectories
from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement

rb = ReplayBuffer(
    storage=LazyTensorStorage(max_size=1000),
    # strict_length=False allows slices shorter than slice_len (ep_2 has only 4 steps)
    sampler=SliceSamplerWithoutReplacement(
        slice_len=5, traj_key="episode", strict_length=False
    ),
)
ep_1 = TensorDict(
    {"obs": torch.arange(100), "episode": torch.zeros(100)},
    batch_size=[100],
)
ep_2 = TensorDict(
    {"obs": torch.arange(4), "episode": torch.ones(4)},
    batch_size=[4],
)
rb.extend(ep_1)
rb.extend(ep_2)
s = rb.sample(50)
t = split_trajectories(s, trajectory_key="episode")
print(t["obs"])
print(t["episode"])

That will ensure that you don't have the same item twice.
import torch
from tensordict import TensorDict
from torchrl.collectors.utils import split_trajectories
from torchrl.data import TensorDictReplayBuffer, ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(max_size=1000),
    sampler=SliceSampler(
        slice_len=5, traj_key="episode", strict_length=False,
    ),
)
ep_1 = TensorDict(
    {"obs": torch.arange(100), "episode": torch.zeros(100)},
    batch_size=[100],
)
ep_2 = TensorDict(
    {"obs": torch.arange(4), "episode": torch.ones(4)},
    batch_size=[4],
)
rb.extend(ep_1)
rb.extend(ep_2)

s = rb.sample(50)
print(s)
# Split on the ("next", "truncated") flag instead of the episode key
t = split_trajectories(s, done_key="truncated")
print(t["obs"])
print(t["episode"])

# With return_info=True, the flag can also be copied from the info dict explicitly
s, info = rb.sample(50, return_info=True)
print(s)
s["next", "truncated"] = info[("next", "truncated")]
t = split_trajectories(s, done_key="truncated")

But in general I do agree that we need better docs.
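Splitting on done_key="truncated" relies on ("next", "truncated") being set on the last step of each sampled slice. One common way to turn such a flag into per-step slice ids is a shifted cumulative sum: each step gets the number of flags strictly before it. The following pure-Python sketch shows that idea. `slice_ids_from_flags` is hypothetical and is not a copy of torchrl's internals.

```python
from itertools import accumulate
from typing import List, Sequence


def slice_ids_from_flags(end_flags: Sequence[bool]) -> List[int]:
    """Step i gets the count of end-of-slice flags at positions < i."""
    if not end_flags:
        return []
    # Shift right by one so the step carrying the flag still belongs to its slice.
    shifted = [0] + [int(f) for f in end_flags[:-1]]
    return list(accumulate(shifted))


truncated = [False, False, True, False, False, True, False, True]
print(slice_ids_from_flags(truncated))  # [0, 0, 0, 1, 1, 1, 2, 2]
```

Ids built this way are unique per slice, so adjacent slices from the same episode get different ids.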
Thanks for responding so quickly! In my particular case, I am collecting a few episodes (of wildly varying length), training on a few large-ish batches of short-ish slices, and then clearing the replay buffer, so unfortunately I first looked at the … I then looked at the …
Is this a good edit?
This seems like a major improvement to the documentation! Thanks for updating that.
Describe the bug
When using SliceSampler with strict_length=False, the documentation recommends the use of split_trajectories. However, if two samples from the same episode are placed next to each other, this produces the wrong output, because consecutive samples may have the same trajectory_key despite being logically independent.

To Reproduce
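The first failure mode can be illustrated without torchrl. If steps are grouped by the value of the trajectory key, two slices of the same episode that are separated by another episode's slice end up in one group. This is a hypothetical pure-Python sketch of that kind of grouping. It is not torchrl's implementation of `split_trajectories`.

```python
from typing import Dict, List, Sequence


def group_by_value(values: Sequence[int], keys: Sequence[int]) -> Dict[int, List[int]]:
    """Collect every step sharing a key, wherever it sits in the batch."""
    groups: Dict[int, List[int]] = {}
    for value, key in zip(values, keys):
        groups.setdefault(key, []).append(value)
    return groups


# Slices from episodes 0, 1, 0: episode 0 shows up in two separate runs.
obs = [10, 11, 0, 1, 50, 51]
episode = [0, 0, 1, 1, 0, 0]

groups = group_by_value(obs, episode)
print(groups)  # {0: [10, 11, 50, 51], 1: [0, 1]}
# The first and last slices are glued into one "trajectory" of length 4,
# even though 11 -> 50 is not a real transition.
```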
split_trajectories returns nonsense results when trajectory_key contains non-contiguous duplicates.

Even if that weren't the case, there would still be a bug:
When SliceSampler is drawing from relatively few trajectories, there will be situations where multiple slices of the same trajectory are returned next to each other. However, split_trajectories will see that episode is the same for both slices and incorrectly combine them into one longer slice.

Expected behavior
SliceSampler should add an additional key to its returned dict to distinguish samples, at least when strict_length=False.
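The following pure-Python sketch shows why an extra per-slice key would help, assuming the sampler knows how long each slice it returns is. The key name `slice_id` and both helpers are hypothetical and not existing torchrl API. Splitting only where the episode key changes merges two adjacent slices of the same episode. Splitting where a per-slice id changes keeps them apart, including a shorter final slice as allowed by strict_length=False.

```python
from itertools import groupby
from typing import List, Sequence


def split_on_key_change(values: Sequence[int], keys: Sequence[int]) -> List[List[int]]:
    """Start a new segment whenever the key differs from the previous step's key."""
    runs = groupby(zip(keys, values), key=lambda pair: pair[0])
    return [[value for _, value in run] for _, run in runs]


def make_slice_ids(slice_lengths: Sequence[int]) -> List[int]:
    """Hypothetical `slice_id` key: every step gets the index of its slice."""
    ids: List[int] = []
    for slice_index, length in enumerate(slice_lengths):
        ids.extend([slice_index] * length)
    return ids


# Three sampled slices: two of length 3 from episode 0, back to back,
# then a shorter slice of length 2 from episode 1 (strict_length=False).
obs = [10, 11, 12, 40, 41, 42, 0, 1]
episode = [0, 0, 0, 0, 0, 0, 1, 1]
slice_id = make_slice_ids([3, 3, 2])

print(split_on_key_change(obs, episode))   # [[10, 11, 12, 40, 41, 42], [0, 1]]
print(slice_id)                            # [0, 0, 0, 1, 1, 1, 2, 2]
print(split_on_key_change(obs, slice_id))  # [[10, 11, 12], [40, 41, 42], [0, 1]]
```

Only the per-slice id recovers the three slices the sampler actually drew. The episode key alone yields a 6-step "slice" even though slice_len is 3.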
System info
M1 Mac, macOS version 15.1
Both torchrl and tensordict were installed from source.