[RLlib-contrib] CRR. (ray-project#36616)
avnishn authored Oct 5, 2023
1 parent 68ad265 commit d48717a
Showing 11 changed files with 1,140 additions and 6 deletions.
18 changes: 15 additions & 3 deletions .buildkite/pipeline.ml.yml
@@ -472,7 +472,7 @@
    - ./ci/env/env_info.sh
    - pytest rllib_contrib/a3c/tests/test_a3c.py

- label: ":exploding_death_star: RLlib Contrib: Alpha Star Tests"
- label: ":exploding_death_star: RLlib Contrib: AlphaStar Tests"
  conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
  commands:
    - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
@@ -491,10 +491,10 @@
    - conda deactivate
    - conda create -n rllib_contrib python=3.8 -y
    - conda activate rllib_contrib
    - (cd rllib_contrib/alpha_zero && pip install -r requirements.txt && pip install -e ".[development"])
    - (cd rllib_contrib/alpha_zero && pip install -r requirements.txt && pip install -e ".[development]")
    - ./ci/env/env_info.sh
    - pytest rllib_contrib/alpha_zero/tests/
    - python alpha_zero/examples/alpha_zero_cartpole_sparse_rewards.py --run-as-test
    - python rllib_contrib/alpha_zero/examples/alpha_zero_cartpole_sparse_rewards.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: APEX DDPG Tests"
  conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
@@ -532,6 +532,18 @@
    - pytest rllib_contrib/bandit/tests/
    - python rllib_contrib/bandit/examples/bandit_linucb_interest_evolution_recsim.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: CRR Tests"
  conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
  commands:
    - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
    - conda deactivate
    - conda create -n rllib_contrib python=3.8 -y
    - conda activate rllib_contrib
    - (cd rllib_contrib/crr && pip install -r requirements.txt && pip install -e ".[development]")
    - ./ci/env/env_info.sh
    - pytest rllib_contrib/crr/tests/
    - python rllib_contrib/crr/examples/crr_cartpole_v1.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: DDPG Tests"
  conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
  commands:
16 changes: 13 additions & 3 deletions rllib_contrib/TOC.md
@@ -1,16 +1,26 @@
# Algorithms


* [A3C](./a3c)
* [A2C](./a2c)
* [Alpha Star](./alpha_star)
* [A3C](./a3c)
* [AlphaStar](./alpha_star)
* [AlphaZero](./alpha_zero)
* [APEX-DDPG](./apex_ddpg)
* [APEX DQN](./apex_dqn/)
* [Bandit](./bandit)
* [DDPG](./ddpg)
* [CRR](./crr)
* [Decision Transformer](./dt)
* [DDPG](./ddpg)
* [ES](./es)
* [LeelaChessZero](./leela_chess_zero)
* [MAML](./maml)
* [MBMPO](./mbmpo)
* [PG](./pg)
* [QMIX](./qmix)
* [R2D2](./r2d2)
* [SimpleQ](./simple_q)
* [SlateQ](./slate_q)
* [TD3](./td3)



18 changes: 18 additions & 0 deletions rllib_contrib/crr/README.md
@@ -0,0 +1,18 @@
# CRR (Critic Regularized Regression)


[CRR](https://arxiv.org/abs/2006.15134) is an offline RL algorithm based on Q-learning that can learn from an offline experience replay. The challenge in applying existing Q-learning algorithms to offline RL lies in the overestimation of the Q-function, as well as the lack of exploration beyond the observed data. The latter becomes increasingly important during bootstrapping in the Bellman equation, where the Q-function is queried for the next state's Q-value(s) on actions that have no support in the observed data. To mitigate these issues, CRR implements a simple yet powerful idea of "value-filtered regression": use a learned critic to filter out the non-promising transitions from the replay dataset.
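
To make "value-filtered regression" concrete, here is a minimal sketch of the CRR policy loss in PyTorch. It is illustrative only: `q_net`, `pi`, and their `sample`/`log_prob` interfaces are assumptions for this sketch, not RLlib's internal API. The `weight_type`, `temperature`, `max_weight`, and `n_action_sample` arguments mirror the same-named options in `CRRConfig.training()`.

```
import torch


def crr_policy_loss(q_net, pi, obs, actions, weight_type="exp",
                    temperature=1.0, max_weight=20.0, n_action_sample=4):
    """Sketch of CRR: behavior cloning on replay actions, weighted by a
    critic-based filter. Assumed shapes: obs [B, obs_dim], actions [B, act_dim]."""
    q_data = q_net(obs, actions)  # Q(s, a) for dataset actions, shape [B]

    # Estimate V(s) as the mean Q over actions sampled from the current policy
    # (the "mean" advantage_type used in the example script).
    sampled_q = [q_net(obs, pi.sample(obs)) for _ in range(n_action_sample)]
    advantage = q_data - torch.stack(sampled_q, dim=0).mean(dim=0)

    if weight_type == "bin":
        # Hard filter: keep only transitions the critic judges advantageous.
        weight = (advantage > 0).float()
    else:
        # "exp": soft filter, clipped at max_weight for stability.
        weight = torch.clamp(torch.exp(advantage / temperature), max=max_weight)

    # Weights act as constants, so gradients flow only through log pi(a|s).
    return -(weight.detach() * pi.log_prob(obs, actions)).mean()
```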


## Installation

```
conda create -n rllib-crr python=3.10
conda activate rllib-crr
pip install -r requirements.txt
pip install -e '.[development]'
```

## Usage

[CRR Example](examples/crr_cartpole_v1.py)
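
A minimal way to run the example locally, assuming the install steps above (this mirrors the Buildkite invocation for this package):

```
python examples/crr_cartpole_v1.py --run-as-test
```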
84 changes: 84 additions & 0 deletions rllib_contrib/crr/examples/crr_cartpole_v1.py
@@ -0,0 +1,84 @@
import argparse

from rllib_crr.crr import CRR, CRRConfig

import ray
from ray import air, tune
from ray.rllib.utils.test_utils import check_learning_achieved


def get_cli_args():
    """Create CLI parser and return parsed arguments"""
    parser = argparse.ArgumentParser()
    parser.add_argument("--run-as-test", action="store_true", default=False)
    args = parser.parse_args()
    print(f"Running with following CLI args: {args}")
    return args


if __name__ == "__main__":
    args = get_cli_args()

    ray.init()
    config = (
        CRRConfig()
        .environment(env="CartPole-v1", clip_actions=True)
        .framework("torch")
        .offline_data(
            # Train purely from a recorded CartPole dataset stored on S3.
            input_="dataset",
            input_config={
                "format": "json",
                "paths": ["s3://anonymous@air-example-data/rllib/cartpole/large.json"],
            },
            actions_in_input_normalized=True,
        )
        .training(
            twin_q=True,
            weight_type="exp",
            advantage_type="mean",
            n_action_sample=4,
            target_network_update_freq=10000,
            tau=0.0005,
            gamma=0.99,
            train_batch_size=2048,
            critic_hidden_activation="tanh",
            critic_hiddens=[128, 128, 128],
            critic_lr=0.0003,
            actor_hidden_activation="tanh",
            actor_hiddens=[128, 128, 128],
            actor_lr=0.0003,
            temperature=1.0,
            max_weight=20.0,
        )
        .evaluation(
            # Evaluate online with a live sampler, since the training input is offline.
            evaluation_interval=1,
            evaluation_num_workers=1,
            evaluation_duration=10,
            evaluation_duration_unit="episodes",
            evaluation_parallel_to_training=True,
            evaluation_config=CRRConfig.overrides(input_="sampler", explore=False),
        )
        .rollouts(num_rollout_workers=3)
    )

    stop_reward = 200

    tuner = tune.Tuner(
        CRR,
        param_space=config.to_dict(),
        run_config=air.RunConfig(
            stop={
                "evaluation/sampler_results/episode_reward_mean": stop_reward,
                "training_iteration": 100,
            },
            failure_config=air.FailureConfig(fail_fast="raise"),
        ),
    )
    results = tuner.fit()

    if args.run_as_test:
        check_learning_achieved(
            results,
            stop_reward,
            metric="evaluation/sampler_results/episode_reward_mean",
        )
18 changes: 18 additions & 0 deletions rllib_contrib/crr/pyproject.toml
@@ -0,0 +1,18 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["src"]

[project]
name = "rllib-crr"
authors = [{name = "Anyscale Inc."}]
version = "0.1.0"
description = ""
readme = "README.md"
requires-python = ">=3.7, <3.11"
dependencies = ["gymnasium", "ray[rllib]==2.5.0"]

[project.optional-dependencies]
development = ["pytest>=7.2.2", "pre-commit==2.21.0", "torch==1.12.0"]
1 change: 1 addition & 0 deletions rllib_contrib/crr/requirements.txt
@@ -0,0 +1 @@
torch==1.12.0
9 changes: 9 additions & 0 deletions rllib_contrib/crr/src/rllib_crr/crr/__init__.py
@@ -0,0 +1,9 @@
from rllib_crr.crr.crr import CRR, CRRConfig
from rllib_crr.crr.crr_torch_model import CRRModel
from rllib_crr.crr.crr_torch_policy import CRRTorchPolicy

from ray.tune.registry import register_trainable

__all__ = ["CRR", "CRRConfig", "CRRModel", "CRRTorchPolicy"]

register_trainable("rllib-contrib-crr", CRR)
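
Because the package `__init__` calls `register_trainable`, Tune can also refer to the algorithm by its registered string name. A minimal sketch, assuming the package is installed as described in the README (a real run would additionally need the offline-data settings from the example script):

```
from rllib_crr.crr import CRRConfig  # importing the package registers "rllib-contrib-crr"

from ray import tune

# The registered string name resolves to the CRR algorithm class inside Tune.
tuner = tune.Tuner("rllib-contrib-crr", param_space=CRRConfig().to_dict())
```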