[RLlib] Add and enhance fault-tolerance tests for APPO. (ray-project#…
sven1977 authored Dec 8, 2023
1 parent 563f7d8 commit 7001982
Showing 19 changed files with 520 additions and 294 deletions.
44 changes: 42 additions & 2 deletions rllib/BUILD
@@ -222,6 +222,48 @@ py_test(
args = ["--dir=tuned_examples/appo"]
)

# Tests against crashing or hanging environments.
# Single-agent: Crash only.
py_test(
name = "learning_tests_cartpole_crashing_appo",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "crashing_cartpole"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/appo/cartpole-crashing-recreate-workers-appo.py"],
args = ["--dir=tuned_examples/appo", "--num-cpus=6"]
)
# Single-agent: Crash and stall.
py_test(
name = "learning_tests_cartpole_crashing_and_stalling_appo",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "crashing_cartpole"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/appo/cartpole-crashing-and-stalling-recreate-workers-appo.py"],
args = ["--dir=tuned_examples/appo", "--num-cpus=6"]
)
# Multi-agent: Crash only.
py_test(
name = "learning_tests_multi_agent_cartpole_crashing_appo",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "crashing_cartpole"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/appo/multi-agent-cartpole-crashing-recreate-workers-appo.py"],
args = ["--dir=tuned_examples/appo", "--num-cpus=6"]
)
# Multi-agent: Crash and stall.
py_test(
name = "learning_tests_multi_agent_cartpole_crashing_and_stalling_appo",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "crashing_cartpole"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/appo/multi-agent-cartpole-crashing-and-stalling-recreate-workers-appo.py"],
args = ["--dir=tuned_examples/appo", "--num-cpus=6"]
)

# CQL
py_test(
name = "learning_tests_pendulum_cql",
@@ -1569,7 +1611,6 @@ py_test(
args = ["TestCheckpointRestorePPO"]
)


py_test(
name = "tests/test_checkpoint_restore_ppo_gpu",
main = "tests/test_algorithm_checkpoint_restore.py",
@@ -1588,7 +1629,6 @@ py_test(
args = ["TestCheckpointRestoreOffPolicy"]
)


py_test(
name = "tests/test_checkpoint_restore_off_policy_gpu",
main = "tests/test_algorithm_checkpoint_restore.py",
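
The four crashing/stalling learning tests added above each point at a tuned-example script under tuned_examples/appo/ whose contents are not part of this page. For orientation only, a minimal sketch of what such a script might look like follows; the module-level config/stop convention expected by run_regression_tests.py, the import path for CartPoleCrashing, and the env-config keys (p_crash, p_stall) are assumptions, not taken from this commit.

# Hypothetical sketch of a tuned_examples/appo/cartpole-crashing-*.py script.
# Import path and env_config keys below are assumptions, not from this commit.
from ray import tune
from ray.rllib.algorithms.appo import APPOConfig
from ray.rllib.examples.env.cartpole_crashing import CartPoleCrashing

# Register the crashing env under a string name so that recreated workers
# can rebuild it from the registry.
tune.register_env("cartpole_crashing", lambda cfg: CartPoleCrashing(cfg))

config = (
    APPOConfig()
    .environment(
        "cartpole_crashing",
        # Assumed options: small per-step probabilities to crash or stall.
        env_config={"p_crash": 0.0001, "p_stall": 0.0001},
    )
    .rollouts(num_rollout_workers=4)
    # Recreate (rather than permanently lose) workers whose env has crashed.
    .fault_tolerance(recreate_failed_workers=True)
)

# Stopping criteria picked up by the regression-test harness (assumed convention).
stop = {
    "sampler_results/episode_reward_mean": 150.0,
    "timesteps_total": 500000,
}

The --num-cpus=6 argument in the BUILD targets presumably just gives each test enough CPUs for the local worker plus the remote rollout workers such a config requests.
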
3 changes: 3 additions & 0 deletions rllib/algorithms/algorithm.py
@@ -1564,6 +1564,9 @@ def restore_workers(self, workers: WorkerSet) -> None:
restored = workers.probe_unhealthy_workers()

if restored:
# Count the restored workers.
self._counters["total_num_restored_workers"] += len(restored)

from_worker = workers.local_worker() or self.workers.local_worker()
# Get the state of the correct (reference) worker. E.g. The local worker
# of the main WorkerSet.
34 changes: 30 additions & 4 deletions rllib/algorithms/impala/impala.py
@@ -860,12 +860,11 @@ def default_resource_request(
strategy=cf.placement_strategy,
)

def concatenate_batches_and_pre_queue(self, batches: List[SampleBatch]):
def concatenate_batches_and_pre_queue(self, batches: List[SampleBatch]) -> None:
"""Concatenate batches that are being returned from rollout workers
Args:
batches: batches of experiences from rollout workers
batches: List of batches of experiences from EnvRunners.
"""

def aggregate_into_larger_batch():
@@ -878,6 +877,33 @@ def aggregate_into_larger_batch():
self.batch_being_built = []

for batch in batches:
# TODO (sven): Strange bug in tf/tf2 after a RolloutWorker crash and proper
# restart. The bug is related to (old, non-V2) connectors being used and
# seems to happen inside the AgentCollector's `add_action_reward_next_obs`
# method, at the end of which the number of vf_preds (and all other
# extra action outs) in the batch is one smaller than the number of obs/
# actions/rewards, which leads to a malformed train batch. IMPALA/APPO then
# crash inside the loss function (during v-trace operations). The following
# if-block prevents this from happening and it can be removed once we are
# on the new API stack for good (and use the new connectors and also no
# longer AgentCollectors, RolloutWorkers, Policies, TrajectoryView API,
# etc..):
if (
self.config.batch_mode == "truncate_episodes"
and self.config.enable_connectors
and self.config.recreate_failed_workers
and self.config.framework_str in ["tf", "tf2"]
):
if any(
SampleBatch.VF_PREDS in pb
and (
pb[SampleBatch.VF_PREDS].shape[0]
!= pb[SampleBatch.REWARDS].shape[0]
)
for pb in batch.policy_batches.values()
):
continue

self.batch_being_built.append(batch)
aggregate_into_larger_batch()

@@ -929,7 +955,7 @@ def get_samples_from_workers(
sample_batches = [(0, sample_batch)]
else:
# Not much we can do. Return empty list and wait.
return []
sample_batches = []

return sample_batches

28 changes: 18 additions & 10 deletions rllib/algorithms/tests/test_callbacks.py
@@ -11,9 +11,10 @@
from ray.rllib.evaluation.episode import Episode
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.utils.test_utils import framework_iterator
from ray import tune


class OnWorkerCreatedCallbacks(DefaultCallbacks):
class OnWorkersRecreatedCallbacks(DefaultCallbacks):
def on_workers_recreated(
self,
*,
@@ -109,11 +110,13 @@ def tearDownClass(cls):
ray.shutdown()

def test_on_workers_recreated_callback(self):
tune.register_env("env", lambda cfg: CartPoleCrashing(cfg))

config = (
APPOConfig()
.environment(CartPoleCrashing)
.callbacks(OnWorkerCreatedCallbacks)
.rollouts(num_rollout_workers=2)
.environment("env")
.callbacks(OnWorkersRecreatedCallbacks)
.rollouts(num_rollout_workers=3)
.fault_tolerance(recreate_failed_workers=True)
)

@@ -122,19 +125,24 @@ def test_on_workers_recreated_callback(self):
original_worker_ids = algo.workers.healthy_worker_ids()
for id_ in original_worker_ids:
self.assertTrue(algo._counters[f"worker_{id_}_recreated"] == 0)
self.assertTrue(algo._counters["total_num_workers_recreated"] == 0)

# After building the algorithm, we should have 3 healthy (remote) workers.
self.assertTrue(len(original_worker_ids) == 2)
self.assertTrue(len(original_worker_ids) == 3)

# Train a bit (and have the envs/workers crash a couple of times).
for _ in range(3):
algo.train()
for _ in range(5):
print(algo.train())

# After training, each new worker should have been recreated at least once.
# After training, the `on_workers_recreated` callback should have captured
# the exact worker IDs recreated (the exact number of times) as the actor
# manager itself. This confirms that the callback is triggered correctly,
# always.
new_worker_ids = algo.workers.healthy_worker_ids()
self.assertTrue(len(new_worker_ids) == 2)
self.assertTrue(len(new_worker_ids) == 3)
for id_ in new_worker_ids:
self.assertTrue(algo._counters[f"worker_{id_}_recreated"] >= 1)
# num_restored = algo.workers.restored_actors_history[id_]
self.assertTrue(algo._counters[f"worker_{id_}_recreated"] > 1)
algo.stop()

def test_on_init_and_checkpoint_loaded(self):
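
The body of OnWorkersRecreatedCallbacks.on_workers_recreated is collapsed in this view. Judging from the counters the test asserts on (worker_<id>_recreated and total_num_workers_recreated), a minimal sketch of such a callback could look like the following; the exact keyword arguments of on_workers_recreated are an assumption here, since the signature is truncated above.

# Minimal sketch of the collapsed callback above. The keyword arguments of
# on_workers_recreated are assumed; only the counter names come from the test.
from ray.rllib.algorithms.callbacks import DefaultCallbacks


class OnWorkersRecreatedCallbacks(DefaultCallbacks):
    def on_workers_recreated(
        self,
        *,
        algorithm,
        worker_set,
        worker_ids,
        is_evaluation,
        **kwargs,
    ):
        # Bump one counter per recreated worker ID plus a global total; these
        # are exactly the counters test_on_workers_recreated_callback checks.
        for worker_id in worker_ids:
            algorithm._counters[f"worker_{worker_id}_recreated"] += 1
        algorithm._counters["total_num_workers_recreated"] += len(worker_ids)
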
3 changes: 1 addition & 2 deletions rllib/evaluation/collectors/agent_collector.py
@@ -9,7 +9,7 @@

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.spaces.space_utils import (
flatten_to_single_ndarray,
get_dummy_batch_for_space,
Expand All @@ -24,7 +24,6 @@

logger = logging.getLogger(__name__)

_, tf, _ = try_import_tf()
torch, _ = try_import_torch()


3 changes: 0 additions & 3 deletions rllib/evaluation/collectors/simple_list_collector.py
@@ -426,9 +426,6 @@ def postprocess_episode(
episode_id = episode.episode_id
policy_collector_group = episode.batch_builder

# TODO: (sven) Once we implement multi-agent communication channels,
# we have to resolve the restriction of only sending other agent
# batches from the same policy to the postprocess methods.
# Build SampleBatches for the given episode.
pre_batches = {}
for (eps_id, agent_id), collector in self.agent_collectors.items():
(The remaining changed files in this commit are not shown here.)