Skip to content

Commit

Permalink
[Doc, Feature] Doc improvements for video recording and CSV video for…
Browse files Browse the repository at this point in the history
…mats (#1829)
  • Loading branch information
vmoens authored Jan 23, 2024
1 parent c390cf6 commit 24d14ad
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ dependencies:
- patchelf
- pyopengl==3.1.4
- ray<2.8.0
- av
42 changes: 35 additions & 7 deletions test/test_loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import pytest
import torch

from tensordict import MemoryMappedTensor
from torchrl.record.loggers.csv import CSVLogger
from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger
from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger
Expand Down Expand Up @@ -150,16 +152,22 @@ def test_log_scalar(self, steps, tmpdir):
assert row == f"{step},{values[i].item()}\n"

@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
def test_log_video(self, steps, tmpdir):
@pytest.mark.parametrize(
"video_format", ["pt", "memmap"] + ["mp4"] if _has_tv else []
)
def test_log_video(self, steps, video_format, tmpdir):
torch.manual_seed(0)
exp_name = "ramala"
logger = CSVLogger(log_dir=tmpdir, exp_name=exp_name)
logger = CSVLogger(log_dir=tmpdir, exp_name=exp_name, video_format=video_format)

# creating a sample video (T, C, H, W), where T - number of frames,
# C - number of image channels (e.g. 3 for RGB), H, W - image dimensions.
# the first 64 frames are black and the next 64 are white
video = torch.cat(
(torch.zeros(64, 1, 32, 32), torch.full((64, 1, 32, 32), 255))
(
torch.zeros(64, 1, 32, 32, dtype=torch.uint8),
torch.full((64, 1, 32, 32), 255, dtype=torch.uint8),
)
)
video = video[None, :]
for i in range(3):
Expand All @@ -171,11 +179,31 @@ def test_log_video(self, steps, tmpdir):
sleep(0.01) # wait until events are registered

# check that the logged videos are the same as the initial video
video_file_name = "foo_" + ("0" if not steps else str(steps[0])) + ".pt"
logged_video = torch.load(
os.path.join(tmpdir, exp_name, "videos", video_file_name)
extention = (
".pt"
if video_format == "pt"
else ".memmap"
if video_format == "memmap"
else ".mp4"
)
assert torch.equal(video, logged_video), logged_video
video_file_name = "foo_" + ("0" if not steps else str(steps[0])) + extention
path = os.path.join(tmpdir, exp_name, "videos", video_file_name)
if video_format == "pt":
logged_video = torch.load(path)
assert torch.equal(video, logged_video), logged_video
elif video_format == "memmap":
logged_video = MemoryMappedTensor.from_filename(
path, dtype=torch.uint8, shape=(1, 128, 1, 32, 32)
)
assert torch.equal(video, logged_video), logged_video
elif video_format == "mp4":
import torchvision

logged_video = torchvision.io.read_video(path, output_format="TCHW")[0][
:, :1
]
logged_video = logged_video.unsqueeze(0)
torch.testing.assert_close(video, logged_video)

# check that we catch the error in case the format of the tensor is wrong
video_wrong_format = torch.zeros(64, 2, 32, 32)
Expand Down
4 changes: 1 addition & 3 deletions torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,7 @@ class A2CLoss(LossModule):
the expected keyword arguments are:
``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and critic.
The return value is a tuple of tensors in the following order:
``["loss_objective"]``
+ ``["loss_critic"]`` if critic_coef is not None
+ ``["entropy", "loss_entropy"]`` if entropy_bonus is True and critic_coef is not None
``["loss_objective"]`` + ``["loss_critic"]`` if critic_coef is not None + ``["entropy", "loss_entropy"]`` if entropy_bonus is True and critic_coef is not None
Examples:
>>> import torch
Expand Down
3 changes: 1 addition & 2 deletions torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,7 @@ class PPOLoss(LossModule):
the expected keyword arguments are:
``["action", "sample_log_prob", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and value network.
The return value is a tuple of tensors in the following order:
``["loss_objective"]`` + ``["entropy", "loss_entropy"]`` if entropy_bonus is set
+ ``"loss_critic"`` if critic_coef is not None.
``["loss_objective"]`` + ``["entropy", "loss_entropy"]`` if entropy_bonus is set + ``"loss_critic"`` if critic_coef is not ``None``.
The output keys can also be filtered using :meth:`PPOLoss.select_out_keys` method.
Examples:
Expand Down
3 changes: 1 addition & 2 deletions torchrl/objectives/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ class REDQLoss(LossModule):
the expected keyword arguments are:
``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network
The return value is a tuple of tensors in the following order:
``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy",
"state_action_value_actor", "action_log_prob_actor", "next.state_value", "target_value",]``.
``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy", "state_action_value_actor", "action_log_prob_actor", "next.state_value", "target_value",]``.
Examples:
>>> import torch
Expand Down
64 changes: 58 additions & 6 deletions torchrl/record/loggers/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from typing import Dict, Optional, Sequence, Union

import torch

from tensordict import MemoryMappedTensor
from torch import Tensor

from .common import Logger
Expand All @@ -16,11 +18,13 @@
class CSVExperiment:
"""A CSV logger experiment class."""

def __init__(self, log_dir: str):
def __init__(self, log_dir: str, *, video_format="pt", video_fps=30):
self.scalars = defaultdict(lambda: [])
self.videos_counter = defaultdict(lambda: 0)
self.text_counter = defaultdict(lambda: 0)
self.log_dir = log_dir
self.video_format = video_format
self.video_fps = video_fps
os.makedirs(self.log_dir, exist_ok=True)
os.makedirs(os.path.join(self.log_dir, "scalars"), exist_ok=True)
os.makedirs(os.path.join(self.log_dir, "videos"), exist_ok=True)
Expand All @@ -44,12 +48,43 @@ def add_video(self, tag, vid_tensor, global_step: Optional[int] = None, **kwargs
if global_step is None:
global_step = self.videos_counter[tag]
self.videos_counter[tag] += 1
if self.video_format == "pt":
extension = ".pt"
elif self.video_format == "memmap":
extension = ".memmap"
elif self.video_format == "mp4":
extension = ".mp4"
else:
raise ValueError(
f"Unknown video format {self.video_format}. Must be one of 'pt', 'memmap' or 'mp4'."
)

filepath = os.path.join(
self.log_dir, "videos", "_".join([tag, str(global_step)]) + ".pt"
self.log_dir, "videos", "_".join([tag, str(global_step)]) + extension
)
path_to_create = Path(str(filepath)).parent
os.makedirs(path_to_create, exist_ok=True)
torch.save(vid_tensor, filepath)
if self.video_format == "pt":
torch.save(vid_tensor, filepath)
elif self.video_format == "memmap":
MemoryMappedTensor.from_tensor(vid_tensor, filename=filepath)
elif self.video_format == "mp4":
import torchvision

if vid_tensor.shape[-3] not in (3, 1):
raise RuntimeError(
"expected the video tensor to be of format [T, C, H, W] but the third channel "
f"starting from the end isn't in (1, 3) but is {vid_tensor.shape[-3]}."
)
if vid_tensor.ndim > 4:
vid_tensor = vid_tensor.flatten(0, vid_tensor.ndim - 4)
vid_tensor = vid_tensor.permute((0, 2, 3, 1))
vid_tensor = vid_tensor.expand(*vid_tensor.shape[:-1], 3)
torchvision.io.write_video(filepath, vid_tensor, fps=self.video_fps)
else:
raise ValueError(
f"Unknown video format {self.video_format}. Must be one of 'pt', 'memmap' or 'mp4'."
)

def add_text(self, tag, text, global_step: Optional[int] = None):
if global_step is None:
Expand Down Expand Up @@ -77,20 +112,37 @@ class CSVLogger(Logger):
Args:
exp_name (str): The name of the experiment.
log_dir (str or Path, optional): where the experiment should be saved.
Defaults to ``<cur_dir>/csv_logs``.
video_format (str, optional): how videos should be saved. Must be one of
``"pt"`` (video saved as a ``<tag>_<step>.pt`` file with :func:`torch.save`),
``"memmap"`` (video saved as a ``<tag>_<step>.memmap`` file with :class:`~tensordict.MemoryMappedTensor`),
``"mp4"`` (video saved as a ``<tag>_<step>.mp4`` file, requires torchvision to be installed).
Defaults to ``"pt"``.
video_fps (int, optional): the video frame rate, in frames per second, used when ``video_format="mp4"``. Defaults to ``30``.
"""

def __init__(self, exp_name: str, log_dir: Optional[str] = None) -> None:
def __init__(
self,
exp_name: str,
log_dir: Optional[str] = None,
video_format: str = "pt",
video_fps: int = 30,
) -> None:
if log_dir is None:
log_dir = "csv_logs"
self.video_format = video_format
self.video_fps = video_fps
super().__init__(exp_name=exp_name, log_dir=log_dir)

self._has_imported_moviepy = False

def _create_experiment(self) -> "CSVExperiment":
"""Creates a CSV experiment."""
log_dir = str(os.path.join(self.log_dir, self.exp_name))
return CSVExperiment(log_dir)
return CSVExperiment(
log_dir, video_format=self.video_format, video_fps=self.video_fps
)

def log_scalar(self, name: str, value: float, step: int = None) -> None:
"""Logs a scalar value with the CSV logger.
Expand Down
2 changes: 1 addition & 1 deletion torchrl/record/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class TensorboardLogger(Logger):
Args:
exp_name (str): The name of the experiment.
log_dir (str): the tensorboard log_dir.
log_dir (str): the tensorboard log_dir. Defaults to ``td_logs``.
"""

Expand Down
16 changes: 16 additions & 0 deletions torchrl/record/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,24 @@
class WandbLogger(Logger):
"""Wrapper for the wandb logger.
The keyword arguments are mainly based on the :func:`wandb.init` kwargs.
See the doc `here <https://docs.wandb.ai/ref/python/init>`__.
Args:
exp_name (str): The name of the experiment.
offline (bool, optional): if ``True``, the logs will be stored locally
only. Defaults to ``False``.
save_dir (path, optional): the directory where to save data. Exclusive with
``log_dir``.
log_dir (path, optional): the directory where to save data. Exclusive with
``save_dir``.
id (str, optional): A unique ID for this run, used for resuming.
It must be unique in the project, and if you delete a run you can't reuse the ID.
project (str, optional): The name of the project where you're sending
the new run. If the project is not specified, the run is put in
an ``"Uncategorized"`` project.
**kwargs: Extra keyword arguments for ``wandb.init``. See relevant page for
more info.
"""

Expand Down
26 changes: 25 additions & 1 deletion torchrl/record/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class VideoRecorder(ObservationTransform):
Args:
logger (Logger): a Logger instance where the video
should be written.
should be written. To save the video as a memmap tensor or an mp4 file, use
the :class:`~torchrl.record.loggers.CSVLogger` class.
tag (str): the video tag in the logger.
in_keys (Sequence of NestedKey, optional): keys to be read to produce the video.
Default is :obj:`"pixels"`.
Expand All @@ -43,6 +44,29 @@ class VideoRecorder(ObservationTransform):
out_keys (sequence of NestedKey, optional): destination keys. Defaults
to ``in_keys`` if not provided.
Examples:
The following example shows how to save a rollout as a video. First, a few imports:
>>> from torchrl.record import VideoRecorder
>>> from torchrl.record.loggers.csv import CSVLogger
>>> from torchrl.envs import TransformedEnv, DMControlEnv
The video format is chosen in the logger. Wandb and tensorboard take care of that
on their own, whereas the CSV logger accepts several video formats (``"pt"``, ``"memmap"`` or ``"mp4"``).
>>> logger = CSVLogger(exp_name="cheetah", log_dir="cheetah_videos", video_format="mp4")
Some envs (e.g., Atari games) natively return images, while others require the user to ask for them.
Check :class:`~torchrl.envs.GymEnv` or :class:`~torchrl.envs.DMControlEnv` to see how to render images
in these contexts.
>>> base_env = DMControlEnv("cheetah", "run", from_pixels=True)
>>> env = TransformedEnv(base_env, VideoRecorder(logger=logger, tag="run_video"))
>>> env.rollout(100)
All transforms have a ``dump`` method, mostly a no-op except for ``VideoRecorder`` and :class:`~torchrl.envs.transforms.Compose`,
which will dispatch the ``dump`` call to all its members.
>>> env.transform.dump()
Our video is available under ``./cheetah_videos/cheetah/videos/run_video_0.mp4``!
"""

def __init__(
Expand Down

0 comments on commit 24d14ad

Please sign in to comment.