diff --git a/LICENSE b/LICENSE index cfe48cb8c4294..cd24136c3d3fa 100644 --- a/LICENSE +++ b/LICENSE @@ -224,3 +224,22 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +-------------------------------------------------------------------------------- + +Code in python/ray/rllib/impala/vtrace.py from +https://github.com/deepmind/scalable_agent + +Copyright 2018 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/doc/source/impala.png b/doc/source/impala.png new file mode 100644 index 0000000000000..a7d12e4b5a0f9 Binary files /dev/null and b/doc/source/impala.png differ diff --git a/doc/source/rllib-algorithms.rst b/doc/source/rllib-algorithms.rst index b2f804dd05d0f..3d3a020c4f8fc 100644 --- a/doc/source/rllib-algorithms.rst +++ b/doc/source/rllib-algorithms.rst @@ -16,7 +16,7 @@ Tuned examples: `PongNoFrameskip-v4 `__ `[implementation] `__ -RLlib's A3C uses the AsyncGradientsOptimizer to apply gradients computed remotely on policy evaluation actors. It scales to up to 16-32 worker processes, depending on the environment. Both a TensorFlow (LSTM), and PyTorch version are available. +RLlib's A3C uses the AsyncGradientsOptimizer to apply gradients computed remotely on policy evaluation actors. It scales to 16-32 worker processes, depending on the environment. Both a TensorFlow (with LSTM support) and a PyTorch version are available. Note that if you have a GPU, `IMPALA <#importance-weighted-actor-learner-architecture>`__ will probably perform better than A3C. Tuned examples: `PongDeterministic-v4 `__, `PyTorch version `__ @@ -47,6 +47,20 @@ Tuned examples: `Humanoid-v1 `__ + +Importance Weighted Actor-Learner Architecture (IMPALA) +-------------------------------------------------------- +`[implementation] `__ +In IMPALA, a central learner runs SGD in a tight loop while asynchronously pulling sample batches from many actor processes. RLlib's IMPALA implementation uses DeepMind's reference `V-trace code `__. Note that we do not provide a deep residual network out of the box, but one can be plugged in as a `custom model `__. + +Tuned examples: `PongNoFrameskip-v4 `__, `vectorized configuration `__ + +.. figure:: impala.png + :align: center + + RLlib's IMPALA implementation scales from 16 to 128 workers on PongNoFrameskip-v4. With vectorization, similar learning performance to 128 workers can be achieved with only 32 workers. This is about an order of magnitude faster than A3C, with similar sample efficiency. + Policy Gradients ---------------- `[paper] `__ `[implementation] `__ We include a vanilla policy gradients implementation as an example algorithm. This is usually outperformed by PPO. @@ -64,4 +78,4 @@ Tuned examples: `Humanoid-v1 `__ (some of them are tuned to run on GPUs). If you find better settings or tune an algorithm on a different domain, consider submitting a Pull Request! +You can run these with the ``train.py`` script as follows: + +..
code-block:: bash + + python ray/python/ray/rllib/train.py -f /path/to/tuned/example.yaml + Python API ---------- diff --git a/doc/source/rllib.rst b/doc/source/rllib.rst index 5f9d94681650e..7dbe039ab744a 100644 --- a/doc/source/rllib.rst +++ b/doc/source/rllib.rst @@ -47,6 +47,7 @@ Algorithms * `Deep Deterministic Policy Gradients `__ * `Deep Q Networks `__ * `Evolution Strategies `__ +* `Importance Weighted Actor-Learner Architecture `__ * `Policy Gradients `__ * `Proximal Policy Optimization `__ diff --git a/python/ray/rllib/__init__.py b/python/ray/rllib/__init__.py index cf0f1058083f6..4d6575fcf6bd1 100644 --- a/python/ray/rllib/__init__.py +++ b/python/ray/rllib/__init__.py @@ -19,7 +19,7 @@ def _register_all(): for key in [ "PPO", "ES", "DQN", "APEX", "A3C", "BC", "PG", "DDPG", "APEX_DDPG", - "__fake", "__sigmoid_fake_data", "__parameter_tuning" + "IMPALA", "__fake", "__sigmoid_fake_data", "__parameter_tuning" ]: from ray.rllib.agents.agent import get_agent_class register_trainable(key, get_agent_class(key)) diff --git a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py b/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py index 00f630d3b0e43..f368f1470ea75 100644 --- a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py +++ b/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py @@ -1,3 +1,5 @@ +"""Note: Keep in sync with changes to VTracePolicyGraph.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -77,8 +79,6 @@ def __init__(self, observation_space, action_space, config): ("advantages", advantages), ("value_targets", v_target), ] - self.state_in = self.model.state_in - self.state_out = self.model.state_out TFPolicyGraph.__init__( self, observation_space, @@ -88,29 +88,21 @@ def __init__(self, observation_space, action_space, config): action_sampler=action_dist.sample(), loss=self.loss.total_loss, loss_inputs=loss_in, - state_inputs=self.state_in, - state_outputs=self.state_out, + state_inputs=self.model.state_in, + state_outputs=self.model.state_out, seq_lens=self.model.seq_lens, max_seq_len=self.config["model"]["max_seq_len"]) - if self.config.get("summarize"): - bs = tf.to_float(tf.shape(self.observations)[0]) - tf.summary.scalar("model/policy_graph", self.loss.pi_loss / bs) - tf.summary.scalar("model/value_loss", self.loss.vf_loss / bs) - tf.summary.scalar("model/entropy", self.loss.entropy / bs) - tf.summary.scalar("model/grad_gnorm", tf.global_norm(self._grads)) - tf.summary.scalar("model/var_gnorm", tf.global_norm(self.var_list)) - self.summary_op = tf.summary.merge_all() - self.sess.run(tf.global_variables_initializer()) def extra_compute_action_fetches(self): return {"vf_preds": self.vf} def value(self, ob, *args): - feed_dict = {self.observations: [ob]} - assert len(args) == len(self.state_in), (args, self.state_in) - for k, v in zip(self.state_in, args): + feed_dict = {self.observations: [ob], self.model.seq_lens: [1]} + assert len(args) == len(self.model.state_in), \ + (args, self.model.state_in) + for k, v in zip(self.model.state_in, args): feed_dict[k] = v vf = self.sess.run(self.vf, feed_dict) return vf[0] @@ -126,7 +118,15 @@ def gradients(self, optimizer): def extra_compute_grad_fetches(self): if self.config.get("summarize"): - return {"summary": self.summary_op} + return { + "stats": { + "policy_loss": self.loss.pi_loss, + "value_loss": self.loss.vf_loss, + "entropy": self.loss.entropy, + "grad_gnorm": tf.global_norm(self._grads), + "var_gnorm": tf.global_norm(self.var_list), + }, + } else: return 
{} @@ -139,7 +139,7 @@ def postprocess_trajectory(self, sample_batch, other_agent_batches=None): last_r = 0.0 else: next_state = [] - for i in range(len(self.state_in)): + for i in range(len(self.model.state_in)): next_state.append([sample_batch["state_out_{}".format(i)][-1]]) last_r = self.value(sample_batch["new_obs"][-1], *next_state) return compute_advantages(sample_batch, last_r, self.config["gamma"], diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 314ea2b6dbb82..0ac52cc8b8c67 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -360,6 +360,9 @@ def get_agent_class(alg): elif alg == "PG": from ray.rllib.agents import pg return pg.PGAgent + elif alg == "IMPALA": + from ray.rllib.agents import impala + return impala.ImpalaAgent elif alg == "script": from ray.tune import script_runner return script_runner.ScriptRunner diff --git a/python/ray/rllib/agents/ddpg/apex.py b/python/ray/rllib/agents/ddpg/apex.py index 93c0a9e3d983d..d0508463c7129 100644 --- a/python/ray/rllib/agents/ddpg/apex.py +++ b/python/ray/rllib/agents/ddpg/apex.py @@ -9,7 +9,7 @@ APEX_DDPG_DEFAULT_CONFIG = merge_dicts( DDPG_CONFIG, { - "optimizer_class": "AsyncSamplesOptimizer", + "optimizer_class": "AsyncReplayOptimizer", "optimizer": merge_dicts( DDPG_CONFIG["optimizer"], { "max_weight_sync_delay": 400, diff --git a/python/ray/rllib/agents/dqn/apex.py b/python/ray/rllib/agents/dqn/apex.py index bbe18e174ac68..b120a0fbb75d7 100644 --- a/python/ray/rllib/agents/dqn/apex.py +++ b/python/ray/rllib/agents/dqn/apex.py @@ -9,7 +9,7 @@ APEX_DEFAULT_CONFIG = merge_dicts( DQN_CONFIG, { - "optimizer_class": "AsyncSamplesOptimizer", + "optimizer_class": "AsyncReplayOptimizer", "optimizer": merge_dicts( DQN_CONFIG["optimizer"], { "max_weight_sync_delay": 400, diff --git a/python/ray/rllib/agents/impala/__init__.py b/python/ray/rllib/agents/impala/__init__.py new file mode 100644 index 0000000000000..087a568c31ac3 --- /dev/null +++ b/python/ray/rllib/agents/impala/__init__.py @@ -0,0 +1,3 @@ +from ray.rllib.agents.impala.impala import ImpalaAgent, DEFAULT_CONFIG + +__all__ = ["ImpalaAgent", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py new file mode 100644 index 0000000000000..bfabeb9409799 --- /dev/null +++ b/python/ray/rllib/agents/impala/impala.py @@ -0,0 +1,123 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pickle +import os +import time + +import ray +from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph +from ray.rllib.agents.impala.vtrace_policy_graph import VTracePolicyGraph +from ray.rllib.agents.agent import Agent, with_common_config +from ray.rllib.optimizers import AsyncSamplesOptimizer +from ray.rllib.utils import FilterManager +from ray.tune.trial import Resources + +OPTIMIZER_SHARED_CONFIGS = [ + "sample_batch_size", + "train_batch_size", +] + +DEFAULT_CONFIG = with_common_config({ + # V-trace params (see vtrace.py). + "vtrace": True, + "vtrace_clip_rho_threshold": 1.0, + "vtrace_clip_pg_rho_threshold": 1.0, + + # System params. + "sample_batch_size": 50, + "train_batch_size": 500, + "min_iter_time_s": 10, + "summarize": False, + "gpu": True, + "num_workers": 2, + "num_cpus_per_worker": 1, + "num_gpus_per_worker": 0, + + # Learning params. + "grad_clip": 40.0, + "lr": 0.0001, + "vf_loss_coeff": 0.5, + "entropy_coeff": -0.01, + + # Model and preprocessor options. 
+ "clip_rewards": True, + "preprocessor_pref": "deepmind", + "model": { + "use_lstm": False, + "max_seq_len": 20, + "dim": 80, + }, +}) + + +class ImpalaAgent(Agent): + """IMPALA implementation using DeepMind's V-trace.""" + + _agent_name = "IMPALA" + _default_config = DEFAULT_CONFIG + + @classmethod + def default_resource_request(cls, config): + cf = dict(cls._default_config, **config) + return Resources( + cpu=1, + gpu=cf["gpu"] and 1 or 0, + extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"], + extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) + + def _init(self): + for k in OPTIMIZER_SHARED_CONFIGS: + if k not in self.config["optimizer"]: + self.config["optimizer"][k] = self.config[k] + if self.config["vtrace"]: + policy_cls = VTracePolicyGraph + else: + policy_cls = A3CPolicyGraph + self.local_evaluator = self.make_local_evaluator( + self.env_creator, policy_cls) + self.remote_evaluators = self.make_remote_evaluators( + self.env_creator, policy_cls, self.config["num_workers"], + {"num_cpus": 1}) + self.optimizer = AsyncSamplesOptimizer(self.local_evaluator, + self.remote_evaluators, + self.config["optimizer"]) + + def _train(self): + prev_steps = self.optimizer.num_steps_sampled + start = time.time() + self.optimizer.step() + while time.time() - start < self.config["min_iter_time_s"]: + self.optimizer.step() + FilterManager.synchronize(self.local_evaluator.filters, + self.remote_evaluators) + result = self.optimizer.collect_metrics() + result = result._replace( + timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps) + return result + + def _stop(self): + # workaround for https://github.com/ray-project/ray/issues/1516 + for ev in self.remote_evaluators: + ev.__ray_terminate__.remote() + + def _save(self, checkpoint_dir): + checkpoint_path = os.path.join(checkpoint_dir, + "checkpoint-{}".format(self.iteration)) + agent_state = ray.get( + [a.save.remote() for a in self.remote_evaluators]) + extra_data = { + "remote_state": agent_state, + "local_state": self.local_evaluator.save() + } + pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb")) + return checkpoint_path + + def _restore(self, checkpoint_path): + extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb")) + ray.get([ + a.restore.remote(o) + for a, o in zip(self.remote_evaluators, extra_data["remote_state"]) + ]) + self.local_evaluator.restore(extra_data["local_state"]) diff --git a/python/ray/rllib/agents/impala/vtrace.py b/python/ray/rllib/agents/impala/vtrace.py new file mode 100644 index 0000000000000..ac5abf0e65924 --- /dev/null +++ b/python/ray/rllib/agents/impala/vtrace.py @@ -0,0 +1,300 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Functions to compute V-trace off-policy actor critic targets. + +For details and theory see: + +"IMPALA: Scalable Distributed Deep-RL with +Importance Weighted Actor-Learner Architectures" +by Espeholt, Soyer, Munos et al. + +See https://arxiv.org/abs/1802.01561 for the full paper. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import tensorflow as tf + +nest = tf.contrib.framework.nest + +VTraceFromLogitsReturns = collections.namedtuple('VTraceFromLogitsReturns', [ + 'vs', 'pg_advantages', 'log_rhos', 'behaviour_action_log_probs', + 'target_action_log_probs' +]) + +VTraceReturns = collections.namedtuple('VTraceReturns', 'vs pg_advantages') + + +def log_probs_from_logits_and_actions(policy_logits, actions): + """Computes action log-probs from policy logits and actions. + + In the notation used throughout documentation and comments, T refers to the + time dimension ranging from 0 to T-1. B refers to the batch size and + NUM_ACTIONS refers to the number of actions. + + Args: + policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with + un-normalized log-probabilities parameterizing a softmax policy. + actions: An int32 tensor of shape [T, B] with actions. + + Returns: + A float32 tensor of shape [T, B] corresponding to the sampling log + probability of the chosen action w.r.t. the policy. + """ + policy_logits = tf.convert_to_tensor(policy_logits, dtype=tf.float32) + actions = tf.convert_to_tensor(actions, dtype=tf.int32) + + policy_logits.shape.assert_has_rank(3) + actions.shape.assert_has_rank(2) + + return -tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=policy_logits, labels=actions) + + +def from_logits(behaviour_policy_logits, + target_policy_logits, + actions, + discounts, + rewards, + values, + bootstrap_value, + clip_rho_threshold=1.0, + clip_pg_rho_threshold=1.0, + name='vtrace_from_logits'): + r"""V-trace for softmax policies. + + Calculates V-trace actor critic targets for softmax polices as described in + + "IMPALA: Scalable Distributed Deep-RL with + Importance Weighted Actor-Learner Architectures" + by Espeholt, Soyer, Munos et al. + + Target policy refers to the policy we are interested in improving and + behaviour policy refers to the policy that generated the given + rewards and actions. + + In the notation used throughout documentation and comments, T refers to the + time dimension ranging from 0 to T-1. B refers to the batch size and + NUM_ACTIONS refers to the number of actions. + + Args: + behaviour_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with + un-normalized log-probabilities parametrizing the softmax behaviour + policy. + target_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with + un-normalized log-probabilities parametrizing the softmax target policy. + actions: An int32 tensor of shape [T, B] of actions sampled from the + behaviour policy. + discounts: A float32 tensor of shape [T, B] with the discount encountered + when following the behaviour policy. + rewards: A float32 tensor of shape [T, B] with the rewards generated by + following the behaviour policy. + values: A float32 tensor of shape [T, B] with the value function estimates + wrt. the target policy. + bootstrap_value: A float32 of shape [B] with the value function estimate at + time T. + clip_rho_threshold: A scalar float32 tensor with the clipping threshold for + importance weights (rho) when calculating the baseline targets (vs). + rho^bar in the paper. + clip_pg_rho_threshold: A scalar float32 tensor with the clipping threshold + on rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)). + name: The name scope that all V-trace operations will be created in. 
+ + Returns: + A `VTraceFromLogitsReturns` namedtuple with the following fields: + vs: A float32 tensor of shape [T, B]. Can be used as target to train a + baseline (V(x_t) - vs_t)^2. + pg_advantages: A float 32 tensor of shape [T, B]. Can be used as an + estimate of the advantage in the calculation of policy gradients. + log_rhos: A float32 tensor of shape [T, B] containing the log importance + sampling weights (log rhos). + behaviour_action_log_probs: A float32 tensor of shape [T, B] containing + behaviour policy action log probabilities (log \mu(a_t)). + target_action_log_probs: A float32 tensor of shape [T, B] containing + target policy action probabilities (log \pi(a_t)). + """ + behaviour_policy_logits = tf.convert_to_tensor( + behaviour_policy_logits, dtype=tf.float32) + target_policy_logits = tf.convert_to_tensor( + target_policy_logits, dtype=tf.float32) + actions = tf.convert_to_tensor(actions, dtype=tf.int32) + + # Make sure tensor ranks are as expected. + # The rest will be checked by from_action_log_probs. + behaviour_policy_logits.shape.assert_has_rank(3) + target_policy_logits.shape.assert_has_rank(3) + actions.shape.assert_has_rank(2) + + with tf.name_scope( + name, + values=[ + behaviour_policy_logits, target_policy_logits, actions, + discounts, rewards, values, bootstrap_value + ]): + target_action_log_probs = log_probs_from_logits_and_actions( + target_policy_logits, actions) + behaviour_action_log_probs = log_probs_from_logits_and_actions( + behaviour_policy_logits, actions) + log_rhos = target_action_log_probs - behaviour_action_log_probs + vtrace_returns = from_importance_weights( + log_rhos=log_rhos, + discounts=discounts, + rewards=rewards, + values=values, + bootstrap_value=bootstrap_value, + clip_rho_threshold=clip_rho_threshold, + clip_pg_rho_threshold=clip_pg_rho_threshold) + return VTraceFromLogitsReturns( + log_rhos=log_rhos, + behaviour_action_log_probs=behaviour_action_log_probs, + target_action_log_probs=target_action_log_probs, + **vtrace_returns._asdict()) + + +def from_importance_weights(log_rhos, + discounts, + rewards, + values, + bootstrap_value, + clip_rho_threshold=1.0, + clip_pg_rho_threshold=1.0, + name='vtrace_from_importance_weights'): + r"""V-trace from log importance weights. + + Calculates V-trace actor critic targets as described in + + "IMPALA: Scalable Distributed Deep-RL with + Importance Weighted Actor-Learner Architectures" + by Espeholt, Soyer, Munos et al. + + In the notation used throughout documentation and comments, T refers to the + time dimension ranging from 0 to T-1. B refers to the batch size and + NUM_ACTIONS refers to the number of actions. This code also supports the + case where all tensors have the same number of additional dimensions, e.g., + `rewards` is [T, B, C], `values` is [T, B, C], `bootstrap_value` is [B, C]. + + Args: + log_rhos: A float32 tensor of shape [T, B, NUM_ACTIONS] representing the + log importance sampling weights, i.e. + log(target_policy(a) / behaviour_policy(a)). V-trace performs operations + on rhos in log-space for numerical stability. + discounts: A float32 tensor of shape [T, B] with discounts encountered when + following the behaviour policy. + rewards: A float32 tensor of shape [T, B] containing rewards generated by + following the behaviour policy. + values: A float32 tensor of shape [T, B] with the value function estimates + wrt. the target policy. + bootstrap_value: A float32 of shape [B] with the value function estimate at + time T. 
+ clip_rho_threshold: A scalar float32 tensor with the clipping threshold for + importance weights (rho) when calculating the baseline targets (vs). + rho^bar in the paper. If None, no clipping is applied. + clip_pg_rho_threshold: A scalar float32 tensor with the clipping threshold + on rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)). If + None, no clipping is applied. + name: The name scope that all V-trace operations will be created in. + + Returns: + A VTraceReturns namedtuple (vs, pg_advantages) where: + vs: A float32 tensor of shape [T, B]. Can be used as target to + train a baseline (V(x_t) - vs_t)^2. + pg_advantages: A float32 tensor of shape [T, B]. Can be used as the + advantage in the calculation of policy gradients. + """ + log_rhos = tf.convert_to_tensor(log_rhos, dtype=tf.float32) + discounts = tf.convert_to_tensor(discounts, dtype=tf.float32) + rewards = tf.convert_to_tensor(rewards, dtype=tf.float32) + values = tf.convert_to_tensor(values, dtype=tf.float32) + bootstrap_value = tf.convert_to_tensor(bootstrap_value, dtype=tf.float32) + if clip_rho_threshold is not None: + clip_rho_threshold = tf.convert_to_tensor( + clip_rho_threshold, dtype=tf.float32) + if clip_pg_rho_threshold is not None: + clip_pg_rho_threshold = tf.convert_to_tensor( + clip_pg_rho_threshold, dtype=tf.float32) + + # Make sure tensor ranks are consistent. + rho_rank = log_rhos.shape.ndims # Usually 2. + values.shape.assert_has_rank(rho_rank) + bootstrap_value.shape.assert_has_rank(rho_rank - 1) + discounts.shape.assert_has_rank(rho_rank) + rewards.shape.assert_has_rank(rho_rank) + if clip_rho_threshold is not None: + clip_rho_threshold.shape.assert_has_rank(0) + if clip_pg_rho_threshold is not None: + clip_pg_rho_threshold.shape.assert_has_rank(0) + + with tf.name_scope( + name, + values=[log_rhos, discounts, rewards, values, bootstrap_value]): + rhos = tf.exp(log_rhos) + if clip_rho_threshold is not None: + clipped_rhos = tf.minimum( + clip_rho_threshold, rhos, name='clipped_rhos') + else: + clipped_rhos = rhos + + cs = tf.minimum(1.0, rhos, name='cs') + # Append bootstrapped value to get [v1, ..., v_t+1] + values_t_plus_1 = tf.concat( + [values[1:], tf.expand_dims(bootstrap_value, 0)], axis=0) + deltas = clipped_rhos * ( + rewards + discounts * values_t_plus_1 - values) + + # All sequences are reversed, computation starts from the back. + sequences = ( + tf.reverse(discounts, axis=[0]), + tf.reverse(cs, axis=[0]), + tf.reverse(deltas, axis=[0]), + ) + + # V-trace vs are calculated through a scan from the back to the + # beginning of the given trajectory. + def scanfunc(acc, sequence_item): + discount_t, c_t, delta_t = sequence_item + return delta_t + discount_t * c_t * acc + + initial_values = tf.zeros_like(bootstrap_value) + vs_minus_v_xs = tf.scan( + fn=scanfunc, + elems=sequences, + initializer=initial_values, + parallel_iterations=1, + back_prop=False, + name='scan') + # Reverse the results back to original order. + vs_minus_v_xs = tf.reverse(vs_minus_v_xs, [0], name='vs_minus_v_xs') + + # Add V(x_s) to get v_s. + vs = tf.add(vs_minus_v_xs, values, name='vs') + + # Advantage for policy gradient. 
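The scan above implements the backward recursion v_s - V(x_s) = delta_s + gamma_s * c_s * (v_{s+1} - V(x_{s+1})), followed by the policy-gradient advantage. A plain-NumPy restatement of ``from_importance_weights`` (illustrative only, shapes fixed to [T, B]) may make the computation easier to follow:

.. code-block:: python

    import numpy as np

    def np_vtrace(log_rhos, discounts, rewards, values, bootstrap_value,
                  clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0):
        """NumPy reference for from_importance_weights with [T, B] inputs."""
        rhos = np.exp(log_rhos)
        clipped_rhos = np.minimum(clip_rho_threshold, rhos)
        cs = np.minimum(1.0, rhos)
        # values_t_plus_1[t] = V(x_{t+1}), with the bootstrap value appended.
        values_t_plus_1 = np.concatenate([values[1:], bootstrap_value[None]], axis=0)
        deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)

        # Backward recursion over the trajectory.
        vs_minus_v_xs = np.zeros_like(values)
        acc = np.zeros_like(bootstrap_value)
        for t in reversed(range(len(rewards))):
            acc = deltas[t] + discounts[t] * cs[t] * acc
            vs_minus_v_xs[t] = acc
        vs = vs_minus_v_xs + values

        # Policy-gradient advantage, using v_{s+1} as the one-step target.
        vs_t_plus_1 = np.concatenate([vs[1:], bootstrap_value[None]], axis=0)
        clipped_pg_rhos = np.minimum(clip_pg_rho_threshold, rhos)
        pg_advantages = clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values)
        return vs, pg_advantages

    T, B = 5, 2
    vs, pg_adv = np_vtrace(
        log_rhos=np.random.randn(T, B) * 0.1,
        discounts=np.full((T, B), 0.99),
        rewards=np.random.randn(T, B),
        values=np.random.randn(T, B),
        bootstrap_value=np.random.randn(B))
    print(vs.shape, pg_adv.shape)  # (5, 2) (5, 2)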
+ vs_t_plus_1 = tf.concat( + [vs[1:], tf.expand_dims(bootstrap_value, 0)], axis=0) + if clip_pg_rho_threshold is not None: + clipped_pg_rhos = tf.minimum( + clip_pg_rho_threshold, rhos, name='clipped_pg_rhos') + else: + clipped_pg_rhos = rhos + pg_advantages = ( + clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values)) + + # Make sure no gradients backpropagated through the returned values. + return VTraceReturns( + vs=tf.stop_gradient(vs), + pg_advantages=tf.stop_gradient(pg_advantages)) diff --git a/python/ray/rllib/agents/impala/vtrace_policy_graph.py b/python/ray/rllib/agents/impala/vtrace_policy_graph.py new file mode 100644 index 0000000000000..0b9c46c9a842a --- /dev/null +++ b/python/ray/rllib/agents/impala/vtrace_policy_graph.py @@ -0,0 +1,217 @@ +"""Adapted from A3CPolicyGraph to add V-trace. + +Keep in sync with changes to A3CPolicyGraph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import gym + +import ray +from ray.rllib.agents.impala import vtrace +from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.misc import linear, normc_initializer +from ray.rllib.utils.error import UnsupportedSpaceException + + +class VTraceLoss(object): + def __init__(self, + actions, + actions_logp, + actions_entropy, + dones, + behaviour_logits, + target_logits, + discount, + rewards, + values, + bootstrap_value, + vf_loss_coeff=0.5, + entropy_coeff=-0.01, + clip_rho_threshold=1.0, + clip_pg_rho_threshold=1.0): + """Policy gradient loss with vtrace importance weighting. + + VTraceLoss takes tensors of shape [T, B, ...], where `B` is the + batch_size. The reason we need to know `B` is for V-trace to properly + handle episode cut boundaries. + + Args: + actions: An int32 tensor of shape [T, B, NUM_ACTIONS]. + actions_logp: A float32 tensor of shape [T, B]. + actions_entropy: A float32 tensor of shape [T, B]. + dones: A bool tensor of shape [T, B]. + behaviour_logits: A float32 tensor of shape [T, B, NUM_ACTIONS]. + target_logits: A float32 tensor of shape [T, B, NUM_ACTIONS]. + discount: A float32 scalar. + rewards: A float32 tensor of shape [T, B]. + values: A float32 tensor of shape [T, B]. + bootstrap_value: A float32 tensor of shape [B]. + """ + + # Compute vtrace on the CPU for better perf. 
+ with tf.device("/cpu:0"): + vtrace_returns = vtrace.from_logits( + behaviour_policy_logits=behaviour_logits, + target_policy_logits=target_logits, + actions=tf.cast(actions, tf.int32), + discounts=tf.to_float(~dones) * discount, + rewards=rewards, + values=values, + bootstrap_value=bootstrap_value, + clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32), + clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold, + tf.float32)) + + # The policy gradients loss + self.pi_loss = -tf.reduce_sum( + actions_logp * vtrace_returns.pg_advantages) + + # The baseline loss + delta = values - vtrace_returns.vs + self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta)) + + # The entropy loss + self.entropy = tf.reduce_sum(actions_entropy) + + # The summed weighted loss + self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff + + self.entropy * entropy_coeff) + + +class VTracePolicyGraph(TFPolicyGraph): + def __init__(self, observation_space, action_space, config): + config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config) + assert config["batch_mode"] == "truncate_episodes", \ + "Must use `truncate_episodes` batch mode with V-trace." + self.config = config + self.sess = tf.get_default_session() + + # Setup the policy + self.observations = tf.placeholder( + tf.float32, [None] + list(observation_space.shape)) + dist_class, logit_dim = ModelCatalog.get_action_dist( + action_space, self.config["model"]) + self.model = ModelCatalog.get_model(self.observations, logit_dim, + self.config["model"]) + action_dist = dist_class(self.model.outputs) + values = tf.reshape( + linear(self.model.last_layer, 1, "value", normc_initializer(1.0)), + [-1]) + self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, + tf.get_variable_scope().name) + + # Setup the policy loss + if isinstance(action_space, gym.spaces.Box): + ac_size = action_space.shape[0] + actions = tf.placeholder(tf.float32, [None, ac_size], name="ac") + elif isinstance(action_space, gym.spaces.Discrete): + ac_size = action_space.n + actions = tf.placeholder(tf.int64, [None], name="ac") + else: + raise UnsupportedSpaceException( + "Action space {} is not supported for IMPALA.".format( + action_space)) + dones = tf.placeholder(tf.bool, [None], name="dones") + rewards = tf.placeholder(tf.float32, [None], name="rewards") + behaviour_logits = tf.placeholder( + tf.float32, [None, ac_size], name="behaviour_logits") + + def to_batches(tensor): + if self.config["model"]["use_lstm"]: + B = tf.shape(self.model.seq_lens)[0] + T = tf.shape(tensor)[0] // B + else: + # Important: chop the tensor into batches at known episode cut + # boundaries. TODO(ekl) this is kind of a hack + T = (self.config["sample_batch_size"] // + self.config["num_envs_per_worker"]) + B = tf.shape(tensor)[0] // T + rs = tf.reshape(tensor, + tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0)) + # swap B and T axes + return tf.transpose( + rs, + [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0])))) + + if self.config["clip_rewards"]: + clipped_rewards = tf.clip_by_value(rewards, -1, 1) + else: + clipped_rewards = rewards + + # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc. 
+ self.loss = VTraceLoss( + actions=to_batches(actions)[:-1], + actions_logp=to_batches(action_dist.logp(actions))[:-1], + actions_entropy=to_batches(action_dist.entropy())[:-1], + dones=to_batches(dones)[:-1], + behaviour_logits=to_batches(behaviour_logits)[:-1], + target_logits=to_batches(self.model.outputs)[:-1], + discount=config["gamma"], + rewards=to_batches(clipped_rewards)[:-1], + values=to_batches(values)[:-1], + bootstrap_value=to_batches(values)[-1], + vf_loss_coeff=self.config["vf_loss_coeff"], + entropy_coeff=self.config["entropy_coeff"], + clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], + clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"]) + + # Initialize TFPolicyGraph + loss_in = [ + ("actions", actions), + ("dones", dones), + ("behaviour_logits", behaviour_logits), + ("rewards", rewards), + ("obs", self.observations), + ] + TFPolicyGraph.__init__( + self, + observation_space, + action_space, + self.sess, + obs_input=self.observations, + action_sampler=action_dist.sample(), + loss=self.loss.total_loss, + loss_inputs=loss_in, + state_inputs=self.model.state_in, + state_outputs=self.model.state_out, + seq_lens=self.model.seq_lens, + max_seq_len=self.config["model"]["max_seq_len"]) + + self.sess.run(tf.global_variables_initializer()) + + def optimizer(self): + return tf.train.AdamOptimizer(self.config["lr"]) + + def gradients(self, optimizer): + grads = tf.gradients(self.loss.total_loss, self.var_list) + self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"]) + clipped_grads = list(zip(self.grads, self.var_list)) + return clipped_grads + + def extra_compute_action_fetches(self): + return {"behaviour_logits": self.model.outputs} + + def extra_compute_grad_fetches(self): + if self.config.get("summarize"): + return { + "stats": { + "policy_loss": self.loss.pi_loss, + "value_loss": self.loss.vf_loss, + "entropy": self.loss.entropy, + "grad_gnorm": tf.global_norm(self._grads), + "var_gnorm": tf.global_norm(self.var_list), + }, + } + else: + return {} + + def postprocess_trajectory(self, sample_batch, other_agent_batches=None): + del sample_batch.data["new_obs"] # not used, so save some bandwidth + return sample_batch + + def get_initial_state(self): + return self.model.state_init diff --git a/python/ray/rllib/optimizers/__init__.py b/python/ray/rllib/optimizers/__init__.py index eadb38620de67..f7ede66f72872 100644 --- a/python/ray/rllib/optimizers/__init__.py +++ b/python/ray/rllib/optimizers/__init__.py @@ -1,4 +1,5 @@ from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer +from ray.rllib.optimizers.async_replay_optimizer import AsyncReplayOptimizer from ray.rllib.optimizers.async_samples_optimizer import AsyncSamplesOptimizer from ray.rllib.optimizers.async_gradients_optimizer import \ AsyncGradientsOptimizer @@ -7,6 +8,7 @@ from ray.rllib.optimizers.multi_gpu_optimizer import LocalMultiGPUOptimizer __all__ = [ - "PolicyOptimizer", "AsyncSamplesOptimizer", "AsyncGradientsOptimizer", - "SyncSamplesOptimizer", "SyncReplayOptimizer", "LocalMultiGPUOptimizer" + "PolicyOptimizer", "AsyncReplayOptimizer", "AsyncSamplesOptimizer", + "AsyncGradientsOptimizer", "SyncSamplesOptimizer", "SyncReplayOptimizer", + "LocalMultiGPUOptimizer" ] diff --git a/python/ray/rllib/optimizers/async_gradients_optimizer.py b/python/ray/rllib/optimizers/async_gradients_optimizer.py index 397fabba9ba97..fc7fdb2488a33 100644 --- a/python/ray/rllib/optimizers/async_gradients_optimizer.py +++ b/python/ray/rllib/optimizers/async_gradients_optimizer.py @@ 
-20,6 +20,7 @@ def _init(self, grads_per_step=100): self.wait_timer = TimerStat() self.dispatch_timer = TimerStat() self.grads_per_step = grads_per_step + self.learner_stats = {} if not self.remote_evaluators: raise ValueError( "Async optimizer requires at least 1 remote evaluator") @@ -41,6 +42,8 @@ def step(self): with self.wait_timer: fut, e = gradient_queue.pop(0) gradient, info = ray.get(fut) + if "stats" in info: + self.learner_stats = info["stats"] if gradient is not None: with self.apply_timer: @@ -61,4 +64,5 @@ def stats(self): "wait_time_ms": round(1000 * self.wait_timer.mean, 3), "apply_time_ms": round(1000 * self.apply_timer.mean, 3), "dispatch_time_ms": round(1000 * self.dispatch_timer.mean, 3), + "learner": self.learner_stats, }) diff --git a/python/ray/rllib/optimizers/async_replay_optimizer.py b/python/ray/rllib/optimizers/async_replay_optimizer.py new file mode 100644 index 0000000000000..0037ea7a07385 --- /dev/null +++ b/python/ray/rllib/optimizers/async_replay_optimizer.py @@ -0,0 +1,295 @@ +"""Implements Distributed Prioritized Experience Replay. + +https://arxiv.org/abs/1803.00933""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import random +import time +import threading + +import numpy as np +from six.moves import queue + +import ray +from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer +from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer +from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.utils.actors import TaskPool, create_colocated +from ray.rllib.utils.timer import TimerStat +from ray.rllib.utils.window_stat import WindowStat + +SAMPLE_QUEUE_DEPTH = 2 +REPLAY_QUEUE_DEPTH = 4 +LEARNER_QUEUE_MAX_SIZE = 16 + + +@ray.remote +class ReplayActor(object): + """A replay buffer shard. 
+ + Ray actors are single-threaded, so for scalability multiple replay actors + may be created to increase parallelism.""" + + def __init__(self, num_shards, learning_starts, buffer_size, + train_batch_size, prioritized_replay_alpha, + prioritized_replay_beta, prioritized_replay_eps, + clip_rewards): + self.replay_starts = learning_starts // num_shards + self.buffer_size = buffer_size // num_shards + self.train_batch_size = train_batch_size + self.prioritized_replay_beta = prioritized_replay_beta + self.prioritized_replay_eps = prioritized_replay_eps + + self.replay_buffer = PrioritizedReplayBuffer( + self.buffer_size, + alpha=prioritized_replay_alpha, + clip_rewards=clip_rewards) + + # Metrics + self.add_batch_timer = TimerStat() + self.replay_timer = TimerStat() + self.update_priorities_timer = TimerStat() + + def get_host(self): + return os.uname()[1] + + def add_batch(self, batch): + PolicyOptimizer._check_not_multiagent(batch) + with self.add_batch_timer: + for row in batch.rows(): + self.replay_buffer.add(row["obs"], row["actions"], + row["rewards"], row["new_obs"], + row["dones"], row["weights"]) + + def replay(self): + with self.replay_timer: + if len(self.replay_buffer) < self.replay_starts: + return None + + (obses_t, actions, rewards, obses_tp1, dones, weights, + batch_indexes) = self.replay_buffer.sample( + self.train_batch_size, beta=self.prioritized_replay_beta) + + batch = SampleBatch({ + "obs": obses_t, + "actions": actions, + "rewards": rewards, + "new_obs": obses_tp1, + "dones": dones, + "weights": weights, + "batch_indexes": batch_indexes + }) + return batch + + def update_priorities(self, batch_indexes, td_errors): + with self.update_priorities_timer: + new_priorities = (np.abs(td_errors) + self.prioritized_replay_eps) + self.replay_buffer.update_priorities(batch_indexes, new_priorities) + + def stats(self): + stat = { + "add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3), + "replay_time_ms": round(1000 * self.replay_timer.mean, 3), + "update_priorities_time_ms": round( + 1000 * self.update_priorities_timer.mean, 3), + } + stat.update(self.replay_buffer.stats()) + return stat + + +class LearnerThread(threading.Thread): + """Background thread that updates the local model from replay data. + + The learner thread communicates with the main thread through Queues. This + is needed since Ray operations can only be run on the main thread. In + addition, moving heavyweight gradient ops session runs off the main thread + improves overall throughput. + """ + + def __init__(self, local_evaluator): + threading.Thread.__init__(self) + self.learner_queue_size = WindowStat("size", 50) + self.local_evaluator = local_evaluator + self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE) + self.outqueue = queue.Queue() + self.queue_timer = TimerStat() + self.grad_timer = TimerStat() + self.daemon = True + self.weights_updated = False + + def run(self): + while True: + self.step() + + def step(self): + with self.queue_timer: + ra, replay = self.inqueue.get() + if replay is not None: + with self.grad_timer: + td_error = self.local_evaluator.compute_apply(replay)[ + "td_error"] + self.outqueue.put((ra, replay, td_error, replay.count)) + self.learner_queue_size.push(self.inqueue.qsize()) + self.weights_updated = True + + +class AsyncReplayOptimizer(PolicyOptimizer): + """Main event loop of the Ape-X optimizer (async sampling with replay). + + This class coordinates the data transfers between the learner thread, + remote evaluators (Ape-X actors), and replay buffer actors. 
+ + This optimizer requires that policy evaluators return an additional + "td_error" array in the info return of compute_gradients(). This error + term will be used for sample prioritization.""" + + def _init(self, + learning_starts=1000, + buffer_size=10000, + prioritized_replay=True, + prioritized_replay_alpha=0.6, + prioritized_replay_beta=0.4, + prioritized_replay_eps=1e-6, + train_batch_size=512, + sample_batch_size=50, + num_replay_buffer_shards=1, + max_weight_sync_delay=400, + clip_rewards=True, + debug=False): + + self.debug = debug + self.replay_starts = learning_starts + self.prioritized_replay_beta = prioritized_replay_beta + self.prioritized_replay_eps = prioritized_replay_eps + self.max_weight_sync_delay = max_weight_sync_delay + + self.learner = LearnerThread(self.local_evaluator) + self.learner.start() + + self.replay_actors = create_colocated(ReplayActor, [ + num_replay_buffer_shards, learning_starts, buffer_size, + train_batch_size, prioritized_replay_alpha, + prioritized_replay_beta, prioritized_replay_eps, clip_rewards + ], num_replay_buffer_shards) + assert len(self.remote_evaluators) > 0 + + # Stats + self.timers = { + k: TimerStat() + for k in [ + "put_weights", "get_samples", "enqueue", "sample_processing", + "replay_processing", "update_priorities", "train", "sample" + ] + } + self.num_weight_syncs = 0 + self.learning_started = False + + # Number of worker steps since the last weight update + self.steps_since_update = {} + + # Otherwise kick of replay tasks for local gradient updates + self.replay_tasks = TaskPool() + for ra in self.replay_actors: + for _ in range(REPLAY_QUEUE_DEPTH): + self.replay_tasks.add(ra, ra.replay.remote()) + + # Kick off async background sampling + self.sample_tasks = TaskPool() + weights = self.local_evaluator.get_weights() + for ev in self.remote_evaluators: + ev.set_weights.remote(weights) + self.steps_since_update[ev] = 0 + for _ in range(SAMPLE_QUEUE_DEPTH): + self.sample_tasks.add(ev, ev.sample_with_count.remote()) + + def step(self): + start = time.time() + sample_timesteps, train_timesteps = self._step() + time_delta = time.time() - start + self.timers["sample"].push(time_delta) + self.timers["sample"].push_units_processed(sample_timesteps) + if train_timesteps > 0: + self.learning_started = True + if self.learning_started: + self.timers["train"].push(time_delta) + self.timers["train"].push_units_processed(train_timesteps) + self.num_steps_sampled += sample_timesteps + self.num_steps_trained += train_timesteps + + def _step(self): + sample_timesteps, train_timesteps = 0, 0 + weights = None + + with self.timers["sample_processing"]: + completed = list(self.sample_tasks.completed()) + counts = ray.get([c[1][1] for c in completed]) + for i, (ev, (sample_batch, count)) in enumerate(completed): + sample_timesteps += counts[i] + + # Send the data to the replay buffer + random.choice( + self.replay_actors).add_batch.remote(sample_batch) + + # Update weights if needed + self.steps_since_update[ev] += counts[i] + if self.steps_since_update[ev] >= self.max_weight_sync_delay: + # Note that it's important to pull new weights once + # updated to avoid excessive correlation between actors + if weights is None or self.learner.weights_updated: + self.learner.weights_updated = False + with self.timers["put_weights"]: + weights = ray.put( + self.local_evaluator.get_weights()) + ev.set_weights.remote(weights) + self.num_weight_syncs += 1 + self.steps_since_update[ev] = 0 + + # Kick off another sample request + self.sample_tasks.add(ev, 
ev.sample_with_count.remote()) + + with self.timers["replay_processing"]: + for ra, replay in self.replay_tasks.completed(): + self.replay_tasks.add(ra, ra.replay.remote()) + with self.timers["get_samples"]: + samples = ray.get(replay) + with self.timers["enqueue"]: + self.learner.inqueue.put((ra, samples)) + + with self.timers["update_priorities"]: + while not self.learner.outqueue.empty(): + ra, replay, td_error, count = self.learner.outqueue.get() + ra.update_priorities.remote(replay["batch_indexes"], td_error) + train_timesteps += count + + return sample_timesteps, train_timesteps + + def stats(self): + replay_stats = ray.get(self.replay_actors[0].stats.remote()) + timing = { + "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3) + for k in self.timers + } + timing["learner_grad_time_ms"] = round( + 1000 * self.learner.grad_timer.mean, 3) + timing["learner_dequeue_time_ms"] = round( + 1000 * self.learner.queue_timer.mean, 3) + stats = { + "sample_throughput": round(self.timers["sample"].mean_throughput, + 3), + "train_throughput": round(self.timers["train"].mean_throughput, 3), + "num_weight_syncs": self.num_weight_syncs, + } + debug_stats = { + "replay_shard_0": replay_stats, + "timing_breakdown": timing, + "pending_sample_tasks": self.sample_tasks.count, + "pending_replay_tasks": self.replay_tasks.count, + "learner_queue": self.learner.learner_queue_size.stats(), + } + if self.debug: + stats.update(debug_stats) + return dict(PolicyOptimizer.stats(self), **stats) diff --git a/python/ray/rllib/optimizers/async_samples_optimizer.py b/python/ray/rllib/optimizers/async_samples_optimizer.py index ebc8676cd073c..3b6bb861b4824 100644 --- a/python/ray/rllib/optimizers/async_samples_optimizer.py +++ b/python/ray/rllib/optimizers/async_samples_optimizer.py @@ -1,108 +1,28 @@ -"""Implements Distributed Prioritized Experience Replay. +"""Implements the IMPALA architecture. -https://arxiv.org/abs/1803.00933""" +https://arxiv.org/abs/1802.01561""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import random import time import threading -import numpy as np from six.moves import queue import ray from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer -from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer -from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.utils.actors import TaskPool, create_colocated +from ray.rllib.utils.actors import TaskPool from ray.rllib.utils.timer import TimerStat from ray.rllib.utils.window_stat import WindowStat SAMPLE_QUEUE_DEPTH = 2 -REPLAY_QUEUE_DEPTH = 4 LEARNER_QUEUE_MAX_SIZE = 16 -@ray.remote -class ReplayActor(object): - """A replay buffer shard. 
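The weight-sync bookkeeping in ``AsyncReplayOptimizer._step`` above throttles broadcasts: a worker only receives fresh weights after it has contributed ``max_weight_sync_delay`` steps since its last sync, and the weights are only re-serialized when the learner has actually produced an update. A Ray-free sketch of that logic (names are placeholders, not from the patch):

.. code-block:: python

    def maybe_sync_weights(worker, count, state, max_weight_sync_delay=400):
        """Re-broadcast learner weights to `worker` at most every N sampled steps."""
        state["steps_since_update"][worker] = (
            state["steps_since_update"].get(worker, 0) + count)
        if state["steps_since_update"][worker] < max_weight_sync_delay:
            return False
        # Only refresh the cached weights when the learner has a newer copy,
        # so the same stale weights are not repeatedly re-serialized.
        if state["cached_weights"] is None or state["learner_updated"]:
            state["cached_weights"] = state["get_weights"]()
            state["learner_updated"] = False
        state["steps_since_update"][worker] = 0
        state["num_weight_syncs"] += 1
        return True

    state = {"steps_since_update": {}, "cached_weights": None,
             "learner_updated": True, "num_weight_syncs": 0,
             "get_weights": lambda: {"w": 1.0}}
    print(maybe_sync_weights("worker_0", 500, state))  # True: 500 >= 400
    print(maybe_sync_weights("worker_0", 100, state))  # False: only 100 since last sync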
- - Ray actors are single-threaded, so for scalability multiple replay actors - may be created to increase parallelism.""" - - def __init__(self, num_shards, learning_starts, buffer_size, - train_batch_size, prioritized_replay_alpha, - prioritized_replay_beta, prioritized_replay_eps, - clip_rewards): - self.replay_starts = learning_starts // num_shards - self.buffer_size = buffer_size // num_shards - self.train_batch_size = train_batch_size - self.prioritized_replay_beta = prioritized_replay_beta - self.prioritized_replay_eps = prioritized_replay_eps - - self.replay_buffer = PrioritizedReplayBuffer( - self.buffer_size, - alpha=prioritized_replay_alpha, - clip_rewards=clip_rewards) - - # Metrics - self.add_batch_timer = TimerStat() - self.replay_timer = TimerStat() - self.update_priorities_timer = TimerStat() - - def get_host(self): - return os.uname()[1] - - def add_batch(self, batch): - PolicyOptimizer._check_not_multiagent(batch) - with self.add_batch_timer: - for row in batch.rows(): - self.replay_buffer.add(row["obs"], row["actions"], - row["rewards"], row["new_obs"], - row["dones"], row["weights"]) - - def replay(self): - with self.replay_timer: - if len(self.replay_buffer) < self.replay_starts: - return None - - (obses_t, actions, rewards, obses_tp1, dones, weights, - batch_indexes) = self.replay_buffer.sample( - self.train_batch_size, beta=self.prioritized_replay_beta) - - batch = SampleBatch({ - "obs": obses_t, - "actions": actions, - "rewards": rewards, - "new_obs": obses_tp1, - "dones": dones, - "weights": weights, - "batch_indexes": batch_indexes - }) - return batch - - def update_priorities(self, batch_indexes, td_errors): - with self.update_priorities_timer: - new_priorities = (np.abs(td_errors) + self.prioritized_replay_eps) - self.replay_buffer.update_priorities(batch_indexes, new_priorities) - - def stats(self): - stat = { - "add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3), - "replay_time_ms": round(1000 * self.replay_timer.mean, 3), - "update_priorities_time_ms": round( - 1000 * self.update_priorities_timer.mean, 3), - } - stat.update(self.replay_buffer.stats()) - return stat - - class LearnerThread(threading.Thread): - """Background thread that updates the local model from replay data. + """Background thread that updates the local model from sample trajectories. The learner thread communicates with the main thread through Queues. This is needed since Ray operations can only be run on the main thread. In @@ -119,7 +39,8 @@ def __init__(self, local_evaluator): self.queue_timer = TimerStat() self.grad_timer = TimerStat() self.daemon = True - self.weights_updated = False + self.weights_updated = 0 + self.stats = {} def run(self): while True: @@ -127,86 +48,57 @@ def run(self): def step(self): with self.queue_timer: - ra, replay = self.inqueue.get() - if replay is not None: + ra, batch = self.inqueue.get() + + if batch is not None: with self.grad_timer: - td_error = self.local_evaluator.compute_apply(replay)[ - "td_error"] - self.outqueue.put((ra, replay, td_error, replay.count)) + fetches = self.local_evaluator.compute_apply(batch) + self.weights_updated += 1 + if "stats" in fetches: + self.stats = fetches["stats"] + self.outqueue.put(batch.count) self.learner_queue_size.push(self.inqueue.qsize()) - self.weights_updated = True class AsyncSamplesOptimizer(PolicyOptimizer): - """Main event loop of the Ape-X optimizer (async sampling with replay). 
- - This class coordinates the data transfers between the learner thread, - remote evaluators (Ape-X actors), and replay buffer actors. + """Main event loop of the IMPALA architecture. - This optimizer requires that policy evaluators return an additional - "td_error" array in the info return of compute_gradients(). This error - term will be used for sample prioritization.""" + This class coordinates the data transfers between the learner thread + and remote evaluators (IMPALA actors). + """ - def _init(self, - learning_starts=1000, - buffer_size=10000, - prioritized_replay=True, - prioritized_replay_alpha=0.6, - prioritized_replay_beta=0.4, - prioritized_replay_eps=1e-6, - train_batch_size=512, - sample_batch_size=50, - num_replay_buffer_shards=1, - max_weight_sync_delay=400, - clip_rewards=True, - debug=False): + def _init(self, train_batch_size=512, sample_batch_size=50, debug=False): self.debug = debug - self.replay_starts = learning_starts - self.prioritized_replay_beta = prioritized_replay_beta - self.prioritized_replay_eps = prioritized_replay_eps - self.max_weight_sync_delay = max_weight_sync_delay + self.learning_started = False + self.train_batch_size = train_batch_size self.learner = LearnerThread(self.local_evaluator) self.learner.start() - self.replay_actors = create_colocated(ReplayActor, [ - num_replay_buffer_shards, learning_starts, buffer_size, - train_batch_size, prioritized_replay_alpha, - prioritized_replay_beta, prioritized_replay_eps, clip_rewards - ], num_replay_buffer_shards) assert len(self.remote_evaluators) > 0 # Stats self.timers = { k: TimerStat() - for k in [ - "put_weights", "get_samples", "enqueue", "sample_processing", - "replay_processing", "update_priorities", "train", "sample" - ] + for k in + ["put_weights", "enqueue", "sample_processing", "train", "sample"] } self.num_weight_syncs = 0 self.learning_started = False - # Number of worker steps since the last weight update - self.steps_since_update = {} - - # Otherwise kick of replay tasks for local gradient updates - self.replay_tasks = TaskPool() - for ra in self.replay_actors: - for _ in range(REPLAY_QUEUE_DEPTH): - self.replay_tasks.add(ra, ra.replay.remote()) - # Kick off async background sampling self.sample_tasks = TaskPool() weights = self.local_evaluator.get_weights() for ev in self.remote_evaluators: ev.set_weights.remote(weights) - self.steps_since_update[ev] = 0 for _ in range(SAMPLE_QUEUE_DEPTH): - self.sample_tasks.add(ev, ev.sample_with_count.remote()) + self.sample_tasks.add(ev, ev.sample.remote()) + + self.batch_buffer = [] def step(self): + assert self.learner.is_alive() start = time.time() sample_timesteps, train_timesteps = self._step() time_delta = time.time() - start @@ -225,50 +117,37 @@ def _step(self): weights = None with self.timers["sample_processing"]: - completed = list(self.sample_tasks.completed()) - counts = ray.get([c[1][1] for c in completed]) - for i, (ev, (sample_batch, count)) in enumerate(completed): - sample_timesteps += counts[i] - - # Send the data to the replay buffer - random.choice( - self.replay_actors).add_batch.remote(sample_batch) - - # Update weights if needed - self.steps_since_update[ev] += counts[i] - if self.steps_since_update[ev] >= self.max_weight_sync_delay: - # Note that it's important to pull new weights once - # updated to avoid excessive correlation between actors - if weights is None or self.learner.weights_updated: - self.learner.weights_updated = False - with self.timers["put_weights"]: - weights = ray.put( - self.local_evaluator.get_weights()) 
- ev.set_weights.remote(weights) - self.num_weight_syncs += 1 - self.steps_since_update[ev] = 0 + for ev, sample_batch in self.sample_tasks.completed_prefetch(): + sample_batch = ray.get(sample_batch) + sample_timesteps += sample_batch.count + self.batch_buffer.append(sample_batch) + if sum(b.count + for b in self.batch_buffer) >= self.train_batch_size: + train_batch = self.batch_buffer[0].concat_samples( + self.batch_buffer) + with self.timers["enqueue"]: + self.learner.inqueue.put((ev, train_batch)) + self.batch_buffer = [] + + # Note that it's important to pull new weights once + # updated to avoid excessive correlation between actors + if weights is None or self.learner.weights_updated: + self.learner.weights_updated = False + with self.timers["put_weights"]: + weights = ray.put(self.local_evaluator.get_weights()) + ev.set_weights.remote(weights) + self.num_weight_syncs += 1 # Kick off another sample request - self.sample_tasks.add(ev, ev.sample_with_count.remote()) - - with self.timers["replay_processing"]: - for ra, replay in self.replay_tasks.completed(): - self.replay_tasks.add(ra, ra.replay.remote()) - with self.timers["get_samples"]: - samples = ray.get(replay) - with self.timers["enqueue"]: - self.learner.inqueue.put((ra, samples)) + self.sample_tasks.add(ev, ev.sample.remote()) - with self.timers["update_priorities"]: - while not self.learner.outqueue.empty(): - ra, replay, td_error, count = self.learner.outqueue.get() - ra.update_priorities.remote(replay["batch_indexes"], td_error) - train_timesteps += count + while not self.learner.outqueue.empty(): + count = self.learner.outqueue.get() + train_timesteps += count return sample_timesteps, train_timesteps def stats(self): - replay_stats = ray.get(self.replay_actors[0].stats.remote()) timing = { "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3) for k in self.timers @@ -284,12 +163,12 @@ def stats(self): "num_weight_syncs": self.num_weight_syncs, } debug_stats = { - "replay_shard_0": replay_stats, "timing_breakdown": timing, "pending_sample_tasks": self.sample_tasks.count, - "pending_replay_tasks": self.replay_tasks.count, "learner_queue": self.learner.learner_queue_size.stats(), } if self.debug: stats.update(debug_stats) + if self.learner.stats: + stats["learner"] = self.learner.stats return dict(PolicyOptimizer.stats(self), **stats) diff --git a/python/ray/rllib/tuned_examples/pong-impala-vectorized.yaml b/python/ray/rllib/tuned_examples/pong-impala-vectorized.yaml new file mode 100644 index 0000000000000..9525f4115521e --- /dev/null +++ b/python/ray/rllib/tuned_examples/pong-impala-vectorized.yaml @@ -0,0 +1,11 @@ +# This can reach 18-19 reward within 10 minutes on a Tesla M60 GPU (e.g., G3 EC2 node) +# with 32 workers and 10 envs per worker. This is more efficient than the non-vectorized +# configuration which requires 128 workers to achieve the same performance. 
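In the rewritten ``AsyncSamplesOptimizer._step`` above, rollouts are buffered until they add up to ``train_batch_size`` and only then concatenated and handed to the learner thread. A minimal sketch of that buffering (plain Python lists stand in for ``SampleBatch`` and ``concat_samples``; not part of the patch):

.. code-block:: python

    def buffer_until_train_batch(batch_buffer, new_batch, train_batch_size):
        """Returns (train_batch or None, updated buffer)."""
        batch_buffer = batch_buffer + [new_batch]
        if sum(len(b) for b in batch_buffer) >= train_batch_size:
            # Stand-in for SampleBatch.concat_samples(batch_buffer).
            train_batch = [x for b in batch_buffer for x in b]
            return train_batch, []
        return None, batch_buffer

    buf = []
    for rollout in ([0] * 50, [0] * 50, [0] * 400):   # e.g. sample_batch_size=50
        train_batch, buf = buffer_until_train_batch(buf, rollout, train_batch_size=500)
        print(len(train_batch) if train_batch else None)  # None, None, 500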
+pong-impala-vectorized: + env: PongNoFrameskip-v4 + run: IMPALA + config: + sample_batch_size: 500 # 50 * num_envs_per_worker + train_batch_size: 500 + num_workers: 32 + num_envs_per_worker: 10 diff --git a/python/ray/rllib/tuned_examples/pong-impala.yaml b/python/ray/rllib/tuned_examples/pong-impala.yaml new file mode 100644 index 0000000000000..b54c79849c5ab --- /dev/null +++ b/python/ray/rllib/tuned_examples/pong-impala.yaml @@ -0,0 +1,13 @@ +# This can reach 18-19 reward within 10 minutes on a Tesla M60 GPU (e.g., G3 EC2 node): +# 128 workers -> 8 minutes +# 32 workers -> 17 minutes +# 16 workers -> 40 min+ +# See also: pong-impala-vectorized.yaml +pong-impala: + env: PongNoFrameskip-v4 + run: IMPALA + config: + sample_batch_size: 50 + train_batch_size: 500 + num_workers: 128 + num_envs_per_worker: 1 diff --git a/python/ray/rllib/utils/actors.py b/python/ray/rllib/utils/actors.py index c663087eb04bf..68788cbc99da3 100644 --- a/python/ray/rllib/utils/actors.py +++ b/python/ray/rllib/utils/actors.py @@ -12,6 +12,7 @@ class TaskPool(object): def __init__(self): self._tasks = {} self._objects = {} + self._fetching = [] def add(self, worker, all_obj_ids): if isinstance(all_obj_ids, list): @@ -28,6 +29,25 @@ def completed(self): for obj_id in ready: yield (self._tasks.pop(obj_id), self._objects.pop(obj_id)) + def completed_prefetch(self): + """Similar to completed but only returns once the object is local. + + Assumes obj_id only is one id.""" + + for worker, obj_id in self.completed(): + plasma_id = ray.pyarrow.plasma.ObjectID(obj_id.id()) + ray.worker.global_worker.plasma_client.fetch([plasma_id]) + self._fetching.append((worker, obj_id)) + + remaining = [] + for worker, obj_id in self._fetching: + plasma_id = ray.pyarrow.plasma.ObjectID(obj_id.id()) + if ray.worker.global_worker.plasma_client.contains(plasma_id): + yield (worker, obj_id) + else: + remaining.append((worker, obj_id)) + self._fetching = remaining + @property def count(self): return len(self._tasks) diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index dea16b4d705fe..51e27602d385f 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -128,6 +128,13 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2, "use_pytorch": true}' +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/train.py \ + --env CartPole-v1 \ + --run A3C \ + --stop '{"training_iteration": 2}' \ + --config '{"num_workers": 2, "model": {"use_lstm": true}}' + docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ @@ -177,6 +184,20 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 1}' +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/train.py \ + --env CartPole-v0 \ + --run IMPALA \ + --stop '{"training_iteration": 2}' \ + --config '{"gpu": false, "num_workers": 2, "min_iter_time_s": 1}' + +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/train.py \ + --env CartPole-v0 \ + --run IMPALA \ + --stop '{"training_iteration": 2}' \ + --config '{"gpu": false, "num_workers": 2, "min_iter_time_s": 1, "model": {"use_lstm": true}}' + docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env 
MountainCarContinuous-v0 \