Skip to content

Commit

Permalink
[Feature] A PixelRenderTransform (pytorch#2099)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 23, 2024
1 parent 7dd0128 commit df749a3
Show file tree
Hide file tree
Showing 12 changed files with 532 additions and 51 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,7 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
DiscreteTensorSpec
MultiDiscreteTensorSpec
MultiOneHotDiscreteTensorSpec
NonTensorSpec
OneHotDiscreteTensorSpec
UnboundedContinuousTensorSpec
UnboundedDiscreteTensorSpec
Expand Down
70 changes: 70 additions & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,75 @@ to always know what the latest available actions are. You can do this like so:
Recorders
---------

.. _Environment-Recorders:

Recording data during environment rollout execution is crucial to keep an eye on the algorithm performance as well as
reporting results after training.

TorchRL offers several tools to interact with the environment output: first and foremost, a ``callback`` callable
can be passed to the :meth:`~torchrl.envs.EnvBase.rollout` method. This function will be called upon the collected
tensordict at each iteration of the rollout (if some iterations have to be skipped, an internal variable should be added
to keep track of the call count within ``callback``).

To save collected tensordicts on disk, the :class:`~torchrl.record.TensorDictRecorder` can be used.

Recording videos
~~~~~~~~~~~~~~~~

Several backends offer the possibility of recording rendered images from the environment.
If the pixels are already part of the environment output (e.g. Atari or other game simulators), a
:class:`~torchrl.record.VideoRecorder` can be appended to the environment. This environment transform takes as input
a logger capable of recording videos (e.g. :class:`~torchrl.record.loggers.CSVLogger`, :class:`~torchrl.record.loggers.WandbLogger`
or :class:`~torchrl.record.loggers.TensorBoardLogger`) as well as a tag indicating where the video should be saved.
For instance, to save mp4 videos on disk, one can use :class:`~torchrl.record.loggers.CSVLogger` with a `video_format="mp4"`
argument.

The :class:`~torchrl.record.VideoRecorder` transform can handle batched images and automatically detects numpy or PyTorch
formatted images (WHC or CWH).

>>> logger = CSVLogger("dummy-exp", video_format="mp4")
>>> env = GymEnv("ALE/Pong-v5")
>>> env = env.append_transform(VideoRecorder(logger, tag="rendered", in_keys=["pixels"]))
>>> env.rollout(10)
>>> env.transform.dump() # Save the video and clear cache

Note that the cache of the transform will keep on growing until dump is called. It is the user responsibility to
take care of calling dumpy when needed to avoid OOM issues.

In some cases, creating a testing environment where images can be collected is tedious or expensive, or simply impossible
(some libraries only allow one environment instance per workspace).
In these cases, assuming that a `render` method is available in the environment, the :class:`~torchrl.record.PixelRenderTransform`
can be used to call `render` on the parent environment and save the images in the rollout data stream.
This class works over single and batched environments alike:

>>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator
>>> from torchrl.record.loggers import CSVLogger
>>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
>>>
>>> def make_env():
>>> env = GymEnv("CartPole-v1", render_mode="rgb_array")
>>> # Uncomment this line to execute per-env
>>> # env = env.append_transform(PixelRenderTransform())
>>> return env
>>>
>>> if __name__ == "__main__":
... logger = CSVLogger("dummy", video_format="mp4")
...
... env = ParallelEnv(16, EnvCreator(make_env))
... env.start()
... # Comment this line to execute per-env
... env = env.append_transform(PixelRenderTransform())
...
... env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record"))
... env.rollout(3)
...
... check_env_specs(env)
...
... r = env.rollout(30)
... env.transform.dump()
... env.close()


.. currentmodule:: torchrl.record

Recorders are transforms that register data as they come in, for logging purposes.
Expand All @@ -769,6 +838,7 @@ Recorders are transforms that register data as they come in, for logging purpose

TensorDictRecorder
VideoRecorder
PixelRenderTransform


Helpers
Expand Down
63 changes: 33 additions & 30 deletions docs/source/reference/trainers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ loop the optimization steps. We believe this fits multiple RL training schemes,
on-policy, off-policy, model-based and model-free solutions, offline RL and others.
More particular cases, such as meta-RL algorithms may have training schemes that differ substentially.

The :obj:`trainer.train()` method can be sketched as follows:
The ``trainer.train()`` method can be sketched as follows:

.. code-block::
:caption: Trainer loops
Expand Down Expand Up @@ -63,35 +63,35 @@ The :obj:`trainer.train()` method can be sketched as follows:
... self._post_steps_hook() # "post_steps"
... self._post_steps_log_hook(batch) # "post_steps_log"
There are 10 hooks that can be used in a trainer loop: :obj:`"batch_process"`, :obj:`"pre_optim_steps"`,
:obj:`"process_optim_batch"`, :obj:`"post_loss"`, :obj:`"post_steps"`, :obj:`"post_optim"`, :obj:`"pre_steps_log"`,
:obj:`"post_steps_log"`, :obj:`"post_optim_log"` and :obj:`"optimizer"`. They are indicated in the comments where they are applied.
Hooks can be split into 3 categories: **data processing** (:obj:`"batch_process"` and :obj:`"process_optim_batch"`),
**logging** (:obj:`"pre_steps_log"`, :obj:`"post_optim_log"` and :obj:`"post_steps_log"`) and **operations** hook
(:obj:`"pre_optim_steps"`, :obj:`"post_loss"`, :obj:`"post_optim"` and :obj:`"post_steps"`).

- **Data processing** hooks update a tensordict of data. Hooks :obj:`__call__` method should accept
a :obj:`TensorDict` object as input and update it given some strategy.
Examples of such hooks include Replay Buffer extension (:obj:`ReplayBufferTrainer.extend`), data normalization (including normalization
constants update), data subsampling (:class:`~torchrl.trainers.BatchSubSampler`) and such.

- **Logging** hooks take a batch of data presented as a :obj:`TensorDict` and write in the logger
some information retrieved from that data. Examples include the :obj:`Recorder` hook, the reward
logger (:obj:`LogReward`) and such. Hooks should return a dictionary (or a None value) containing the
data to log. The key :obj:`"log_pbar"` is reserved to boolean values indicating if the logged value
There are 10 hooks that can be used in a trainer loop: ``"batch_process"``, ``"pre_optim_steps"``,
``"process_optim_batch"``, ``"post_loss"``, ``"post_steps"``, ``"post_optim"``, ``"pre_steps_log"``,
``"post_steps_log"``, ``"post_optim_log"`` and ``"optimizer"``. They are indicated in the comments where they are applied.
Hooks can be split into 3 categories: **data processing** (``"batch_process"`` and ``"process_optim_batch"``),
**logging** (``"pre_steps_log"``, ``"post_optim_log"`` and ``"post_steps_log"``) and **operations** hook
(``"pre_optim_steps"``, ``"post_loss"``, ``"post_optim"`` and ``"post_steps"``).

- **Data processing** hooks update a tensordict of data. Hooks ``__call__`` method should accept
a ``TensorDict`` object as input and update it given some strategy.
Examples of such hooks include Replay Buffer extension (``ReplayBufferTrainer.extend``), data normalization (including normalization
constants update), data subsampling (:class:``~torchrl.trainers.BatchSubSampler``) and such.

- **Logging** hooks take a batch of data presented as a ``TensorDict`` and write in the logger
some information retrieved from that data. Examples include the ``Recorder`` hook, the reward
logger (``LogReward``) and such. Hooks should return a dictionary (or a None value) containing the
data to log. The key ``"log_pbar"`` is reserved to boolean values indicating if the logged value
should be displayed on the progression bar printed on the training log.

- **Operation** hooks are hooks that execute specific operations over the models, data collectors,
target network updates and such. For instance, syncing the weights of the collectors using :obj:`UpdateWeights`
or update the priority of the replay buffer using :obj:`ReplayBufferTrainer.update_priority` are examples
of operation hooks. They are data-independent (they do not require a :obj:`TensorDict`
target network updates and such. For instance, syncing the weights of the collectors using ``UpdateWeights``
or update the priority of the replay buffer using ``ReplayBufferTrainer.update_priority`` are examples
of operation hooks. They are data-independent (they do not require a ``TensorDict``
input), they are just supposed to be executed once at every iteration (or every N iterations).

The hooks provided by TorchRL usually inherit from a common abstract class :obj:`TrainerHookBase`,
and all implement three base methods: a :obj:`state_dict` and :obj:`load_state_dict` method for
checkpointing and a :obj:`register` method that registers the hook at the default value in the
The hooks provided by TorchRL usually inherit from a common abstract class ``TrainerHookBase``,
and all implement three base methods: a ``state_dict`` and ``load_state_dict`` method for
checkpointing and a ``register`` method that registers the hook at the default value in the
trainer. This method takes a trainer and a module name as input. For instance, the following logging
hook is executed every 10 calls to :obj:`"post_optim_log"`:
hook is executed every 10 calls to ``"post_optim_log"``:

.. code-block::
Expand Down Expand Up @@ -122,22 +122,22 @@ Checkpointing
-------------

The trainer class and hooks support checkpointing, which can be achieved either
using the `torchsnapshot <https://github.com/pytorch/torchsnapshot/>`_ backend or
the regular torch backend. This can be controlled via the global variable :obj:`CKPT_BACKEND`:
using the ``torchsnapshot <https://github.com/pytorch/torchsnapshot/>``_ backend or
the regular torch backend. This can be controlled via the global variable ``CKPT_BACKEND``:

.. code-block::
$ CKPT_BACKEND=torch python script.py
which defaults to :obj:`torchsnapshot`. The advantage of torchsnapshot over pytorch
which defaults to ``torchsnapshot``. The advantage of torchsnapshot over pytorch
is that it is a more flexible API, which supports distributed checkpointing and
also allows users to load tensors from a file stored on disk to a tensor with a
physical storage (which pytorch currently does not support). This allows, for instance,
to load tensors from and to a replay buffer that would otherwise not fit in memory.

When building a trainer, one can provide a file path where the checkpoints are to
be written. With the :obj:`torchsnapshot` backend, a directory path is expected,
whereas the :obj:`torch` backend expects a file path (typically a :obj:`.pt` file).
be written. With the ``torchsnapshot`` backend, a directory path is expected,
whereas the ``torch`` backend expects a file path (typically a ``.pt`` file).

.. code-block::
Expand All @@ -157,7 +157,7 @@ whereas the :obj:`torch` backend expects a file path (typically a :obj:`.pt` fi
>>> # to load from a path
>>> trainer.load_from_file(filepath)
The :obj:`Trainer.train()` method can be used to execute the above loop with all of
The ``Trainer.train()`` method can be used to execute the above loop with all of
its hooks, although using the :obj:`Trainer` class for its checkpointing capability
only is also a perfectly valid use.

Expand Down Expand Up @@ -238,6 +238,8 @@ Loggers
Recording utils
---------------

Recording utils are detailed :ref:`here <Environment-Recorders>`.

.. currentmodule:: torchrl.record

.. autosummary::
Expand All @@ -246,3 +248,4 @@ Recording utils

VideoRecorder
TensorDictRecorder
PixelRenderTransform
40 changes: 39 additions & 1 deletion test/test_loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import argparse
import importlib.util
import os
import os.path
import pathlib
Expand All @@ -12,12 +13,14 @@

import pytest
import torch

from tensordict import MemoryMappedTensor

from torchrl.envs import check_env_specs, GymEnv, ParallelEnv
from torchrl.record.loggers.csv import CSVLogger
from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger
from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger
from torchrl.record.loggers.wandb import _has_wandb, WandbLogger
from torchrl.record.recorder import PixelRenderTransform, VideoRecorder

if _has_tv:
import torchvision
Expand All @@ -28,6 +31,11 @@
if _has_mlflow:
import mlflow

_has_gym = (
importlib.util.find_spec("gym", None) is not None
or importlib.util.find_spec("gymnasium", None) is not None
)


@pytest.fixture
def tb_logger(tmp_path_factory):
Expand Down Expand Up @@ -397,6 +405,36 @@ def test_log_hparams(self, mlflow_fixture, config):
logger.log_hparams(config)


@pytest.mark.skipif(not _has_gym, reason="gym required to test rendering")
class TestPixelRenderTransform:
@pytest.mark.parametrize("parallel", [False, True])
@pytest.mark.parametrize("in_key", ["pixels", ("nested", "pix")])
def test_pixel_render(self, parallel, in_key, tmpdir):
def make_env():
env = GymEnv("CartPole-v1", render_mode="rgb_array", device=None)
env = env.append_transform(PixelRenderTransform(out_keys=in_key))
return env

if parallel:
env = ParallelEnv(2, make_env, mp_start_method="spawn")
else:
env = make_env()
logger = CSVLogger("dummy", log_dir=tmpdir)
try:
env = env.append_transform(
VideoRecorder(logger=logger, in_keys=[in_key], tag="pixels_record")
)
check_env_specs(env)
env.rollout(10)
env.transform.dump()
assert os.path.isfile(
os.path.join(tmpdir, "dummy", "videos", "pixels_record_0.pt")
)
finally:
if not env.is_closed:
env.close()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
28 changes: 28 additions & 0 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
LazyStackedCompositeSpec,
MultiDiscreteTensorSpec,
MultiOneHotDiscreteTensorSpec,
NonTensorSpec,
OneHotDiscreteTensorSpec,
TensorSpec,
UnboundedContinuousTensorSpec,
Expand Down Expand Up @@ -1462,6 +1463,14 @@ def test_multionehot(self, shape1, shape2):
assert spec2.rand().shape == spec2.shape
assert spec2.zero().shape == spec2.shape

def test_non_tensor(self):
spec = NonTensorSpec((3, 4), device="cpu")
assert (
spec.expand(2, 3, 4)
== spec.expand((2, 3, 4))
== NonTensorSpec((2, 3, 4), device="cpu")
)

@pytest.mark.parametrize("shape1", [None, (), (5,)])
@pytest.mark.parametrize("shape2", [(), (10,)])
def test_onehot(self, shape1, shape2):
Expand Down Expand Up @@ -1675,6 +1684,11 @@ def test_multionehot(
assert spec == spec.clone()
assert spec is not spec.clone()

def test_non_tensor(self):
spec = NonTensorSpec(shape=(3, 4), device="cpu")
assert spec.clone() == spec
assert spec.clone() is not spec

@pytest.mark.parametrize("shape1", [None, (), (5,)])
def test_onehot(
self,
Expand Down Expand Up @@ -1840,6 +1854,11 @@ def test_multionehot(
with pytest.raises(ValueError):
spec.unbind(-1)

def test_non_tensor(self):
spec = NonTensorSpec(shape=(3, 4), device="cpu")
assert spec.unbind(1)[0] == spec[:, 0]
assert spec.unbind(1)[0] is not spec[:, 0]

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_onehot(
self,
Expand Down Expand Up @@ -2114,6 +2133,15 @@ def test_stack_multionehot_zero(self, shape, stack_dim):
r = c.zero()
assert r.shape == c.shape

def test_stack_non_tensor(self, shape, stack_dim):
spec0 = NonTensorSpec(shape=shape, device="cpu")
spec1 = NonTensorSpec(shape=shape, device="cpu")
new_spec = torch.stack([spec0, spec1], stack_dim)
shape_insert = list(shape)
shape_insert.insert(stack_dim, 2)
assert new_spec.shape == torch.Size(shape_insert)
assert new_spec.device == torch.device("cpu")

def test_stack_onehot(self, shape, stack_dim):
n = 5
shape = (*shape, 5)
Expand Down
1 change: 1 addition & 0 deletions torchrl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
LazyStackedTensorSpec,
MultiDiscreteTensorSpec,
MultiOneHotDiscreteTensorSpec,
NonTensorSpec,
OneHotDiscreteTensorSpec,
TensorSpec,
UnboundedContinuousTensorSpec,
Expand Down
Loading

0 comments on commit df749a3

Please sign in to comment.