[RLlib] Fix `policy_to_train` logic for new API stack generically for all algorithms. #41529
main = "examples/self_play_with_open_spiel.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "small",
srcs = ["examples/self_play_with_open_spiel.py"],
args = ["--framework=tf", "--env=connect_four", "--win-rate-threshold=0.9", "--num-episodes-human-play=0", "--as-test", "--min-win-rate=0.6"]
args = ["--framework=tf", "--env=connect_four", "--win-rate-threshold=0.9", "--num-episodes-human-play=0", "--as-test", "--min-league-size=3"]
see below on why min-win-rate does not make sense
rllib/algorithms/algorithm.py
@@ -778,7 +778,27 @@ def setup(self, config: AlgorithmConfig) -> None:
modules_to_load=modules_to_load,
rl_module_ckpt_dirs=rl_module_ckpt_dirs,
)
# sync the weights from the learner group to the rollout workers
# Setup proper policies-to-train/should-module-be-updated functions
The ground truth for "is module trainable" (now renamed to the clearer `should_module_be_updated_fn`) is now the LearnerGroup. The reason is that, in the future, we may not need this information on the EnvRunners (RolloutWorkers) at all, as these are only concerned with sampling and not with training/updating (separation of concerns).
Either way, for now, both LearnerGroup AND RolloutWorkers will carry this information properly at all times.
@@ -23,7 +23,7 @@
from ray.rllib.core.learner.learner import LearnerHyperparameters
from ray.rllib.core.learner.learner_group_config import LearnerGroupConfig, ModuleSpec
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import ModuleID, SingleAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
Moved this into `utils.typing` for better structure of the lib.
@@ -169,8 +169,6 @@ def training_step(self) -> ResultDict:
self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()

# Updating the policy.
is_module_trainable = self.workers.local_worker().is_policy_to_train
These algo-specific "hacks" are not needed anymore, which makes the algo-specific `training_step` code a little easier to read.
@@ -437,21 +437,6 @@ def training_step(self) -> ResultDict:
if self.config._enable_new_api_stack:
# TODO (Kourosh) Clearly define what train_batch_size
# vs. sgd_minibatch_size and num_sgd_iter is in the config.
# TODO (Kourosh) Do this inside the Learner so that we don't have to do
Same: see BC above
# lengths in B. See SampleBatch for more information.
if (
self.module[pid].is_stateful()
or policy_batch.get("seq_lens") is not None
Bug fix: For APPO, we do have SEQ_LENS and need to slice on the B-axis (not T-axis), but the Module might still not be stateful.
Does `policy_batch.get("seq_lens") is not None` cover `self.module[pid].is_stateful()` or no?
I do think these are different concepts. You might have a stateful model that does NOT operate on (time series) sequences (e.g. DreamerV3; well, it does operate on sequences but does NOT need the `seq_lens` key) or vice versa (like transformers, which operate on sequences but statelessly).
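The condition discussed in this thread can be sketched as a small predicate. This is only an illustration under assumptions: `needs_batch_axis_slicing` is a hypothetical helper name, and a plain dict stands in for the policy batch; it is not the actual RLlib code.

```python
from typing import Any, Mapping


def needs_batch_axis_slicing(
    module_is_stateful: bool, policy_batch: Mapping[str, Any]
) -> bool:
    """Hypothetical helper: True if minibatches should be cut along the B-axis.

    Mirrors the `or` in the diff: either the module is stateful, or the batch
    carries a "seq_lens" entry. The two checks are independent of each other.
    """
    return module_is_stateful or policy_batch.get("seq_lens") is not None


# Stateless module whose batch still carries seq_lens (the APPO case above).
appo_like = needs_batch_axis_slicing(False, {"obs": [0, 1, 2], "seq_lens": [2, 1]})
# Stateful module whose batch has no seq_lens key.
stateful_no_seq_lens = needs_batch_axis_slicing(True, {"obs": [0, 1, 2]})
# Neither condition holds.
plain = needs_batch_axis_slicing(False, {"obs": [0, 1, 2]})
print(appo_like, stateful_no_seq_lens, plain)
```

Since neither check implies the other, dropping either side of the `or` would miss one of the first two cases.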
@@ -51,10 +51,13 @@ def _get_backend_config(learner_class: Type["Learner"]) -> str:
return backend_config

def _is_module_trainable(module_id: ModuleID, batch: MultiAgentBatch) -> bool:
"""Default implementation for is_module_trainable()
def _default_should_module_be_updated_fn(
Cleaner (more descriptive) naming.
With the old name, one could be confused into thinking the module itself is not trainable (frozen weights, etc.). However, the control here happens on a much higher level: the LearnerGroup filters out single-agent batches for modules that should not be updated before(!) even sending individual batches to the individual Learner workers.
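The filtering step described in this comment could look roughly like the sketch below. It is an assumption-laden illustration, not the LearnerGroup implementation: `filter_batches_for_update` and the module IDs are hypothetical, a plain dict of lists stands in for a multi-agent batch, and the callable takes `(module_id, batch)` as in the signature shown in the diff.

```python
from typing import Any, Callable, Dict

ModuleID = str  # stand-in for RLlib's ModuleID type


def filter_batches_for_update(
    module_batches: Dict[ModuleID, Any],
    should_module_be_updated_fn: Callable[[ModuleID, Dict[ModuleID, Any]], bool],
) -> Dict[ModuleID, Any]:
    """Keep only the per-module batches whose module should be updated."""
    return {
        module_id: batch
        for module_id, batch in module_batches.items()
        if should_module_be_updated_fn(module_id, module_batches)
    }


# Hypothetical self-play setup: only "main" learns, the opponent is not updated.
batches = {"main": [1, 2, 3], "opponent_v0": [4, 5]}
filtered = filter_batches_for_update(
    batches, lambda module_id, batch: module_id == "main"
)
print(sorted(filtered))
```

Filtering before dispatch means a module that should not be updated never has its batch sent to a Learner worker at all.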
@@ -433,25 +436,6 @@ def remove_module(self, module_id: ModuleID) -> None:
refs.append(ref)
ray.get(refs)

def set_weights(self, weights: Mapping[str, Any]) -> None:
Moved this to the right place: getter before setter
""" | ||
if self.is_local:
return self._learner.get_state()
self._learner.set_module_state(weights)
I assume these are exactly copy pasted.
learner_state = self._get_results(results)[0]
return {
"learner_state": learner_state,
"should_module_be_updated_fn": self.should_module_be_updated_fn,
I see.
rllib/core/learner/learner_group.py
# If container given, construct a simple callable returning True
# if the ModuleID is found in the list/set of IDs.
elif not callable(should_module_be_updated_fn):
assert isinstance(should_module_be_updated_fn, (list, set, tuple)), (
Don't raise assertion errors for things that users have to take action on. Only use assertions for internal violations of assumptions, i.e. things that point to a bug. That way, when people report an assertion error, we immediately know we have a bug and not misuse.
You are right! Fixed.
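A minimal sketch of what the requested change could look like, assuming the container-to-callable conversion described in the diff comment. The helper name `to_should_module_be_updated_fn` is hypothetical, and the choice of `ValueError` is an assumption; the thread does not say which exception type the fix actually raises.

```python
from typing import Any, Callable, Collection, Union

ModuleID = str  # stand-in for RLlib's ModuleID type
UpdateFn = Callable[..., bool]


def to_should_module_be_updated_fn(
    spec: Union[UpdateFn, Collection[ModuleID]]
) -> UpdateFn:
    """Turn a callable or a list/set/tuple of module IDs into a callable."""
    if callable(spec):
        return spec
    # User-facing misconfiguration: raise a regular exception (assumed here to
    # be ValueError) rather than using `assert`, so that an AssertionError
    # keeps meaning "internal bug".
    if not isinstance(spec, (list, set, tuple)):
        raise ValueError(
            "`should_module_be_updated_fn` must be a callable or a "
            f"list/set/tuple of module IDs, got {type(spec).__name__}!"
        )
    module_ids = set(spec)
    return lambda module_id, batch=None: module_id in module_ids


fn = to_should_module_be_updated_fn(["main"])
print(fn("main"), fn("opponent_v0"))
```

Passing something like an integer raises the `ValueError` with a message that tells the user what to change, which is the distinction the reviewer asks for.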
self._should_module_be_updated_fn = should_module_be_updated_fn

# TODO (sven): Why did we choose to re-invent the wheel here and provide load/save
Fair. I think it comes from the philosophy that things should be usable stand-alone.
)

ray.init(
num_cpus=args.num_cpus or None,
oh ok, I think this is the pattern we should use for all the setup calls of all tests then
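The `num_cpus=args.num_cpus or None` pattern can be sketched without starting Ray. The `--num-cpus` flag with a default of 0 and the helper `ray_init_kwargs` are assumptions made for this illustration; in a real script the resulting dict would be passed on as `ray.init(**kwargs)`.

```python
import argparse
from typing import Any, Dict, List

parser = argparse.ArgumentParser()
# Assumed flag: 0 means "not specified by the user".
parser.add_argument("--num-cpus", type=int, default=0)


def ray_init_kwargs(argv: List[str]) -> Dict[str, Any]:
    """Build keyword arguments for `ray.init()` from command line args."""
    args = parser.parse_args(argv)
    # `0 or None` evaluates to None, so an unset flag leaves CPU detection to
    # Ray, while an explicit positive value is passed through unchanged.
    return {"num_cpus": args.num_cpus or None}


print(ray_init_kwargs([]))
print(ray_init_kwargs(["--num-cpus", "4"]))
```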
Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> Signed-off-by: Sven Mika <sven@anyscale.io>
This PR solves a couple of problems related to APPO/IMPALA on the new API stack:
- Fixes the `policy_to_train` logic for the new API stack, generically for all algorithms. No more specific logic will be required in each algorithm's `training_step` method, as the main filter step moves completely into `LearnerGroup`.
- The new API stack should now work properly with APPO and IMPALA, as additional CI test cases prove.