
Commit

[RLlib] Refactor: All tf static graph code should reside inside Policy class. (ray-project#17169)
sven1977 authored Jul 20, 2021
1 parent efed070 commit 5a313ba
Showing 42 changed files with 1,020 additions and 806 deletions.
4 changes: 2 additions & 2 deletions rllib/BUILD
@@ -989,7 +989,7 @@ py_test(
"--env", "CartPole-v0",
"--run", "IMPALA",
"--stop", "'{\"training_iteration\": 1}'",
"--config", "'{\"framework\": \"tf\", \"num_gpus\": 0, \"num_workers\": 2, \"min_iter_time_s\": 1, \"num_data_loader_buffers\": 2, \"replay_buffer_num_slots\": 100, \"replay_proportion\": 1.0}'",
"--config", "'{\"framework\": \"tf\", \"num_gpus\": 0, \"num_workers\": 2, \"min_iter_time_s\": 1, \"num_multi_gpu_tower_stacks\": 2, \"replay_buffer_num_slots\": 100, \"replay_proportion\": 1.0}'",
"--ray-num-cpus", "4",
]
)
@@ -1003,7 +1003,7 @@ py_test(
"--env", "CartPole-v0",
"--run", "IMPALA",
"--stop", "'{\"training_iteration\": 1}'",
"--config", "'{\"framework\": \"tf\", \"num_gpus\": 0, \"num_workers\": 2, \"min_iter_time_s\": 1, \"num_data_loader_buffers\": 2, \"replay_buffer_num_slots\": 100, \"replay_proportion\": 1.0, \"model\": {\"use_lstm\": true}}'",
"--config", "'{\"framework\": \"tf\", \"num_gpus\": 0, \"num_workers\": 2, \"min_iter_time_s\": 1, \"num_multi_gpu_tower_stacks\": 2, \"replay_buffer_num_slots\": 100, \"replay_proportion\": 1.0, \"model\": {\"use_lstm\": true}}'",
"--ray-num-cpus", "4",
]
)
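A rough Python equivalent of the updated test invocation above — a sketch only, since the BUILD rule actually drives RLlib's command-line trainer; the tune.run form shown here is an assumption, but the config values mirror the --config flags verbatim:

from ray import tune

# The key "num_multi_gpu_tower_stacks" replaces the old
# "num_data_loader_buffers" (see the IMPALA changes further down).
tune.run(
    "IMPALA",
    stop={"training_iteration": 1},
    config={
        "env": "CartPole-v0",
        "framework": "tf",
        "num_gpus": 0,
        "num_workers": 2,
        "min_iter_time_s": 1,
        "num_multi_gpu_tower_stacks": 2,
        "replay_buffer_num_slots": 100,
        "replay_proportion": 1.0,
    },
)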
4 changes: 2 additions & 2 deletions rllib/agents/a3c/a2c.py
@@ -8,7 +8,7 @@
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import ComputeGradients, AverageGradients, \
ApplyGradients, TrainTFMultiGPU, TrainOneStep
ApplyGradients, MultiGPUTrainOneStep, TrainOneStep
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.evaluation.worker_set import WorkerSet
@@ -66,7 +66,7 @@ def execution_plan(workers: WorkerSet,
if config["simple_optimizer"]:
train_step_op = TrainOneStep(workers)
else:
train_step_op = TrainTFMultiGPU(
train_step_op = MultiGPUTrainOneStep(
workers=workers,
sgd_minibatch_size=config["train_batch_size"],
num_sgd_iter=1,
4 changes: 2 additions & 2 deletions rllib/agents/cql/cql.py
@@ -10,7 +10,7 @@
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay
from ray.rllib.execution.train_ops import TrainTFMultiGPU, TrainOneStep, \
from ray.rllib.execution.train_ops import MultiGPUTrainOneStep, TrainOneStep, \
UpdateTargetNetwork
from ray.rllib.offline.shuffled_input import ShuffledInput
from ray.rllib.policy.policy import LEARNER_STATS_KEY, Policy
@@ -103,7 +103,7 @@ def update_prio(item):
if config["simple_optimizer"]:
train_step_op = TrainOneStep(workers)
else:
train_step_op = TrainTFMultiGPU(
train_step_op = MultiGPUTrainOneStep(
workers=workers,
sgd_minibatch_size=config["train_batch_size"],
num_sgd_iter=1,
4 changes: 2 additions & 2 deletions rllib/agents/dqn/dqn.py
@@ -23,7 +23,7 @@
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork, \
TrainTFMultiGPU
MultiGPUTrainOneStep
from ray.rllib.policy.policy import LEARNER_STATS_KEY, Policy
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
@@ -255,7 +255,7 @@ def update_prio(item):
if config["simple_optimizer"]:
train_step_op = TrainOneStep(workers)
else:
train_step_op = TrainTFMultiGPU(
train_step_op = MultiGPUTrainOneStep(
workers=workers,
sgd_minibatch_size=config["train_batch_size"],
num_sgd_iter=1,
5 changes: 4 additions & 1 deletion rllib/agents/dqn/dqn_tf_policy.py
@@ -220,7 +220,7 @@ def get_distribution_inputs_and_class(policy: Policy,
q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals

policy.q_values = q_vals
policy.q_func_vars = model.variables()

return policy.q_values, Categorical, [] # state-out


@@ -304,6 +304,9 @@ def adam_optimizer(policy: Policy, config: TrainerConfigDict

def clip_gradients(policy: Policy, optimizer: "tf.keras.optimizers.Optimizer",
loss: TensorType) -> ModelGradients:
if not hasattr(policy, "q_func_vars"):
policy.q_func_vars = policy.model.variables()

return minimize_and_clip(
optimizer,
loss,
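The hasattr guards above defer collecting the Q-network's variables until they are first needed (here: by the gradient-clipping optimizer), which keeps all variable/graph access inside the Policy rather than in free-standing build-time functions. A toy illustration of that lazy-caching pattern — not RLlib API, just the guard idiom shown on a dummy model:

class DummyModel:
    def variables(self):
        print("collecting variables")  # expensive step; should run only once
        return ["w", "b"]

class LazyVarHolder:
    def __init__(self, model):
        self.model = model

    def trainable_vars(self):
        # Same guard style as clip_gradients() above: create the cached
        # attribute on first request, then reuse it.
        if not hasattr(self, "_vars"):
            self._vars = self.model.variables()
        return self._vars

holder = LazyVarHolder(DummyModel())
holder.trainable_vars()  # prints "collecting variables"
holder.trainable_vars()  # cached; the model is not queried again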
4 changes: 2 additions & 2 deletions rllib/agents/dqn/simple_q.py
@@ -22,7 +22,7 @@
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainTFMultiGPU, TrainOneStep, \
from ray.rllib.execution.train_ops import MultiGPUTrainOneStep, TrainOneStep, \
UpdateTargetNetwork
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict
@@ -143,7 +143,7 @@ def execution_plan(workers: WorkerSet,
if config["simple_optimizer"]:
train_step_op = TrainOneStep(workers)
else:
train_step_op = TrainTFMultiGPU(
train_step_op = MultiGPUTrainOneStep(
workers=workers,
sgd_minibatch_size=config["train_batch_size"],
num_sgd_iter=1,
8 changes: 6 additions & 2 deletions rllib/agents/dqn/simple_q_tf_policy.py
@@ -54,6 +54,10 @@ def do_update():

@override(TFPolicy)
def variables(self):
if not hasattr(self, "q_func_vars"):
self.q_func_vars = self.model.variables()
if not hasattr(self, "target_q_func_vars"):
self.target_q_func_vars = self.target_q_model.variables()
return self.q_func_vars + self.target_q_func_vars


@@ -114,7 +118,6 @@ def get_distribution_inputs_and_class(
q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals

policy.q_values = q_vals
policy.q_func_vars = q_model.variables()
return policy.q_values, (TorchCategorical
if policy.config["framework"] == "torch" else
Categorical), [] # state-outs
@@ -144,7 +147,8 @@ def build_q_losses(policy: Policy, model: ModelV2,
policy.target_q_model,
train_batch[SampleBatch.NEXT_OBS],
explore=False)
policy.target_q_func_vars = policy.target_q_model.variables()
if not hasattr(policy, "target_q_func_vars"):
policy.target_q_func_vars = policy.target_q_model.variables()

# q scores for actions which we know were selected in the given state.
one_hot_selection = tf.one_hot(
2 changes: 1 addition & 1 deletion rllib/agents/dqn/simple_q_torch_policy.py
@@ -77,7 +77,7 @@ def build_q_losses(policy: Policy, model, dist_class,
# q network evaluation
q_t = compute_q_values(
policy,
policy.model,
model,
train_batch[SampleBatch.CUR_OBS],
explore=False,
is_training=True)
28 changes: 14 additions & 14 deletions rllib/agents/dqn/tests/test_simple_q.py
@@ -52,20 +52,20 @@ def test_simple_q_fake_multi_gpu_learning(self):
# Fake GPU setup.
config["num_gpus"] = 2
config["_fake_gpus"] = True
config["framework"] = "tf"

trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v0")
num_iterations = 200
learnt = False
for i in range(num_iterations):
results = trainer.train()
print("reward={}".format(results["episode_reward_mean"]))
if results["episode_reward_mean"] > 75.0:
learnt = True
break
assert learnt, "SimpleQ multi-GPU (with fake-GPUs) did not " \
"learn CartPole!"
trainer.stop()

for _ in framework_iterator(config, frameworks=("tf", "torch")):
trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v0")
num_iterations = 200
learnt = False
for i in range(num_iterations):
results = trainer.train()
print("reward={}".format(results["episode_reward_mean"]))
if results["episode_reward_mean"] > 75.0:
learnt = True
break
assert learnt, "SimpleQ multi-GPU (with fake-GPUs) did not " \
"learn CartPole!"
trainer.stop()

def test_simple_q_loss_function(self):
"""Tests the Simple-Q loss function results on all frameworks."""
78 changes: 50 additions & 28 deletions rllib/agents/impala/impala.py
@@ -5,7 +5,7 @@
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.execution.learner_thread import LearnerThread
from ray.rllib.execution.multi_gpu_learner import TFMultiGPULearner
from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
from ray.rllib.execution.tree_agg import gather_experiences_tree_aggregation
from ray.rllib.execution.common import STEPS_TRAINED_COUNTER, \
_get_global_vars, _get_shared_metrics
@@ -14,6 +14,7 @@
from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.tune.trainable import Trainable
from ray.tune.utils.placement_groups import PlacementGroupFactory

@@ -42,31 +43,41 @@
"train_batch_size": 500,
"min_iter_time_s": 10,
"num_workers": 2,
# number of GPUs the learner should use.
# Number of GPUs the learner should use.
"num_gpus": 1,
# set >1 to load data into GPUs in parallel. Increases GPU memory usage
# proportionally with the number of buffers.
"num_data_loader_buffers": 1,
# how many train batches should be retained for minibatching. This conf
# For each stack of multi-GPU towers, how many slots should we reserve for
# parallel data loading? Set this to >1 to load data into GPUs in
# parallel. This will increase GPU memory usage proportionally with the
# number of stacks.
# Example:
# 2 GPUs and `num_multi_gpu_tower_stacks=3`:
# - One tower stack consists of 2 GPUs, each with a copy of the
# model/graph.
# - Each of the stacks will create 3 slots for batch data on each of its
# GPUs, increasing memory requirements on each GPU by 3x.
# - This enables us to preload data into these stacks while another stack
# is performing gradient calculations.
"num_multi_gpu_tower_stacks": 1,
# How many train batches should be retained for minibatching. This conf
# only has an effect if `num_sgd_iter > 1`.
"minibatch_buffer_size": 1,
# number of passes to make over each train batch
# Number of passes to make over each train batch.
"num_sgd_iter": 1,
# set >0 to enable experience replay. Saved samples will be replayed with
# Set >0 to enable experience replay. Saved samples will be replayed with
# a p:1 proportion to new data samples.
"replay_proportion": 0.0,
# number of sample batches to store for replay. The number of transitions
# Number of sample batches to store for replay. The number of transitions
# saved total will be (replay_buffer_num_slots * rollout_fragment_length).
"replay_buffer_num_slots": 0,
# max queue size for train batches feeding into the learner
# Max queue size for train batches feeding into the learner.
"learner_queue_size": 16,
# wait for train batches to be available in minibatch buffer queue
# Wait for train batches to be available in minibatch buffer queue
# this many seconds. This may need to be increased e.g. when training
# with a slow environment
# with a slow environment.
"learner_queue_timeout": 300,
# level of queuing for sampling.
# Level of queuing for sampling.
"max_sample_requests_in_flight_per_worker": 2,
# max number of workers to broadcast one set of weights to
# Max number of workers to broadcast one set of weights to.
"broadcast_interval": 1,
# Use n (`num_aggregation_workers`) extra Actors for multi-level
# aggregation of the data produced by the m RolloutWorkers
@@ -77,22 +88,25 @@

# Learning params.
"grad_clip": 40.0,
# either "adam" or "rmsprop"
# Either "adam" or "rmsprop".
"opt_type": "adam",
"lr": 0.0005,
"lr_schedule": None,
# rmsprop considered
# `opt_type=rmsprop` settings.
"decay": 0.99,
"momentum": 0.0,
"epsilon": 0.1,
# balancing the three losses
# Balancing the three losses.
"vf_loss_coeff": 0.5,
"entropy_coeff": 0.01,
"entropy_coeff_schedule": None,

# Callback for APPO to use to update KL, target network periodically.
# The input to the callback is the learner fetches dict.
"after_train_step": None,

# DEPRECATED:
"num_data_loader_buffers": DEPRECATED_VALUE,
})
# __sphinx_doc_end__
# yapf: enable
@@ -140,23 +154,23 @@ def default_resource_request(cls, config):


def make_learner_thread(local_worker, config):
if not config["simple_optimizer"] and (
config["num_gpus"] > 1 or config["num_data_loader_buffers"] > 1):
if not config["simple_optimizer"]:
logger.info(
"Enabling multi-GPU mode, {} GPUs, {} parallel loaders".format(
config["num_gpus"], config["num_data_loader_buffers"]))
if config["num_data_loader_buffers"] < config["minibatch_buffer_size"]:
"Enabling multi-GPU mode, {} GPUs, {} parallel tower-stacks".
format(config["num_gpus"], config["num_multi_gpu_tower_stacks"]))
if config["num_multi_gpu_tower_stacks"] < \
config["minibatch_buffer_size"]:
raise ValueError(
"In multi-gpu mode you must have at least as many "
"parallel data loader buffers as minibatch buffers: "
"{} vs {}".format(config["num_data_loader_buffers"],
"In multi-GPU mode you must have at least as many "
"parallel multi-GPU towers as minibatch buffers: "
"{} vs {}".format(config["num_multi_gpu_tower_stacks"],
config["minibatch_buffer_size"]))
learner_thread = TFMultiGPULearner(
learner_thread = MultiGPULearnerThread(
local_worker,
num_gpus=config["num_gpus"],
lr=config["lr"],
train_batch_size=config["train_batch_size"],
num_data_loader_buffers=config["num_data_loader_buffers"],
num_multi_gpu_tower_stacks=config["num_multi_gpu_tower_stacks"],
minibatch_buffer_size=config["minibatch_buffer_size"],
num_sgd_iter=config["num_sgd_iter"],
learner_queue_size=config["learner_queue_size"],
@@ -190,8 +204,16 @@ def get_policy_class(config):


def validate_config(config):
if config["num_data_loader_buffers"] != DEPRECATED_VALUE:
deprecation_warning(
"num_data_loader_buffers",
"num_multi_gpu_tower_stacks",
error=False)
config["num_multi_gpu_tower_stacks"] = \
config["num_data_loader_buffers"]

if config["entropy_coeff"] < 0.0:
raise DeprecationWarning("`entropy_coeff` must be >= 0.0!")
raise ValueError("`entropy_coeff` must be >= 0.0!")

if config["vtrace"] and not config["in_evaluation"]:
if config["batch_mode"] != "truncate_episodes":
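The validate_config() change above shows the deprecation idiom for the renamed setting: the default config carries a DEPRECATED_VALUE sentinel for the old key, so any user-supplied value can be detected, warned about, and copied onto the new key. A generic sketch of the same idea, assuming only the DEPRECATED_VALUE sentinel and deprecation_warning() helper imported in the diff (translate_renamed_key itself is a hypothetical helper, not part of RLlib):

from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning

def translate_renamed_key(config, old_key, new_key):
    # The default config sets old_key to the DEPRECATED_VALUE sentinel, so a
    # user-supplied value is detectable: warn (without raising) and copy it
    # onto the new key, as validate_config() does above for
    # num_data_loader_buffers -> num_multi_gpu_tower_stacks.
    if config.get(old_key, DEPRECATED_VALUE) != DEPRECATED_VALUE:
        deprecation_warning(old_key, new_key, error=False)
        config[new_key] = config[old_key]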
28 changes: 28 additions & 0 deletions rllib/agents/impala/tests/test_impala.py
@@ -1,3 +1,4 @@
import copy
import unittest

import ray
@@ -55,6 +56,7 @@ def test_impala_lr_schedule(self):
[0, 0.0005],
[10000, 0.000001],
]
config["num_gpus"] = 0 # Do not use any (fake) GPUs.
config["env"] = "CartPole-v0"

def get_lr(result):
@@ -75,6 +77,32 @@ def get_lr(result):
finally:
trainer.stop()

def test_impala_fake_multi_gpu_learning(self):
"""Test whether IMPALATrainer can learn CartPole w/ faked multi-GPU."""
config = copy.deepcopy(impala.DEFAULT_CONFIG)
# Fake GPU setup.
config["_fake_gpus"] = True
config["num_gpus"] = 2

config["train_batch_size"] *= 2

# Test w/ LSTMs.
config["model"]["use_lstm"] = True

for _ in framework_iterator(config, frameworks=("tf", "torch")):
trainer = impala.ImpalaTrainer(config=config, env="CartPole-v0")
num_iterations = 200
learnt = False
for i in range(num_iterations):
results = trainer.train()
print(results)
if results["episode_reward_mean"] > 55.0:
learnt = True
break
assert learnt, \
"IMPALA multi-GPU (with fake-GPUs) did not learn CartPole!"
trainer.stop()


if __name__ == "__main__":
import pytest
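The new test exercises the multi-GPU tower-stack code path without physical GPUs via _fake_gpus. A condensed sketch of the same setup outside the test harness (single framework instead of the framework_iterator loop; illustrative, not a tuned training script):

import copy
import ray
import ray.rllib.agents.impala as impala

ray.init()
config = copy.deepcopy(impala.DEFAULT_CONFIG)
config["_fake_gpus"] = True         # CPU-backed stand-ins for real GPUs
config["num_gpus"] = 2              # two (fake) GPU towers
config["framework"] = "torch"       # the test iterates over "tf" and "torch"
config["model"]["use_lstm"] = True  # as in the test above

trainer = impala.ImpalaTrainer(config=config, env="CartPole-v0")
for _ in range(3):
    print(trainer.train()["episode_reward_mean"])
trainer.stop()
ray.shutdown()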
4 changes: 2 additions & 2 deletions rllib/agents/impala/vtrace_torch_policy.py
@@ -182,7 +182,7 @@ def _make_time_major(*args, **kw):
clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"])

# Store loss object only for multi-GPU tower 0.
if policy.device == values.device:
if model is policy.model_gpu_towers[0]:
policy.loss = loss

return loss.total_loss
@@ -229,7 +229,7 @@ def stats(policy, train_batch):
values_batched = make_time_major(
policy,
train_batch.get("seq_lens"),
policy.model.value_function(),
policy.model_gpu_towers[0].value_function(),
drop_last=policy.config["vtrace"])

return {
(Diff truncated: the remaining changed files are not shown here.)
