Skip to content

Commit

Permalink
[BugFix] Fix strict_length in prioritized slice sampler (pytorch#2194)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 4, 2024
1 parent 3e6cb84 commit e5c3e32
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 75 deletions.
27 changes: 17 additions & 10 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1974,6 +1974,7 @@ def test_slice_sampler(
"count": torch.arange(100),
"other": torch.randn((20, 50)).expand(100, 20, 50),
done_key: done,
"terminated": done,
},
[100],
device=device,
Expand Down Expand Up @@ -2035,6 +2036,14 @@ def test_slice_sampler(
samples["another_episode"].view(-1).tolist()
)
count_unique = count_unique.union(samples.get("count").view(-1).tolist())

truncated = info[("next", "truncated")]
terminated = info[("next", "terminated")]
assert (truncated | terminated).view(num_slices, -1)[:, -1].all()
assert (terminated == samples["terminated"].view_as(terminated)).all()
done = info[("next", "done")]
assert done.view(num_slices, -1)[:, -1].all()

if len(count_unique) == 100:
# all items have been sampled
break
Expand All @@ -2049,11 +2058,6 @@ def test_slice_sampler(
assert too_short

assert len(trajs_unique_id) == 4
done = info[("next", "done")]
assert done.view(num_slices, -1)[:, -1].all()
truncated = info[("next", "truncated")]
terminated = info[("next", "terminated")]
assert (truncated | terminated).view(num_slices, -1)[:, -1].all()

@pytest.mark.parametrize("sampler", [SliceSampler, SliceSamplerWithoutReplacement])
def test_slice_sampler_at_capacity(self, sampler):
Expand Down Expand Up @@ -2877,9 +2881,9 @@ def test_done_slicesampler(self, strict_length):
env = SerialEnv(
3,
[
lambda: CountingEnv(max_steps=31),
lambda: CountingEnv(max_steps=32),
lambda: CountingEnv(max_steps=33),
lambda: CountingEnv(max_steps=31).add_truncated_keys(),
lambda: CountingEnv(max_steps=32).add_truncated_keys(),
lambda: CountingEnv(max_steps=33).add_truncated_keys(),
],
)
full_action_spec = CountingEnv(max_steps=32).full_action_spec
Expand All @@ -2896,9 +2900,12 @@ def test_done_slicesampler(self, strict_length):
batch_size=128,
)

# env.add_truncated_keys()

for i in range(50):
r = env.rollout(50, policy=policy, break_when_any_done=False)
r["next", "done"][:, -1] = 1
r = env.rollout(
50, policy=policy, break_when_any_done=False, set_truncated=True
)
rb.extend(r)

sample = rb.sample()
Expand Down
1 change: 1 addition & 0 deletions torchrl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
NestedStorageCheckpointer,
PrioritizedReplayBuffer,
PrioritizedSampler,
PrioritizedSliceSampler,
RandomSampler,
RemoteTensorDictReplayBuffer,
ReplayBuffer,
Expand Down
113 changes: 48 additions & 65 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,12 +1184,9 @@ def _get_index(
out_of_traj = relative_starts + seq_length > lengths[traj_idx]
if out_of_traj.any():
# a negative start means sampling fewer elements
# print('seq_length before', seq_length)
# print('relative_starts', relative_starts)
seq_length = torch.minimum(
seq_length, lengths[traj_idx] - relative_starts
)
# print('seq_length after', seq_length)

starts = torch.cat(
[
Expand All @@ -1212,18 +1209,11 @@ def _get_index(
truncated.view(num_slices, -1)[:, -1] = 1
else:
truncated[seq_length.cumsum(0) - 1] = 1
# a traj is terminated if the stop index along col 0 (time)
# equates start + traj length - 1
traj_terminated = (
stop_idx[traj_idx, 0] == start_idx[traj_idx, 0] + seq_length - 1
terminated = (
(index[:, 0].unsqueeze(0) == stop_idx[:, 0].unsqueeze(1))
.any(0)
.unsqueeze(1)
)
terminated = torch.zeros_like(truncated)
if traj_terminated.any():
if isinstance(seq_length, int):
terminated.view(num_slices, -1)[traj_terminated, -1] = 1
else:
terminated[(seq_length.cumsum(0) - 1)[traj_terminated]] = 1
truncated = truncated & ~terminated
done = terminated | truncated
return index.to(torch.long).unbind(-1), {
truncated_key: truncated,
Expand Down Expand Up @@ -1664,13 +1654,43 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]

if (lengths < seq_length).any():
if self.strict_length:
raise RuntimeError(
"Some stored trajectories have a length shorter than the slice that was asked for. "
"Create the sampler with `strict_length=False` to allow shorter trajectories to appear "
"in you batch."
)
# make seq_length a tensor with values clamped by lengths
seq_length = lengths[traj_idx].clamp_max(seq_length)
idx = lengths >= seq_length
if not idx.any():
raise RuntimeError(
f"Did not find a single trajectory with sufficient length (length range: {lengths.min()} - {lengths.max()} / required={seq_length}))."
)
if (
isinstance(seq_length, torch.Tensor)
and seq_length.shape == lengths.shape
):
seq_length = seq_length[idx]
lengths_idx = lengths[idx]
start_idx = start_idx[idx]
stop_idx = stop_idx[idx]

# Here we must filter out the indices that correspond to trajectories
# we don't want to keep. That could potentially lead to an empty sample.
# The difficulty with this adjustment is that traj_idx points to the full
# sequence of lengths, but we filter part of it out, so we must
# convert traj_idx to a boolean mask, index this mask with the
# valid indices, and then recover the nonzero entries.
idx_mask = torch.zeros_like(idx)
idx_mask[traj_idx] = True
traj_idx = idx_mask[idx].nonzero().squeeze(-1)
if not traj_idx.numel():
raise RuntimeError(
"None of the provided indices pointed to a trajectory of "
"sufficient length. Consider using strict_length=False for the "
"sampler instead."
)
num_slices = traj_idx.shape[0]
del idx
lengths = lengths_idx
else:
num_slices = traj_idx.shape[0]

# make seq_length a tensor with values clamped by lengths
seq_length = lengths[traj_idx].clamp_max(seq_length)

# build a list of index that we don't want to sample: all the steps at a `seq_length` distance of
# the end the trajectory, with the end of trajectory (`stop_idx`) included
Expand All @@ -1695,7 +1715,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
-1,
)
if storage.ndim > 1:
# convert the 2d index into a flat one to accomodate the _sum_tree
# convert the 2d index into a flat one to accommodate the _sum_tree
preceding_stop_idx = torch.as_tensor(
np.ravel_multi_index(
tuple(preceding_stop_idx.transpose(0, 1).numpy()), storage.shape
Expand Down Expand Up @@ -1729,42 +1749,10 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
raise ValueError(
f"Number of indices is expected to match the batch size ({index.shape[0]} != {batch_size})."
)

# if self.truncated_key is not None:
# truncated_key = self.truncated_key
# done_key = _replace_last(truncated_key, "done")
# terminated_key = _replace_last(truncated_key, "terminated")
#
# truncated = torch.zeros(
# (index.shape[0], 1), dtype=torch.bool, device=index.device
# )
# if isinstance(seq_length, int):
# truncated.view(num_slices, -1)[:, -1] = 1
# else:
# truncated[seq_length.cumsum(0) - 1] = 1
# # a traj is terminated if the stop index along col 0 (time)
# # equates start + traj length - 1
# traj_terminated = (
# stop_idx[traj_idx, 0] == start_idx[traj_idx, 0] + seq_length - 1
# )
# terminated = torch.zeros_like(truncated)
# if traj_terminated.any():
# if isinstance(seq_length, int):
# truncated.view(num_slices, -1)[traj_terminated] = 1
# else:
# truncated[(seq_length.cumsum(0) - 1)[traj_terminated]] = 1
# truncated = truncated & ~terminated
# done = terminated | truncated
# return index.to(torch.long).unbind(-1), {
# truncated_key: truncated,
# done_key: done,
# terminated_key: terminated,
# }

if self.truncated_key is not None:
# TODO: fix this part
# following logics borrowed from SliceSampler
truncated_key = self.truncated_key

done_key = _replace_last(truncated_key, "done")
terminated_key = _replace_last(truncated_key, "terminated")

Expand All @@ -1775,25 +1763,20 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
truncated.view(num_slices, -1)[:, -1] = 1
else:
truncated[seq_length.cumsum(0) - 1] = 1
traj_terminated = stop_idx[traj_idx, 0] == (
start_idx[traj_idx, 0] + seq_length - 1
terminated = (
(index[:, 0].unsqueeze(0) == stop_idx[:, 0].unsqueeze(1))
.any(0)
.unsqueeze(1)
)
terminated = torch.zeros_like(truncated)
if traj_terminated.any():
if isinstance(seq_length, int):
terminated.view(num_slices, -1)[traj_terminated, -1] = 1
else:
terminated[(seq_length.cumsum(0) - 1)[traj_terminated]] = 1
truncated = truncated & ~terminated
done = terminated | truncated

info.update(
{
truncated_key: truncated,
done_key: done,
terminated_key: terminated,
}
)
return index.to(torch.long).unbind(-1), info
return index.to(torch.long).unbind(-1), info

def _empty(self):
Expand Down
6 changes: 6 additions & 0 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,12 @@ def _reset_proc_data(self, tensordict, tensordict_reset):
return _update_during_reset(tensordict_reset, tensordict, self.reset_keys)
return tensordict_reset

def add_truncated_keys(self):
    """Disallow adding truncated keys directly on a batched environment.

    Raises:
        RuntimeError: always. Truncated entries must be added to each
            nested environment individually via
            ``sub_env.add_truncated_keys()`` instead.
    """
    raise RuntimeError(
        "Cannot add truncated keys to a batched environment. Please add these entries to "
        "the nested environments by calling sub_env.add_truncated_keys()"
    )


class SerialEnv(BatchedEnvBase):
"""Creates a series of environments in the same process."""
Expand Down

0 comments on commit e5c3e32

Please sign in to comment.