[RLlib] Add and enhance fault-tolerance tests for APPO. #40743
Changes from 11 commits
@@ -11,9 +11,10 @@
 from ray.rllib.evaluation.episode import Episode
 from ray.rllib.examples.env.random_env import RandomEnv
 from ray.rllib.utils.test_utils import framework_iterator
+from ray import tune


-class OnWorkerCreatedCallbacks(DefaultCallbacks):
+class OnWorkersRecreatedCallbacks(DefaultCallbacks):
     def on_workers_recreated(
         self,
         *,
@@ -102,39 +103,48 @@ def on_episode_created(
 class TestCallbacks(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        ray.init()
+        ray.init(num_cpus=12)

Review comment: Please don't add num cpus. On workspaces you cannot run this code.
Reply: done

     @classmethod
     def tearDownClass(cls):
         ray.shutdown()

     def test_on_workers_recreated_callback(self):
+        tune.register_env("env", lambda cfg: CartPoleCrashing(cfg))
+
         config = (
             APPOConfig()
-            .environment(CartPoleCrashing)
-            .callbacks(OnWorkerCreatedCallbacks)
-            .rollouts(num_rollout_workers=2)
+            .environment("env")
+            .callbacks(OnWorkersRecreatedCallbacks)
+            .rollouts(num_rollout_workers=3)
             .fault_tolerance(recreate_failed_workers=True)
         )

-        for _ in framework_iterator(config, frameworks=("tf2", "torch")):
+        for _ in framework_iterator(config, frameworks=("torch", "tf2")):
             algo = config.build()
             original_worker_ids = algo.workers.healthy_worker_ids()
+            for id_ in original_worker_ids:
+                self.assertTrue(algo._counters[f"worker_{id_}_recreated"] == 0)
+            self.assertTrue(algo._counters["total_num_workers_recreated"] == 0)

             # After building the algorithm, we should have 2 healthy (remote) workers.
-            self.assertTrue(len(original_worker_ids) == 2)
+            self.assertTrue(len(original_worker_ids) == 3)

             # Train a bit (and have the envs/workers crash a couple of times).
-            for _ in range(3):
-                algo.train()
+            for _ in range(5):
+                print(algo.train())

Review comment: No need to print?
Reply: It's clearer to see the results when you are debugging/watching the test run, no? Leaving this in.

-            # After training, each new worker should have been recreated at least once.
+            # After training, the `on_workers_recreated` callback should have captured
+            # the exact worker IDs recreated (the exact number of times) as the actor
+            # manager itself. This confirms that the callback is triggered correctly,
+            # always.
             new_worker_ids = algo.workers.healthy_worker_ids()
-            self.assertTrue(len(new_worker_ids) == 2)
+            self.assertTrue(len(new_worker_ids) == 3)
             for id_ in new_worker_ids:
-                self.assertTrue(algo._counters[f"worker_{id_}_recreated"] >= 1)
+                num_restored = algo.workers.restored_actors_history[id_]
+                self.assertTrue(
+                    algo._counters[f"worker_{id_}_recreated"] == num_restored
+                )
             algo.stop()

     def test_on_init_and_checkpoint_loaded(self):
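For context, here is a minimal sketch of what a callback class like `OnWorkersRecreatedCallbacks` might do to feed the counters the test asserts on. Only the class name, the `on_workers_recreated` method name, and the counter keys appear in the diff above; the exact keyword arguments are assumptions, not the PR's actual implementation.

from ray.rllib.algorithms.callbacks import DefaultCallbacks


class OnWorkersRecreatedCallbacks(DefaultCallbacks):
    def on_workers_recreated(
        self,
        *,
        algorithm=None,   # assumed kwarg: the Algorithm instance
        worker_ids=None,  # assumed kwarg: IDs of the workers that were recreated
        **kwargs,
    ):
        # Bump one counter per recreated worker plus a global counter; these are
        # the `algo._counters[...]` keys the test compares against
        # `algo.workers.restored_actors_history`.
        for worker_id in worker_ids or []:
            algorithm._counters[f"worker_{worker_id}_recreated"] += 1
            algorithm._counters["total_num_workers_recreated"] += 1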
@@ -11,45 +11,89 @@


 class CartPoleCrashing(CartPoleEnv):
-    """A CartPole env that crashes from time to time.
+    """A CartPole env that crashes (or stalls) from time to time.

Review comment: Added option to also just make this env stall (not crash) for a while.

     Useful for testing faulty sub-env (within a vectorized env) handling by
-    RolloutWorkers.
+    EnvRunners.

     After crashing, the env expects a `reset()` call next (calling `step()` will
     result in yet another error), which may or may not take a very long time to
     complete. This simulates the env having to reinitialize some sub-processes, e.g.
     an external connection.
+
+    The env can also be configured to stall (and do nothing during a call to `step()`)
+    from time to time for a configurable amount of time.
     """
     def __init__(self, config=None):
         super().__init__()

-        config = config or {}
+        self.config = config or {}

         # Crash probability (in each `step()`).
         self.p_crash = config.get("p_crash", 0.005)
         # Crash probability when `reset()` is called.
         self.p_crash_reset = config.get("p_crash_reset", 0.0)
+        # Crash exactly after every n steps. If a 2-tuple, will uniformly sample
+        # crash timesteps from in between the two given values.
         self.crash_after_n_steps = config.get("crash_after_n_steps")
-        # Only crash (with prob=p_crash) if on certain worker indices.
+        self._crash_after_n_steps = None
+        assert (
+            self.crash_after_n_steps is None
+            or isinstance(self.crash_after_n_steps, int)
+            or (
+                isinstance(self.crash_after_n_steps, tuple)
+                and len(self.crash_after_n_steps) == 2
+            )
+        )
+        # Only ever crash, if on certain worker indices.
         faulty_indices = config.get("crash_on_worker_indices", None)
         if faulty_indices and config.worker_index not in faulty_indices:
             self.p_crash = 0.0
             self.p_crash_reset = 0.0
             self.crash_after_n_steps = None
+
+        # Stall probability (in each `step()`).
+        self.p_stall = config.get("p_stall", 0.0)
+        # Stall probability when `reset()` is called.
+        self.p_stall_reset = config.get("p_stall_reset", 0.0)
+        # Stall exactly after every n steps.
+        self.stall_after_n_steps = config.get("stall_after_n_steps")
+        self._stall_after_n_steps = None
+        # Amount of time to stall. If a 2-tuple, will uniformly sample from in between
+        # the two given values.
+        self.stall_time_sec = config.get("stall_time_sec")
+        assert (
+            self.stall_time_sec is None
+            or isinstance(self.stall_time_sec, (int, float))
+            or (
+                isinstance(self.stall_time_sec, tuple) and len(self.stall_time_sec) == 2
+            )
+        )
+
+        # Only ever stall, if on certain worker indices.
+        faulty_indices = config.get("stall_on_worker_indices", None)
+        if faulty_indices and config.worker_index not in faulty_indices:
+            self.p_stall = 0.0
+            self.p_stall_reset = 0.0
+            self.stall_after_n_steps = None

         # Timestep counter for the ongoing episode.
         self.timesteps = 0

         # Time in seconds to initialize (in this c'tor).
+        sample = 0.0
         if "init_time_s" in config:
-            init_time_s = config.get("init_time_s", 0)
-        else:
-            init_time_s = np.random.randint(
-                config.get("init_time_s_min", 0),
-                config.get("init_time_s_max", 1),
+            sample = (
+                config["init_time_s"]
+                if not isinstance(config["init_time_s"], tuple)
+                else np.random.uniform(
+                    config["init_time_s"][0], config["init_time_s"][1]
+                )
             )
-        print(f"Initializing crashing env with init-delay of {init_time_s}sec ...")
-        time.sleep(init_time_s)

+        print(f"Initializing crashing env (with init-delay of {sample}sec) ...")
+        time.sleep(sample)

         # No env pre-checking?
         self._skip_env_checking = config.get("skip_env_checking", False)
@@ -61,30 +105,81 @@ def __init__(self, config=None):
     def reset(self, *, seed=None, options=None):
         # Reset timestep counter for the new episode.
         self.timesteps = 0
+        self._crash_after_n_steps = None

         # Should we crash?
-        if self._rng.rand() < self.p_crash_reset or (
-            self.crash_after_n_steps is not None and self.crash_after_n_steps == 0
-        ):
+        if self._should_crash(p=self.p_crash_reset):
             raise EnvError(
                 "Simulated env crash in `reset()`! Feel free to use any "
                 "other exception type here instead."
+                # f"Simulated env crash on worker={self.config.worker_index} "
+                # f"env-idx={self.config.vector_index} during `reset()`! "
+                # "Feel free to use any other exception type here instead."
             )

Review comment: Remove?
Reply: Fixed. This was a problem with the EnvContext being passed in with an empty dict and then in the env code [...]. Even if [...]

+        # Should we stall for a while?
+        self._stall_if_necessary(p=self.p_stall_reset)

         return super().reset()
     @override(CartPoleEnv)
     def step(self, action):
         # Increase timestep counter for the ongoing episode.
         self.timesteps += 1

         # Should we crash?
-        if self._rng.rand() < self.p_crash or (
-            self.crash_after_n_steps and self.crash_after_n_steps == self.timesteps
-        ):
+        if self._should_crash(p=self.p_crash):
             raise EnvError(
                 "Simulated env crash in `step()`! Feel free to use any "
                 "other exception type here instead."
+                # f"Simulated env crash on worker={self.config.worker_index} "
+                # f"env-idx={self.config.vector_index} during `step()`! "
+                # "Feel free to use any other exception type here instead."
             )
-        # No crash.
+        # Should we stall for a while?
+        self._stall_if_necessary(p=self.p_stall)

         return super().step(action)

+    def _should_crash(self, p):
+        rnd = self._rng.rand()
+        if rnd < p:
+            print(f"Should crash! ({rnd} < {p})")
+            return True
+        elif self.crash_after_n_steps is not None:
+            if self._crash_after_n_steps is None:
+                self._crash_after_n_steps = (
+                    self.crash_after_n_steps
+                    if not isinstance(self.crash_after_n_steps, tuple)
+                    else np.random.randint(
+                        self.crash_after_n_steps[0], self.crash_after_n_steps[1]
+                    )
+                )
+            if self._crash_after_n_steps == self.timesteps:
+                print(f"Should crash! (after {self.timesteps} steps)")
+                return True
+
+        return False
+
+    def _stall_if_necessary(self, p):
+        stall = False
+        if self._rng.rand() < p:
+            stall = True
+        elif self.stall_after_n_steps is not None:
+            if self._stall_after_n_steps is None:
+                self._stall_after_n_steps = (
+                    self.stall_after_n_steps
+                    if not isinstance(self.stall_after_n_steps, tuple)
+                    else np.random.randint(
+                        self.stall_after_n_steps[0], self.stall_after_n_steps[1]
+                    )
+                )
+            if self._stall_after_n_steps == self.timesteps:
+                stall = True
+
+        if stall:
+            sec = (
+                self.stall_time_sec
+                if not isinstance(self.stall_time_sec, tuple)
+                else np.random.uniform(self.stall_time_sec[0], self.stall_time_sec[1])
+            )
+            print(f" -> will stall for {sec}sec ...")
+            time.sleep(sec)


 MultiAgentCartPoleCrashing = make_multi_agent(lambda config: CartPoleCrashing(config))
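To make the crash-and-reset contract described in the docstring above concrete, here is a minimal stand-alone sketch. The import paths of `CartPoleCrashing` and `EnvError` are assumptions; the config keys come from the diff.

import numpy as np
# Assumed import paths, for illustration only.
from ray.rllib.examples.env.cartpole_crashing import CartPoleCrashing
from ray.rllib.utils.error import EnvError

# Crash in ~1% of `step()` calls; stall for 1 second at step 10 of each episode.
env = CartPoleCrashing(
    {"p_crash": 0.01, "stall_after_n_steps": 10, "stall_time_sec": 1.0}
)
obs, info = env.reset()
for _ in range(500):
    try:
        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
        if terminated or truncated:
            obs, info = env.reset()
    except EnvError:
        # After a simulated crash, the env expects a `reset()` call next.
        obs, info = env.reset()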
@@ -140,7 +140,12 @@ def __init__(

     @override(LearnerThread)
     def step(self) -> None:
-        assert self.loader_thread.is_alive()
+        if not self.loader_thread.is_alive():
+            raise RuntimeError(
+                "The `_MultiGPULoaderThread` has died! Will therefore also terminate "
+                "the `MultiGPULearnerThread`."
+            )

Review comment: Better error message.

         with self.load_wait_timer:
             buffer_idx, released = self.ready_tower_stacks_buffer.get()
Added these new fault-tolerance tests for APPO. The old fault-tolerance tests were for PG, which should be moved to rllib_contrib and which is synchronous, not asynchronous. We should probably use PPO in the near future to cover that case again.
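As a rough sketch of the pattern the new test exercises: register the crashing env, enable worker recreation, and train while workers die and come back. The `env_config` values and the `CartPoleCrashing` import path are illustrative assumptions; the `tune.register_env`, APPO, and `recreate_failed_workers` usage come from the diff above.

from ray import tune
from ray.rllib.algorithms.appo import APPOConfig
# Assumed import path for the example env shown above.
from ray.rllib.examples.env.cartpole_crashing import CartPoleCrashing

# Register the env so RLlib passes an EnvContext (with worker_index etc.) as `cfg`.
tune.register_env("env", lambda cfg: CartPoleCrashing(cfg))

config = (
    APPOConfig()
    .environment(
        "env",
        env_config={"p_crash": 0.005, "p_stall": 0.001, "stall_time_sec": (1, 3)},
    )
    .rollouts(num_rollout_workers=3)
    # Recreate rollout workers that die (e.g., because the env crashed) instead of
    # failing the whole run; this is what the `on_workers_recreated` test asserts on.
    .fault_tolerance(recreate_failed_workers=True)
)
algo = config.build()
for _ in range(5):
    print(algo.train())
algo.stop()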