
[RLlib] Add and enhance fault-tolerance tests for APPO. #40743

Merged (22 commits) on Dec 8, 2023
44 changes: 42 additions & 2 deletions rllib/BUILD
@@ -222,6 +222,48 @@ py_test(
args = ["--dir=tuned_examples/appo"]
)

# Tests against crashing or hanging environments.
Contributor Author:

Added these new fault-tolerant tests for APPO.

The old fault-tolerance tests were for PG, which should be moved to rllib_contrib and which is synchronous, not asynchronous. We should probably use PPO in the near future to cover the synchronous case again.

# Single-agent: Crash only.
py_test(
name = "learning_tests_cartpole_crashing_appo",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "crashing_cartpole"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/appo/cartpole-crashing-recreate-workers-appo.py"],
args = ["--dir=tuned_examples/appo", "--num-cpus=6"]
)
# Single-agent: Crash and stall.
py_test(
name = "learning_tests_cartpole_crashing_and_stalling_appo",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "crashing_cartpole"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/appo/cartpole-crashing-and-stalling-recreate-workers-appo.py"],
args = ["--dir=tuned_examples/appo", "--num-cpus=6"]
)
# Multi-agent: Crash only.
py_test(
name = "learning_tests_multi_agent_cartpole_crashing_appo",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "crashing_cartpole"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/appo/multi-agent-cartpole-crashing-recreate-workers-appo.yaml"],
args = ["--dir=tuned_examples/appo", "--num-cpus=6"]
)
# Multi-agent: Crash and stall.
py_test(
name = "learning_tests_multi_agent_cartpole_crashing_and_stalling_appo",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "crashing_cartpole"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/appo/multi-agent-cartpole-crashing-and-stalling-recreate-workers-appo.py"],
args = ["--dir=tuned_examples/appo", "--num-cpus=6"]
)

# CQL
py_test(
name = "learning_tests_pendulum_cql",
@@ -1576,7 +1618,6 @@ py_test(
args = ["TestCheckpointRestorePPO"]
)


py_test(
name = "tests/test_checkpoint_restore_ppo_gpu",
main = "tests/test_algorithm_checkpoint_restore.py",
@@ -1595,7 +1636,6 @@ py_test(
args = ["TestCheckpointRestoreOffPolicy"]
)


py_test(
name = "tests/test_checkpoint_restore_off_policy_gpu",
main = "tests/test_algorithm_checkpoint_restore.py",
3 changes: 3 additions & 0 deletions rllib/algorithms/algorithm.py
@@ -1564,6 +1564,9 @@ def restore_workers(self, workers: WorkerSet) -> None:
restored = workers.probe_unhealthy_workers()

if restored:
# Count the restored workers.
self._counters["total_num_restored_workers"] += len(restored)

from_worker = workers.local_worker() or self.workers.local_worker()
# Get the state of the correct (reference) worker. E.g. The local worker
# of the main WorkerSet.
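
For illustration (not part of the diff): a minimal sketch of reading this new counter after training, assuming an algo built from any fault-tolerant config.

    algo = config.build()
    for _ in range(3):
        algo.train()
    # `total_num_restored_workers` is the counter incremented above.
    print(algo._counters["total_num_restored_workers"])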
34 changes: 22 additions & 12 deletions rllib/algorithms/tests/test_callbacks.py
@@ -11,9 +11,10 @@
from ray.rllib.evaluation.episode import Episode
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.utils.test_utils import framework_iterator
from ray import tune


class OnWorkerCreatedCallbacks(DefaultCallbacks):
class OnWorkersRecreatedCallbacks(DefaultCallbacks):
def on_workers_recreated(
self,
*,
@@ -102,39 +103,48 @@ def on_episode_created(
class TestCallbacks(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init()
ray.init(num_cpus=12)
Contributor:

Please don't add num_cpus here. On workspaces you cannot run this code.

Contributor Author:

done


@classmethod
def tearDownClass(cls):
ray.shutdown()

def test_on_workers_recreated_callback(self):
tune.register_env("env", lambda cfg: CartPoleCrashing(cfg))

config = (
APPOConfig()
.environment(CartPoleCrashing)
.callbacks(OnWorkerCreatedCallbacks)
.rollouts(num_rollout_workers=2)
.environment("env")
.callbacks(OnWorkersRecreatedCallbacks)
.rollouts(num_rollout_workers=3)
.fault_tolerance(recreate_failed_workers=True)
)

for _ in framework_iterator(config, frameworks=("tf2", "torch")):
for _ in framework_iterator(config, frameworks=("torch", "tf2")):
algo = config.build()
original_worker_ids = algo.workers.healthy_worker_ids()
for id_ in original_worker_ids:
self.assertTrue(algo._counters[f"worker_{id_}_recreated"] == 0)
self.assertTrue(algo._counters["total_num_workers_recreated"] == 0)

# After building the algorithm, we should have 3 healthy (remote) workers.
self.assertTrue(len(original_worker_ids) == 2)
self.assertTrue(len(original_worker_ids) == 3)

# Train a bit (and have the envs/workers crash a couple of times).
for _ in range(3):
algo.train()
for _ in range(5):
print(algo.train())
Contributor:

No need to print?

Contributor Author:

It's clearer to see the results when you are debugging/watching the test run, no? Leaving this in.


# After training, each new worker should have been recreated at least once.
# After training, the `on_workers_recreated` callback should have captured
# exactly the same recreated worker IDs (and recreation counts) as the actor
# manager itself. This confirms that the callback is always triggered
# correctly.
new_worker_ids = algo.workers.healthy_worker_ids()
self.assertTrue(len(new_worker_ids) == 2)
self.assertTrue(len(new_worker_ids) == 3)
for id_ in new_worker_ids:
self.assertTrue(algo._counters[f"worker_{id_}_recreated"] >= 1)
num_restored = algo.workers.restored_actors_history[id_]
self.assertTrue(
algo._counters[f"worker_{id_}_recreated"] == num_restored
)
algo.stop()

def test_on_init_and_checkpoint_loaded(self):
5 changes: 5 additions & 0 deletions rllib/evaluation/worker_set.py
@@ -1006,3 +1006,8 @@ def _remote_workers(self) -> List[ActorHandle]:
)
def remote_workers(self) -> List[ActorHandle]:
return list(self.__worker_manager.actors().values())

# TODO: remove
@property
def restored_actors_history(self):
return self.__worker_manager.restored_actors_history
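
As exercised in the updated callback test above, this (TODO-marked, temporary) property lets tests cross-check the per-worker recreation counters against the actor manager's own history. A hedged sketch mirroring the test code in this PR:

    for id_ in algo.workers.healthy_worker_ids():
        num_restored = algo.workers.restored_actors_history[id_]
        assert algo._counters[f"worker_{id_}_recreated"] == num_restored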
139 changes: 117 additions & 22 deletions rllib/examples/env/cartpole_crashing.py
@@ -11,45 +11,89 @@


class CartPoleCrashing(CartPoleEnv):
"""A CartPole env that crashes from time to time.
"""A CartPole env that crashes (or stalls) from time to time.
Contributor Author:

Added an option to also make this env stall (not crash) for a while.


Useful for testing faulty sub-env (within a vectorized env) handling by
RolloutWorkers.
EnvRunners.

After crashing, the env expects a `reset()` call next (calling `step()` will
result in yet another error), which may or may not take a very long time to
complete. This simulates the env having to reinitialize some sub-processes, e.g.
an external connection.

The env can also be configured to stall (and do nothing during a call to `step()`)
from time to time for a configurable amount of time.
"""

def __init__(self, config=None):
super().__init__()

config = config or {}
self.config = config or {}

# Crash probability (in each `step()`).
self.p_crash = config.get("p_crash", 0.005)
# Crash probability when `reset()` is called.
self.p_crash_reset = config.get("p_crash_reset", 0.0)
# Crash exactly after every n steps. If a 2-tuple, will uniformly sample
# crash timesteps from in between the two given values.
self.crash_after_n_steps = config.get("crash_after_n_steps")
# Only crash (with prob=p_crash) if on certain worker indices.
self._crash_after_n_steps = None
assert (
self.crash_after_n_steps is None
or isinstance(self.crash_after_n_steps, int)
or (
isinstance(self.crash_after_n_steps, tuple)
and len(self.crash_after_n_steps) == 2
)
)
# Only ever crash if on certain worker indices.
faulty_indices = config.get("crash_on_worker_indices", None)
if faulty_indices and config.worker_index not in faulty_indices:
self.p_crash = 0.0
self.p_crash_reset = 0.0
self.crash_after_n_steps = None

# Stall probability (in each `step()`).
self.p_stall = config.get("p_stall", 0.0)
# Stall probability when `reset()` is called.
self.p_stall_reset = config.get("p_stall_reset", 0.0)
# Stall exactly after every n steps.
self.stall_after_n_steps = config.get("stall_after_n_steps")
self._stall_after_n_steps = None
# Amount of time to stall. If a 2-tuple, will uniformly sample from in between
# the two given values.
self.stall_time_sec = config.get("stall_time_sec")
assert (
self.stall_time_sec is None
or isinstance(self.stall_time_sec, (int, float))
or (
isinstance(self.stall_time_sec, tuple) and len(self.stall_time_sec) == 2
)
)

# Only ever stall if on certain worker indices.
faulty_indices = config.get("stall_on_worker_indices", None)
if faulty_indices and config.worker_index not in faulty_indices:
self.p_stall = 0.0
self.p_stall_reset = 0.0
self.stall_after_n_steps = None

# Timestep counter for the ongoing episode.
self.timesteps = 0

# Time in seconds to initialize (in this c'tor).
sample = 0.0
if "init_time_s" in config:
init_time_s = config.get("init_time_s", 0)
else:
init_time_s = np.random.randint(
config.get("init_time_s_min", 0),
config.get("init_time_s_max", 1),
sample = (
config["init_time_s"]
if not isinstance(config["init_time_s"], tuple)
else np.random.uniform(
config["init_time_s"][0], config["init_time_s"][1]
)
)
print(f"Initializing crashing env with init-delay of {init_time_s}sec ...")
time.sleep(init_time_s)

print(f"Initializing crashing env (with init-delay of {sample}sec) ...")
time.sleep(sample)

# No env pre-checking?
self._skip_env_checking = config.get("skip_env_checking", False)
@@ -61,30 +105,81 @@ def __init__(self, config=None):
def reset(self, *, seed=None, options=None):
# Reset timestep counter for the new episode.
self.timesteps = 0
self._crash_after_n_steps = None

# Should we crash?
if self._rng.rand() < self.p_crash_reset or (
self.crash_after_n_steps is not None and self.crash_after_n_steps == 0
):
if self._should_crash(p=self.p_crash_reset):
raise EnvError(
"Simulated env crash in `reset()`! Feel free to use any "
"other exception type here instead."
# f"Simulated env crash on worker={self.config.worker_index} "
Contributor:

Remove?

Contributor Author:

Fixed. This was a problem with the EnvContext being passed in with an empty underlying dict, combined with this line in the env code:

self.config = config or {}

Even if config is an EnvContext (with no dict settings), Python would still choose the empty dict here, and the result then does NOT have a worker_index attribute (because it's a plain dict, not an EnvContext).
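
To illustrate the pitfall, a minimal self-contained sketch (EnvContextLike is a hypothetical stand-in for RLlib's EnvContext, which subclasses dict):

    class EnvContextLike(dict):
        # Hypothetical stand-in: a dict subclass carrying a worker_index.
        def __init__(self, d=None, worker_index=0):
            super().__init__(d or {})
            self.worker_index = worker_index

    config = EnvContextLike({}, worker_index=3)
    # An empty dict subclass is falsy, so `config or {}` silently picks the
    # plain `{}`, dropping the worker_index attribute.
    chosen = config or {}
    assert not hasattr(chosen, "worker_index")
    # Safer: substitute only when config is actually None.
    chosen = config if config is not None else {}
    assert chosen.worker_index == 3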

# f"env-idx={self.config.vector_index} during `reset()`! "
# "Feel free to use any other exception type here instead."
)
# Should we stall for a while?
self._stall_if_necessary(p=self.p_stall_reset)

return super().reset()

@override(CartPoleEnv)
def step(self, action):
# Increase timestep counter for the ongoing episode.
self.timesteps += 1

# Should we crash?
if self._rng.rand() < self.p_crash or (
self.crash_after_n_steps and self.crash_after_n_steps == self.timesteps
):
if self._should_crash(p=self.p_crash):
raise EnvError(
"Simulated env crash in `step()`! Feel free to use any "
"other exception type here instead."
# f"Simulated env crash on worker={self.config.worker_index} "
# f"env-idx={self.config.vector_index} during `step()`! "
# "Feel free to use any other exception type here instead."
)
# No crash.
# Should we stall for a while?
self._stall_if_necessary(p=self.p_stall)

return super().step(action)

def _should_crash(self, p):
rnd = self._rng.rand()
if rnd < p:
print(f"Should crash! ({rnd} < {p})")
return True
elif self.crash_after_n_steps is not None:
if self._crash_after_n_steps is None:
self._crash_after_n_steps = (
self.crash_after_n_steps
if not isinstance(self.crash_after_n_steps, tuple)
else np.random.randint(
self.crash_after_n_steps[0], self.crash_after_n_steps[1]
)
)
if self._crash_after_n_steps == self.timesteps:
print(f"Should crash! (after {self.timesteps} steps)")
return True

return False

def _stall_if_necessary(self, p):
stall = False
if self._rng.rand() < p:
stall = True
elif self.stall_after_n_steps is not None:
if self._stall_after_n_steps is None:
self._stall_after_n_steps = (
self.stall_after_n_steps
if not isinstance(self.stall_after_n_steps, tuple)
else np.random.randint(
self.stall_after_n_steps[0], self.stall_after_n_steps[1]
)
)
if self._stall_after_n_steps == self.timesteps:
stall = True

if stall:
sec = (
self.stall_time_sec
if not isinstance(self.stall_time_sec, tuple)
else np.random.uniform(self.stall_time_sec[0], self.stall_time_sec[1])
)
print(f" -> will stall for {sec}sec ...")
time.sleep(sec)


MultiAgentCartPoleCrashing = make_multi_agent(lambda config: CartPoleCrashing(config))
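
For reference, a hedged usage sketch (not part of this diff): the env_config keys mirror those read in __init__ above, and the fault-tolerance settings mirror the new tests.

    from ray import tune
    from ray.rllib.algorithms.appo import APPOConfig
    from ray.rllib.examples.env.cartpole_crashing import CartPoleCrashing

    tune.register_env("env", lambda cfg: CartPoleCrashing(cfg))

    config = (
        APPOConfig()
        .environment("env", env_config={
            "p_crash": 0.001,          # crash chance per `step()`
            "p_stall": 0.0005,         # stall chance per `step()`
            "stall_time_sec": (1, 5),  # stall duration, uniformly sampled
        })
        .rollouts(num_rollout_workers=2)
        # Recreate (rather than fail on) crashed workers.
        .fault_tolerance(recreate_failed_workers=True)
    )
    algo = config.build()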
7 changes: 6 additions & 1 deletion rllib/execution/multi_gpu_learner_thread.py
@@ -140,7 +140,12 @@ def __init__(

@override(LearnerThread)
def step(self) -> None:
assert self.loader_thread.is_alive()
if not self.loader_thread.is_alive():
Contributor Author:

Better error message.

raise RuntimeError(
"The `_MultiGPULoaderThread` has died! Will therefore also terminate "
"the `MultiGPULearnerThread`."
)

with self.load_wait_timer:
buffer_idx, released = self.ready_tower_stacks_buffer.get()
