[Feature] Adds value clipping in ClipPPOLoss loss #2005
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2005
Note: links to docs will display an error until the docs builds have been completed.
❌ 6 New Failures, 1 Unrelated Failure as of commit c72347c with merge base c371266.
NEW FAILURES: the following jobs have failed.
BROKEN TRUNK: the following job failed but was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
torchrl/objectives/ppo.py
Outdated
@@ -664,6 +664,8 @@ class ClipPPOLoss(PPOLoss):
            ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
            ``"mean"``: the sum of the output will be divided by the number of
            elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
        clip_value_loss (bool, optional): if ``True``, the value loss will be clipped with respect to the
I would have this:
- not provided / False = no clipping
- True = same clipping value as the log ratio
- number = clip with that number

That way you can reuse the same semantics for other losses (see the sketch below).
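A minimal sketch of how that argument could be resolved; the helper name `_resolve_clip_value` and its signature are illustrative, not the final torchrl API:

```python
def _resolve_clip_value(clip_value, clip_epsilon=None):
    """Map the user-facing ``clip_value`` argument to an optional float threshold."""
    # not provided / False -> no value clipping
    if clip_value is None or clip_value is False:
        return None
    # True -> reuse the log-ratio clipping value (only meaningful for ClipPPOLoss)
    if clip_value is True:
        if clip_epsilon is None:
            raise ValueError(
                "clip_value=True requires a log-ratio clip_epsilon to fall back on"
            )
        return float(clip_epsilon)
    # a number -> clip the value loss with that threshold
    return float(clip_value)
```

Losses without a log-ratio clipping value (PPOLoss, A2C, Reinforce) would simply not accept `True` and require a float.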
We could potentially add it to more losses (PPOLoss, KLPENPPOLoss, A2C, Reinforce), with the only difference that we would require the user to provide a float rather than a bool, since we cannot default to the log-ratio clipping value. But I am not sure; maybe we should first see whether the feature is requested for those other losses, since I have never seen it used with them.
I don't think it's a problem to add it; we don't want to copy existing repos but to enable people to experiment and swap configurations quickly.
Thanks for this! Cool feature!
A couple of comments along the way but otherwise LGTM
Thanks for the feedback! I integrated the suggested changes, but I think it also makes sense to add the metrics suggested in #1977 (comment) if they are helpful.
I added the
torchrl/objectives/ppo.py
Outdated
@@ -506,6 +527,16 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
                f"can be used for the value loss."
            )

        if self.clip_value:
            try:
                old_state_value = tensordict.get(self.tensor_keys.value).clone()
why do we clone here?
Because the `self.tensor_keys.value` prediction will be recomputed once we pass the tensordict through the value network, and we would otherwise lose this tensor. Does that make sense?
We never write anything in-place AFAICT
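A toy illustration of the point in question, under the assumption that the value network writes a new tensor under the value key rather than mutating the stored one in place; in that case a plain reference to the old prediction survives and `.clone()` is not strictly needed:

```python
import torch

# stand-in for a tensordict entry holding the value estimate from data collection
storage = {"state_value": torch.tensor([1.0, 2.0, 3.0])}

old_state_value = storage["state_value"]                 # plain reference, no clone
storage["state_value"] = torch.tensor([9.0, 9.0, 9.0])   # "recomputed" prediction: a new tensor

print(old_state_value)  # tensor([1., 2., 3.]) -> the old estimate is still intact
```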
torchrl/objectives/a2c.py
Outdated
            )
            # Choose the most pessimistic value prediction between clipped and non-clipped
            loss_value = torch.max(loss_value, loss_value_clipped)
            clip_fraction = (
I don't get what clip_fraction represents
One thing I like to look at is what proportion of data is being clipped
Also we should be careful about this since it doesn't come for free. The overhead introduced by logging this metric could impact performance...
Yes, as I understand it, it is a measure of how much the old model that collected the data differs from the current model. If the clip fraction stays close to 0.0, the value predictions rarely move past the clipping threshold and the model is changing slowly. If a large fraction of predictions is consistently being clipped, the new model is substantially different from the old model, indicating a faster rate of model change. It might help to better understand the learning process, but we could make it optional if it is too costly.
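For reference, a small self-contained sketch of the metric being discussed; the tensor names and the squared-error loss are illustrative, not the exact torchrl implementation. The new prediction is clipped around the old one, the most pessimistic of the two losses is kept, and `clip_fraction` is the proportion of samples whose prediction actually hit the clipping threshold:

```python
import torch

def clipped_value_loss(old_value, value, target_return, clip_value):
    # unclipped squared-error value loss
    loss_value = (value - target_return).pow(2)
    # keep the new prediction within +/- clip_value of the old prediction
    value_clipped = old_value + (value - old_value).clamp(-clip_value, clip_value)
    loss_value_clipped = (value_clipped - target_return).pow(2)
    # proportion of samples whose prediction was actually clipped
    clip_fraction = ((value - old_value).abs() > clip_value).float().mean()
    # choose the most pessimistic (largest) of the clipped / unclipped losses
    return torch.max(loss_value, loss_value_clipped), clip_fraction

old_v = torch.tensor([0.0, 0.0, 0.0])
new_v = torch.tensor([0.05, 0.5, -0.4])
target = torch.tensor([0.1, 0.1, 0.1])
loss, frac = clipped_value_loss(old_v, new_v, target, clip_value=0.2)
print(frac)  # tensor(0.6667): two of the three predictions moved further than 0.2
```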
Some very minor details (clone and type annotations) and we're good to go!
LGTM
Description
This PR adds a value clipping option to ClipPPOLoss.
This PR is related to #1977.
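A hedged usage sketch of the new option: the `clip_value` argument is the one added here, the other arguments follow the existing ClipPPOLoss constructor as I understand it, and `policy` / `value_net` are assumed to be defined elsewhere.

```python
from torchrl.objectives import ClipPPOLoss

loss_module = ClipPPOLoss(
    actor_network=policy,       # assumed: an actor defined elsewhere
    critic_network=value_net,   # assumed: a critic defined elsewhere
    clip_epsilon=0.2,           # log-ratio clipping, unchanged by this PR
    clip_value=True,            # reuse clip_epsilon to also clip the value loss
    # clip_value=0.5,           # ...or pass a float to use a dedicated threshold
)
```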
Motivation and Context
Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax `close #15213` if this solves the issue #15213.
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!