[Feature] Adds value clipping in ClipPPOLoss loss (#2005)
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
Co-authored-by: Vincent Moens <vmoens@meta.com>
3 people authored Mar 18, 2024
1 parent e3b66bb commit 43c6bca
Showing 15 changed files with 390 additions and 11 deletions.
160 changes: 160 additions & 0 deletions test/test_cost.py
@@ -6885,6 +6885,61 @@ def test_ppo_reduction(self, reduction, loss_class):
                continue
            assert loss[key].shape == torch.Size([])

    @pytest.mark.parametrize("device", get_default_devices())
    @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
    @pytest.mark.parametrize("clip_value", [True, False, None, 0.5, torch.tensor(0.5)])
    def test_ppo_value_clipping(self, clip_value, loss_class, device):
        torch.manual_seed(self.seed)
        td = self._create_seq_mock_data_ppo(device=device)
        actor = self._create_mock_actor(device=device)
        value = self._create_mock_value(device=device)
        advantage = GAE(
            gamma=0.9,
            lmbda=0.9,
            value_network=value,
        )

        if isinstance(clip_value, bool) and loss_class is not ClipPPOLoss:
            # Only ClipPPOLoss accepts a boolean clip_value
            with pytest.raises(
                ValueError,
                match=f"clip_value must be a float or a scalar tensor, got {clip_value}.",
            ):
                loss_fn = loss_class(
                    actor,
                    value,
                    loss_critic_type="l2",
                    clip_value=clip_value,
                )
            return
        else:
            loss_fn = loss_class(
                actor,
                value,
                loss_critic_type="l2",
                clip_value=clip_value,
            )
        advantage(td)

        value = td.pop(loss_fn.tensor_keys.value)

        if clip_value:
            # Test it fails without the value key
            with pytest.raises(
                KeyError,
                match=f"clip_value is set to {loss_fn.clip_value}, but the key "
                "state_value was not found in the input tensordict. "
                "Make sure that the value_key passed to PPO exists in "
                "the input tensordict.",
            ):
                loss = loss_fn(td)

        # Add value back to td
        td.set(loss_fn.tensor_keys.value, value)

        # Test it works with the value key
        loss = loss_fn(td)
        assert "loss_critic" in loss.keys()


class TestA2C(LossModuleTestBase):
    seed = 0
@@ -7519,6 +7574,59 @@ def test_a2c_reduction(self, reduction):
                continue
            assert loss[key].shape == torch.Size([])

    @pytest.mark.parametrize("device", get_default_devices())
    @pytest.mark.parametrize("clip_value", [True, None, 0.5, torch.tensor(0.5)])
    def test_a2c_value_clipping(self, clip_value, device):
        torch.manual_seed(self.seed)
        td = self._create_seq_mock_data_a2c(device=device)
        actor = self._create_mock_actor(device=device)
        value = self._create_mock_value(device=device)
        advantage = GAE(
            gamma=0.9,
            lmbda=0.9,
            value_network=value,
        )

        if isinstance(clip_value, bool):
            # A boolean clip_value is not a valid A2C setting
            with pytest.raises(
                ValueError,
                match=f"clip_value must be a float or a scalar tensor, got {clip_value}.",
            ):
                loss_fn = A2CLoss(
                    actor,
                    value,
                    loss_critic_type="l2",
                    clip_value=clip_value,
                )
            return
        else:
            loss_fn = A2CLoss(
                actor,
                value,
                loss_critic_type="l2",
                clip_value=clip_value,
            )
        advantage(td)

        value = td.pop(loss_fn.tensor_keys.value)

        if clip_value:
            # Test it fails without the value key
            with pytest.raises(
                KeyError,
                match=f"clip_value is set to {clip_value}, but the key "
                "state_value was not found in the input tensordict. "
                "Make sure that the value_key passed to A2C exists in "
                "the input tensordict.",
            ):
                loss = loss_fn(td)

        # Add value back to td
        td.set(loss_fn.tensor_keys.value, value)

        # Test it works with the value key
        loss = loss_fn(td)
        assert "loss_critic" in loss.keys()


class TestReinforce(LossModuleTestBase):
    seed = 0
@@ -7923,6 +8031,58 @@ def test_reinforce_reduction(self, reduction):
                continue
            assert loss[key].shape == torch.Size([])

    @pytest.mark.parametrize("device", get_default_devices())
    @pytest.mark.parametrize("clip_value", [True, None, 0.5, torch.tensor(0.5)])
    def test_reinforce_value_clipping(self, clip_value, device):
        torch.manual_seed(self.seed)
        actor, critic, common, td = self._create_mock_common_layer_setup()
        actor = actor.to(device)
        critic = critic.to(device)
        td = td.to(device)
        advantage = GAE(
            gamma=0.9,
            lmbda=0.9,
            value_network=critic,
        )
        if isinstance(clip_value, bool):
            # A boolean clip_value is not a valid Reinforce setting
            with pytest.raises(
                ValueError,
                match=f"clip_value must be a float or a scalar tensor, got {clip_value}.",
            ):
                loss_fn = ReinforceLoss(
                    actor_network=actor,
                    critic_network=critic,
                    clip_value=clip_value,
                )
            return
        else:
            loss_fn = ReinforceLoss(
                actor_network=actor,
                critic_network=critic,
                clip_value=clip_value,
            )
        advantage(td)

        value = td.pop(loss_fn.tensor_keys.value)

        if clip_value:
            # Test it fails without the value key
            with pytest.raises(
                KeyError,
                match=f"clip_value is set to {loss_fn.clip_value}, but the key "
                "state_value was not found in the input tensordict. "
                "Make sure that the value_key passed to Reinforce exists in "
                "the input tensordict.",
            ):
                loss = loss_fn(td)

        # Add value back to td
        td.set(loss_fn.tensor_keys.value, value)

        # Test it works with the value key
        loss = loss_fn(td)
        assert "loss_value" in loss.keys()


@pytest.mark.parametrize("device", get_default_devices())
class TestDreamer(LossModuleTestBase):
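
Taken together, these tests pin down the public contract of the new clip_value argument: non-scalar values are rejected at construction time, and when clipping is enabled the input tensordict must carry the old value estimate. A minimal usage sketch of that contract (assuming actor, value and a tensordict td built as in the mocks above; not verbatim from the diff):

    from torchrl.objectives import ClipPPOLoss
    from torchrl.objectives.value import GAE

    loss_fn = ClipPPOLoss(actor, value, loss_critic_type="l2", clip_value=0.5)

    advantage = GAE(gamma=0.9, lmbda=0.9, value_network=value)
    advantage(td)  # writes "state_value" into td; clipping is computed against it

    loss = loss_fn(td)   # raises KeyError if "state_value" is missing and clip_value is set
    loss["loss_critic"]  # critic loss, clipped pessimistically against the old estimate
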
49 changes: 47 additions & 2 deletions torchrl/objectives/a2c.py
@@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import contextlib
import warnings
from copy import deepcopy
@@ -18,6 +20,7 @@

from torchrl.objectives.utils import (
    _cache_values,
    _clip_value_loss,
    _GAMMA_LMBDA_DEPREC_ERROR,
    _reduce,
    default_value_kwargs,
@@ -74,6 +77,11 @@ class A2CLoss(LossModule):
            ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
            ``"mean"``: the sum of the output will be divided by the number of
            elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
        clip_value (float, optional): If provided, it will be used to compute a clipped version of the value
            prediction with respect to the input value estimate and use it to calculate the value loss.
            The purpose of clipping is to limit the impact of extreme value predictions, helping to stabilize
            training and prevent large updates. However, it will have no impact if the value estimate was
            done by the current version of the value estimator. Defaults to ``None``.

    .. note::
      The advantage (typically GAE) can be computed by the loss function or
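
For reference, the clipping described in the docstring above is the standard PPO-style value clipping (notation ours, not from the diff; d is the critic distance, here "l2"):

    L_value = max( d(V_theta(s), V_target),
                   d(V_old(s) + clip(V_theta(s) - V_old(s), -clip_value, +clip_value), V_target) )

Taking the elementwise maximum keeps the larger (pessimistic) of the clipped and unclipped losses, which bounds how far a single update can move the value head away from the estimate that produced the targets.
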
@@ -241,6 +249,7 @@ def __init__(
        actor: ProbabilisticTensorDictSequential = None,
        critic: ProbabilisticTensorDictSequential = None,
        reduction: str = None,
        clip_value: float | None = None,
    ):
        if actor is not None:
            actor_network = actor
@@ -301,6 +310,20 @@ def __init__(
            raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
        self.loss_critic_type = loss_critic_type

        if clip_value is not None:
            if isinstance(clip_value, float):
                clip_value = torch.tensor(clip_value)
            elif isinstance(clip_value, torch.Tensor):
                if clip_value.numel() != 1:
                    raise ValueError(
                        f"clip_value must be a float or a scalar tensor, got {clip_value}."
                    )
            else:
                raise ValueError(
                    f"clip_value must be a float or a scalar tensor, got {clip_value}."
                )
        self.register_buffer("clip_value", clip_value)

    @property
    def functional(self):
        return self._functional
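
A few hypothetical calls against the validation above (networks elided; the behavior follows directly from the branches shown):

    A2CLoss(actor, critic, clip_value=0.5)                # ok: float promoted to torch.tensor(0.5)
    A2CLoss(actor, critic, clip_value=torch.tensor(0.5))  # ok: scalar tensor registered as-is
    A2CLoss(actor, critic, clip_value=torch.ones(2))      # ValueError: tensor is not a scalar
    A2CLoss(actor, critic, clip_value=True)               # ValueError: bool is neither float nor tensor

Registering clip_value as a buffer means it follows the module across .to(device) calls and is included in state dicts.
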
@@ -423,6 +446,16 @@ def _log_probs(
        return log_prob, dist

    def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
        if self.clip_value:
            try:
                old_state_value = tensordict.get(self.tensor_keys.value).clone()
            except KeyError:
                raise KeyError(
                    f"clip_value is set to {self.clip_value}, but "
                    f"the key {self.tensor_keys.value} was not found in the input tensordict. "
                    f"Make sure that the value_key passed to A2C exists in the input tensordict."
                )

        try:
            # TODO: if the advantage is gathered by forward, this introduces an
            # overhead that we could easily reduce.
@@ -449,7 +482,17 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
                f"TDLambdaEstimate and TDEstimate all return a 'value_target' entry that "
                f"can be used for the value loss."
            )
        clip_fraction = None
        if self.clip_value:
            loss_value, clip_fraction = _clip_value_loss(
                old_state_value,
                state_value,
                self.clip_value.to(state_value.device),
                target_return,
                loss_value,
                self.loss_critic_type,
            )
        return self.critic_coef * loss_value, clip_fraction

    @property
    @_cache_values
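
The _clip_value_loss helper imported at the top of this file lives in torchrl/objectives/utils.py, whose diff is not shown in this excerpt. A minimal sketch of what the call site above implies it does (argument names taken from the call site; distance_loss is the existing torchrl distance helper, and the exact reduction of clip_fraction is an assumption):

    def _clip_value_loss(old_state_value, state_value, clip_value, target_return,
                         loss_value, loss_critic_type):
        # Bound the new value prediction within +/- clip_value of the old one.
        state_value_clipped = old_state_value + (state_value - old_state_value).clamp(
            -clip_value, clip_value
        )
        # Loss of the clipped prediction against the same target.
        loss_value_clipped = distance_loss(
            target_return, state_value_clipped, loss_function=loss_critic_type
        )
        # Keep the pessimistic (larger) of the two losses, elementwise.
        loss_value = torch.maximum(loss_value, loss_value_clipped)
        # Fraction of elements where clipping was active, for logging.
        clip_fraction = ((state_value - old_state_value).abs() > clip_value).float().mean()
        return loss_value, clip_fraction
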
@@ -478,8 +521,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
            td_out.set("entropy", entropy.detach().mean())  # for logging
            td_out.set("loss_entropy", -self.entropy_coef * entropy)
        if self.critic_coef:
            loss_critic, value_clip_fraction = self.loss_critic(tensordict)
            td_out.set("loss_critic", loss_critic)
            if value_clip_fraction is not None:
                td_out.set("value_clip_fraction", value_clip_fraction)
        td_out = td_out.named_apply(
            lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
            if name.startswith("loss_")
2 changes: 2 additions & 0 deletions torchrl/objectives/cql.py
@@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import math
import warnings
from copy import deepcopy
1 change: 1 addition & 0 deletions torchrl/objectives/decision_transformer.py
@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import math
from dataclasses import dataclass
2 changes: 2 additions & 0 deletions torchrl/objectives/deprecated.py
@@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import math
from dataclasses import dataclass
from numbers import Number
2 changes: 2 additions & 0 deletions torchrl/objectives/dqn.py
@@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import warnings
from dataclasses import dataclass
from typing import Optional, Union
2 changes: 2 additions & 0 deletions torchrl/objectives/dreamer.py
@@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple

1 change: 1 addition & 0 deletions torchrl/objectives/functional.py
@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import torch

2 changes: 2 additions & 0 deletions torchrl/objectives/iql.py
@@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union