[Doc] DDPG and DQN refactoring -- Doc cleaning (#1036)
vmoens authored Apr 11, 2023
1 parent 80beaaa commit 4f01f1b
Showing 44 changed files with 1,602 additions and 5,548 deletions.
Binary file added docs/source/_static/img/replaybuffer_traj.png
3,824 changes: 2 additions & 3,822 deletions docs/source/_static/js/theme.js

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -74,7 +74,7 @@
intersphinx_mapping = {
"torch": ("https://pytorch.org/docs/stable/", None),
"tensordict": ("https://pytorch-labs.github.io/tensordict/", None),
"torchrl": ("https://pytorch.org/rl/", None),
# "torchrl": ("https://pytorch.org/rl/", None),
"torchaudio": ("https://pytorch.org/audio/stable/", None),
"torchtext": ("https://pytorch.org/text/stable/", None),
"torchvision": ("https://pytorch.org/vision/stable/", None),
2 changes: 1 addition & 1 deletion docs/source/reference/collectors.rst
@@ -88,7 +88,7 @@ node or across multiple nodes.

.. note::
*Choosing the sub-collector*: All distributed collectors support the various single machine collectors.
One may wonder why using a :class:`MultiSyncDataCollector` or a :class:`torchrl.envs.ParallelEnv`
One may wonder why using a :class:`MultiSyncDataCollector` or a :class:`~torchrl.envs.ParallelEnv`
instead. In general, multiprocessed collectors have a lower IO footprint than
parallel environments which need to communicate at each step. Yet, the model specs
play a role in the opposite direction, since using parallel environments will
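To make the note's trade-off concrete, here is a minimal sketch of the two options it compares, assuming the Gym backend and the ``Pendulum-v1`` task are available (illustrative only, not taken from the repository):

>>> from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
>>> from torchrl.envs import ParallelEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> make_env = lambda: GymEnv("Pendulum-v1")
>>> # option 1: one worker process per environment, each stepping independently
>>> collector_mp = MultiSyncDataCollector(
...     [make_env] * 4, policy=None, frames_per_batch=200, total_frames=1_000)
>>> # option 2: a single collector over a ParallelEnv that synchronises at every step
>>> collector_penv = SyncDataCollector(
...     ParallelEnv(4, make_env), policy=None, frames_per_batch=200, total_frames=1_000)
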
8 changes: 4 additions & 4 deletions docs/source/reference/data.rst
@@ -68,8 +68,8 @@ does not mean storing 1M frames but 1M trajectories.

When sampling trajectories, it may be desirable to sample sub-trajectories
to diversify learning or make the sampling more efficient.
To do this, we provide a custom :class:`torchrl.envs.Transform` class named
:class:`torchrl.envs.RandomCropTensorDict`. Here is an example of how this class
To do this, we provide a custom :class:`~torchrl.envs.Transform` class named
:class:`~torchrl.envs.RandomCropTensorDict`. Here is an example of how this class
can be used:

.. code-block::Python
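
The example under this directive is truncated by the hunk. As a rough sketch of how such a transform might be applied to stored trajectories (the ``sub_seq_len`` and ``sample_dim`` argument names are assumptions here, not taken from the file):

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs import RandomCropTensorDict
>>> # ten stored trajectories of 200 steps each
>>> data = TensorDict(
...     {"obs": torch.randn(10, 200, 3), "reward": torch.randn(10, 200, 1)},
...     batch_size=[10, 200])
>>> crop = RandomCropTensorDict(sub_seq_len=20, sample_dim=-1)
>>> sub_data = crop(data)  # random 20-step windows taken along the time dimension
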
@@ -104,7 +104,7 @@ Datasets
--------

TorchRL provides wrappers around offline RL datasets.
These data are presented a :class:`torchrl.data.ReplayBuffer` instances, which
These data are presented a :class:`~torchrl.data.ReplayBuffer` instances, which
means that they can be customized at will with transforms, samplers and storages.
By default, datasets are stored as memory mapped tensors, allowing them to be
promptly sampled with virtually no memory footprint.
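
Because these wrappers are regular :class:`~torchrl.data.ReplayBuffer` instances, they compose like a hand-built buffer. A minimal sketch of such a buffer with a memory-mapped storage and a custom sampler (illustrative only):

>>> from torchrl.data import LazyMemmapStorage, ReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
>>> rb = ReplayBuffer(
...     storage=LazyMemmapStorage(1_000_000),  # memory-mapped, near-zero RAM footprint
...     sampler=SamplerWithoutReplacement())
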
@@ -218,7 +218,7 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
Utils
-----

.. currentmodule:: torchrl.data.datasets
.. currentmodule:: torchrl.data

.. autosummary::
:toctree: generated/
19 changes: 9 additions & 10 deletions docs/source/reference/envs.rst
@@ -9,7 +9,7 @@ The goal is to be able to swap environments in an experiment with little or no e
even if these environments are simulated using different libraries.
TorchRL offers some out-of-the-box environment wrappers under :obj:`torchrl.envs.libs`,
which we hope can be easily imitated for other libraries.
The parent class :class:`torchrl.envs.EnvBase` is a :class:`torch.nn.Module` subclass that implements
The parent class :class:`~torchrl.envs.EnvBase` is a :class:`torch.nn.Module` subclass that implements
some typical environment methods using :class:`tensordict.TensorDict` as a data organiser. This allows this
class to be generic and to handle an arbitrary number of input and outputs, as well as
nested or batched data structures.
@@ -26,18 +26,18 @@ Each env will have the following attributes:
This is especially useful for transforms (see below). For parametric environments (e.g.
model-based environments), the device does represent the hardware that will be used to
compute the operations.
- :obj:`env.input_spec`: a :class:`torchrl.data.CompositeSpec` object containing
- :obj:`env.input_spec`: a :class:`~torchrl.data.CompositeSpec` object containing
all the input keys (:obj:`"action"` and others).
- :obj:`env.output_spec`: a :class:`torchrl.data.CompositeSpec` object containing
- :obj:`env.output_spec`: a :class:`~torchrl.data.CompositeSpec` object containing
all the output keys (:obj:`"observation"`, :obj:`"reward"` and :obj:`"done"`).
- :obj:`env.observation_spec`: a :class:`torchrl.data.CompositeSpec` object
- :obj:`env.observation_spec`: a :class:`~torchrl.data.CompositeSpec` object
containing all the observation key-spec pairs.
This is a pointer to ``env.output_spec["observation"]``.
- :obj:`env.action_spec`: a :class:`torchrl.data.TensorSpec` object
- :obj:`env.action_spec`: a :class:`~torchrl.data.TensorSpec` object
representing the action spec. This is a pointer to ``env.input_spec["action"]``.
- :obj:`env.reward_spec`: a :class:`torchrl.data.TensorSpec` object representing
- :obj:`env.reward_spec`: a :class:`~torchrl.data.TensorSpec` object representing
the reward spec. This is a pointer to ``env.output_spec["reward"]``.
- :obj:`env.done_spec`: a :class:`torchrl.data.TensorSpec` object representing
- :obj:`env.done_spec`: a :class:`~torchrl.data.TensorSpec` object representing
the done-flag spec. This is a pointer to ``env.output_spec["done"]``.
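
A quick way to inspect these attributes (a sketch, assuming the Gym wrapper and ``Pendulum-v1`` are installed):

>>> from torchrl.envs.libs.gym import GymEnv
>>> env = GymEnv("Pendulum-v1")
>>> env.observation_spec  # a CompositeSpec with one entry per observation key
>>> env.action_spec       # points to env.input_spec["action"]
>>> env.reward_spec       # points to env.output_spec["reward"]
>>> env.batch_size        # torch.Size([]) for this single, non-batched env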

Importantly, the environment spec shapes should contain the batch size, e.g.
@@ -100,7 +100,7 @@ function.
.. note::

In some contexts, it can be useful to mark the first step of a trajectory.
TorchRL provides such functionality through the :class:`torchrl.envs.InitTracker`
TorchRL provides such functionality through the :class:`~torchrl.envs.InitTracker`
transform.
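
A minimal sketch of this transform in use, assuming a Gym environment and the default ``"is_init"`` entry name:

>>> from torchrl.envs import InitTracker, TransformedEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> td = env.reset()
>>> td["is_init"]          # True right after a reset
>>> td = env.rand_step(td)
>>> td["next", "is_init"]  # False on subsequent steps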


@@ -114,7 +114,6 @@ provides more information on how to design a custom environment from scratch.
EnvBase
GymLikeEnv
EnvMetaData
Specs

Vectorized envs
---------------
@@ -132,7 +131,7 @@ Of course, a :class:`ParallelEnv` will have a batch size that corresponds to its

It is important that your environment specs match the input and output that it sends and receives, as
:class:`ParallelEnv` will create buffers from these specs to communicate with the spawn processes.
Check the :func:`torchrl.envs.utils.check_env_specs` method for a sanity check.
Check the :func:`~torchrl.envs.utils.check_env_specs` method for a sanity check.

.. code-block::
:caption: Parallel environment
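The snippet under that caption is truncated in this view. A minimal sketch of a parallel environment together with the spec sanity check, assuming the Gym backend is installed:

>>> from torchrl.envs import ParallelEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.envs.utils import check_env_specs
>>> env = ParallelEnv(4, lambda: GymEnv("Pendulum-v1"))
>>> env.batch_size        # torch.Size([4])
>>> check_env_specs(env)  # runs a short rollout and errors out if data and specs disagree
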
2 changes: 1 addition & 1 deletion docs/source/reference/modules.rst
@@ -32,7 +32,7 @@ TensorDict modules

Hooks
-----
.. currentmodule:: torchrl.modules.tensordict_module.actors
.. currentmodule:: torchrl.modules

.. autosummary::
:toctree: generated/
10 changes: 6 additions & 4 deletions docs/source/reference/objectives.rst
@@ -16,13 +16,15 @@ The main characteristics of TorchRL losses are:
method will receive a tensordict as input that contains all the necessary
information to return a loss value.
- They output a :class:`tensordict.TensorDict` instance with the loss values
written under a ``"loss_<smth>`` where ``smth`` is a string describing the
written under a ``"loss_<smth>"`` where ``smth`` is a string describing the
loss. Additional keys in the tensordict may be useful metrics to log during
training time.
.. note::
The reason we return independent losses is to let the user use a different
optimizer for different sets of parameters for instance. Summing the losses
can be simply done via ``sum(loss for key, loss in loss_vals.items() if key.startswith("loss_")``.
can be simply done via

>>> loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))
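
Extended into a full optimisation step, this might look as follows (a sketch; ``loss_module``, ``batch`` and ``optimizer`` are assumed to exist):

>>> loss_vals = loss_module(batch)  # a TensorDict with "loss_..." entries plus logging metrics
>>> loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))
>>> loss_val.backward()
>>> optimizer.step()
>>> optimizer.zero_grad()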

Training value functions
------------------------
@@ -68,7 +70,7 @@ follow a similar structure:
>>> kwargs = {"gamma": 0.9, "lmbda": 0.9}
>>> loss_module.make_value_estimator(ValueEstimators.TDLambda, **kwargs)

The :class:`torchrl.objectives.ValueEstimators` class enumerates the value
The :class:`~torchrl.objectives.ValueEstimators` class enumerates the value
estimators to choose from. This makes it easy for the users to rely on
auto-completion to make their choice.

@@ -216,5 +218,5 @@ Utils
next_state_value
SoftUpdate
HardUpdate
ValueFunctions
ValueEstimators
default_value_kwargs
2 changes: 1 addition & 1 deletion docs/source/reference/trainers.rst
@@ -73,7 +73,7 @@ Hooks can be split into 3 categories: **data processing** (:obj:`"batch_process"
- **Data processing** hooks update a tensordict of data. Hooks :obj:`__call__` method should accept
a :obj:`TensorDict` object as input and update it given some strategy.
Examples of such hooks include Replay Buffer extension (:obj:`ReplayBufferTrainer.extend`), data normalization (including normalization
constants update), data subsampling (:doc:`BatchSubSampler`) and such.
constants update), data subsampling (:class:`~torchrl.trainers.BatchSubSampler`) and such.

- **Logging** hooks take a batch of data presented as a :obj:`TensorDict` and write in the logger
some information retrieved from that data. Examples include the :obj:`Recorder` hook, the reward
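A minimal sketch of a custom data-processing hook, assuming a ``trainer`` instance and the ``register_op`` entry point:

>>> def halve_reward(batch):
...     # toy "batch_process" hook: rescale rewards before they reach the loss
...     batch["reward"] = batch["reward"] / 2
...     return batch
>>> trainer.register_op("batch_process", halve_reward)
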
12 changes: 11 additions & 1 deletion test/test_distributed.py
@@ -13,7 +13,16 @@
import time

import pytest
import ray

try:
import ray

_has_ray = True
RAY_ERR = None
except ModuleNotFoundError as err:
_has_ray = False
RAY_ERR = err

import torch

from mocking_classes import ContinuousActionVecMockEnv, CountingEnv
@@ -418,6 +427,7 @@ def test_distributed_collector_updatepolicy(self, collector_class, update_interv
queue.close()


@pytest.mark.skipif(not _has_ray, reason=f"Ray not found (error: {RAY_ERR})")
class TestRayCollector(DistributedCollectorBase):
"""A testing distributed data collector class that runs tests without using a Queue,
to avoid potential deadlocks when combining Ray and multiprocessing.
17 changes: 8 additions & 9 deletions test/test_trainer.py
@@ -89,11 +89,10 @@ class MockingLossModule(nn.Module):

def mocking_trainer(file=None, optimizer=_mocking_optim) -> Trainer:
trainer = Trainer(
MockingCollector(),
*[
None,
]
* 2,
collector=MockingCollector(),
total_frames=None,
frame_skip=None,
optim_steps_per_batch=None,
loss_module=MockingLossModule(),
optimizer=optimizer,
save_trainer_file=file,
@@ -862,7 +861,7 @@ def test_recorder(self, N=8):
with tempfile.TemporaryDirectory() as folder:
logger = TensorboardLogger(exp_name=folder)

recorder = transformed_env_constructor(
environment = transformed_env_constructor(
args,
video_tag="tmp",
norm_obs_only=True,
@@ -874,7 +873,7 @@
record_frames=args.record_frames,
frame_skip=args.frame_skip,
policy_exploration=None,
recorder=recorder,
environment=environment,
record_interval=args.record_interval,
)
trainer = mocking_trainer()
@@ -936,7 +935,7 @@ def _make_recorder_and_trainer(tmpdirname):
raise NotImplementedError
trainer = mocking_trainer(file)

recorder = transformed_env_constructor(
environment = transformed_env_constructor(
args,
video_tag="tmp",
norm_obs_only=True,
@@ -948,7 +947,7 @@
record_frames=args.record_frames,
frame_skip=args.frame_skip,
policy_exploration=None,
recorder=recorder,
environment=environment,
record_interval=args.record_interval,
)
recorder.register(trainer)
1 change: 1 addition & 0 deletions torchrl/collectors/__init__.py
@@ -8,5 +8,6 @@
DataCollectorBase,
MultiaSyncDataCollector,
MultiSyncDataCollector,
RandomPolicy,
SyncDataCollector,
)
43 changes: 19 additions & 24 deletions torchrl/collectors/collectors.py
@@ -52,27 +52,22 @@


 class RandomPolicy:
-    """A random policy for data collectors."""
-
-    def __init__(self, action_spec: TensorSpec):
-        """Random policy for a given action_spec.
-
-        This is a wrapper around the action_spec.rand method.
-
-        $ python example_google.py
-
-        Args:
-            action_spec: TensorSpec object describing the action specs
-
-        Examples:
-            >>> from tensordict import TensorDict
-            >>> from torchrl.data.tensor_specs import BoundedTensorSpec
-            >>> action_spec = BoundedTensorSpec(-torch.ones(3), torch.ones(3))
-            >>> actor = RandomPolicy(spec=action_spec)
-            >>> td = actor(TensorDict(batch_size=[])) # selects a random action in the cube [-1; 1]
-        """
+    """A random policy for data collectors.
+
+    This is a wrapper around the action_spec.rand method.
+
+    Args:
+        action_spec: TensorSpec object describing the action specs
+
+    Examples:
+        >>> from tensordict import TensorDict
+        >>> from torchrl.data.tensor_specs import BoundedTensorSpec
+        >>> action_spec = BoundedTensorSpec(-torch.ones(3), torch.ones(3))
+        >>> actor = RandomPolicy(spec=action_spec)
+        >>> td = actor(TensorDict(batch_size=[])) # selects a random action in the cube [-1; 1]
+    """
+
+    def __init__(self, action_spec: TensorSpec):
         self.action_spec = action_spec

def __call__(self, td: TensorDictBase) -> TensorDictBase:
@@ -339,11 +334,11 @@ class SyncDataCollector(DataCollectorBase):
Args:
create_env_fn (Callable): a callable that returns an instance of
:class:`torchrl.envs.EnvBase` class.
:class:`~torchrl.envs.EnvBase` class.
policy (Callable): Policy to be executed in the environment.
Must accept :class:`tensordict.tensordict.TensorDictBase` object as input.
If ``None`` is provided, the policy used will be a
:class:`RandomPolicy` instance with the environment
:class:`~torchrl.collectors.RandomPolicy` instance with the environment
``action_spec``.
frames_per_batch (int): A keyword-only argument representing the total
number of elements in a batch.
@@ -383,12 +378,12 @@ class SyncDataCollector(DataCollectorBase):
at the beginning of a batch collection.
Defaults to ``False``.
postproc (Callable, optional): A post-processing transform, such as
a :class:`torchrl.envs.Transform` or a :class:`torchrl.data.postprocs.MultiStep`
a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
instance.
Defaults to ``None``.
split_trajs (bool, optional): Boolean indicating whether the resulting
TensorDict should be split according to the trajectories.
See :func:`torchrl.collectors.utils.split_trajectories` for more
See :func:`~torchrl.collectors.utils.split_trajectories` for more
information.
Defaults to ``False``.
exploration_mode (str, optional): interaction mode to be used when
@@ -936,7 +931,7 @@ class _MultiDataCollector(DataCollectorBase):
Args:
create_env_fn (List[Callabled]): list of Callables, each returning an
instance of :class:`torchrl.envs.EnvBase`.
instance of :class:`~torchrl.envs.EnvBase`.
policy (Callable, optional): Instance of TensorDictModule class.
Must accept TensorDictBase object as input.
If ``None`` is provided, the policy used will be a
@@ -987,12 +982,12 @@ class _MultiDataCollector(DataCollectorBase):
at the beginning of a batch collection.
Defaults to ``False``.
postproc (Callable, optional): A post-processing transform, such as
a :class:`torchrl.envs.Transform` or a :class:`torchrl.data.postprocs.MultiStep`
a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
instance.
Defaults to ``None``.
split_trajs (bool, optional): Boolean indicating whether the resulting
TensorDict should be split according to the trajectories.
See :func:`torchrl.collectors.utils.split_trajectories` for more
See :func:`~torchrl.collectors.utils.split_trajectories` for more
information.
Defaults to ``False``.
exploration_mode (str, optional): interaction mode to be used when
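For context on the collector docstrings touched above, a minimal usage sketch of :class:`~torchrl.collectors.SyncDataCollector`, assuming the Gym backend is installed:

>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs.libs.gym import GymEnv
>>> collector = SyncDataCollector(
...     lambda: GymEnv("Pendulum-v1"),
...     policy=None,         # falls back to a RandomPolicy built from the action spec
...     frames_per_batch=64,
...     total_frames=256)
>>> for data in collector:
...     print(data.numel())  # 64 transitions per yielded batch
>>> collector.shutdown()
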
16 changes: 8 additions & 8 deletions torchrl/collectors/distributed/generic.py
@@ -245,7 +245,7 @@ class DistributedDataCollector(DataCollectorBase):
Args:
create_env_fn (Callable or List[Callabled]): list of Callables, each returning an
instance of :class:`torchrl.envs.EnvBase`.
instance of :class:`~torchrl.envs.EnvBase`.
policy (Callable): Policy to be executed in the environment.
Must accept :class:`tensordict.tensordict.TensorDictBase` object as input.
If ``None`` is provided, the policy used will be a
@@ -275,12 +275,12 @@ class DistributedDataCollector(DataCollectorBase):
at the beginning of a batch collection.
Defaults to ``False``.
postproc (Callable, optional): A post-processing transform, such as
a :class:`torchrl.envs.Transform` or a :class:`torchrl.data.postprocs.MultiStep`
a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
instance.
Defaults to ``None``.
split_trajs (bool, optional): Boolean indicating whether the resulting
TensorDict should be split according to the trajectories.
See :func:`torchrl.collectors.utils.split_trajectories` for more
See :func:`~torchrl.collectors.utils.split_trajectories` for more
information.
Defaults to ``False``.
exploration_mode (str, optional): interaction mode to be used when
@@ -291,12 +291,12 @@ class DistributedDataCollector(DataCollectorBase):
that return a ``True`` value in its ``"done"`` or ``"truncated"``
entry will be reset at the corresponding indices.
collector_class (type or str, optional): a collector class for the remote node. Can be
:class:`torchrl.collectors.SyncDataCollector`,
:class:`torchrl.collectors.MultiSyncDataCollector`,
:class:`torchrl.collectors.MultiaSyncDataCollector`
:class:`~torchrl.collectors.SyncDataCollector`,
:class:`~torchrl.collectors.MultiSyncDataCollector`,
:class:`~torchrl.collectors.MultiaSyncDataCollector`
or a derived class of these. The strings "single", "sync" and
"async" correspond to respective class.
Defaults to :class:`torchrl.collectors.SyncDataCollector`.
Defaults to :class:`~torchrl.collectors.SyncDataCollector`.
collector_kwargs (dict or list, optional): a dictionary of parameters to be passed to the
remote data-collector. If a list is provided, each element will
correspond to an individual set of keyword arguments for the
@@ -327,7 +327,7 @@ class DistributedDataCollector(DataCollectorBase):
updated.
Defaults to ``False``, ie. updates have to be executed manually
through
:meth:`torchrl.collectors.distributed.DistributedDataCollector.update_policy_weights_`.
:meth:`~torchrl.collectors.distributed.DistributedDataCollector.update_policy_weights_`.
max_weight_update_interval (int, optional): the maximum number of
batches that can be collected before the policy weights of a worker
is updated.
