[RLlib-contrib] APEX DDPG. (ray-project#36596)
Showing 9 changed files with 319 additions and 0 deletions.
@@ -0,0 +1,19 @@
# APEX DDPG (Distributed Prioritized Experience Replay)

[APEX DDPG](https://arxiv.org/pdf/1803.00933.pdf) (Distributed Prioritized Experience Replay) is an algorithm that decouples
acting from learning. Actors interact with their own instances of the environment, selecting actions according to a shared
neural network and accumulating the resulting experience in a shared replay memory; the learner replays samples of that
experience and updates the network. The architecture relies on prioritized experience replay to focus only on the most
significant data generated by the actors.

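For concreteness, here is a minimal sketch of how those pieces map onto the `ApexDDPGConfig` added in this commit (the worker and shard counts are illustrative, not the tuned defaults):

```
from rllib_apex_ddpg.apex_ddpg import ApexDDPGConfig

config = (
    ApexDDPGConfig()
    .environment("Pendulum-v1")
    # Actors: parallel rollout workers, each with its own env instance.
    .rollouts(num_rollout_workers=8)
    # Learner side: sharded, prioritized replay buffers feed the updates.
    .training(optimizer={"num_replay_buffer_shards": 2})
)
algo = config.build()
print(algo.train())
```
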
## Installation

```
conda create -n rllib-apex-ddpg python=3.10
conda activate rllib-apex-ddpg
pip install -r requirements.txt
pip install -e '.[development]'
```

## Usage

[APEX-DDPG Example]()
@@ -0,0 +1,49 @@
import argparse

from rllib_apex_ddpg.apex_ddpg import ApexDDPG, ApexDDPGConfig

import ray
from ray import air, tune
from ray.rllib.utils.test_utils import check_learning_achieved


def get_cli_args():
    """Create CLI parser and return parsed arguments"""
    parser = argparse.ArgumentParser()
    parser.add_argument("--run-as-test", action="store_true", default=False)
    args = parser.parse_args()
    print(f"Running with following CLI args: {args}")
    return args


if __name__ == "__main__":
    args = get_cli_args()

    ray.init()

    config = (
        ApexDDPGConfig()
        .rollouts(num_rollout_workers=3)
        .framework("torch")
        .environment("Pendulum-v1", clip_rewards=False)
        .training(n_step=1, target_network_update_freq=50000, tau=1.0, use_huber=True)
        .evaluation(evaluation_interval=5, evaluation_duration=10)
    )

    stop_reward = -320

    tuner = tune.Tuner(
        ApexDDPG,
        param_space=config.to_dict(),
        run_config=air.RunConfig(
            stop={
                "sampler_results/episode_reward_mean": stop_reward,
                "timesteps_total": 1500000,
            },
            failure_config=air.FailureConfig(fail_fast="raise"),
        ),
    )
    results = tuner.fit()

    if args.run_as_test:
        check_learning_achieved(results, stop_reward)
@@ -0,0 +1,18 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["src"]

[project]
name = "rllib-apex-ddpg"
authors = [{name = "Anyscale Inc."}]
version = "0.1.0"
description = ""
readme = "README.md"
requires-python = ">=3.7, <3.11"
dependencies = ["gymnasium[atari]", "ray[rllib]==2.5.0"]

[project.optional-dependencies]
development = ["pytest>=7.2.2", "pre-commit==2.21.0", "tensorflow==2.11.0", "torch==1.12.0"]
@@ -0,0 +1,2 @@
tensorflow==2.11.0
torch==1.12.0
rllib_contrib/apex_ddpg/src/rllib_apex_ddpg/apex_ddpg/__init__.py (7 additions & 0 deletions)
@@ -0,0 +1,7 @@
from rllib_apex_ddpg.apex_ddpg.apex_ddpg import ApexDDPG, ApexDDPGConfig

from ray.tune.registry import register_trainable

__all__ = ["ApexDDPGConfig", "ApexDDPG"]

register_trainable("rllib-contrib-apex-ddpg", ApexDDPG)
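The `register_trainable` call above is what lets Ray Tune refer to the algorithm by its string name once the package is imported. A minimal sketch of that usage (the environment and worker count are illustrative):

```
import ray
from ray import tune
# Importing the subpackage runs register_trainable("rllib-contrib-apex-ddpg", ...).
from rllib_apex_ddpg.apex_ddpg import ApexDDPGConfig

ray.init()
config = (
    ApexDDPGConfig()
    .environment("Pendulum-v1")
    .rollouts(num_rollout_workers=2)
)
# The string below is the name registered in __init__.py above.
tune.Tuner("rllib-contrib-apex-ddpg", param_space=config.to_dict()).fit()
```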
rllib_contrib/apex_ddpg/src/rllib_apex_ddpg/apex_ddpg/apex_ddpg.py (147 additions & 0 deletions)
@@ -0,0 +1,147 @@
from typing import Optional

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.apex_dqn.apex_dqn import ApexDQN
from ray.rllib.algorithms.ddpg.ddpg import DDPG, DDPGConfig
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import ResultDict


class ApexDDPGConfig(DDPGConfig):
    """Defines a configuration class from which an ApexDDPG Trainer can be built.

    Example:
        >>> from rllib_apex_ddpg.apex_ddpg import ApexDDPGConfig
        >>> config = ApexDDPGConfig().training(lr=0.01).resources(num_gpus=1)
        >>> print(config.to_dict())  # doctest: +SKIP
        >>> # Build a Trainer object from the config and run one training iteration.
        >>> algo = config.build(env="Pendulum-v1")
        >>> algo.train()  # doctest: +SKIP

    Example:
        >>> from rllib_apex_ddpg.apex_ddpg import ApexDDPGConfig
        >>> from ray import tune
        >>> import ray.air as air
        >>> config = ApexDDPGConfig()
        >>> # Print out some default values.
        >>> print(config.lr)  # doctest: +SKIP
        0.0004
        >>> # Update the config object.
        >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
        >>> # Set the config object's env.
        >>> config.environment(env="Pendulum-v1")
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.Tuner(  # doctest: +SKIP
        ...     "APEX_DDPG",
        ...     run_config=air.RunConfig(stop={"episode_reward_mean": 200}),
        ...     param_space=config.to_dict(),
        ... ).fit()
    """

    def __init__(self, algo_class=None):
        """Initializes an ApexDDPGConfig instance."""
        super().__init__(algo_class=algo_class or ApexDDPG)

        # fmt: off
        # __sphinx_doc_begin__
        # ApexDDPG-specific settings.
        self.optimizer = {
            "max_weight_sync_delay": 400,
            "num_replay_buffer_shards": 4,
            "debug": False,
        }
        # Overwrite the default max_requests_in_flight_per_replay_worker.
        self.max_requests_in_flight_per_replay_worker = float("inf")
        self.timeout_s_sampler_manager = 0.0
        self.timeout_s_replay_manager = 0.0

        # Override some of Trainer/DDPG's default values with ApexDDPG-specific values.
        self.n_step = 3
        self.exploration_config = {"type": "PerWorkerOrnsteinUhlenbeckNoise"}
        self.num_gpus = 0
        self.num_rollout_workers = 32
        self.min_sample_timesteps_per_iteration = 25000
        self.min_time_s_per_iteration = 30
        self.train_batch_size = 512
        self.rollout_fragment_length = 50
        self.replay_buffer_config = {
            "type": "MultiAgentPrioritizedReplayBuffer",
            "capacity": 2000000,
            "no_local_replay_buffer": True,
            # Alpha parameter for prioritized replay buffer.
            "prioritized_replay_alpha": 0.6,
            # Beta parameter for sampling from prioritized replay buffer.
            "prioritized_replay_beta": 0.4,
            # Epsilon to add to the TD errors when updating priorities.
            "prioritized_replay_eps": 1e-6,
            # Whether all shards of the replay buffer must be co-located
            # with the learner process (running the execution plan).
            # This is preferred b/c the learner process should have quick
            # access to the data from the buffer shards, avoiding network
            # traffic each time samples from the buffer(s) are drawn.
            # Set this to False for relaxing this constraint and allowing
            # replay shards to be created on node(s) other than the one
            # on which the learner is located.
            "replay_buffer_shards_colocated_with_driver": True,
            # Whether to compute priorities on workers.
            "worker_side_prioritization": True,
            # Specify prioritized replay by supplying a buffer type that supports
            # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
            "prioritized_replay": DEPRECATED_VALUE,
        }
        # Number of timesteps to collect from rollout workers before we start
        # sampling from replay buffers for learning. Whether we count this in agent
        # steps or environment steps depends on config.multi_agent(count_steps_by=..).
        self.num_steps_sampled_before_learning_starts = 50000
        self.target_network_update_freq = 500000
        self.training_intensity = 1
        # __sphinx_doc_end__
        # fmt: on

    @override(DDPGConfig)
    def training(
        self,
        *,
        timeout_s_sampler_manager: Optional[float] = NotProvided,
        timeout_s_replay_manager: Optional[float] = NotProvided,
        **kwargs,
    ) -> "ApexDDPGConfig":
        """Sets the training related configuration.

        Args:
            timeout_s_sampler_manager: The timeout for waiting for sampling results
                for workers -- typically if this is too low, the manager won't be able
                to retrieve ready sampling results.
            timeout_s_replay_manager: The timeout for waiting for replay worker
                results -- typically if this is too low, the manager won't be able to
                retrieve ready replay requests.

        Returns:
            This updated ApexDDPGConfig object.
        """
        super().training(**kwargs)

        if timeout_s_sampler_manager is not NotProvided:
            self.timeout_s_sampler_manager = timeout_s_sampler_manager
        if timeout_s_replay_manager is not NotProvided:
            self.timeout_s_replay_manager = timeout_s_replay_manager

        return self


class ApexDDPG(DDPG, ApexDQN):
    @classmethod
    @override(DDPG)
    def get_default_config(cls) -> AlgorithmConfig:
        return ApexDDPGConfig()

    @override(DDPG)
    def setup(self, config: AlgorithmConfig):
        return ApexDQN.setup(self, config)

    @override(DDPG)
    def training_step(self) -> ResultDict:
        """Use APEX-DQN's training iteration function."""
        return ApexDQN.training_step(self)
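As a quick illustration of the two Ape-X-specific knobs exposed by the `training()` override above, here is a minimal sketch (the 0.1-second values are arbitrary, chosen only to show the call, not recommendations):

```
from rllib_apex_ddpg.apex_ddpg import ApexDDPGConfig

config = (
    ApexDDPGConfig()
    .environment("Pendulum-v1")
    # Non-zero timeouts make the managers wait briefly for sampler/replay
    # results instead of returning immediately (the class default is 0.0).
    .training(timeout_s_sampler_manager=0.1, timeout_s_replay_manager=0.1)
)
algo = config.build()
```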
@@ -0,0 +1,67 @@
import unittest

import pytest
from rllib_apex_ddpg.apex_ddpg.apex_ddpg import ApexDDPGConfig

import ray
from ray.rllib.utils.test_utils import (
    check,
    check_compute_single_action,
    check_train_results,
    framework_iterator,
)


class TestApexDDPG(unittest.TestCase):
    def setUp(self):
        ray.init()

    def tearDown(self):
        ray.shutdown()

    def test_apex_ddpg_compilation_and_per_worker_epsilon_values(self):
        """Test whether APEX-DDPG can be built on all frameworks."""
        config = (
            ApexDDPGConfig()
            .environment(env="Pendulum-v1")
            .rollouts(num_rollout_workers=2)
            .reporting(min_sample_timesteps_per_iteration=100)
            .training(
                num_steps_sampled_before_learning_starts=0,
                optimizer={"num_replay_buffer_shards": 1},
            )
        )

        num_iterations = 1

        for _ in framework_iterator(config, with_eager_tracing=True):
            algo = config.build()

            # Test per-worker scale distribution.
            infos = algo.workers.foreach_policy(lambda p, _: p.get_exploration_state())
            scale = [i["cur_scale"] for i in infos]
            expected = [
                0.4 ** (1 + (i + 1) / float(config.num_rollout_workers - 1) * 7)
                for i in range(config.num_rollout_workers)
            ]
            check(scale, [0.0] + expected)

            for _ in range(num_iterations):
                results = algo.train()
                check_train_results(results)
                print(results)
            check_compute_single_action(algo)

            # Test again per-worker scale distribution
            # (should not have changed).
            infos = algo.workers.foreach_policy(lambda p, _: p.get_exploration_state())
            scale = [i["cur_scale"] for i in infos]
            check(scale, [0.0] + expected)

            algo.stop()


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", __file__]))