
Commit

minor fix
albertbou92 committed Feb 15, 2024
1 parent 832a118 commit 711d741
Showing 2 changed files with 17 additions and 59 deletions.
62 changes: 17 additions & 45 deletions torchrl/objectives/ppo.py
@@ -27,7 +27,6 @@
)

from .common import LossModule
-from .utils import _reduce
from .value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator, VTrace


@@ -86,10 +85,6 @@ class PPOLoss(LossModule):
Functionalizing permits features like meta-RL, but makes it
impossible to use distributed models (DDP, FSDP, ...) and comes
with a little cost. Defaults to ``True``.
-reduction (str, optional): Specifies the reduction to apply to the output:
-``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
-``"mean"``: the sum of the output will be divided by the number of
-elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
.. note::
The advantage (typically GAE) can be computed by the loss function or
@@ -283,7 +278,6 @@ def __init__(
functional: bool = True,
actor: ProbabilisticTensorDictSequential = None,
critic: ProbabilisticTensorDictSequential = None,
-reduction: str = "mean",
):
if actor is not None:
actor_network = actor
@@ -325,7 +319,6 @@ def __init__(
self.samples_mc_entropy = samples_mc_entropy
self.entropy_bonus = entropy_bonus
self.separate_losses = separate_losses
-self.reduction = reduction

try:
device = next(self.parameters()).device
@@ -537,17 +530,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
advantage = (advantage - loc) / scale

log_weight, dist = self._log_weight(tensordict)
-neg_loss = log_weight.exp() * advantage
-td_out = TensorDict({"loss_objective": -_reduce(neg_loss, self.reduction)}, [])
+neg_loss = (log_weight.exp() * advantage).mean()
+td_out = TensorDict({"loss_objective": -neg_loss.mean()}, [])
if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
-td_out.set(
-"entropy", _reduce(entropy.detach(), self.reduction)
-) # for logging
-td_out.set("loss_entropy", -self.entropy_coef * _reduce(entropy))
+td_out.set("entropy", entropy.mean().detach()) # for logging
+td_out.set("loss_entropy", -self.entropy_coef * entropy.mean())
if self.critic_coef:
-loss_critic = self.loss_critic(tensordict)
-td_out.set("loss_critic", _reduce(loss_critic, self.reduction))
+loss_critic = self.loss_critic(tensordict).mean()
+td_out.set("loss_critic", loss_critic.mean())
return td_out

def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
@@ -643,10 +634,6 @@ class ClipPPOLoss(PPOLoss):
Functionalizing permits features like meta-RL, but makes it
impossible to use distributed models (DDP, FSDP, ...) and comes
with a little cost. Defaults to ``True``.
-reduction (str, optional): Specifies the reduction to apply to the output:
-``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
-``"mean"``: the sum of the output will be divided by the number of
-elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
.. note:
The advantage (typically GAE) can be computed by the loss function or
@@ -705,7 +692,6 @@ def __init__(
normalize_advantage: bool = True,
gamma: float = None,
separate_losses: bool = False,
-reduction: str = "mean",
**kwargs,
):
super(ClipPPOLoss, self).__init__(
@@ -719,7 +705,6 @@
normalize_advantage=normalize_advantage,
gamma=gamma,
separate_losses=separate_losses,
-reduction=reduction,
**kwargs,
)
self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon))
@@ -779,20 +764,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
gain2 = log_weight_clip.exp() * advantage

gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0]
-td_out = TensorDict({"loss_objective": -_reduce(gain, self.reduction)}, [])
+td_out = TensorDict({"loss_objective": -gain.mean()}, [])

if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
-td_out.set(
-"entropy", _reduce(entropy, self.reduction).detach()
-) # for logging
-td_out.set(
-"loss_entropy", -self.entropy_coef * _reduce(entropy, self.reduction)
-)
+td_out.set("entropy", entropy.mean().detach()) # for logging
+td_out.set("loss_entropy", -self.entropy_coef * entropy.mean())
if self.critic_coef:
loss_critic = self.loss_critic(tensordict)
td_out.set("loss_critic", _reduce(loss_critic, self.reduction))
td_out.set("ESS", _reduce(ess, self.reduction) / batch)
td_out.set("loss_critic", loss_critic.mean())
td_out.set("ESS", ess.mean() / batch)
return td_out


@@ -851,10 +832,7 @@ class KLPENPPOLoss(PPOLoss):
Functionalizing permits features like meta-RL, but makes it
impossible to use distributed models (DDP, FSDP, ...) and comes
with a little cost. Defaults to ``True``.
-reduction (str, optional): Specifies the reduction to apply to the output:
-``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
-``"mean"``: the sum of the output will be divided by the number of
-elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
.. note:
The advantage (typically GAE) can be computed by the loss function or
@@ -917,7 +895,6 @@ def __init__(
normalize_advantage: bool = True,
gamma: float = None,
separate_losses: bool = False,
-reduction: str = "mean",
**kwargs,
):
super(KLPENPPOLoss, self).__init__(
@@ -931,7 +908,6 @@
normalize_advantage=normalize_advantage,
gamma=gamma,
separate_losses=separate_losses,
-reduction=reduction,
**kwargs,
)

@@ -1002,24 +978,20 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
self.beta.data *= self.decrement
td_out = TensorDict(
{
"loss_objective": -_reduce(neg_loss, self.reduction),
"kl": _reduce(kl.detach(), self.reduction),
"loss_objective": -neg_loss.mean(),
"kl": kl.detach().mean(),
},
[],
)

if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
-td_out.set(
-"entropy", _reduce(entropy, self.reduction).detach()
-) # for logging
-td_out.set(
-"loss_entropy", -self.entropy_coef * _reduce(entropy, self.reduction)
-)
+td_out.set("entropy", entropy.mean().detach()) # for logging
+td_out.set("loss_entropy", -self.entropy_coef * entropy.mean())

if self.critic_coef:
loss_critic = self.loss_critic(tensordict)
td_out.set("loss_critic", _reduce(loss_critic, self.reduction))
td_out.set("loss_critic", loss_critic.mean())

return td_out

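For context, a minimal self-contained sketch (not part of this commit; the tensor shapes and values are illustrative assumptions) of the mean reduction that these forward() methods now apply to each loss term before storing it in the returned TensorDict:

import torch

# Assumed toy batch of 8 transitions; mirrors the surrogate expression used in
# PPOLoss.forward(), where the per-sample gain is collapsed to a scalar.
log_weight = torch.randn(8, 1)                   # log pi_new(a|s) - log pi_old(a|s)
advantage = torch.randn(8, 1)                    # advantage estimate per transition

per_sample_gain = log_weight.exp() * advantage   # one surrogate value per transition
loss_objective = -per_sample_gain.mean()         # scalar stored under "loss_objective"
print(loss_objective.shape)                      # torch.Size([])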
14 changes: 0 additions & 14 deletions torchrl/objectives/utils.py
@@ -2,7 +2,6 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from __future__ import annotations

import functools
import re
@@ -502,16 +501,3 @@ def decorated_module(*module_args_params):
raise RuntimeError(
"Please use <loss_module>.set_vmap_randomness('different') to handle random operations during vmap."
) from err


-def _reduce(tensor: torch.Tensor, reduction: str) -> Union[float, torch.Tensor]:
-    """Reduces a tensor given the reduction method."""
-    if reduction is None:
-        return tensor
-    elif reduction == "mean":
-        result = tensor.mean()
-    elif reduction == "sum":
-        result = tensor.sum()
-    else:
-        raise NotImplementedError(f"Unknown reduction method {reduction}")
-    return result.item()
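For reference, a standalone sketch reproducing the deleted helper exactly as shown above, with two illustrative calls (the demo inputs are assumptions): "mean" and "sum" collapse the tensor and return a plain Python float via .item(), while passing None returns the tensor unchanged.

import torch
from typing import Union

# Copy of the removed helper, kept only to illustrate its behavior.
def _reduce(tensor: torch.Tensor, reduction: str) -> Union[float, torch.Tensor]:
    """Reduces a tensor given the reduction method."""
    if reduction is None:
        return tensor
    elif reduction == "mean":
        result = tensor.mean()
    elif reduction == "sum":
        result = tensor.sum()
    else:
        raise NotImplementedError(f"Unknown reduction method {reduction}")
    return result.item()

print(_reduce(torch.ones(4), "sum"))   # 4.0 (plain float, no autograd graph)
print(_reduce(torch.ones(4), None))    # tensor([1., 1., 1., 1.])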
