[Feature] VideoRecorder for datasets and replay buffers #2069

Merged · 1 commit · Apr 9, 2024
init
vmoens committed Apr 9, 2024
commit 0ccf306a7c581317f318c27066fb3d14c707f42a
48 changes: 48 additions & 0 deletions examples/video/video-from-dataset.py
@@ -0,0 +1,48 @@
"""Video from dataset example.

This example shows how to save a video from a dataset.

To run it, you will need to install the openx requirements as well as torchvision.
"""

from torchrl.data.datasets import OpenXExperienceReplay
from torchrl.record import CSVLogger, VideoRecorder

# Create a logger that saves videos as mp4
logger = CSVLogger("./dump", video_format="mp4")


# We use the VideoRecorder transform to register the images coming from the batch.
t = VideoRecorder(
    logger=logger, tag="pixels", in_keys=[("next", "observation", "image")]
)
# Each batch of data will hold 10 consecutive clips of up to 200 frames each (strict_length=False allows shorter ones)
dataset = OpenXExperienceReplay(
    "cmu_stretch",
    batch_size=2000,
    slice_len=200,
    download=True,
    strict_length=False,
    transform=t,
)

# Get a batch of data and visualize it
for _ in dataset:
    # The transform has already seen the data, since it is stored in the replay buffer
    t.dump()
    break

# Alternatively, we can build the dataset without the VideoRecorder and call it manually:
dataset = OpenXExperienceReplay(
    "cmu_stretch",
    batch_size=2000,
    slice_len=200,
    download=True,
    strict_length=False,
)

# Get a batch of data and visualize it
for data in dataset:
    t(data)
    t.dump()
    break
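A note for anyone trying the example without the OpenX requirements: the same pipeline can be exercised on synthetic data. The sketch below is ours, not part of this PR; it assumes a random uint8 channel-first tensor stands in for the dataset images, and the (200, 3, 64, 64) shape and "./dump_fake" directory are made-up illustrations. It still needs torchvision for the mp4 output.

import torch
from tensordict import TensorDict

from torchrl.record import CSVLogger, VideoRecorder

logger = CSVLogger("./dump_fake", video_format="mp4")
t = VideoRecorder(
    logger=logger, tag="pixels", in_keys=[("next", "observation", "image")]
)

# Fake batch of 200 frames, channel-first, mimicking the dataset images above
data = TensorDict(
    {
        ("next", "observation", "image"): torch.randint(
            0, 255, (200, 3, 64, 64), dtype=torch.uint8
        )
    },
    batch_size=[200],
)
t(data)   # registers the frames, as in the manual variant above
t.dump()  # writes the mp4 through the logger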
11 changes: 9 additions & 2 deletions torchrl/data/datasets/openx.py
@@ -305,9 +305,9 @@ def __init__(
        slice_len: int | None = None,
        pad: float | bool | None = None,
        replacement: bool = None,
-        streaming: bool = True,
+        streaming: bool | None = None,
        root: str | Path | None = None,
-        download: bool = False,
+        download: bool | None = None,
        sampler: Sampler | None = None,
        writer: Writer | None = None,
        collate_fn: Callable | None = None,
@@ -317,6 +317,13 @@
        split_trajs: bool = False,
        strict_length: bool = True,
    ):
+        if download is None and streaming is None:
+            download = False
+            streaming = True
+        elif download is None:
+            download = not streaming
+        elif streaming is None:
+            streaming = not download
        self.download = download
        self.streaming = streaming
        self.dataset_id = dataset_id
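For clarity, the new defaulting between ``download`` and ``streaming`` can be read as a small pure function. This is an illustrative sketch of the branch above, not a torchrl API; ``resolve_modes`` is our own name.

def resolve_modes(download=None, streaming=None):
    # Neither flag given: stream by default and do not download.
    if download is None and streaming is None:
        return False, True
    # Only streaming given: download exactly when not streaming.
    if download is None:
        return not streaming, streaming
    # Only download given: stream exactly when not downloading.
    if streaming is None:
        return download, not download
    return download, streaming

assert resolve_modes() == (False, True)
assert resolve_modes(streaming=False) == (True, False)
assert resolve_modes(download=True) == (True, False)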
128 changes: 104 additions & 24 deletions torchrl/record/recorder.py
@@ -33,11 +33,12 @@ class VideoRecorder(ObservationTransform):
        in_keys (Sequence of NestedKey, optional): keys to be read to produce the video.
            Default is :obj:`"pixels"`.
        skip (int): frame interval in the output video.
-            Default is 2.
+            Default is ``2`` if the transform has a parent environment, and ``1`` if not.
        center_crop (int, optional): value of square center crop.
        make_grid (bool, optional): if ``True``, a grid is created assuming that a
            tensor of shape [B x W x H x 3] is provided, with B being the batch
-            size. Default is True.
+            size. Default is ``True`` if the transform has a parent environment, and ``False``
+            if not.
        out_keys (sequence of NestedKey, optional): destination keys. Defaults
            to ``in_keys`` if not provided.
@@ -66,6 +67,26 @@ class VideoRecorder(ObservationTransform):

    >>> env.transform.dump()

+    The transform can also be used within a dataset, to save the videos it collects. Unlike in the
+    environment case, images will come in batches; the ``skip`` argument makes it possible to save
+    the images only at specific intervals.
+
+    >>> from torchrl.data.datasets import OpenXExperienceReplay
+    >>> from torchrl.envs import Compose
+    >>> from torchrl.record import VideoRecorder, CSVLogger
+    >>> # Create a logger that saves videos as mp4
+    >>> logger = CSVLogger("./dump", video_format="mp4")
+    >>> # We use the VideoRecorder transform to register the images coming from the batch.
+    >>> t = VideoRecorder(logger=logger, tag="pixels", in_keys=[("next", "observation", "image")])
+    >>> # Each batch of data will hold 10 consecutive clips of up to 200 frames each (strict_length=False)
+    >>> dataset = OpenXExperienceReplay("cmu_stretch", batch_size=2000, slice_len=200,
+    ...     download=True, strict_length=False,
+    ...     transform=t)
+    >>> # Get a batch of data and visualize it
+    >>> for data in dataset:
+    ...     t.dump()
+    ...     break

    Our video is available under ``./cheetah_videos/cheetah/videos/run_video_0.mp4``!

"""
@@ -75,9 +96,9 @@ def __init__(
        logger: Logger,
        tag: str,
        in_keys: Optional[Sequence[NestedKey]] = None,
-        skip: int = 2,
+        skip: int | None = None,
        center_crop: Optional[int] = None,
-        make_grid: bool = True,
+        make_grid: bool | None = None,
        out_keys: Optional[Sequence[NestedKey]] = None,
        **kwargs,
    ) -> None:
@@ -102,12 +123,59 @@
        )
        self.obs = []

+    @property
+    def make_grid(self):
+        make_grid = self._make_grid
+        if make_grid is None:
+            if self.parent is not None:
+                self._make_grid = True
+                return True
+            self._make_grid = False
+            return False
+        return make_grid
+
+    @make_grid.setter
+    def make_grid(self, value):
+        self._make_grid = value
+
+    @property
+    def skip(self):
+        skip = self._skip
+        if skip is None:
+            if self.parent is not None:
+                self._skip = 2
+                return 2
+            self._skip = 1
+            return 1
+        return skip
+
+    @skip.setter
+    def skip(self, value):
+        self._skip = value
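These properties resolve their defaults lazily because ``self.parent`` is only known once the transform has been attached to an environment (or left standalone for dataset use). A stripped-down sketch of the pattern, using a toy class of our own rather than the real transform:

class LazySkip:
    """Toy illustration of the lazy default above (not torchrl code)."""

    def __init__(self, skip=None):
        self._skip = skip
        self.parent = None  # set when the transform is appended to an env

    @property
    def skip(self):
        if self._skip is None:
            # The default is resolved and cached on first access.
            self._skip = 2 if self.parent is not None else 1
        return self._skip

standalone = LazySkip()
assert standalone.skip == 1  # no parent: keep every frame

attached = LazySkip()
attached.parent = object()   # stand-in for a parent env
assert attached.skip == 2    # parented: record every other frame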

    def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
-        if not (observation.shape[-1] == 3 or observation.ndimension() == 2):
-            raise RuntimeError(f"Invalid observation shape, got: {observation.shape}")
-        observation_trsf = observation.clone()
        self.count += 1
        if self.count % self.skip == 0:
+            if (
+                observation.ndim >= 3
+                and observation.shape[-3] == 3
+                and observation.shape[-2] > 3
+                and observation.shape[-1] > 3
+            ):
+                # permute the channels to the last dim
+                observation_trsf = observation.permute(
+                    *range(observation.ndim - 3), -2, -1, -3
+                )
+            else:
+                observation_trsf = observation
+            if not (
+                observation_trsf.shape[-1] == 3 or observation_trsf.ndimension() == 2
+            ):
+                raise RuntimeError(
+                    f"Invalid observation shape, got: {observation.shape}"
+                )
+            observation_trsf = observation_trsf.clone()
+
            if observation.ndimension() == 2:
                observation_trsf = observation.unsqueeze(-3)
            else:
@@ -131,38 +199,50 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
            observation_trsf = center_crop_fn(
                observation_trsf, [self.center_crop, self.center_crop]
            )
-        if self.make_grid and observation_trsf.ndimension() == 4:
+        if self.make_grid and observation_trsf.ndimension() >= 4:
            if not _has_tv:
                raise ImportError(
                    "Could not import torchvision, `make_grid` not available."
                    "Make sure torchvision is installed in your environment."
                )
            from torchvision.utils import make_grid

-            observation_trsf = make_grid(observation_trsf)
-            self.obs.append(observation_trsf.to(torch.uint8))
+            observation_trsf = make_grid(observation_trsf.flatten(0, -4))
+            self.obs.append(observation_trsf.to(torch.uint8))
+        elif observation_trsf.ndimension() >= 4:
+            self.obs.extend(observation_trsf.to(torch.uint8).flatten(0, -4))
+        else:
+            self.obs.append(observation_trsf.to(torch.uint8))
        return observation

+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        return self._call(tensordict)
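The new branching above moves channel-first batches to channel-last and, when no grid is built, stores frames one by one; ``forward`` is overridden so batched data passed directly (as a replay buffer does) goes through the same ``_call`` path as env steps. A hedged walkthrough of the shape handling on a dummy tensor of our own (shapes are illustrative):

import torch

obs = torch.zeros(8, 3, 32, 32)  # dummy batch: [B, C, H, W] with C == 3

# Channel-first inputs are detected and permuted to channel-last, as above.
if (
    obs.ndim >= 3
    and obs.shape[-3] == 3
    and obs.shape[-2] > 3
    and obs.shape[-1] > 3
):
    obs = obs.permute(*range(obs.ndim - 3), -2, -1, -3)
assert obs.shape == (8, 32, 32, 3)

# With make_grid=False, batched frames are stored individually:
frames = list(obs.flatten(0, -4))  # flatten(0, -4) merges leading batch dims
assert len(frames) == 8 and frames[0].shape == (32, 32, 3)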

    def dump(self, suffix: Optional[str] = None) -> None:
-        """Writes the video to the self.logger attribute.
+        """Writes the video to the ``self.logger`` attribute.
+
+        Calling ``dump`` when no image has been stored is a no-op.

        Args:
            suffix (str, optional): a suffix for the video to be recorded
        """
-        if suffix is None:
-            tag = self.tag
-        else:
-            tag = "_".join([self.tag, suffix])
-        obs = torch.stack(self.obs, 0).unsqueeze(0).cpu()
-        del self.obs
-        if self.logger is not None:
-            self.logger.log_video(
-                name=tag,
-                video=obs,
-                step=self.iter,
-                **self.video_kwargs,
-            )
-        del obs
+        if self.obs:
+            obs = torch.stack(self.obs, 0).unsqueeze(0).cpu()
+        else:
+            obs = None
+        self.obs = []
+        if obs is not None:
+            if suffix is None:
+                tag = self.tag
+            else:
+                tag = "_".join([self.tag, suffix])
+            if self.logger is not None:
+                self.logger.log_video(
+                    name=tag,
+                    video=obs,
+                    step=self.iter,
+                    **self.video_kwargs,
+                )
        self.iter += 1
        self.count = 0
        self.obs = []
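Usage-wise, the new guard means ``dump`` can be called unconditionally, and the ``suffix`` argument keeps successive videos apart. A small hedged sketch (the directory and suffix names are our own choices, not part of this PR):

from torchrl.record import CSVLogger, VideoRecorder

logger = CSVLogger("./dump", video_format="mp4")
t = VideoRecorder(logger=logger, tag="pixels")

t.dump()               # nothing recorded yet: a no-op under the new guard
# ... record some frames by applying `t` to data ...
t.dump(suffix="eval")  # logs the video under the tag "pixels_eval"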