[Feature] Split-trajectories and represent as nested tensor (pytorch#…
vmoens authored Jun 28, 2024
1 parent 1083b35 commit a563c5e
Showing 2 changed files with 118 additions and 2 deletions.
68 changes: 68 additions & 0 deletions test/test_postprocs.py
@@ -9,6 +9,8 @@
import torch
from _utils_internal import get_default_devices
from tensordict import assert_allclose_td, TensorDict

from torchrl._utils import _ends_with
from torchrl.collectors.utils import split_trajectories
from torchrl.data.postprocs.postprocs import MultiStep

@@ -310,6 +312,72 @@ def test_splits(self, num_workers, traj_len, constr):
            == split_trajs.get(("collector", "traj_ids")).max() + 1
        )

@pytest.mark.parametrize("num_workers", range(3, 34, 3))
@pytest.mark.parametrize("traj_len", [10, 17, 50, 97])
@pytest.mark.parametrize(
"constr",
[
functools.partial(split_trajectories, prefix="collector", as_nested=True),
functools.partial(split_trajectories, as_nested=True),
functools.partial(
split_trajectories,
trajectory_key=("collector", "traj_ids"),
as_nested=True,
),
],
)
def test_split_traj_nested(self, num_workers, traj_len, constr):
trajs = TestSplits.create_fake_trajs(num_workers, traj_len)
assert trajs.shape[0] == num_workers
assert trajs.shape[1] == traj_len
split_trajs = constr(trajs)
assert split_trajs.shape[-1] == -1
assert split_trajs["next", "done"].is_nested

@pytest.mark.parametrize("num_workers", range(3, 34, 3))
@pytest.mark.parametrize("traj_len", [10, 17, 50, 97])
@pytest.mark.parametrize(
"constr0,constr1",
[
[
functools.partial(
split_trajectories, prefix="collector", as_nested=True
),
functools.partial(
split_trajectories, prefix="collector", as_nested=False
),
],
[
functools.partial(split_trajectories, as_nested=True),
functools.partial(split_trajectories, as_nested=False),
],
[
functools.partial(
split_trajectories,
trajectory_key=("collector", "traj_ids"),
as_nested=True,
),
functools.partial(
split_trajectories,
trajectory_key=("collector", "traj_ids"),
as_nested=False,
),
],
],
)
def test_split_traj_nested_equiv(self, num_workers, traj_len, constr0, constr1):
trajs = TestSplits.create_fake_trajs(num_workers, traj_len)
assert trajs.shape[0] == num_workers
assert trajs.shape[1] == traj_len
split_trajs1 = constr1(trajs)
mask_key = None
for key in split_trajs1.keys(True, True):
if _ends_with(key, "mask"):
mask_key = key
break
split_trajs0 = constr0(trajs).to_padded_tensor(mask_key=mask_key)
assert (split_trajs0 == split_trajs1).all()
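
The pad/nest round trip that ``test_split_traj_nested_equiv`` relies on can be reproduced on plain tensors (a sketch; ``0.0`` is an assumed padding value):

    import torch

    t0 = torch.ones(3, 2)
    t1 = torch.ones(5, 2)
    nt = torch.nested.nested_tensor([t0, t1])
    padded = torch.nested.to_padded_tensor(nt, 0.0)  # pad ragged dim to max length
    assert padded.shape == (2, 5, 2)
    assert (padded[0, :3] == t0).all() and (padded[0, 3:] == 0).all()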


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
52 changes: 50 additions & 2 deletions torchrl/collectors/utils.py
@@ -34,6 +34,7 @@ def split_trajectories(
    prefix=None,
    trajectory_key: NestedKey | None = None,
    done_key: NestedKey | None = None,
    as_nested: bool = False,
) -> TensorDictBase:
"""A util function for trajectory separation.
@@ -44,6 +45,8 @@
    Args:
        rollout_tensordict (TensorDictBase): a rollout with adjacent trajectories
            along the last dimension.

    Keyword Args:
        prefix (NestedKey, optional): the prefix used to read and write meta-data,
            such as ``"traj_ids"`` (the optional integer id of each trajectory)
            and the ``"mask"`` entry indicating which data are valid and which
@@ -56,6 +59,19 @@
            to ``(prefix, "traj_ids")``.
        done_key (NestedKey, optional): the key pointing to the ``"done"`` signal,
            if the trajectory could not be directly recovered. Defaults to ``"done"``.
        as_nested (bool or torch.layout, optional): whether to return the results
            as nested tensors. Defaults to ``False``. If a ``torch.layout`` is
            provided, it will be used to construct the nested tensor; otherwise
            the default layout will be used.

            .. note:: Using ``split_trajectories(tensordict, as_nested=True).to_padded_tensor(mask_key=mask_key)``
              should produce exactly the same result as ``as_nested=False``. Since this
              is an experimental feature that relies on nested tensors, whose API may
              change in the future, we made it opt-in. The runtime should be faster
              with ``as_nested=True``.

            .. note:: Providing a layout lets the user control whether the nested
              tensor is built with the ``torch.strided`` or ``torch.jagged`` layout.
              While the former has slightly more capabilities at the time of writing,
              the latter will be the main focus of the PyTorch team in the future due
              to its better compatibility with :func:`~torch.compile`.
    Returns:
        A new tensordict with a leading dimension corresponding to the trajectory.
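
The two layouts mentioned in the note above can be compared directly with the public ``torch.nested`` API (a minimal sketch, independent of torchrl; the shapes are arbitrary):

    import torch

    ts = [torch.randn(3, 4), torch.randn(5, 4)]  # two "trajectories" of unequal length
    nt_strided = torch.nested.nested_tensor(ts)  # default layout: torch.strided
    nt_jagged = torch.nested.nested_tensor(ts, layout=torch.jagged)  # compile-friendly
    assert nt_strided.layout == torch.strided
    assert nt_jagged.layout == torch.jagged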
@@ -92,7 +108,7 @@ def split_trajectories(
            batch_size=torch.Size([2, 10]),
            device=None,
            is_shared=False)
>>> # check that split_trajectory got the trajectories right with the done signal
>>> # check that split_trajectories got the trajectories right with the done signal
>>> assert (data_split["traj_ids"] == data_split["trajectory"]).all()
>>> print(data_split["mask"])
tensor([[ True, True, True, True, True, True, True, True, True, True],
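
A usage sketch for the new flag, continuing the docstring example (the output shown is illustrative, and the input rollout is assumed to be the ``data`` tensordict split earlier in the example):

    >>> data_nested = split_trajectories(data, as_nested=True)
    >>> data_nested["traj_ids"].is_nested
    True
    >>> # padding restores the dense layout and writes the validity mask
    >>> assert (data_nested.to_padded_tensor(mask_key="mask") == data_split).all()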
@@ -171,7 +187,39 @@ def split_trajectories(
        rollout_tensordict = rollout_tensordict.unsqueeze(0)
        return rollout_tensordict

-    out_splits = rollout_tensordict.reshape(-1).split(splits, 0)
+    out_splits = rollout_tensordict.reshape(-1)

    if as_nested:
        if hasattr(torch, "_nested_compute_contiguous_strides_offsets"):
            # Fast path (private API): view the flat buffer as a contiguous
            # nested tensor without copying the data.

            def nest(x, splits=splits):
                # Convert splits into shapes
                shape = torch.tensor([[int(split), *x.shape[1:]] for split in splits])
                return torch._nested_view_from_buffer(
                    x.reshape(-1),
                    shape,
                    *torch._nested_compute_contiguous_strides_offsets(shape),
                )

            return out_splits._fast_apply(
                nest,
                batch_size=[len(splits), -1],
            )
        else:
            out_splits = out_splits.split(splits, 0)

            # Interpret ``as_nested`` as a layout if one was passed; a plain
            # ``True`` keeps the default layout.
            layout = as_nested if not isinstance(as_nested, bool) else None

            def nest(*x):
                return torch.nested.nested_tensor(list(x), layout=layout)

            return out_splits[0]._fast_apply(
                nest,
                *out_splits[1:],
                batch_size=[len(out_splits), *out_splits[0].batch_size[:-1], -1],
            )

    out_splits = out_splits.split(splits, 0)

    for out_split in out_splits:
        out_split.set(
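
For reference, the private fast path above (``_nested_view_from_buffer``) produces the same nested values as a public-API construction from the split pieces, at the cost of a copy. A sketch under that assumption (``nest_public`` is a hypothetical helper, not part of torchrl):

    import torch

    def nest_public(x: torch.Tensor, splits: list[int]) -> torch.Tensor:
        # Split the flat leading dimension by trajectory length, then nest the pieces.
        return torch.nested.nested_tensor(list(x.split(splits, 0)))

    x = torch.arange(12.0).reshape(6, 2)  # 6 steps, e.g. trajectories of length 2 and 4
    nt = nest_public(x, [2, 4])
    assert nt.is_nested and nt.size(0) == 2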
