
[Feature] Adds value clipping in ClipPPOLoss loss #2005

Merged: 44 commits from clip_value_loss into pytorch:main on Mar 18, 2024

Conversation

albertbou92 (Contributor)

Description

This PR adds a value clipping option to ClipPPOLoss.

This PR is related to #1977.

Motivation and Context


  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

pytorch-bot bot commented Mar 9, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2005

Note: Links to docs will display an error until the docs builds have been completed.

❌ 6 New Failures, 1 Unrelated Failure

As of commit c72347c with merge base c371266:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following job failed but was already failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Mar 9, 2024.
@@ -664,6 +664,8 @@ class ClipPPOLoss(PPOLoss):
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
``"mean"``: the sum of the output will be divided by the number of
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
clip_value_loss (bool, optional): if ``True``, the value loss will be clipped with respect to the
Review comment (Contributor):

I would have this:

  • not provided / False = no clipping
  • True = same clipping value as for the log ratio
  • number = clip with that number

That way you can reuse the same semantics for other losses (see the sketch below).
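A minimal sketch (not the PR's code; the helper name and structure are hypothetical) of how this argument interpretation could be normalized inside the loss constructor, with clip_epsilon standing for the existing log-ratio clipping value:

from typing import Optional, Union

import torch


def resolve_clip_value(
    clip_value: Union[bool, float, None], clip_epsilon: float
) -> Optional[torch.Tensor]:
    """Turn the user-facing clip_value argument into a tensor threshold or None."""
    if clip_value is None or clip_value is False:
        return None  # not provided / False -> no value clipping
    if clip_value is True:
        return torch.as_tensor(clip_epsilon)  # True -> reuse the log-ratio clipping value
    return torch.as_tensor(float(clip_value))  # number -> clip with that number

For example, resolve_clip_value(True, 0.2) would return tensor(0.2000), while resolve_clip_value(None, 0.2) disables value clipping entirely.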

albertbou92 requested a review from vmoens on March 10, 2024 at 17:23.
albertbou92 (Contributor, Author) commented:

We can potentially add it to more losses (PPOLoss, KLPENPPOLoss, A2C, Reinforce), with the only difference that we would require the user to provide a float rather than a bool, since we cannot default to the log-ratio clipping value.

But I am not sure whether we should wait to see if the feature is requested for those other losses first, since I have never seen it used in them.
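For losses that have no log-ratio clipping value to fall back on, the same argument could be validated roughly as follows (a hedged sketch; the helper and its error message are illustrative, not the PR's code):

from typing import Optional, Union

import torch


def resolve_clip_value_without_default(
    clip_value: Union[bool, float, None]
) -> Optional[torch.Tensor]:
    """Like the sketch above, but for losses with no clip_epsilon to reuse."""
    if clip_value is None or clip_value is False:
        return None  # no value clipping
    if clip_value is True:
        raise ValueError(
            "clip_value=True needs a default clipping value, which this loss "
            "does not define; please pass an explicit float instead."
        )
    return torch.as_tensor(float(clip_value))  # explicit threshold required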

vmoens (Contributor) commented Mar 10, 2024:

I don't think it's a problem to add it; we don't want to copy existing repos, but to enable people to experiment and swap configurations quickly.

vmoens added the enhancement (New feature or request) label on Mar 10, 2024.
vmoens (Contributor) left a review:

Thanks for this! Cool feature!
A couple of comments along the way, but otherwise LGTM.

Review threads (outdated, resolved) on: torchrl/objectives/a2c.py, torchrl/objectives/ppo.py, torchrl/objectives/reinforce.py
albertbou92 and others added 2 commits March 12, 2024 08:59
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
albertbou92 and others added 7 commits March 12, 2024 09:00
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
albertbou92 (Contributor, Author) commented:

Thanks for the feedback! I integrated the suggested changes, but I think it also makes sense to add the metrics suggested in #1977 (comment), if they are helpful.

albertbou92 (Contributor, Author) commented Mar 12, 2024:

I added the clip_fraction for both the PPO loss and the value loss.

Review thread (resolved) on torchrl/objectives/a2c.py
@@ -506,6 +527,16 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
                f"can be used for the value loss."
            )

        if self.clip_value:
            try:
                old_state_value = tensordict.get(self.tensor_keys.value).clone()
Review comment (Contributor):
why do we clone here?

albertbou92 (Contributor, Author) replied on Mar 14, 2024:

Because the self.tensor_keys.value prediction will be recomputed once we pass the tensordict through the value network, and we would lose this tensor. Does that make sense?

Review comment (Contributor):
We never write anything in-place AFAICT
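For readers following along, a tiny sketch (not the PR's code) of the behaviour being discussed, assuming the value network writes a fresh tensor under the key rather than modifying the stored tensor in place:

import torch
from tensordict import TensorDict

td = TensorDict({"state_value": torch.zeros(3)}, batch_size=[3])

old_state_value = td.get("state_value")  # plain reference, no clone
td.set("state_value", torch.ones(3))     # stand-in for the value network overwriting the key

print(old_state_value)         # tensor([0., 0., 0.]) -- the old prediction is still reachable
print(td.get("state_value"))   # tensor([1., 1., 1.])

Under that assumption the clone is not strictly needed to keep the old prediction; it would only matter if something later wrote into the tensor in place.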

Review thread (outdated, resolved) on torchrl/objectives/a2c.py
        )
        # Choose the most pessimistic value prediction between clipped and non-clipped
        loss_value = torch.max(loss_value, loss_value_clipped)
        clip_fraction = (
Review comment (Contributor):

I don't get what clip_fraction represents

One thing I like to look at is what proportion of data is being clipped

Also we should be careful about this since it doesn't come for free. The overhead introduced by logging this metric could impact performance...

albertbou92 (Contributor, Author) replied on Mar 14, 2024:

Yes, as I understand it, it is a measure of how much the old model that collected the data differs from the current model. If the ratio values stay close to 1.0, the model is changing slowly. If the ratio consistently deviates from 1.0 and approaches or hits the clipping value, the new model differs substantially from the old one, indicating a faster rate of change. It might help us better understand the learning process, but we could make it optional if it is too costly.
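To make the metric concrete, here is a standalone sketch (not the PR's code; tensor shapes, names, and the 0.2 threshold are illustrative) of the clipped value loss together with the fraction of elements where the clip is active:

import torch

# Stand-ins for the loss module's internals: old prediction from data
# collection, current prediction from the critic, and the value target.
old_state_value = torch.randn(8)
state_value = torch.randn(8)
target_return = torch.randn(8)
clip_value = 0.2  # illustrative threshold

# Keep the new prediction within clip_value of the old one.
state_value_clipped = old_state_value + (
    state_value - old_state_value
).clamp(-clip_value, clip_value)

loss_value = (state_value - target_return).pow(2)
loss_value_clipped = (state_value_clipped - target_return).pow(2)
# Keep the most pessimistic of the two losses, as in the diff above.
loss_value = torch.max(loss_value, loss_value_clipped)

# Proportion of samples whose value update was actually clipped.
clip_fraction = ((state_value - old_state_value).abs() > clip_value).float().mean()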

Review thread (outdated, resolved) on torchrl/objectives/a2c.py
vmoens (Contributor) left a review:

Some very minor details (clone and type annotations) and we're good to go!


Review threads (outdated, resolved) on: torchrl/objectives/ppo.py, torchrl/objectives/reinforce.py, torchrl/objectives/a2c.py
albertbou92 and others added 5 commits March 14, 2024 16:43
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
albertbou92 requested a review from vmoens on March 18, 2024 at 08:39.
vmoens (Contributor) left a review:

LGTM

vmoens merged commit 43c6bca into pytorch:main on Mar 18, 2024 (60 of 67 checks passed).
vmoens deleted the clip_value_loss branch on Mar 18, 2024 at 11:05.
SandishKumarHN pushed a commit to SandishKumarHN/rl that referenced this pull request Mar 18, 2024
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
Co-authored-by: Vincent Moens <vmoens@meta.com>
Labels
CLA Signed (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) · enhancement (New feature or request)
Projects
None yet
3 participants