Skip to content

Commit

Permalink
[BugFix] Fix clip_fraction in PO losses (#2021)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Mar 19, 2024
1 parent 9747170 commit 4bce371
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
3 changes: 2 additions & 1 deletion torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,12 +856,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
gain1 = log_weight.exp() * advantage

log_weight_clip = log_weight.clamp(*self._clip_bounds)
clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
ratio = log_weight_clip.exp()
gain2 = ratio * advantage

gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0]
td_out = TensorDict({"loss_objective": -gain}, batch_size=[])
td_out.set("clip_fraction", ratio.abs().detach())
td_out.set("clip_fraction", clip_fraction)

if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
Expand Down
12 changes: 5 additions & 7 deletions torchrl/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,18 +530,16 @@ def _clip_value_loss(
and returns the most pessimistic value prediction between clipped and non-clipped options.
It also computes the clip fraction.
"""
state_value_clipped = old_state_value + (state_value - old_state_value).clamp(
-clip_value, clip_value
)
pre_clipped = state_value - old_state_value
clipped = pre_clipped.clamp(-clip_value, clip_value)
with torch.no_grad():
clip_fraction = (pre_clipped != clipped).to(state_value.dtype).mean()
state_value_clipped = old_state_value + clipped
loss_value_clipped = distance_loss(
target_return,
state_value_clipped,
loss_function=loss_critic_type,
)
# Chose the most pessimistic value prediction between clipped and non-clipped
loss_value = torch.max(loss_value, loss_value_clipped)
with torch.no_grad():
clip_fraction = (
(state_value / old_state_value).clamp(1 - clip_value, 1 + clip_value).abs()
)
return loss_value, clip_fraction

0 comments on commit 4bce371

Please sign in to comment.