Skip to content

Commit

Permalink
[Doc] Getting started tutos (pytorch#1886)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 10, 2024
1 parent 4d52d5f commit 601867f
Show file tree
Hide file tree
Showing 30 changed files with 1,478 additions and 109 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ On the low-level end, torchrl comes with a set of highly re-usable functionals f

TorchRL aims at (1) a high modularity and (2) good runtime performance. Read the [full paper](https://arxiv.org/abs/2306.00577) for a more curated description of the library.

## Getting started

Check our [Getting Started tutorials](https://pytorch.org/rl/index.html#getting-started) for quickly ramp up with the basic
features of the library!

## Documentation and knowledge base

The TorchRL documentation can be found [here](https://pytorch.org/rl).
Expand Down
17 changes: 17 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,23 @@ or via a ``git clone`` if you're willing to contribute to the library:
$ cd ../rl
$ python setup.py develop
Getting started
===============

A series of quick tutorials to get ramped up with the basic features of the
library. If you're in a hurry, you can start by
:ref:`the last item of the series <gs_first_training>`
and navigate to the previous ones whenever you want to learn more!

.. toctree::
:maxdepth: 1

tutorials/getting-started-0
tutorials/getting-started-1
tutorials/getting-started-2
tutorials/getting-started-3
tutorials/getting-started-4
tutorials/getting-started-5

Tutorials
=========
Expand Down
2 changes: 2 additions & 0 deletions docs/source/reference/collectors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
torchrl.collectors package
==========================

.. _ref_collectors:

Data collectors are somewhat equivalent to pytorch dataloaders, except that (1) they
collect data over non-static data sources and (2) the data is collected using a model
(likely a version of the model that is being trained).
Expand Down
2 changes: 2 additions & 0 deletions docs/source/reference/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
torchrl.data package
====================

.. _ref_data:

Replay Buffers
--------------

Expand Down
3 changes: 3 additions & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,9 @@ single agent standards.

Transforms
----------

.. _transforms:

.. currentmodule:: torchrl.envs.transforms

In most cases, the raw output of an environment must be treated before being passed to another object (such as a
Expand Down
4 changes: 4 additions & 0 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
torchrl.modules package
=======================

.. _ref_modules:

TensorDict modules: Actors, exploration, value models and generative models
---------------------------------------------------------------------------

.. _tdmodules:

TorchRL offers a series of module wrappers aimed at making it easy to build
RL models from the ground up. These wrappers are exclusively based on
:class:`tensordict.nn.TensorDictModule` and :class:`tensordict.nn.TensorDictSequential`.
Expand Down
2 changes: 2 additions & 0 deletions docs/source/reference/objectives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
torchrl.objectives package
==========================

.. _ref_objectives:

TorchRL provides a series of losses to use in your training scripts.
The aim is to have losses that are easily reusable/swappable and that have
a simple signature.
Expand Down
2 changes: 2 additions & 0 deletions docs/source/reference/trainers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ Utils
Loggers
-------

.. _ref_loggers:

.. currentmodule:: torchrl.record.loggers

.. autosummary::
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ def __init__(
if end_key is None:
end_key = ("next", "done")
if traj_key is None:
traj_key = "run"
traj_key = "episode"
self.end_key = end_key
self.traj_key = traj_key

Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2055,8 +2055,8 @@ def reset(

tensordict_reset = self._reset(tensordict, **kwargs)
# We assume that this is done properly
# if tensordict_reset.device != self.device:
# tensordict_reset = tensordict_reset.to(self.device, non_blocking=True)
# if reset.device != self.device:
# reset = reset.to(self.device, non_blocking=True)
if tensordict_reset is tensordict:
raise RuntimeError(
"EnvBase._reset should return outplace changes to the input "
Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs):
return tensordict_reset

def _reset_proc_data(self, tensordict, tensordict_reset):
# self._complete_done(self.full_done_spec, tensordict_reset)
# self._complete_done(self.full_done_spec, reset)
self._reset_check_done(tensordict, tensordict_reset)
if tensordict is not None:
tensordict_reset = _update_during_reset(
Expand All @@ -802,7 +802,7 @@ def _reset_proc_data(self, tensordict, tensordict_reset):
# # doesn't do anything special
# mt_mode = self.transform.missing_tolerance
# self.set_missing_tolerance(True)
# tensordict_reset = self.transform._call(tensordict_reset)
# reset = self.transform._call(reset)
# self.set_missing_tolerance(mt_mode)
return tensordict_reset

Expand Down
52 changes: 29 additions & 23 deletions torchrl/modules/tensordict_module/actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@
class Actor(SafeModule):
"""General class for deterministic actors in RL.
The Actor class comes with default values for the out_keys (["action"])
and if the spec is provided but not as a CompositeSpec object, it will be
automatically translated into :obj:`spec = CompositeSpec(action=spec)`
The Actor class comes with default values for the out_keys (``["action"]``)
and if the spec is provided but not as a
:class:`~torchrl.data.CompositeSpec` object, it will be
automatically translated into ``spec = CompositeSpec(action=spec)``.
Args:
module (nn.Module): a :class:`torch.nn.Module` used to map the input to
module (nn.Module): a :class:`~torch.nn.Module` used to map the input to
the output parameter space.
in_keys (iterable of str, optional): keys to be read from input
tensordict and passed to the module. If it
Expand All @@ -47,9 +48,11 @@ class Actor(SafeModule):
Defaults to ``["observation"]``.
out_keys (iterable of str): keys to be written to the input tensordict.
The length of out_keys must match the
number of tensors returned by the embedded module. Using "_" as a
number of tensors returned by the embedded module. Using ``"_"`` as a
key avoid writing tensor to output.
Defaults to ``["action"]``.
Keyword Args:
spec (TensorSpec, optional): Keyword-only argument.
Specs of the output tensor. If the module
outputs multiple output tensors,
Expand All @@ -59,7 +62,7 @@ class Actor(SafeModule):
input spec. Out-of-domain sampling can
occur because of exploration policies or numerical under/overflow
issues. If this value is out of bounds, it is projected back onto the
desired space using the :obj:`TensorSpec.project`
desired space using the :meth:`~torchrl.data.TensorSpec.project`
method. Default is ``False``.
Examples:
Expand Down Expand Up @@ -148,17 +151,23 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
issues. If this value is out of bounds, it is projected back onto the
desired space using the :obj:`TensorSpec.project`
method. Default is ``False``.
default_interaction_type=InteractionType.RANDOM (str, optional): keyword-only argument.
default_interaction_type (str, optional): keyword-only argument.
Default method to be used to retrieve
the output value. Should be one of: 'mode', 'median', 'mean' or 'random'
(in which case the value is sampled randomly from the distribution). Default
is 'mode'.
Note: When a sample is drawn, the :obj:`ProbabilisticTDModule` instance will
first look for the interaction mode dictated by the `interaction_typ()`
global function. If this returns `None` (its default value), then the
`default_interaction_type` of the `ProbabilisticTDModule` instance will be
used. Note that DataCollector instances will use `set_interaction_type` to
:class:`tensordict.nn.InteractionType.RANDOM` by default.
the output value. Should be one of: 'InteractionType.MODE',
'InteractionType.MEDIAN', 'InteractionType.MEAN' or
'InteractionType.RANDOM' (in which case the value is sampled
randomly from the distribution). Defaults to is 'InteractionType.RANDOM'.
.. note:: When a sample is drawn, the :class:`ProbabilisticActor` instance will
first look for the interaction mode dictated by the
:func:`~tensordict.nn.probabilistic.interaction_type`
global function. If this returns `None` (its default value), then the
`default_interaction_type` of the `ProbabilisticTDModule`
instance will be used. Note that
:class:`~torchrl.collectors.collectors.DataCollectorBase`
instances will use `set_interaction_type` to
:class:`tensordict.nn.InteractionType.RANDOM` by default.
distribution_class (Type, optional): keyword-only argument.
A :class:`torch.distributions.Distribution` class to
be used for sampling.
Expand Down Expand Up @@ -197,9 +206,7 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
... in_keys=["loc", "scale"],
... distribution_class=TanhNormal,
... )
>>> params = TensorDict.from_module(td_module)
>>> with params.to_module(td_module):
... td = td_module(td)
>>> td = td_module(td)
>>> td
TensorDict(
fields={
Expand Down Expand Up @@ -315,7 +322,8 @@ class ValueOperator(TensorDictModule):
The length of out_keys must match the
number of tensors returned by the embedded module. Using "_" as a
key avoid writing tensor to output.
Defaults to ``["action"]``.
Defaults to ``["state_value"]`` or
``["state_action_value"]`` if ``"action"`` is part of the ``in_keys``.
Examples:
>>> import torch
Expand All @@ -334,9 +342,7 @@ class ValueOperator(TensorDictModule):
>>> td_module = ValueOperator(
... in_keys=["observation", "action"], module=module
... )
>>> params = TensorDict.from_module(td_module)
>>> with params.to_module(td_module):
... td = td_module(td)
>>> td = td_module(td)
>>> print(td)
TensorDict(
fields={
Expand Down
5 changes: 4 additions & 1 deletion torchrl/objectives/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,10 @@ def __init__(
try:
action_space = value_network.action_space
except AttributeError:
raise ValueError(self.ACTION_SPEC_ERROR)
raise ValueError(
"The action space could not be retrieved from the value_network. "
"Make sure it is available to the DQN loss module."
)
if action_space is None:
warnings.warn(
"action_space was not specified. DQNLoss will default to 'one-hot'."
Expand Down
3 changes: 1 addition & 2 deletions torchrl/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,7 @@ def __init__(
):
if eps is None and tau is None:
raise RuntimeError(
"Neither eps nor tau was provided. " "This behaviour is deprecated.",
category=DeprecationWarning,
"Neither eps nor tau was provided. This behaviour is deprecated.",
)
eps = 0.999
if (eps is None) ^ (tau is None):
Expand Down
4 changes: 3 additions & 1 deletion torchrl/record/loggers/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import os
from collections import defaultdict
from pathlib import Path
Expand Down Expand Up @@ -126,7 +128,7 @@ class CSVLogger(Logger):
def __init__(
self,
exp_name: str,
log_dir: Optional[str] = None,
log_dir: str | None = None,
video_format: str = "pt",
video_fps: int = 30,
) -> None:
Expand Down
2 changes: 2 additions & 0 deletions tutorials/sphinx-tutorials/README.rst
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
README Tutos
============

Check the tutorials on torchrl documentation: https://pytorch.org/rl
2 changes: 2 additions & 0 deletions tutorials/sphinx-tutorials/coding_ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
======================================
**Author**: `Vincent Moens <https://github.com/vmoens>`_
.. _coding_ddpg:
"""

##############################################################################
Expand Down
8 changes: 5 additions & 3 deletions tutorials/sphinx-tutorials/coding_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
==============================
**Author**: `Vincent Moens <https://github.com/vmoens>`_
.. _coding_dqn:
"""

##############################################################################
Expand Down Expand Up @@ -404,9 +406,9 @@ def get_replay_buffer(buffer_size, n_optim, batch_size):
# environment executed in parallel in each collector (controlled by the
# ``num_workers`` hyperparameter).
#
# When building the collector, we can choose on which device we want the
# environment and policy to execute the operations through the ``device``
# keyword argument. The ``storing_devices`` argument will modify the
# Collector's devices are fully parametrizable through the ``device`` (general),
# ``policy_device``, ``env_device`` and ``storing_device`` arguments.
# The ``storing_device`` argument will modify the
# location of the data being collected: if the batches that we are gathering
# have a considerable size, we may want to store them on a different location
# than the device where the computation is happening. For asynchronous data
Expand Down
2 changes: 2 additions & 0 deletions tutorials/sphinx-tutorials/coding_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
==================================================
**Author**: `Vincent Moens <https://github.com/vmoens>`_
.. _coding_ppo:
This tutorial demonstrates how to use PyTorch and :py:mod:`torchrl` to train a parametric policy
network to solve the Inverted Pendulum task from the `OpenAI-Gym/Farama-Gymnasium
control library <https://github.com/Farama-Foundation/Gymnasium>`__.
Expand Down
2 changes: 2 additions & 0 deletions tutorials/sphinx-tutorials/dqn_with_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
**Author**: `Vincent Moens <https://github.com/vmoens>`_
.. _RNN_tuto:
.. grid:: 2
.. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
Expand Down
Loading

0 comments on commit 601867f

Please sign in to comment.