[Feature] Adds value clipping in ClipPPOLoss loss #2005

Merged · 44 commits · Mar 18, 2024

Changes from 1 commit

Commits
88b1ea3
clipping code
albertbou92 Mar 9, 2024
704f7da
fix test
albertbou92 Mar 9, 2024
fee26d6
fix test
albertbou92 Mar 9, 2024
0cd5af0
fix test
albertbou92 Mar 9, 2024
27e1822
extend logic
albertbou92 Mar 10, 2024
3af7f0c
fix test
albertbou92 Mar 10, 2024
91577c9
register param
albertbou92 Mar 10, 2024
42071f3
register param
albertbou92 Mar 10, 2024
5e33992
minor fix
albertbou92 Mar 10, 2024
47ed022
added param for a2c and reinforce
albertbou92 Mar 11, 2024
c363a37
fix test
albertbou92 Mar 11, 2024
c0d03aa
fix test
albertbou92 Mar 11, 2024
102235e
fix test
albertbou92 Mar 11, 2024
8bf1ae2
fix test
albertbou92 Mar 11, 2024
025ceb4
fix test
albertbou92 Mar 11, 2024
3b8728b
fix test
albertbou92 Mar 11, 2024
fdcbf9f
fix test
albertbou92 Mar 11, 2024
5d75566
docstrings
albertbou92 Mar 11, 2024
4c8c783
docstrings
albertbou92 Mar 11, 2024
19e57b5
Update torchrl/objectives/a2c.py
albertbou92 Mar 12, 2024
245e562
Update torchrl/objectives/ppo.py
albertbou92 Mar 12, 2024
315c9e2
Update torchrl/objectives/ppo.py
albertbou92 Mar 12, 2024
b79c99e
Update torchrl/objectives/ppo.py
albertbou92 Mar 12, 2024
48f06ce
Update torchrl/objectives/ppo.py
albertbou92 Mar 12, 2024
c356670
Update torchrl/objectives/a2c.py
albertbou92 Mar 12, 2024
fc3c3dd
Update torchrl/objectives/a2c.py
albertbou92 Mar 12, 2024
7d04eca
Update torchrl/objectives/a2c.py
albertbou92 Mar 12, 2024
28c9cae
integrate feedback
albertbou92 Mar 12, 2024
9e273d1
fix test
albertbou92 Mar 12, 2024
c836bd8
fix test
albertbou92 Mar 12, 2024
1672c48
format
albertbou92 Mar 12, 2024
cc04ec4
return clip fractions
albertbou92 Mar 12, 2024
78ad4d8
fix test
albertbou92 Mar 12, 2024
ad1425c
update feedback
albertbou92 Mar 14, 2024
82e9efd
fix test
albertbou92 Mar 14, 2024
3723dc2
fix test
albertbou92 Mar 14, 2024
bd1c974
fix test
albertbou92 Mar 14, 2024
3e7b824
Update torchrl/objectives/ppo.py
albertbou92 Mar 14, 2024
7a5c9ee
Update torchrl/objectives/a2c.py
albertbou92 Mar 14, 2024
1374d80
Update torchrl/objectives/ppo.py
albertbou92 Mar 14, 2024
cb83230
Update torchrl/objectives/reinforce.py
albertbou92 Mar 14, 2024
021aa71
minor fixes
albertbou92 Mar 14, 2024
c123231
Merge remote-tracking branch 'origin/main' into clip_value_loss
vmoens Mar 18, 2024
c72347c
amend
vmoens Mar 18, 2024
added param for a2c and reinforce
albertbou92 committed Mar 11, 2024
commit 47ed022a831b10ddfdf386b41b8b27f06f1b3f01
120 changes: 118 additions & 2 deletions test/test_cost.py
@@ -6885,8 +6885,9 @@ def test_ppo_reduction(self, reduction, loss_class):
continue
assert loss[key].shape == torch.Size([])

@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("clip_value_loss", [True, False, 0.5])
def test_ppo_value_clipping(self, clip_value_loss):
def test_ppo_value_clipping(self, clip_value_loss, loss_class):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
@@ -6902,7 +6903,22 @@ def test_ppo_value_clipping(self, clip_value_loss):
value_network=value,
)

loss_fn = ClipPPOLoss(
if isinstance(clip_value_loss, bool) and not isinstance(
loss_class, ClipPPOLoss
):
with pytest.raises(
ValueError,
match="If provided, clip_value_loss must be a float.",
):
loss_fn = loss_class(
actor,
value,
loss_critic_type="l2",
clip_value_loss=clip_value_loss,
)
return

loss_fn = loss_class(
actor,
value,
loss_critic_type="l2",
@@ -7564,6 +7580,64 @@ def test_a2c_reduction(self, reduction):
continue
assert loss[key].shape == torch.Size([])

@pytest.mark.parametrize("clip_value_loss", [True, None, 0.5])
def test_a2c_value_clipping(self, clip_value_loss):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
td = self._create_seq_mock_data_a2c(device=device)
actor = self._create_mock_actor(device=device)
value = self._create_mock_value(device=device)
advantage = GAE(
gamma=0.9,
lmbda=0.9,
value_network=value,
)

if isinstance(clip_value_loss, bool):
with pytest.raises(
ValueError,
match="If provided, clip_value_loss must be a float.",
):
loss_fn = A2CLoss(
actor,
value,
loss_critic_type="l2",
clip_value_loss=clip_value_loss,
)
return

loss_fn = A2CLoss(
actor,
value,
loss_critic_type="l2",
clip_value_loss=clip_value_loss,
)
advantage(td)

value = td.pop(loss_fn.tensor_keys.value)

if clip_value_loss:
# Test it fails without value key
with pytest.raises(
KeyError,
match="clip_value_loss is set to True, but the key "
"state_value was not found in the input tensordict. "
"Make sure that the value_key passed to A2C exists in "
"the input tensordict.",
):
loss = loss_fn(td)

# Add value back to td
td.set(loss_fn.tensor_keys.value, value)

# Test it works with value key
loss = loss_fn(td)
assert "loss_critic" in loss.keys()


class TestReinforce(LossModuleTestBase):
seed = 0
@@ -7968,6 +8042,48 @@ def test_reinforce_reduction(self, reduction):
continue
assert loss[key].shape == torch.Size([])

@pytest.mark.parametrize("clip_value_loss", [True, None, 0.5])
def test_reinforce_value_clipping(self, clip_value_loss):
torch.manual_seed(self.seed)
actor, critic, common, td = self._create_mock_common_layer_setup()
if isinstance(clip_value_loss, bool):
with pytest.raises(
ValueError,
match="If provided, clip_value_loss must be a float.",
):
loss_fn = ReinforceLoss(
actor_network=actor,
critic_network=critic,
clip_value_loss=clip_value_loss,
)
return

loss_fn = ReinforceLoss(
actor_network=actor,
critic_network=critic,
clip_value_loss=clip_value_loss,
)

value = td.pop(loss_fn.tensor_keys.value)

if clip_value_loss:
# Test it fails without value key
with pytest.raises(
KeyError,
match="clip_value_loss is set to True, but the key "
"state_value was not found in the input tensordict. "
"Make sure that the value_key passed to A2C exists in "
"the input tensordict.",
):
loss = loss_fn(td)

# Add value back to td
td.set(loss_fn.tensor_keys.value, value)

# Test it works with value key
loss = loss_fn(td)
assert "loss_critic" in loss.keys()


@pytest.mark.parametrize("device", get_default_devices())
class TestDreamer(LossModuleTestBase):
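Taken together, the three tests above pin down a single constructor contract: a float enables value clipping, None leaves it disabled, and a bare bool is rejected with a fixed message. A minimal, library-free sketch of that contract, assuming only pytest (the helper and test names are illustrative stand-ins, not part of this PR):

import pytest


def _check_clip_value_loss(clip_value_loss):
    # Hypothetical stand-in for the validation each loss constructor
    # performs; None (and False) are falsy and skip the type check.
    if clip_value_loss:
        if not isinstance(clip_value_loss, float):
            raise ValueError("If provided, clip_value_loss must be a float.")
    return clip_value_loss


def test_clip_value_loss_contract():
    assert _check_clip_value_loss(0.5) == 0.5    # float: clipping enabled
    assert _check_clip_value_loss(None) is None  # None: clipping disabled
    with pytest.raises(ValueError, match="must be a float."):
        _check_clip_value_loss(True)             # bool: rejected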
34 changes: 34 additions & 0 deletions torchrl/objectives/a2c.py
@@ -74,6 +74,11 @@ class A2CLoss(LossModule):
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
``"mean"``: the sum of the output will be divided by the number of
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
clip_value_loss (float, optional): If provided, it will be used to compute a clipped version of the value
prediction with respect to the input tensordict value estimate and use it to calculate the value loss.
The purpose of clipping is to limit the impact of extreme value predictions, helping stabilize training
and preventing large updates. However, it will have no impact if the value estimate was done by the current
version of the value estimator. Defaults to ``None``.

.. note:
The advantage (typically GAE) can be computed by the loss function or
@@ -241,6 +246,7 @@ def __init__(
actor: ProbabilisticTensorDictSequential = None,
critic: ProbabilisticTensorDictSequential = None,
reduction: str = None,
clip_value_loss: float = None,
):
if actor is not None:
actor_network = actor
@@ -301,6 +307,12 @@
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self.loss_critic_type = loss_critic_type

if clip_value_loss:
if not isinstance(clip_value_loss, float):
raise ValueError("If provided, clip_value_loss must be a float.")
clip_value_loss = torch.tensor(clip_value_loss)
self.register_buffer("clip_value_loss", clip_value_loss)

@property
def functional(self):
return self._functional
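A note on the register_buffer call above: storing clip_value_loss as a tensor buffer rather than a plain float attribute means Module.to() moves it together with the parameters, so the threshold ends up on the same device as the critic weights. A small sketch using stock PyTorch only (the Critic class is illustrative, not torchrl code):

import torch
from torch import nn


class Critic(nn.Module):
    """Toy module registering a clip threshold as a buffer."""

    def __init__(self, clip_value_loss: float):
        super().__init__()
        self.linear = nn.Linear(4, 1)
        # Buffers travel with .to()/.cuda() and are saved in state_dict,
        # unlike plain Python attributes.
        self.register_buffer("clip_value_loss", torch.tensor(clip_value_loss))


critic = Critic(0.5)
print(critic.clip_value_loss)  # tensor(0.5000)
# critic.to("cuda")  # would move buffer and weights together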
@@ -421,6 +433,16 @@ def _log_probs(
return log_prob, dist

def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
if self.clip_value_loss:
try:
old_state_value = tensordict.get(self.tensor_keys.value).clone()
except KeyError:
raise KeyError(
f"clip_value_loss is set to True, but "
f"the key {self.tensor_keys.value} was not found in the input tensordict. "
f"Make sure that the value_key passed to A2C exists in the input tensordict."
)

try:
# TODO: if the advantage is gathered by forward, this introduces an
# overhead that we could easily reduce.
@@ -445,6 +467,18 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
f"TDLambdaEstimate and TDEstimate all return a 'value_target' entry that "
f"can be used for the value loss."
)
if self.clip_value_loss:
self.clip_value_loss = self.clip_value_loss.to(state_value.device)
state_value_clipped = old_state_value + (
state_value - old_state_value
).clamp(-self.clip_value_loss, self.clip_value_loss)
loss_value_clipped = distance_loss(
target_return,
state_value_clipped,
loss_function=self.loss_critic_type,
)
# Choose the most pessimistic value prediction between clipped and non-clipped
loss_value = torch.max(loss_value, loss_value_clipped)
return self.critic_coef * loss_value

@property
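For reference, the arithmetic that loss_critic performs when clipping is enabled fits in a few lines of plain PyTorch. The sketch below assumes distance_loss with loss_function="l2" reduces to an elementwise squared error; all names are illustrative rather than the torchrl source:

import torch


def clipped_value_loss(state_value, old_state_value, target_return, eps):
    """Pessimistic critic loss: max of unclipped and clipped L2 terms."""
    loss = (state_value - target_return).pow(2)
    # The new prediction may move at most eps away from the value
    # estimate stored in the input tensordict.
    state_value_clipped = old_state_value + (
        state_value - old_state_value
    ).clamp(-eps, eps)
    loss_clipped = (state_value_clipped - target_return).pow(2)
    # Keep the most pessimistic of the two terms, as the diff above does.
    return torch.max(loss, loss_clipped)


old = torch.tensor([1.0, 1.0])
new = torch.tensor([2.4, 1.1])
target = torch.tensor([2.0, 2.0])
# First entry: new is clamped to 1.5, so (1.5 - 2.0)^2 = 0.25 beats the
# unclipped (2.4 - 2.0)^2 = 0.16. Second entry: the 0.1 move is within
# eps, so clipped and unclipped losses coincide at 0.81.
print(clipped_value_loss(new, old, target, eps=0.5))  # tensor([0.2500, 0.8100])

Because torch.max can only raise the loss, the clipped term acts as a brake: it limits how far a single update can push the value function away from the estimate recorded in the tensordict.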