
[BugFix] Loss function convert_to_functional seems to generate backprop issues with shared policy-value networks #1034

Closed

Conversation

@albertbou92 (Contributor) commented Apr 11, 2023

Description

At least in PPO (but probably also in the rest of the Objective classes), calling self.convert_to_functional in the following lines when using a shared policy-value architecture causes the value network not to update correctly. Could the current code be stopping the value loss from being propagated to the first (shared) layers?
https://github.com/pytorch/rl/blob/main/torchrl/objectives/ppo.py#L113
https://github.com/pytorch/rl/blob/main/torchrl/objectives/ppo.py#L117

This PR removes the call to self.convert_to_functional. I ran it on a test case for the Atari Pong env and, as the plots below show, it solved the issue. The green line is the modified version of the code. However, a proper solution should probably be discussed, because I am not sure whether simply removing these lines causes problems somewhere else.

[Plots PPO_bug1 and PPO_bug2: Atari Pong training curves comparing the original and modified code; the green line is the modified version.]
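For context, here is a minimal sketch in plain PyTorch of the kind of shared policy-value architecture this affects (the class, names, and sizes are illustrative, not the network used in the Pong tests):

```python
import torch.nn as nn

# Hypothetical minimal shared policy-value architecture of the kind the PR
# describes. The trunk is shared, so it should receive gradients from BOTH
# the policy loss and the value loss.
class SharedActorCritic(nn.Module):
    def __init__(self, obs_dim=8, n_actions=4):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(obs_dim, 64), nn.Tanh())  # shared layers
        self.policy_head = nn.Linear(64, n_actions)
        self.value_head = nn.Linear(64, 1)

    def forward(self, obs):
        h = self.trunk(obs)
        return self.policy_head(h), self.value_head(h)
```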

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 11, 2023
@vmoens (Contributor) commented Apr 11, 2023

When you backprop the two losses, I guess that in your implementation the gradients for the common modules are the sum of the gradients of the two losses, right?
What compare_against does is feed the module (in this case the critic) with a detached version of the params used by the other module (in this case the policy). This is what is done in some algos, but I understand that it may be suboptimal in other cases, especially if the rest of your critic is very shallow.

Have you tried just removing the compare_against?

If that works, a solution could be to make the compare_against call optional, with a default of not doing it.
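A hedged sketch of what this detaching means in practice, using torch.func.functional_call rather than torchrl's internals (the module layout and the param filter are assumptions for illustration):

```python
import torch
import torch.nn as nn
from torch.func import functional_call

# Rough sketch of the semantics described above (NOT torchrl's actual code):
# call the critic functionally with DETACHED copies of the params it shares
# with the policy (here the trunk, submodule "0"), and live params for the rest.
trunk, value_head = nn.Linear(8, 64), nn.Linear(64, 1)
critic = nn.Sequential(trunk, value_head)
obs = torch.randn(2, 8)

params = {name: (p.detach() if name.startswith("0.") else p)
          for name, p in critic.named_parameters()}
functional_call(critic, params, (obs,)).sum().backward()

print(trunk.weight.grad)                   # None: the value loss never reaches the trunk
print(value_head.weight.grad is not None)  # True: the value head itself still updates
```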

@albertbou92 albertbou92 force-pushed the convert_to_functional_bug branch from f13d8f5 to bdf9949 Compare April 11, 2023 08:34
@albertbou92 (Contributor, Author) commented:

Some updates. I actually only need to modify this line: https://github.com/pytorch/rl/blob/main/torchrl/objectives/ppo.py#L117

Here is what works and what does not:

```python
self.convert_to_functional(critic, "critic", compare_against=self.actor_params)  # Value network update is not correct
self.convert_to_functional(critic, "critic")  # Value network update is not correct either
self.critic = critic  # This works
```

So just removing the compare_against does not seem to be the solution.

@albertbou92 (Contributor, Author) commented:

> When you backprop the two losses, I guess that in your implementation the gradients for the common modules are the sum of the gradients of the two losses, right? What compare_against does is feed the module (in this case the critic) with a detached version of the params used by the other module (in this case the policy). This is what is done in some algos, but I understand that it may be suboptimal in other cases, especially if the rest of your critic is very shallow.
>
> Have you tried just removing the compare_against?
>
> If that works, a solution could be to make the compare_against call optional, with a default of not doing it.

Yes, for the code to work the gradients have to be the sum of the gradients of the two losses.
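A minimal runnable sketch of the update this assumes, with stand-in losses and hypothetical names (not the PR's actual training code): backprop one combined objective so the shared trunk accumulates gradients from both terms.

```python
import torch
import torch.nn as nn

# Illustrative only: backprop a single combined objective so the shared trunk
# accumulates gradients from BOTH the policy term and the value term.
trunk = nn.Linear(8, 64)
policy_head, value_head = nn.Linear(64, 4), nn.Linear(64, 1)
opt = torch.optim.Adam(
    [*trunk.parameters(), *policy_head.parameters(), *value_head.parameters()],
    lr=3e-4,
)

obs = torch.randn(16, 8)
h = trunk(obs)
loss_policy = policy_head(h).logsumexp(-1).mean()  # stand-in for the PPO policy loss
loss_value = value_head(h).pow(2).mean()           # stand-in for the value loss

opt.zero_grad()
(loss_policy + loss_value).backward()  # trunk grads = sum of both terms' grads
opt.step()
```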

@vmoens (Contributor) commented Apr 11, 2023

You're right, even without that, there is no param tying. Let me fix that to have a more predictable behaviour.
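For reference, a sketch of what param tying means here in plain nn.Module terms (illustrative, not torchrl's functional-parameter machinery): the two networks must literally share the same parameter tensors, not copies of them.

```python
import torch.nn as nn

# "Param tying" in the plain nn.Module sense: actor and critic hold the SAME
# parameter objects for the shared layers, so both losses update one tensor.
trunk = nn.Linear(8, 64)
actor = nn.Sequential(trunk, nn.Linear(64, 4))
critic = nn.Sequential(trunk, nn.Linear(64, 1))  # same trunk object -> tied params

assert actor[0].weight is critic[0].weight  # one tensor, updated by both losses
```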

@vmoens (Contributor) commented Apr 13, 2023

Closed by #1037

@vmoens vmoens closed this Apr 13, 2023
@albertbou92 albertbou92 deleted the convert_to_functional_bug branch January 18, 2024 10:08