Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Refactor losses (value function, doc, input batch size) #987

Merged
merged 25 commits into from
Mar 28, 2023

Conversation

vmoens
Copy link
Contributor

@vmoens vmoens commented Mar 23, 2023

Description

This is a major refactoring of the loss modules.
From now on, loss and value estimation are decoupled as much as can be. This leaves the possibility to the user to pick the value function that she feels appropriate.

In a nutshell, the previous API left little choice of value estimator:

dqn_loss = DQNLoss(actor, gamma=0.9) # always uses TD0

From now on, one can pick up the value function after creating the loss:

dqn_loss = DQNLoss(actor)
dqn_loss.make_value_function(ValueFunctions.TDLambda, gamma=0.99, lmbda=0.95) # hyperparameters can also be ommitted.

Each loss is equipped with a default value function that can be accessed via

dqn_loss.default_value_function

If make_value_function is not called, then this default value estimator will be used.

The reason why we don't simply pass a ValueFunction module is that some losses (eg SAC) have an intricated value:
Q(s,a) - log_prob(a|s) for instance. Hence we can't just have the users make vf = ValueFunction(network, ...) and pass vf to the loss. With this Enum trick, we get the best of both worlds: total flexibility in user-facing constructors and full compatibility in the backend.

BC-breaking changes

  • We've renamed the value function modules: TD0Estimate, TD1Estimate, TDLambdaEstimate
  • Passing gamma or lmbda to the module raises a warning. However for bc-compatibility the value passed will still be used *unless another is passed via make_value_function.
  • Various typos
  • All value functions now inherit from a common parent class.
  • ValueEstimators and loss modules have a lot more keyword only arguments

Unaddressed

We did not take care of making all losses compatible with GAE. The only reason is that GAE requires the value (not only the next value), which we currently don't support. This can be addressed separately.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 23, 2023
@vmoens vmoens force-pushed the add_adv_module_lossesd branch from 2efc4d4 to dd6ac56 Compare March 27, 2023 10:48
@vmoens vmoens added enhancement New feature or request bc breaking backward compatibility breaking change labels Mar 28, 2023
@vmoens vmoens merged commit 0b2d2d8 into main Mar 28, 2023
@vmoens vmoens mentioned this pull request Mar 29, 2023
10 tasks
@vmoens vmoens deleted the add_adv_module_lossesd branch March 31, 2023 16:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bc breaking backward compatibility breaking change CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants