[Feature] Add reduction parameter to On-Policy losses. (#1890)
Co-authored-by: vmoens <vincentmoens@gmail.com>
albertbou92 and vmoens authored Feb 15, 2024
1 parent 57ac22b commit 67f659c
Showing 5 changed files with 179 additions and 41 deletions.
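
In short, the on-policy losses (PPO, A2C, REINFORCE) gain a reduction argument ("none", "mean" or "sum", defaulting to "mean") that controls how the per-sample loss terms in the returned TensorDict are aggregated. A minimal usage sketch under stated assumptions: actor, critic and batch are pre-built objects (a probabilistic actor, a value operator and a TensorDict carrying the keys the loss expects) and are not part of the diff below:

from torchrl.objectives import A2CLoss

# reduction="none" keeps one loss value per sample so callers can weight
# or aggregate the terms themselves; "mean" (the default) and "sum" behave
# like the usual PyTorch loss reductions.
loss_module = A2CLoss(
    actor,                     # assumed: probabilistic actor module
    critic,                    # assumed: value operator
    loss_critic_type="l2",
    reduction="none",
)

loss_td = loss_module(batch)   # assumed: batch is a TensorDict of rollout data
total = loss_td["loss_objective"].mean() + loss_td["loss_critic"].mean()
if "loss_entropy" in loss_td.keys():
    total = total + loss_td["loss_entropy"].mean()
total.backward()
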
78 changes: 69 additions & 9 deletions test/test_cost.py
@@ -2,7 +2,6 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import contextlib
import functools
@@ -5898,8 +5897,16 @@ def _create_seq_mock_data_ppo(
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@pytest.mark.parametrize("functional", [True, False])
@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_ppo(
self, loss_class, device, gradient_mode, advantage, td_est, functional
self,
loss_class,
device,
gradient_mode,
advantage,
td_est,
functional,
reduction,
):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)
@@ -5930,14 +5937,29 @@ def test_ppo(
else:
raise NotImplementedError

loss_fn = loss_class(actor, value, loss_critic_type="l2", functional=functional)
loss_fn = loss_class(
actor,
value,
loss_critic_type="l2",
functional=functional,
reduction=reduction,
)
if advantage is not None:
advantage(td)
else:
if td_est is not None:
loss_fn.make_value_estimator(td_est)

loss = loss_fn(td)
if reduction == "none":

def func(x):
if x.dtype != torch.float:
return
return x.mean()

loss = loss.apply(func, batch_size=[])

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
@@ -6027,6 +6049,7 @@ def test_ppo_shared(self, loss_class, device, advantage):
if advantage is not None:
advantage(td)
loss = loss_fn(td)

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
@@ -6070,7 +6093,13 @@ def test_ppo_shared(self, loss_class, device, advantage):
)
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("separate_losses", [True, False])
def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses):
def test_ppo_shared_seq(
self,
loss_class,
device,
advantage,
separate_losses,
):
"""Tests PPO with shared module with and without passing twice across the common module."""
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)
@@ -6121,11 +6150,13 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses):
if advantage is not None:
advantage(td)
loss = loss_fn(td).exclude("entropy")

sum(val for key, val in loss.items() if key.startswith("loss_")).backward()
grad = TensorDict(dict(model.named_parameters()), []).apply(
lambda x: x.grad.clone()
)
loss2 = loss_fn2(td).exclude("entropy")

model.zero_grad()
sum(val for key, val in loss2.items() if key.startswith("loss_")).backward()
grad2 = TensorDict(dict(model.named_parameters()), []).apply(
@@ -6618,7 +6649,8 @@ def _create_seq_mock_data_a2c(
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@pytest.mark.parametrize("functional", (True, False))
def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reduction):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_a2c(device=device)

@@ -6648,7 +6680,13 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
else:
raise NotImplementedError

loss_fn = A2CLoss(actor, value, loss_critic_type="l2", functional=functional)
loss_fn = A2CLoss(
actor,
value,
loss_critic_type="l2",
functional=functional,
reduction=reduction,
)

# Check error is raised when actions require grads
td["action"].requires_grad = True
@@ -6665,6 +6703,14 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
elif td_est is not None:
loss_fn.make_value_estimator(td_est)
loss = loss_fn(td)
if reduction == "none":

def func(x):
if x.dtype != torch.float:
return
return x.mean()

loss = loss.apply(func, batch_size=[])
loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
@@ -6711,7 +6757,9 @@ def test_a2c_separate_losses(self, separate_losses):
torch.manual_seed(self.seed)
actor, critic, common, td = self._create_mock_common_layer_setup()
loss_fn = A2CLoss(
actor_network=actor, critic_network=critic, separate_losses=separate_losses
actor_network=actor,
critic_network=critic,
separate_losses=separate_losses,
)

# Check error is raised when actions require grads
@@ -7291,14 +7339,26 @@ def _create_mock_common_layer_setup(
return actor, critic, common, td

@pytest.mark.parametrize("separate_losses", [False, True])
def test_reinforce_tensordict_separate_losses(self, separate_losses):
@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_reinforce_tensordict_separate_losses(self, separate_losses, reduction):
torch.manual_seed(self.seed)
actor, critic, common, td = self._create_mock_common_layer_setup()
loss_fn = ReinforceLoss(
actor_network=actor, critic_network=critic, separate_losses=separate_losses
actor_network=actor,
critic_network=critic,
separate_losses=separate_losses,
reduction=reduction,
)

loss = loss_fn(td)
if reduction == "none":

def func(x):
if x.dtype != torch.float:
return
return x.mean()

loss = loss.apply(func, batch_size=[])

assert all(
(p.grad is None) or (p.grad == 0).all()
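
Each test handles reduction="none" the same way: every floating-point leaf of the returned loss TensorDict is collapsed to its mean so the existing backward() assertions still operate on scalars. The pattern, restated on its own (loss is assumed to be the TensorDict returned by a loss module constructed with reduction="none"):

import torch

def _mean_float_leaves(x):
    # Non-float entries (e.g. integer logging values) are passed over by
    # returning None, mirroring the test code above.
    if x.dtype != torch.float:
        return
    return x.mean()

# batch_size=[] turns the unreduced TensorDict back into a scalar-valued one.
loss = loss.apply(_mean_float_leaves, batch_size=[])
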
26 changes: 20 additions & 6 deletions torchrl/objectives/a2c.py
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import functools
import warnings
from copy import deepcopy
from dataclasses import dataclass
@@ -15,9 +16,11 @@
from torch import distributions as d

from torchrl.objectives.common import LossModule

from torchrl.objectives.utils import (
_cache_values,
_GAMMA_LMBDA_DEPREC_ERROR,
_reduce,
default_value_kwargs,
distance_loss,
ValueEstimators,
@@ -68,6 +71,10 @@ class A2CLoss(LossModule):
Functionalizing permits features like meta-RL, but makes it
impossible to use distributed models (DDP, FSDP, ...) and comes
with a little cost. Defaults to ``True``.
reduction (str, optional): Specifies the reduction to apply to the output:
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
``"mean"``: the sum of the output will be divided by the number of
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
.. note:
The advantage (typically GAE) can be computed by the loss function or
@@ -234,6 +241,7 @@ def __init__(
functional: bool = True,
actor: ProbabilisticTensorDictSequential = None,
critic: ProbabilisticTensorDictSequential = None,
reduction: str = None,
):
if actor is not None:
actor_network = actor
@@ -245,6 +253,8 @@
raise TypeError(
"Missing positional arguments actor_network or critic_network."
)
if reduction is None:
reduction = "mean"

self._functional = functional
self._out_keys = None
@@ -277,6 +287,7 @@ def __init__(

self.samples_mc_entropy = samples_mc_entropy
self.entropy_bonus = entropy_bonus and entropy_coef
self.reduction = reduction

try:
device = next(self.parameters()).device
@@ -389,7 +400,7 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
entropy = dist.entropy()
except NotImplementedError:
x = dist.rsample((self.samples_mc_entropy,))
entropy = -dist.log_prob(x)
entropy = -dist.log_prob(x).mean(0)
return entropy.unsqueeze(-1)

def _log_probs(
@@ -458,14 +469,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
assert not advantage.requires_grad
log_probs, dist = self._log_probs(tensordict)
loss = -(log_probs * advantage)
td_out = TensorDict({"loss_objective": loss.mean()}, [])
td_out = TensorDict({"loss_objective": loss}, batch_size=[])
if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
td_out.set("entropy", entropy.mean().detach()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy.mean())
td_out.set("entropy", entropy.detach()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
if self.critic_coef:
loss_critic = self.loss_critic(tensordict).mean()
td_out.set("loss_critic", loss_critic.mean())
loss_critic = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)
td_out = td_out.apply(
functools.partial(_reduce, reduction=self.reduction), batch_size=[]
)
return td_out

def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
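
forward now stores the unreduced "loss_objective", "loss_entropy" and "loss_critic" tensors and applies the reduction once, uniformly, through the imported _reduce utility. A rough sketch of a helper with that calling convention (the actual torchrl implementation may differ):

import torch

def _reduce(tensor: torch.Tensor, reduction: str) -> torch.Tensor:
    # Illustrative only: mirrors the three documented reduction modes.
    if reduction == "none":
        return tensor
    if reduction == "mean":
        return tensor.mean()
    if reduction == "sum":
        return tensor.sum()
    raise NotImplementedError(f"Unknown reduction: {reduction}")

Applying it via td_out.apply(functools.partial(_reduce, reduction=self.reduction), batch_size=[]) reduces every loss entry consistently. Separately, the Monte Carlo entropy estimate now averages -dist.log_prob(x) over the sample dimension (mean(0)), so the entropy bonus keeps the batch shape expected by the unreduced path.
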