-
Notifications
You must be signed in to change notification settings - Fork 326
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Refactor] Update all instances of exploration *Wrapper
to *Module
#2298
[Refactor] Update all instances of exploration *Wrapper
to *Module
#2298
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2298
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 New Failures, 1 Pending, 3 Unrelated FailuresAs of commit 085bef2 with merge base bdc9784 (): NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
@@ -108,7 +108,7 @@ def main(cfg: "DictConfig"): # noqa: F821 | |||
for _, tensordict in enumerate(collector): | |||
sampling_time = time.time() - sampling_start | |||
# Update exploration policy | |||
exploration_policy.step(tensordict.numel()) | |||
exploration_policy[1].step(tensordict.numel()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not completely sure if this is the best way to do this. Does there happen to be some alternative to TensorDictSequential
which does essentially the same thing but also provides a step
function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, it looks like the same thing was done when EGreedyWrapper
was updated to EGreedyModule
, so I guess it's alright:
rl/sota-implementations/bandits/dqn.py
Lines 89 to 91 in bdc9784
policy = TensorDictSequential( | |
actor, | |
EGreedyModule( |
rl/sota-implementations/bandits/dqn.py
Line 125 in bdc9784
policy[1].step() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep either that or
def update_exploration(module):
if isinstance(module, ExplorationModule):
module.set()
policy.apply(update_exploration)
We could make sure that all exploration modules have the same parent class and use that update function across examples.
9c4ccbd
to
b40430b
Compare
*Wrapper
to *Module
*Wrapper
to *Module
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM thanks for this
In a second time we could consider refactoring the update methods, happy to read your thoughts about this
@@ -108,7 +108,7 @@ def main(cfg: "DictConfig"): # noqa: F821 | |||
for _, tensordict in enumerate(collector): | |||
sampling_time = time.time() - sampling_start | |||
# Update exploration policy | |||
exploration_policy.step(tensordict.numel()) | |||
exploration_policy[1].step(tensordict.numel()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep either that or
def update_exploration(module):
if isinstance(module, ExplorationModule):
module.set()
policy.apply(update_exploration)
We could make sure that all exploration modules have the same parent class and use that update function across examples.
@@ -200,7 +203,7 @@ def train(cfg: "DictConfig"): # noqa: F821 | |||
optim.zero_grad() | |||
target_net_updater.step() | |||
|
|||
policy_explore.step(frames=current_frames) # Update exploration annealing | |||
policy_explore[1].step(frames=current_frames) # Update exploration annealing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in the example I gave above, the update_exploration
or step_exploration
should be turned into a class to allow us to pass the current_frames
Description
Update instances of
AdditiveGaussianWrapper
-->AdditiveGaussianModule
OrnsteinUhlenbeckProcessWrapper
-->OrnsteinUhlenbeckProcessModule
everywhere in the code base, except in
test/test_exploration.py
, which should still test both the wrappers and modules until we finally remove the wrappers in the future.Motivation and Context
close #2295
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an
x
in all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!