[Feature] VideoRecorder for datasets and replay buffers (#2069)
vmoens authored Apr 9, 2024
1 parent 79e2b07 commit 02c8342
Showing 3 changed files with 161 additions and 26 deletions.
48 changes: 48 additions & 0 deletions examples/video/video-from-dataset.py
@@ -0,0 +1,48 @@
"""Video from dataset example.
This example shows how to save a video from a dataset.
To run it, you will need to install the openx requirements as well as torchvision.
"""

from torchrl.data.datasets import OpenXExperienceReplay
from torchrl.record import CSVLogger, VideoRecorder

# Create a logger that saves videos as mp4
logger = CSVLogger("./dump", video_format="mp4")


# We use the VideoRecorder transform to register the images coming from the batch.
t = VideoRecorder(
    logger=logger, tag="pixels", in_keys=[("next", "observation", "image")]
)
# Each 2000-frame batch will hold 10 consecutive videos of up to 200 frames each
# (a maximum, since strict_length=False allows shorter trajectories)
dataset = OpenXExperienceReplay(
    "cmu_stretch",
    batch_size=2000,
    slice_len=200,
    download=True,
    strict_length=False,
    transform=t,
)

# Get a batch of data and visualize it
for _ in dataset:
    # The transform has already seen the data, since it is part of the replay buffer
    t.dump()
    break

# Alternatively, we can build the dataset without the VideoRecorder and call it manually:
dataset = OpenXExperienceReplay(
    "cmu_stretch",
    batch_size=2000,
    slice_len=200,
    download=True,
    strict_length=False,
)

# Get a batch of data and visualize it
for data in dataset:
    t(data)
    t.dump()
    break
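If you want to check the result programmatically, the clip can be read back with torchvision. A minimal sketch, assuming the logger wrote the mp4 under a path such as the hypothetical one below (the exact layout depends on how CSVLogger organizes its files and on the tag):

from torchvision.io import read_video

video_path = "./dump/videos/pixels_0.mp4"  # hypothetical path; adjust to your run
frames, _, info = read_video(video_path, pts_unit="sec")
print(frames.shape, info)  # frames come back as a [T, H, W, 3] uint8 tensor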
11 changes: 9 additions & 2 deletions torchrl/data/datasets/openx.py
@@ -305,9 +305,9 @@ def __init__(
         slice_len: int | None = None,
         pad: float | bool | None = None,
         replacement: bool = None,
-        streaming: bool = True,
+        streaming: bool | None = None,
         root: str | Path | None = None,
-        download: bool = False,
+        download: bool | None = None,
         sampler: Sampler | None = None,
         writer: Writer | None = None,
         collate_fn: Callable | None = None,
@@ -317,6 +317,13 @@ def __init__(
         split_trajs: bool = False,
         strict_length: bool = True,
     ):
+        if download is None and streaming is None:
+            download = False
+            streaming = True
+        elif download is None:
+            download = not streaming
+        elif streaming is None:
+            streaming = not download
         self.download = download
         self.streaming = streaming
         self.dataset_id = dataset_id
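The resolution rule added above is easiest to read as a truth table: if neither flag is set, the dataset streams; otherwise each flag defaults to the negation of the other. A minimal standalone sketch of the same logic (the helper name is ours, not part of the library):

def resolve_download_streaming(download=None, streaming=None):
    # Mirrors the defaulting logic in the diff above.
    if download is None and streaming is None:
        return False, True
    if download is None:
        return not streaming, streaming
    if streaming is None:
        return download, not download
    return download, streaming

assert resolve_download_streaming() == (False, True)  # stream by default
assert resolve_download_streaming(download=True) == (True, False)
assert resolve_download_streaming(streaming=False) == (True, False)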
128 changes: 104 additions & 24 deletions torchrl/record/recorder.py
@@ -33,11 +33,12 @@ class VideoRecorder(ObservationTransform):
         in_keys (Sequence of NestedKey, optional): keys to be read to produce the video.
             Default is :obj:`"pixels"`.
         skip (int): frame interval in the output video.
-            Default is 2.
+            Default is ``2`` if the transform has a parent environment, and ``1`` if not.
         center_crop (int, optional): value of square center crop.
         make_grid (bool, optional): if ``True``, a grid is created assuming that a
             tensor of shape [B x W x H x 3] is provided, with B being the batch
-            size. Default is True.
+            size. Default is ``True`` if the transform has a parent environment, and ``False``
+            if not.
         out_keys (sequence of NestedKey, optional): destination keys. Defaults
             to ``in_keys`` if not provided.
 
@@ -66,6 +67,26 @@ class VideoRecorder(ObservationTransform):
         >>> env.transform.dump()
 
+    The transform can also be used within a dataset to save the video collected. Unlike in the environment case,
+    images will come in a batch. The ``skip`` argument makes it possible to save images only at specific intervals.
+
+        >>> from torchrl.data.datasets import OpenXExperienceReplay
+        >>> from torchrl.envs import Compose
+        >>> from torchrl.record import VideoRecorder, CSVLogger
+        >>> # Create a logger that saves videos as mp4
+        >>> logger = CSVLogger("./dump", video_format="mp4")
+        >>> # We use the VideoRecorder transform to register the images coming from the batch.
+        >>> t = VideoRecorder(logger=logger, tag="pixels", in_keys=[("next", "observation", "image")])
+        >>> # Each 2000-frame batch will hold 10 consecutive videos of up to 200 frames each (a maximum, since strict_length=False)
+        >>> dataset = OpenXExperienceReplay("cmu_stretch", batch_size=2000, slice_len=200,
+        ...     download=True, strict_length=False,
+        ...     transform=t)
+        >>> # Get a batch of data and visualize it
+        >>> for data in dataset:
+        ...     t.dump()
+        ...     break
+
     Our video is available under ``./cheetah_videos/cheetah/videos/run_video_0.mp4``!
 
     """
@@ -75,9 +96,9 @@ def __init__(
         logger: Logger,
         tag: str,
         in_keys: Optional[Sequence[NestedKey]] = None,
-        skip: int = 2,
+        skip: int | None = None,
         center_crop: Optional[int] = None,
-        make_grid: bool = True,
+        make_grid: bool | None = None,
         out_keys: Optional[Sequence[NestedKey]] = None,
         **kwargs,
     ) -> None:
@@ -102,12 +123,59 @@ def __init__(
         )
         self.obs = []
 
+    @property
+    def make_grid(self):
+        make_grid = self._make_grid
+        if make_grid is None:
+            if self.parent is not None:
+                self._make_grid = True
+                return True
+            self._make_grid = False
+            return False
+        return make_grid
+
+    @make_grid.setter
+    def make_grid(self, value):
+        self._make_grid = value
+
+    @property
+    def skip(self):
+        skip = self._skip
+        if skip is None:
+            if self.parent is not None:
+                self._skip = 2
+                return 2
+            self._skip = 1
+            return 1
+        return skip
+
+    @skip.setter
+    def skip(self, value):
+        self._skip = value
     def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
-        if not (observation.shape[-1] == 3 or observation.ndimension() == 2):
-            raise RuntimeError(f"Invalid observation shape, got: {observation.shape}")
-        observation_trsf = observation.clone()
         self.count += 1
         if self.count % self.skip == 0:
+            if (
+                observation.ndim >= 3
+                and observation.shape[-3] == 3
+                and observation.shape[-2] > 3
+                and observation.shape[-1] > 3
+            ):
+                # permute the channels to the last dim
+                observation_trsf = observation.permute(
+                    *range(observation.ndim - 3), -2, -1, -3
+                )
+            else:
+                observation_trsf = observation
+            if not (
+                observation_trsf.shape[-1] == 3 or observation_trsf.ndimension() == 2
+            ):
+                raise RuntimeError(
+                    f"Invalid observation shape, got: {observation.shape}"
+                )
+            observation_trsf = observation_trsf.clone()
+
             if observation.ndimension() == 2:
                 observation_trsf = observation.unsqueeze(-3)
             else:
@@ -131,38 +199,50 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
                 observation_trsf = center_crop_fn(
                     observation_trsf, [self.center_crop, self.center_crop]
                 )
-            if self.make_grid and observation_trsf.ndimension() == 4:
+            if self.make_grid and observation_trsf.ndimension() >= 4:
                 if not _has_tv:
                     raise ImportError(
                         "Could not import torchvision, `make_grid` not available."
                         "Make sure torchvision is installed in your environment."
                     )
                 from torchvision.utils import make_grid
 
-                observation_trsf = make_grid(observation_trsf)
-            self.obs.append(observation_trsf.to(torch.uint8))
+                observation_trsf = make_grid(observation_trsf.flatten(0, -4))
+                self.obs.append(observation_trsf.to(torch.uint8))
+            elif observation_trsf.ndimension() >= 4:
+                self.obs.extend(observation_trsf.to(torch.uint8).flatten(0, -4))
+            else:
+                self.obs.append(observation_trsf.to(torch.uint8))
         return observation
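The permutation above moves a channel-first image, or a batch of them, to channel-last without touching leading batch dimensions. A small sketch with a dummy batch:

import torch

obs = torch.rand(10, 3, 64, 64)  # [B, C, H, W]
obs_hwc = obs.permute(*range(obs.ndim - 3), -2, -1, -3)
assert obs_hwc.shape == (10, 64, 64, 3)  # [B, H, W, C]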

+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        return self._call(tensordict)
+
     def dump(self, suffix: Optional[str] = None) -> None:
-        """Writes the video to the self.logger attribute.
+        """Writes the video to the ``self.logger`` attribute.
+
+        Calling ``dump`` when no image has been stored is a no-op.
 
         Args:
             suffix (str, optional): a suffix for the video to be recorded
         """
-        if suffix is None:
-            tag = self.tag
+        if self.obs:
+            obs = torch.stack(self.obs, 0).unsqueeze(0).cpu()
         else:
-            tag = "_".join([self.tag, suffix])
-        obs = torch.stack(self.obs, 0).unsqueeze(0).cpu()
-        del self.obs
-        if self.logger is not None:
-            self.logger.log_video(
-                name=tag,
-                video=obs,
-                step=self.iter,
-                **self.video_kwargs,
-            )
-        del obs
+            obs = None
+        self.obs = []
+        if obs is not None:
+            if suffix is None:
+                tag = self.tag
+            else:
+                tag = "_".join([self.tag, suffix])
+            if self.logger is not None:
+                self.logger.log_video(
+                    name=tag,
+                    video=obs,
+                    step=self.iter,
+                    **self.video_kwargs,
+                )
         self.iter += 1
         self.count = 0
         self.obs = []
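With the guard on ``self.obs``, flushing an empty recorder is now safe, so ``dump`` can be called defensively after every batch. A minimal sketch:

from torchrl.record import CSVLogger, VideoRecorder

t = VideoRecorder(logger=CSVLogger("./dump", video_format="mp4"), tag="pixels")
t.dump()  # no frames stored yet: nothing is stacked or logged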
