-
Notifications
You must be signed in to change notification settings - Fork 327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Refactor] Update all instances of exploration *Wrapper
to *Module
#2298
Merged
vmoens
merged 2 commits into
pytorch:main
from
kurtamohler:update-exploration-modules-0
Jul 22, 2024
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
import hydra | ||
import torch | ||
|
||
from tensordict.nn import TensorDictModule | ||
from tensordict.nn import TensorDictModule, TensorDictSequential | ||
from torch import nn | ||
from torchrl._utils import logger as torchrl_logger | ||
from torchrl.collectors import SyncDataCollector | ||
|
@@ -18,7 +18,7 @@ | |
from torchrl.envs.libs.vmas import VmasEnv | ||
from torchrl.envs.utils import ExplorationType, set_exploration_type | ||
from torchrl.modules import ( | ||
AdditiveGaussianWrapper, | ||
AdditiveGaussianModule, | ||
ProbabilisticActor, | ||
TanhDelta, | ||
ValueOperator, | ||
|
@@ -102,10 +102,13 @@ def train(cfg: "DictConfig"): # noqa: F821 | |
return_log_prob=False, | ||
) | ||
|
||
policy_explore = AdditiveGaussianWrapper( | ||
policy_explore = TensorDictSequential( | ||
policy, | ||
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), | ||
action_key=env.action_key, | ||
AdditiveGaussianModule( | ||
spec=env.unbatched_action_spec, | ||
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), | ||
action_key=env.action_key, | ||
), | ||
) | ||
|
||
# Critic | ||
|
@@ -200,7 +203,7 @@ def train(cfg: "DictConfig"): # noqa: F821 | |
optim.zero_grad() | ||
target_net_updater.step() | ||
|
||
policy_explore.step(frames=current_frames) # Update exploration annealing | ||
policy_explore[1].step(frames=current_frames) # Update exploration annealing | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in the example I gave above, the |
||
collector.update_policy_weights_() | ||
|
||
training_time = time.time() - training_start | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not completely sure if this is the best way to do this. Does there happen to be some alternative to
TensorDictSequential
which does essentially the same thing but also provides astep
function?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, it looks like the same thing was done when
EGreedyWrapper
was updated toEGreedyModule
, so I guess it's alright:rl/sota-implementations/bandits/dqn.py
Lines 89 to 91 in bdc9784
rl/sota-implementations/bandits/dqn.py
Line 125 in bdc9784
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep either that or
We could make sure that all exploration modules have the same parent class and use that update function across examples.