Skip to content

Commit

Permalink
[BugFix] Fix slicesampler terminated/truncated signaling (pytorch#2044)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Mar 27, 2024
1 parent f439b54 commit c98754f
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 8 deletions.
51 changes: 47 additions & 4 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import torch

from _utils_internal import get_default_devices, make_tc

from mocking_classes import CountingEnv
from packaging import version
from packaging.version import parse
from tensordict import (
Expand All @@ -30,7 +32,6 @@
)
from torch import multiprocessing as mp
from torch.utils._pytree import tree_flatten, tree_map

from torchrl.collectors import RandomPolicy, SyncDataCollector
from torchrl.collectors.utils import split_trajectories
from torchrl.data import (
Expand Down Expand Up @@ -2025,8 +2026,11 @@ def test_slice_sampler(
assert too_short

assert len(trajs_unique_id) == 4
done = info[("next", "done")]
assert done.view(num_slices, -1)[:, -1].all()
truncated = info[("next", "truncated")]
assert truncated.view(num_slices, -1)[:, -1].all()
terminated = info[("next", "terminated")]
assert (truncated | terminated).view(num_slices, -1)[:, -1].all()

@pytest.mark.parametrize("sampler", [SliceSampler, SliceSamplerWithoutReplacement])
def test_slice_sampler_at_capacity(self, sampler):
Expand Down Expand Up @@ -2166,8 +2170,10 @@ def test_slice_sampler_without_replacement(
trajs_unique_id = trajs_unique_id.union(
cur_episodes,
)
truncated = info[("next", "truncated")]
assert truncated.view(num_slices, -1)[:, -1].all()
done = info[("next", "done")]
assert done.view(num_slices, -1)[:, -1].all()
done_recon = info[("next", "truncated")] | info[("next", "terminated")]
assert done_recon.view(num_slices, -1)[:, -1].all()

def test_slicesampler_strictlength(self):

Expand Down Expand Up @@ -2792,6 +2798,43 @@ def test_rb_multidim_collector(
print(f"rb {rb}") # noqa: T201
raise

@pytest.mark.parametrize("strict_length", [True, False])
def test_done_slicesampler(self, strict_length):
env = SerialEnv(
3,
[
lambda: CountingEnv(max_steps=31),
lambda: CountingEnv(max_steps=32),
lambda: CountingEnv(max_steps=33),
],
)
full_action_spec = CountingEnv(max_steps=32).full_action_spec
policy = lambda td: td.update(
full_action_spec.zero((3,)).apply_(lambda x: x + 1)
)
rb = TensorDictReplayBuffer(
storage=LazyTensorStorage(200, ndim=2),
sampler=SliceSampler(
slice_len=32,
strict_length=strict_length,
truncated_key=("next", "truncated"),
),
batch_size=128,
)

for i in range(50):
r = env.rollout(50, policy=policy, break_when_any_done=False)
r["next", "done"][:, -1] = 1
rb.extend(r)

sample = rb.sample()

assert sample["next", "done"].sum() == 128 // 32, (
i,
sample["next", "done"].sum(),
)
assert (split_trajectories(sample)["next", "done"].sum(-2) == 1).all()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down
8 changes: 4 additions & 4 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,9 +1160,9 @@ def _get_index(
terminated = torch.zeros_like(truncated)
if traj_terminated.any():
if isinstance(seq_length, int):
truncated.view(num_slices, -1)[traj_terminated] = 1
terminated.view(num_slices, -1)[traj_terminated, -1] = 1
else:
truncated[(seq_length.cumsum(0) - 1)[traj_terminated]] = 1
terminated[(seq_length.cumsum(0) - 1)[traj_terminated]] = 1
truncated = truncated & ~terminated
done = terminated | truncated
return index.to(torch.long).unbind(-1), {
Expand Down Expand Up @@ -1726,9 +1726,9 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
terminated = torch.zeros_like(truncated)
if traj_terminated.any():
if isinstance(seq_length, int):
truncated.view(num_slices, -1)[traj_terminated] = 1
terminated.view(num_slices, -1)[:, traj_terminated] = 1
else:
truncated[(seq_length.cumsum(0) - 1)[traj_terminated]] = 1
terminated[(seq_length.cumsum(0) - 1)[traj_terminated]] = 1
truncated = truncated & ~terminated
done = terminated | truncated

Expand Down

0 comments on commit c98754f

Please sign in to comment.