[Refactor] Refactor losses (value function, doc, input batch size) (#987)
vmoens authored Mar 28, 2023
1 parent 48208ce commit 0b2d2d8
Showing 48 changed files with 2,612 additions and 723 deletions.
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -73,7 +73,7 @@ Knowledge Base
==============

.. toctree::
:maxdepth: 1
:maxdepth: 2

reference/knowledge_base

98 changes: 91 additions & 7 deletions docs/source/reference/objectives.rst
@@ -3,6 +3,81 @@
torchrl.objectives package
==========================

TorchRL provides a series of losses to use in your training scripts.
The aim is to have losses that are easily reusable/swappable and that have
a simple signature.

The main characteristics of TorchRL losses are:

- They are stateful objects: they contain a copy of the trainable parameters
such that ``loss_module.parameters()`` gives whatever is needed to train the
algorithm.
- They follow the ``tensordict`` convention: the :meth:`torch.nn.Module.forward`
method receives a tensordict as input that contains all the necessary
information to return a loss value.
- They output a :class:`tensordict.TensorDict` instance with the loss values
written under a ``"loss_<smth>"`` key, where ``smth`` is a string describing the
loss. Additional keys in the tensordict may be useful metrics to log during
training.

.. note::
The reason we return independent losses is to let the user use a different
optimizer for different sets of parameters, for instance. Summing the losses
can be done simply via ``sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))``,
as in the training-loop sketch below.

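As a minimal sketch of this convention, assuming a generic ``loss_module`` such
as the ``DQNLoss`` used later on this page, a ``replay_buffer`` that yields
tensordicts with the entries the loss expects, and an ``optimizer`` built over
``loss_module.parameters()`` (illustrative names rather than a prescribed API),
a training step could look like:

>>> data = replay_buffer.sample(256)  # tensordict with the keys required by the loss
>>> loss_vals = loss_module(data)  # TensorDict with "loss_<smth>" entries plus optional metrics
>>> loss = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))
>>> loss.backward()
>>> optimizer.step()
>>> optimizer.zero_grad()
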
Training value functions
------------------------

TorchRL provides a range of **value estimators** such as TD(0), TD(1), TD(:math:`\lambda`)
and GAE.
In a nutshell, a value estimator is a function of data (mostly
rewards and done states) and a state value (i.e. the value
returned by a function that is fit to estimate state-values).
To learn more about value estimators, check the introduction to RL from `Sutton
and Barto <https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf>`_,
in particular the chapters about value iteration and TD learning.
A value estimator gives a somewhat biased estimation of the discounted return
following a state or a state-action pair, based on data and proxy maps. These
estimators are used in two contexts:

- To train the value network to learn the "true" state value (or state-action
value) map, one needs a target value to fit it to. The better (less bias,
less variance) the estimator, the better the value network will be, which in
turn can speed up the policy training significantly. Typically, the value
network loss will look like:

>>> value = value_network(states)  # current (proxy) estimate of the state value
>>> target_value = value_estimator(rewards, done, value_network(next_state))  # data-based target
>>> value_net_loss = (value - target_value).pow(2).mean()  # MSE regression towards the target

- To compute an "advantage" signal for policy optimization. The advantage is
the delta between the value estimate (from the estimator, i.e. from "real" data)
and the output of the value network (i.e. the proxy to this value). A positive
advantage can be seen as a signal that the policy actually performed better
than expected, thereby signaling that there is room for improvement if that
trajectory is taken as an example. Conversely, a negative advantage signifies
that the policy underperformed compared to what was expected (a bare-bones
sketch follows this list).

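Reusing the illustrative names from the pseudocode above (``value``,
``target_value``, ``states``), plus a hypothetical ``policy_network`` that
returns an action distribution and the recorded ``actions`` (both assumptions,
not part of the original example), a bare-bones policy-gradient use of this
advantage signal could be sketched as:

>>> advantage = target_value - value  # positive: the trajectory did better than the value network predicted
>>> log_prob = policy_network(states).log_prob(actions)  # hypothetical policy head and recorded actions
>>> policy_loss = -(log_prob * advantage.detach()).mean()  # detach: the advantage acts as a fixed weight
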
Things are not always as easy as in the example above: the formula to compute
the value estimator or the advantage may be slightly more intricate than this.
To help users flexibly switch between value estimators, we provide a simple
API to change them on the fly. Here is an example with DQN, but all modules
follow a similar structure:

>>> from torchrl.objectives import DQNLoss, ValueEstimators
>>> loss_module = DQNLoss(actor)
>>> kwargs = {"gamma": 0.9, "lmbda": 0.9}
>>> loss_module.make_value_estimator(ValueEstimators.TDLambda, **kwargs)

The :class:`torchrl.objectives.ValueEstimators` class enumerates the value
estimators to choose from. This makes it easy for users to rely on
auto-completion to make their choice.
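
For instance, one can list the available estimators and query sensible default
hyperparameters before switching. The sketch below assumes that
:class:`~torchrl.objectives.ValueEstimators` behaves as a standard Python enum
and that the ``default_value_kwargs`` helper listed in the Utils section is
importable from ``torchrl.objectives.utils``:

>>> from torchrl.objectives import ValueEstimators
>>> from torchrl.objectives.utils import default_value_kwargs
>>> list(ValueEstimators)  # enumerates the supported estimator types
>>> kwargs = default_value_kwargs(ValueEstimators.TDLambda)  # default hyperparameters for that estimator
>>> loss_module.make_value_estimator(ValueEstimators.TDLambda, **kwargs)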

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

LossModule

DQN
---

@@ -108,16 +183,23 @@ Returns
:toctree: generated/
:template: rl_template_noinherit.rst

ValueEstimatorBase
TD0Estimator
TD1Estimator
TDLambdaEstimator
GAE
TDLambdaEstimate
TDEstimate
functional.generalized_advantage_estimate
functional.vec_generalized_advantage_estimate
functional.vec_td_lambda_return_estimate
functional.vec_td_lambda_advantage_estimate
functional.td0_return_estimate
functional.td0_advantage_estimate
functional.td1_return_estimate
functional.vec_td1_return_estimate
functional.td1_advantage_estimate
functional.vec_td1_advantage_estimate
functional.td_lambda_return_estimate
functional.vec_td_lambda_return_estimate
functional.td_lambda_advantage_estimate
functional.td_advantage_estimate
functional.vec_td_lambda_advantage_estimate
functional.generalized_advantage_estimate
functional.vec_generalized_advantage_estimate


Utils
@@ -134,3 +216,5 @@ Utils
next_state_value
SoftUpdate
HardUpdate
ValueFunctions
default_value_kwargs
9 changes: 5 additions & 4 deletions examples/a2c/a2c.py
@@ -10,7 +10,7 @@
from hydra.core.config_store import ConfigStore
from torchrl.envs.transforms import RewardScaling
from torchrl.envs.utils import set_exploration_mode
from torchrl.objectives.value import TDEstimate
from torchrl.objectives.value import TD0Estimator
from torchrl.record.loggers import generate_exp_name, get_logger
from torchrl.trainers.helpers.collectors import (
make_collector_onpolicy,
@@ -144,14 +144,15 @@ def main(cfg: "DictConfig"): # noqa: F821
)

critic_model = model.get_value_operator()
advantage = TDEstimate(
cfg.gamma,
advantage = TD0Estimator(
gamma=cfg.gamma,
value_network=critic_model,
average_rewards=True,
differentiable=True,
)
trainer.register_op(
"process_optim_batch",
advantage,
torch.no_grad()(advantage),
)

final_seed = collector.set_seed(cfg.seed)
7 changes: 4 additions & 3 deletions examples/ppo/ppo.py
@@ -164,14 +164,15 @@ def main(cfg: "DictConfig"): # noqa: F821

critic_model = model.get_value_operator()
advantage = GAE(
cfg.gamma,
cfg.lmbda,
gamma=cfg.gamma,
lmbda=cfg.lmbda,
value_network=critic_model,
average_gae=True,
differentiable=True,
)
trainer.register_op(
"process_optim_batch",
lambda tensordict: advantage(tensordict.to(device)),
lambda tensordict: torch.no_grad()(advantage(tensordict.to(device))),
)
trainer._process_optim_batch_ops = [
trainer._process_optim_batch_ops[-1],