[BugFix] Loss function convert_to_functional seems to generate backprop issues with shared policy-value networks #1034
Description
At least in PPO (but probably also in the other Objective classes), the calls to self.convert_to_functional at the following lines cause the value network not to update correctly when using a shared policy-value architecture. Could the current code be stopping the value loss from propagating to the shared first layers?
https://github.com/pytorch/rl/blob/main/torchrl/objectives/ppo.py#L113
https://github.com/pytorch/rl/blob/main/torchrl/objectives/ppo.py#L117
This PR removes the calls to self.convert_to_functional. I ran it on a test case with the Atari Pong environment and, as the plots show, it solved the issue (the green line is the modified version of the code). However, a proper solution should probably be discussed, since I am not sure whether simply removing these lines causes problems elsewhere.
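For context, a minimal sketch of the failure mode in plain PyTorch (this is not torchrl code, just an illustration of the hypothesis): if the value head is run through the shared trunk with *detached copies* of the trunk's parameters, which is roughly what a functional conversion can do when it re-registers parameters instead of reusing the originals, the value loss never reaches the shared layers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A trunk shared by the policy and value heads.
trunk = nn.Linear(4, 8)
value_head = nn.Linear(8, 1)
obs = torch.randn(16, 4)

# Correct sharing: the value loss backpropagates into the trunk.
value = value_head(trunk(obs))
value.pow(2).mean().backward()
assert trunk.weight.grad is not None

trunk.weight.grad = None
trunk.bias.grad = None

# Broken sharing: run the trunk functionally with detached parameter
# copies. Gradients flow into the copies, not the registered parameters,
# so the trunk's own .grad stays empty and the optimizer never sees it.
detached = {
    k: v.detach().clone().requires_grad_()
    for k, v in trunk.named_parameters()
}
feat = torch.func.functional_call(trunk, detached, (obs,))
value = value_head(feat)
value.pow(2).mean().backward()
assert trunk.weight.grad is None  # value loss never reached the trunk
```

This matches the symptom in the plots: the policy still trains (its loss goes through the live trunk parameters), while the value-loss contribution to the shared layers is silently dropped.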