Skip to content

Commit

Permalink
[Feature] CatFrames.make_rb_transform_and_sampler
Browse files Browse the repository at this point in the history
ghstack-source-id: 7ecf952ec9f102a831aefdba533027ff8c4c29cc
Pull Request resolved: #2643
  • Loading branch information
vmoens committed Dec 13, 2024
1 parent 17983d4 commit 7365fb5
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 3 deletions.
99 changes: 99 additions & 0 deletions examples/replay-buffers/catframes-in-buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import (
CatFrames,
Compose,
DMControlEnv,
StepCounter,
ToTensorImage,
TransformedEnv,
UnsqueezeTransform,
)

# Number of frames to stack together
frame_stack = 4
# Dimension along which the stack should occur
stack_dim = -4
# Max size of the buffer
max_size = 100_000
# Batch size of the replay buffer
training_batch_size = 32

seed = 123


def main():
catframes = CatFrames(
N=frame_stack,
dim=stack_dim,
in_keys=["pixels_trsf"],
out_keys=["pixels_trsf"],
)
env = TransformedEnv(
DMControlEnv(
env_name="cartpole",
task_name="balance",
device="cpu",
from_pixels=True,
pixels_only=True,
),
Compose(
ToTensorImage(
from_int=True,
dtype=torch.float32,
in_keys=["pixels"],
out_keys=["pixels_trsf"],
shape_tolerant=True,
),
UnsqueezeTransform(
dim=stack_dim, in_keys=["pixels_trsf"], out_keys=["pixels_trsf"]
),
catframes,
StepCounter(),
),
)
env.set_seed(seed)

transform, sampler = catframes.make_rb_transform_and_sampler(
batch_size=training_batch_size,
traj_key=("collector", "traj_ids"),
strict_length=True,
)

rb_transforms = Compose(
ToTensorImage(
from_int=True,
dtype=torch.float32,
in_keys=["pixels", ("next", "pixels")],
out_keys=["pixels_trsf", ("next", "pixels_trsf")],
shape_tolerant=True,
), # C W' H' -> C W' H' (unchanged due to shape_tolerant)
UnsqueezeTransform(
dim=stack_dim,
in_keys=["pixels_trsf", ("next", "pixels_trsf")],
out_keys=["pixels_trsf", ("next", "pixels_trsf")],
), # 1 C W' H'
transform,
)

rb = ReplayBuffer(
storage=LazyTensorStorage(max_size=max_size, device="cpu"),
sampler=sampler,
batch_size=training_batch_size,
transform=rb_transforms,
)

data = env.rollout(1000, break_when_any_done=False)
rb.extend(data)

training_batch = rb.sample()
print(training_batch)


if __name__ == "__main__":
main()
23 changes: 23 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,29 @@ def test_transform_rb(self, dim, N, padding, rbclass):
assert (tdsample["out_" + key1] == td["out_" + key1]).all()
assert (tdsample["next", "out_" + key1] == td["next", "out_" + key1]).all()

def test_transform_rb_maker(self):
env = CountingEnv(max_steps=10)
catframes = CatFrames(
in_keys=["observation"], out_keys=["observation_stack"], dim=-1, N=4
)
env.append_transform(catframes)
policy = lambda td: td.update(env.full_action_spec.zeros() + 1)
rollout = env.rollout(150, policy, break_when_any_done=False)
transform, sampler = catframes.make_rb_transform_and_sampler(batch_size=32)
rb = ReplayBuffer(
sampler=sampler, storage=LazyTensorStorage(150), transform=transform
)
rb.extend(rollout)
sample = rb.sample(32)
assert "observation_stack" not in rb._storage._storage
assert sample.shape == (32,)
assert sample["observation_stack"].shape == (32, 4)
assert sample["next", "observation_stack"].shape == (32, 4)
assert (
sample["observation_stack"]
== sample["observation_stack"][:, :1] + torch.arange(4)
).all()

@pytest.mark.parametrize("dim", [-1])
@pytest.mark.parametrize("N", [3, 4])
@pytest.mark.parametrize("padding", ["same", "constant"])
Expand Down
7 changes: 7 additions & 0 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,9 @@ class SliceSampler(Sampler):
"""

# We use this whenever we need to sample N times too many transitions to then select only a 1/N fraction of them
_batch_size_multiplier: int | None = 1

def __init__(
self,
*,
Expand Down Expand Up @@ -1295,6 +1298,8 @@ def _adjusted_batch_size(self, batch_size):
return seq_length, num_slices

def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
if self._batch_size_multiplier is not None:
batch_size = batch_size * self._batch_size_multiplier
# pick up as many trajs as we need
start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
# we have to make sure that the number of dims of the storage
Expand Down Expand Up @@ -1747,6 +1752,8 @@ def _storage_len(self, storage):
def sample(
self, storage: Storage, batch_size: int
) -> Tuple[Tuple[torch.Tensor, ...], dict]:
if self._batch_size_multiplier is not None:
batch_size = batch_size * self._batch_size_multiplier
start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
# we have to make sure that the number of dims of the storage
# is the same as the stop/start signals since we will
Expand Down
83 changes: 80 additions & 3 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2825,9 +2825,9 @@ def _reset(
class CatFrames(ObservationTransform):
"""Concatenates successive observation frames into a single tensor.
This can, for instance, account for movement/velocity of the observed
feature. Proposed in "Playing Atari with Deep Reinforcement Learning" (
https://arxiv.org/pdf/1312.5602.pdf).
This transform is useful for creating a sense of movement or velocity in the observed features.
It can also be used with models that require access to past observations such as transformers and the like.
It was first proposed in "Playing Atari with Deep Reinforcement Learning" (https://arxiv.org/pdf/1312.5602.pdf).
When used within a transformed environment,
:class:`CatFrames` is a stateful class, and it can be reset to its native state by
Expand Down Expand Up @@ -2915,6 +2915,14 @@ class CatFrames(ObservationTransform):
such as those found in MARL settings, are currently not supported.
If this feature is needed, please raise an issue on TorchRL repo.
.. note:: Storing stacks of frames in the replay buffer can significantly increase memory consumption (by N times).
To mitigate this, you can store trajectories directly in the replay buffer and apply :class:`CatFrames` at sampling time.
This approach involves sampling slices of the stored trajectories and then applying the frame stacking transform.
For convenience, :class:`CatFrames` provides a :meth:`~.make_rb_transform_and_sampler` method that creates:
- A modified version of the transform suitable for use in replay buffers
- A corresponding :class:`SliceSampler` to use with the buffer
"""

inplace = False
Expand Down Expand Up @@ -2964,6 +2972,75 @@ def __init__(
self.reset_key = reset_key
self.done_key = done_key

def make_rb_transform_and_sampler(
self, batch_size: int, **sampler_kwargs
) -> Tuple[Transform, "torchrl.data.replay_buffers.SliceSampler"]: # noqa: F821
"""Creates a transform and sampler to be used with a replay buffer when storing frame-stacked data.
This method helps reduce redundancy in stored data by avoiding the need to
store the entire stack of frames in the buffer. Instead, it creates a
transform that stacks frames on-the-fly during sampling, and a sampler that
ensures the correct sequence length is maintained.
Args:
batch_size (int): The batch size to use for the sampler.
**sampler_kwargs: Additional keyword arguments to pass to the
:class:`~torchrl.data.replay_buffers.SliceSampler` constructor.
Returns:
A tuple containing:
- transform (Transform): A transform that stacks frames on-the-fly during sampling.
- sampler (SliceSampler): A sampler that ensures the correct sequence length is maintained.
Example:
>>> env = TransformedEnv(...)
>>> catframes = CatFrames(N=4, ...)
>>> transform, sampler = catframes.make_rb_transform_and_sampler(batch_size=32)
>>> rb = ReplayBuffer(..., sampler=sampler, transform=transform)
.. note:: When working with images, it's recommended to use distinct ``in_keys`` and ``out_keys`` in the preceding
:class:`~torchrl.envs.ToTensorImage` transform. This ensures that the tensors stored in the buffer are separate
from their processed counterparts, which we don't want to store.
For non-image data, consider inserting a :class:`~torchrl.envs.RenameTransform` before :class:`CatFrames` to create
a copy of the data that will be stored in the buffer.
.. note:: When adding the transform to the replay buffer, one should pay attention to also pass the transforms
that precede CatFrames, such as :class:`~torchrl.envs.ToTensorImage` or :class:`~torchrl.envs.UnsqueezeTransform`
in such a way that the :class:`~torchrl.envs.CatFrames` transforms sees data formatted as it was during data
collection.
.. note:: For a more complete example, refer to torchrl's github repo `examples` folder:
https://github.com/pytorch/rl/tree/main/examples/replay-buffers/catframes-in-buffer.py
"""
from torchrl.data.replay_buffers import SliceSampler

in_keys = self.in_keys
in_keys = in_keys + [unravel_key(("next", key)) for key in in_keys]
out_keys = self.out_keys
out_keys = out_keys + [unravel_key(("next", key)) for key in out_keys]
catframes = type(self)(
N=self.N,
in_keys=in_keys,
out_keys=out_keys,
dim=self.dim,
padding=self.padding,
padding_value=self.padding_value,
as_inverse=False,
reset_key=self.reset_key,
done_key=self.done_key,
)
sampler = SliceSampler(slice_len=self.N, **sampler_kwargs)
sampler._batch_size_multiplier = self.N
transform = Compose(
lambda td: td.reshape(-1, self.N),
catframes,
lambda td: td[:, -1],
# We only store "pixels" to the replay buffer to save memory
ExcludeTransform(*out_keys, inverse=True),
)
return transform, sampler

@property
def done_key(self):
done_key = self.__dict__.get("_done_key", None)
Expand Down

0 comments on commit 7365fb5

Please sign in to comment.