[Feature] Add reduction parameter to On-Policy losses. #1890

Merged
merged 31 commits into main from loss_reduction
Feb 15, 2024
Changes from 1 commit
Commits
31 commits
670108a
ppo reduction
albertbou92 Feb 9, 2024
668d212
ppo reduction
albertbou92 Feb 9, 2024
e624083
ppo reduction
albertbou92 Feb 9, 2024
26cd568
ppo reduction
albertbou92 Feb 9, 2024
3b253fc
ppo reduction
albertbou92 Feb 9, 2024
81e6a3a
ppo reduction
albertbou92 Feb 9, 2024
8560e8b
a2c / reinforce reduction
albertbou92 Feb 9, 2024
e747cec
a2c / reinforce reduction
albertbou92 Feb 9, 2024
b7c249c
a2c / reinforce tests
albertbou92 Feb 9, 2024
ea93914
format
albertbou92 Feb 9, 2024
666e7b1
Merge remote-tracking branch 'origin/main' into loss_reduction
vmoens Feb 10, 2024
b5ef409
fix recursion issue
vmoens Feb 10, 2024
6275f02
Merge remote-tracking branch 'origin/main' into loss_reduction
vmoens Feb 10, 2024
107e875
init
vmoens Feb 11, 2024
3163d1d
Merge branch 'fix-loss-exploration' into loss_reduction
vmoens Feb 11, 2024
331bd38
Update torchrl/objectives/reinforce.py
albertbou92 Feb 12, 2024
61fc41b
Update torchrl/objectives/ppo.py
albertbou92 Feb 12, 2024
6dbb622
Update torchrl/objectives/a2c.py
albertbou92 Feb 12, 2024
2d6674e
Update torchrl/objectives/ppo.py
albertbou92 Feb 12, 2024
95efebd
Update torchrl/objectives/ppo.py
albertbou92 Feb 12, 2024
e64ee3d
suggestions added
albertbou92 Feb 12, 2024
5368bdc
format
albertbou92 Feb 12, 2024
efaa893
Merge remote-tracking branch 'origin/main' into loss_reduction
vmoens Feb 12, 2024
ac115a3
Merge branch 'loss_reduction' of https://github.com/PyTorchRL/rl into…
vmoens Feb 12, 2024
c218352
default reduction none
albertbou92 Feb 13, 2024
eebcbb4
Merge branch 'main' into loss_reduction
albertbou92 Feb 15, 2024
7e516f8
remove bs from loss
albertbou92 Feb 15, 2024
2701bb8
fix test
albertbou92 Feb 15, 2024
566b2b9
format
albertbou92 Feb 15, 2024
7e6b4b2
better tests
albertbou92 Feb 15, 2024
8052e33
better tests
albertbou92 Feb 15, 2024
Commit 670108ac7698ffcde35b5b93dfc126bc5e4efc3b: ppo reduction
albertbou92 committed Feb 9, 2024
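This first commit threads the new `reduction` argument through `PPOLoss` and its subclasses. As a quick orientation before the diffs, here is a minimal sketch of the intended call pattern; `actor`, `critic` and `data` are placeholders for a user's own modules and collected batch, not objects taken from this diff:

    # Sketch only: assumes `actor` and `critic` are pre-built TensorDict modules
    # and `data` is a TensorDict of collected transitions.
    from torchrl.objectives import ClipPPOLoss

    loss_module = ClipPPOLoss(
        actor,
        critic,
        loss_critic_type="l2",
        reduction=None,  # keep one loss value per sample instead of a scalar
    )
    loss_td = loss_module(data)
    # With reduction=None the loss entries keep the batch dimensions, so they
    # are reduced by hand before the optimisation step:
    loss = (
        loss_td["loss_objective"]
        + loss_td.get("loss_entropy", 0.0)
        + loss_td["loss_critic"]
    ).mean()
    loss.backward()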
54 changes: 48 additions & 6 deletions test/test_cost.py
@@ -5817,8 +5817,16 @@ def _create_seq_mock_data_ppo(
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@pytest.mark.parametrize("functional", [True, False])
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_ppo(
self, loss_class, device, gradient_mode, advantage, td_est, functional
self,
loss_class,
device,
gradient_mode,
advantage,
td_est,
functional,
reduction,
):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)
@@ -5849,14 +5857,24 @@ def test_ppo(
else:
raise NotImplementedError

loss_fn = loss_class(actor, value, loss_critic_type="l2", functional=functional)
loss_fn = loss_class(
actor,
value,
loss_critic_type="l2",
functional=functional,
reduction=reduction,
)
if advantage is not None:
advantage(td)
else:
if td_est is not None:
loss_fn.make_value_estimator(td_est)

loss = loss_fn(td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
@@ -5904,7 +5922,8 @@ def test_ppo_state_dict(self, loss_class, device, gradient_mode):
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
def test_ppo_shared(self, loss_class, device, advantage):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_ppo_shared(self, loss_class, device, advantage, reduction):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)

@@ -5946,6 +5965,10 @@ def test_ppo_shared(self, loss_class, device, advantage):
if advantage is not None:
advantage(td)
loss = loss_fn(td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
@@ -5989,7 +6012,10 @@ def test_ppo_shared(self, loss_class, device, advantage):
)
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("separate_losses", [True, False])
def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_ppo_shared_seq(
self, loss_class, device, advantage, separate_losses, reduction
):
"""Tests PPO with shared module with and without passing twice across the common module."""
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)
@@ -6040,11 +6066,19 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses):
if advantage is not None:
advantage(td)
loss = loss_fn(td).exclude("entropy")
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

sum(val for key, val in loss.items() if key.startswith("loss_")).backward()
grad = TensorDict(dict(model.named_parameters()), []).apply(
lambda x: x.grad.clone()
)
loss2 = loss_fn2(td).exclude("entropy")
if reduction is None:
assert loss2.batch_size == td.batch_size
loss2 = loss2.apply(lambda x: x.float().mean(), batch_size=[])

model.zero_grad()
sum(val for key, val in loss2.items() if key.startswith("loss_")).backward()
grad2 = TensorDict(dict(model.named_parameters()), []).apply(
@@ -6061,7 +6095,8 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses):
@pytest.mark.parametrize("gradient_mode", (True, False))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
def test_ppo_diff(self, loss_class, device, gradient_mode, advantage):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_ppo_diff(self, loss_class, device, gradient_mode, advantage, reduction):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)

@@ -6107,6 +6142,9 @@ def zero_param(p):
if advantage is not None:
advantage(td)
loss = loss_fn(td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
@@ -6194,7 +6232,8 @@ def test_ppo_tensordict_keys(self, loss_class, td_est):
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est, reduction):
"""Test PPO loss module with non-default tensordict keys."""
torch.manual_seed(self.seed)
gradient_mode = True
@@ -6263,6 +6302,9 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
loss_fn.make_value_estimator(td_est)

loss = loss_fn(td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
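The check added throughout the parametrized tests above follows a single pattern: with `reduction=None` the loss module returns per-sample values whose batch size matches the input, and the test then collapses them to scalars before running the backward passes. A condensed sketch of that pattern, with `loss_fn` and `td` standing in for the test fixtures:

    # Condensed version of the check added to the PPO tests (illustrative names).
    loss = loss_fn(td)
    if reduction is None:
        # Per-sample losses: the output TensorDict keeps the sampling dimensions.
        assert loss.batch_size == td.batch_size
        # Collapse every entry to a scalar so backward() behaves as in the "mean" case.
        loss = loss.apply(lambda x: x.float().mean(), batch_size=[])
    loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
    loss_objective.backward(retain_graph=True)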
62 changes: 45 additions & 17 deletions torchrl/objectives/ppo.py
@@ -27,6 +27,7 @@
)

from .common import LossModule
from .utils import _reduce
from .value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator, VTrace


@@ -85,6 +86,10 @@ class PPOLoss(LossModule):
Functionalizing permits features like meta-RL, but makes it
impossible to use distributed models (DDP, FSDP, ...) and comes
with a little cost. Defaults to ``True``.
reduction (str, optional): Specifies the reduction to apply to the output:
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
``"mean"``: the sum of the output will be divided by the number of
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.

.. note::
The advantage (typically GAE) can be computed by the loss function or
@@ -278,6 +283,7 @@ def __init__(
functional: bool = True,
actor: ProbabilisticTensorDictSequential = None,
critic: ProbabilisticTensorDictSequential = None,
reduction: str = "mean",
):
if actor is not None:
actor_network = actor
@@ -319,6 +325,7 @@ def __init__(
self.samples_mc_entropy = samples_mc_entropy
self.entropy_bonus = entropy_bonus
self.separate_losses = separate_losses
self.reduction = reduction

try:
device = next(self.parameters()).device
@@ -530,15 +537,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
advantage = (advantage - loc) / scale

log_weight, dist = self._log_weight(tensordict)
neg_loss = (log_weight.exp() * advantage).mean()
td_out = TensorDict({"loss_objective": -neg_loss.mean()}, [])
neg_loss = log_weight.exp() * advantage
td_out = TensorDict({"loss_objective": -_reduce(neg_loss, self.reduction)}, [])
if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
td_out.set("entropy", entropy.mean().detach()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy.mean())
td_out.set(
"entropy", _reduce(entropy.detach(), self.reduction)
) # for logging
td_out.set("loss_entropy", -self.entropy_coef * _reduce(entropy))
if self.critic_coef:
loss_critic = self.loss_critic(tensordict).mean()
td_out.set("loss_critic", loss_critic.mean())
loss_critic = self.loss_critic(tensordict)
td_out.set("loss_critic", _reduce(loss_critic, self.reduction))
return td_out

def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
@@ -634,6 +643,10 @@ class ClipPPOLoss(PPOLoss):
Functionalizing permits features like meta-RL, but makes it
impossible to use distributed models (DDP, FSDP, ...) and comes
with a little cost. Defaults to ``True``.
reduction (str, optional): Specifies the reduction to apply to the output:
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
``"mean"``: the sum of the output will be divided by the number of
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.

.. note:
The advantage (typically GAE) can be computed by the loss function or
@@ -692,6 +705,7 @@ def __init__(
normalize_advantage: bool = True,
gamma: float = None,
separate_losses: bool = False,
reduction: str = "mean",
**kwargs,
):
super(ClipPPOLoss, self).__init__(
@@ -705,6 +719,7 @@ def __init__(
normalize_advantage=normalize_advantage,
gamma=gamma,
separate_losses=separate_losses,
reduction=reduction,
**kwargs,
)
self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon))
@@ -764,16 +779,20 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
gain2 = log_weight_clip.exp() * advantage

gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0]
td_out = TensorDict({"loss_objective": -gain.mean()}, [])
td_out = TensorDict({"loss_objective": -_reduce(gain, self.reduction)}, [])

if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
td_out.set("entropy", entropy.mean().detach()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy.mean())
td_out.set(
"entropy", _reduce(entropy, self.reduction).detach()
) # for logging
td_out.set(
"loss_entropy", -self.entropy_coef * _reduce(entropy, self.reduction)
)
if self.critic_coef:
loss_critic = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic.mean())
td_out.set("ESS", ess.mean() / batch)
td_out.set("loss_critic", _reduce(loss_critic, self.reduction))
td_out.set("ESS", _reduce(ess, self.reduction) / batch)
return td_out


@@ -832,7 +851,10 @@ class KLPENPPOLoss(PPOLoss):
Functionalizing permits features like meta-RL, but makes it
impossible to use distributed models (DDP, FSDP, ...) and comes
with a little cost. Defaults to ``True``.

reduction (str, optional): Specifies the reduction to apply to the output:
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
``"mean"``: the sum of the output will be divided by the number of
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.

.. note:
The advantage (typically GAE) can be computed by the loss function or
@@ -895,6 +917,7 @@ def __init__(
normalize_advantage: bool = True,
gamma: float = None,
separate_losses: bool = False,
reduction: str = "mean",
**kwargs,
):
super(KLPENPPOLoss, self).__init__(
@@ -908,6 +931,7 @@ def __init__(
normalize_advantage=normalize_advantage,
gamma=gamma,
separate_losses=separate_losses,
reduction=reduction,
**kwargs,
)

@@ -978,20 +1002,24 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
self.beta.data *= self.decrement
td_out = TensorDict(
{
"loss_objective": -neg_loss.mean(),
"kl": kl.detach().mean(),
"loss_objective": -_reduce(neg_loss, self.reduction),
"kl": _reduce(kl.detach(), self.reduction),
},
[],
)

if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
td_out.set("entropy", entropy.mean().detach()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy.mean())
td_out.set(
"entropy", _reduce(entropy, self.reduction).detach()
) # for logging
td_out.set(
"loss_entropy", -self.entropy_coef * _reduce(entropy, self.reduction)
)

if self.critic_coef:
loss_critic = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic.mean())
td_out.set("loss_critic", _reduce(loss_critic, self.reduction))

return td_out

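One motivation for replacing the hard-coded `.mean()` calls above with `_reduce` is that per-sample losses can be weighted or masked before being reduced. A hedged sketch of such a use case, continuing the earlier `ClipPPOLoss` example; the `sample_weight` entry is purely illustrative and not part of this PR:

    # Illustrative only: weight per-sample PPO losses before reducing them.
    loss_module = ClipPPOLoss(actor, critic, reduction=None)
    loss_td = loss_module(data)
    weight = data["sample_weight"]  # hypothetical per-sample importance weights
    objective = (loss_td["loss_objective"] * weight).mean()
    critic_loss = (loss_td["loss_critic"] * weight).mean()
    (objective + critic_loss).backward()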
14 changes: 14 additions & 0 deletions torchrl/objectives/utils.py
@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import functools
import re
@@ -501,3 +502,16 @@ def decorated_module(*module_args_params):
raise RuntimeError(
"Please use <loss_module>.set_vmap_randomness('different') to handle random operations during vmap."
) from err


def _reduce(tensor: torch.Tensor, reduction: str) -> Union[float, torch.Tensor]:
"""Reduces a tensor given the reduction method."""
if reduction is None or reduction == "none":
return tensor
elif reduction == "mean":
result = tensor.mean()
elif reduction == "sum":
result = tensor.sum()
else:
raise NotImplementedError(f"Unknown reduction method {reduction}")
return result  # keep the tensor so gradients can still flow through the reduced loss
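For completeness, the helper's three modes can be exercised in isolation. A small sketch, assuming `_reduce` keeps tensors (as the loss modules' backward passes require); note that it is a private utility of `torchrl.objectives`:

    # Quick sanity check of _reduce (not a test taken from this PR).
    import torch
    from torchrl.objectives.utils import _reduce

    x = torch.arange(6.0).reshape(2, 3)
    assert _reduce(x, None).shape == (2, 3)         # no reduction applied
    assert _reduce(x, "mean") == torch.tensor(2.5)  # mean over all elements
    assert _reduce(x, "sum") == torch.tensor(15.0)  # sum over all elements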