Commit a8ab342

Project import generated by Copybara.
PiperOrigin-RevId: 331777247
Change-Id: Ieca800cbfd6eb9484689598baf351a408ffff2aa
RLaxDev authored and hbq1 committed Sep 15, 2020
1 parent cac9c2a commit a8ab342
Showing 14 changed files with 433 additions and 33 deletions.
10 changes: 8 additions & 2 deletions README.md
@@ -9,10 +9,16 @@ RLax can be installed with pip directly from github, with the following command:

`pip install git+git://github.com/deepmind/rlax.git`.

or from PyPI:

`pip install rlax`

All RLax code may then be just in time compiled for different hardware
(e.g. CPU, GPU, TPU) using `jax.jit`.
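For illustration, a minimal sketch of this pattern (not part of this repository), assuming the existing per-transition op `rlax.td_learning(v_tm1, r_t, discount_t, v_t)`:

```python
# Minimal sketch: JIT-compile and batch an RLax op with jax.jit / jax.vmap.
# Assumes rlax.td_learning(v_tm1, r_t, discount_t, v_t) operates on a single
# transition and returns its TD error.
import jax
import jax.numpy as jnp
import rlax

# Single transition, compiled on first call.
td_error = jax.jit(rlax.td_learning)(
    jnp.array(1.0),    # v_tm1
    jnp.array(0.5),    # r_t
    jnp.array(0.99),   # discount_t
    jnp.array(1.5))    # v_t

# A batch of transitions: vmap first, then jit the batched function.
batched_td_learning = jax.jit(jax.vmap(rlax.td_learning))
td_errors = batched_td_learning(
    jnp.zeros(32), jnp.ones(32), jnp.full(32, 0.99), jnp.zeros(32))
```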

In order to run the `examples/` you will also need to install
In order to run the `examples/` you will also need to clone the repo and
install the additional requirements:
[optax](https://github.com/deepmind/optax),
[haiku](https://github.com/deepmind/haiku), and
[bsuite](https://github.com/deepmind/bsuite).

@@ -42,7 +48,7 @@ on BSuite's version of the Catch environment (a common unit-test for
agent development in the reinforcement learning literature):

Other examples of JAX reinforcement learning agents using `rlax` can be found in
[bsuite](`https://github.com/deepmind/bsuite/tree/master/bsuite/baselines`).
[bsuite](https://github.com/deepmind/bsuite/tree/master/bsuite/baselines).

## Background

42 changes: 28 additions & 14 deletions rlax/__init__.py
@@ -55,6 +55,7 @@
from rlax._src.multistep import general_off_policy_returns_from_q_and_v
from rlax._src.multistep import lambda_returns
from rlax._src.multistep import n_step_bootstrapped_returns
from rlax._src.multistep import truncated_generalized_advantage_estimation
from rlax._src.nested_updates import incremental_update
from rlax._src.nested_updates import periodic_update
from rlax._src.nonlinear_bellman import HYPERBOLIC_SIN_PAIR
@@ -68,6 +69,7 @@
from rlax._src.nonlinear_bellman import transformed_q_lambda
from rlax._src.nonlinear_bellman import transformed_retrace
from rlax._src.nonlinear_bellman import TxPair
from rlax._src.policy_gradients import clipped_surrogate_pg_loss
from rlax._src.policy_gradients import dpg_loss
from rlax._src.policy_gradients import entropy_loss
from rlax._src.policy_gradients import policy_gradient_loss
@@ -78,6 +80,7 @@
from rlax._src.pop_art import normalize
from rlax._src.pop_art import pop
from rlax._src.pop_art import popart
from rlax._src.pop_art import PopArtState
from rlax._src.pop_art import unnormalize
from rlax._src.pop_art import unnormalize_linear
from rlax._src.transforms import identity
@@ -101,6 +104,7 @@
from rlax._src.value_learning import persistent_q_learning
from rlax._src.value_learning import q_lambda
from rlax._src.value_learning import q_learning
from rlax._src.value_learning import quantile_expected_sarsa
from rlax._src.value_learning import quantile_q_learning
from rlax._src.value_learning import qv_learning
from rlax._src.value_learning import qv_max
@@ -127,6 +131,7 @@
"categorical_q_learning",
"categorical_td_learning",
"clip_gradient",
"clipped_surrogate_pg_loss",
"discounted_returns",
"double_q_learning",
"dpg_loss",
@@ -140,33 +145,43 @@
"HYPERBOLIC_SIN_PAIR",
"squashed_gaussian",
"clipped_entropy_softmax",
"art",
"compute_parametric_kl_penalty_and_dual_loss",
"general_off_policy_returns_from_action_values",
"general_off_policy_returns_from_q_and_v",
"greedy",
"huber_loss",
"identity",
"IDENTITY_PAIR",
"incremental_update",
"l2_loss",
"LagrangePenalty",
"lambda_returns",
"leaky_vtrace",
"likelihood",
"log_loss",
"logit",
"log_loss",
"mpo_compute_weights_and_temperature_loss",
"mpo_loss",
"multivariate_normal_kl_divergence",
"normalize",
"n_step_bootstrapped_returns",
"one_hot",
"periodic_update",
"persistent_q_learning",
"pixel_control_rewards",
"policy_gradient_loss",
"pop",
"popart",
"PopArtState",
"power",
"q_lambda",
"q_learning",
"general_off_policy_returns_from_action_values",
"general_off_policy_returns_from_q_and_v",
"qpg_loss",
"quantile_expected_sarsa",
"quantile_q_learning",
"qv_learning",
"qv_max",
"q_lambda",
"q_learning",
"rm_loss",
"rpg_loss",
"sarsa",
@@ -175,32 +190,31 @@
"signed_expm1",
"signed_hyperbolic",
"SIGNED_HYPERBOLIC_PAIR",
"SIGNED_LOGP1_PAIR",
"signed_logp1",
"SIGNED_LOGP1_PAIR",
"signed_parabolic",
"softmax",
"td_lambda",
"td_learning",
"transform_to_2hot",
"transform_from_2hot",
"transformed_general_off_policy_returns_from_action_values",
"transformed_lambda_returns",
"transformed_n_step_q_learning",
"transformed_n_step_returns",
"transformed_q_lambda",
"transformed_retrace",
"transform_from_2hot",
"transform_to_2hot",
"tree_map_zipped",
"tree_select",
"tree_split_key",
"truncated_generalized_advantage_estimation",
"TxPair",
"vtrace",
"vtrace_td_error_and_advantage",
"LagrangePenalty",
"mpo_compute_weights_and_temperature_loss",
"mpo_loss",
"compute_parametric_kl_penalty_and_dual_loss",
"unnormalize",
"unnormalize_linear",
"vmpo_compute_weights_and_temperature_loss",
"vmpo_loss",
"vtrace",
"vtrace_td_error_and_advantage",
)
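As a rough usage sketch for one of the newly exported names (not part of this diff), assuming `clipped_surrogate_pg_loss` takes probability ratios, advantages, and a clipping epsilon, roughly `(prob_ratios_t, adv_t, epsilon)`:

```python
# Hedged sketch of the newly exported PPO-style clipped surrogate loss.
# The exact signature is assumed; see rlax/_src/policy_gradients.py.
import jax.numpy as jnp
import rlax

prob_ratios_t = jnp.array([0.8, 1.0, 1.3])  # pi_new(a_t|s_t) / pi_old(a_t|s_t)
adv_t = jnp.array([1.0, -0.5, 2.0])         # advantage estimates
epsilon = 0.2                               # clipping parameter

loss = rlax.clipped_surrogate_pg_loss(prob_ratios_t, adv_t, epsilon)
# Scalar: -mean(min(ratio * adv, clip(ratio, 1 - eps, 1 + eps) * adv)).
```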

# _________________________________________
2 changes: 1 addition & 1 deletion rlax/_src/distributions.py
@@ -133,7 +133,7 @@ def log_prob_fn(sample: Array, logits: Array):
def entropy_fn(logits: Array):
probs = jax.nn.softmax(logits / temperature)
probs = _mix_with_uniform(probs, epsilon)
return -jnp.sum(probs * jnp.log(probs), axis=-1)
return -jnp.nansum(probs * jnp.log(probs), axis=-1)

def kl_fn(p_logits: Array, q_logits: Array):
return categorical_kl_divergence(p_logits, q_logits, temperature)
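The change from `jnp.sum` to `jnp.nansum` above guards the `0 * log(0)` terms that arise when some probabilities are exactly zero; a small standalone illustration (not from the repository):

```python
# Standalone illustration of why nansum is used for the entropy above.
import jax.numpy as jnp

probs = jnp.array([0.5, 0.5, 0.0])   # one outcome has exactly zero probability
terms = probs * jnp.log(probs)       # last term is 0 * log(0) = 0 * -inf = nan

print(-jnp.sum(terms))               # nan: a single nan poisons the sum
print(-jnp.nansum(terms))            # ~0.6931: the nan term is treated as 0
```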
54 changes: 51 additions & 3 deletions rlax/_src/multistep.py
@@ -21,12 +21,13 @@
of experience; trajectories are not assumed to align with episode boundaries,
and bootstrapping is used to estimate returns beyond the end of a trajectory.
"""

from typing import Union
import chex
import jax.numpy as jnp
from rlax._src import base

Array = chex.Array
Scalar = chex.Scalar
Numeric = chex.Numeric


@@ -95,6 +96,7 @@ def lambda_returns(
"""
chex.assert_rank([r_t, discount_t, v_t, lambda_], [1, 1, 1, {0, 1}])
chex.assert_type([r_t, discount_t, v_t, lambda_], float)
chex.assert_equal_shape([r_t, discount_t, v_t])

# If scalar make into vector.
lambda_ = jnp.ones_like(discount_t) * lambda_
@@ -132,6 +134,7 @@ def n_step_bootstrapped_returns(
"""
chex.assert_rank([r_t, discount_t, v_t], 1)
chex.assert_type([r_t, discount_t, v_t], float)
chex.assert_equal_shape([r_t, discount_t, v_t])
seq_len = r_t.shape[0]

# Pad end of reward and discount sequences with 0 and 1 respectively.
@@ -211,10 +214,11 @@ def importance_corrected_td_errors(
values: sequence of state values under π for all timesteps t in [0, T].
Returns:
Off-policy estimates of the multistep lambda returns from each state.
Off-policy estimates of the multistep td errors.
"""
chex.assert_rank([r_t, discount_t, rho_tm1, values], [1, 1, 1, 1])
chex.assert_type([r_t, discount_t, rho_tm1, values], float)
chex.assert_equal_shape([r_t, discount_t, rho_tm1, values[1:]])

v_tm1 = values[:-1] # Predictions to compute errors for.
v_t = values[1:] # Values for bootstrapping.
@@ -233,6 +237,48 @@ def importance_corrected_td_errors(
return rho_tm1 * jnp.array(errors)
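The `chex.assert_equal_shape` checks added in this file make mismatched trajectory inputs fail fast; a hypothetical sketch of the kind of error they catch (not from the repository):

```python
# Hypothetical sketch of the shape checks added throughout this file.
import chex
import jax.numpy as jnp

r_t = jnp.zeros(5)          # one reward per transition
discount_t = jnp.zeros(5)   # one discount per transition
values = jnp.zeros(6)       # one value per state (note the extra entry)

chex.assert_equal_shape([r_t, discount_t, values[1:]])     # passes
try:
  chex.assert_equal_shape([r_t, discount_t, values])       # raises AssertionError
except AssertionError as e:
  print('caught shape mismatch:', e)
```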


def truncated_generalized_advantage_estimation(
r_t: Array,
discount_t: Array,
lambda_: Union[Array, Scalar],
values: Array,
) -> Array:
"""Computes truncated generalized advantage estimates for a sequence length k.
The advantages are computed in a backwards fashion according to the equation:
Âₜ = δₜ + (γλ) * δₜ₊₁ + ... + ... + (γλ)ᵏ⁻ᵗ⁺¹ * δₖ₋₁
where δₜ = rₜ₊₁ + γₜ₊₁ * v(sₜ₊₁) - v(sₜ).
See Proximal Policy Optimization Algorithms, Schulman et al.:
https://arxiv.org/abs/1707.06347
* Note: This paper uses a different notation than the RLax standard
convention that follows Sutton & Barto. We use rₜ₊₁ to denote the reward
received after acting in state sₜ, while the PPO paper uses rₜ.
Args:
r_t: Sequence of rewards at times [1, k]
discount_t: Sequence of discounts at times [1, k]
lambda_: Mixing parameter; a scalar or sequence of lambda_t at times [1, k]
values: Sequence of values under π at times [0, k]
Returns:
Multistep truncated generalized advantage estimation.
"""
chex.assert_rank([r_t, values, discount_t], 1)
chex.assert_type([r_t, values, discount_t], float)
lambda_ = jnp.ones_like(discount_t) * lambda_ # If scalar, make into vector.

delta_t = r_t + discount_t * values[1:] - values[:-1]

# Iterate backwards to calculate advantages.
advantage_t = [0.]
for t in reversed(range(delta_t.shape[0])):
advantage_t.insert(0,
delta_t[t] + lambda_[t] * discount_t[t] * advantage_t[0])
return jnp.array(advantage_t[:-1])
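A small hedged usage sketch for the new function (not part of the commit), on a length-3 trajectory, batched with `jax.vmap` as in the tests below:

```python
# Hedged usage sketch for truncated_generalized_advantage_estimation.
import jax
import jax.numpy as jnp
from rlax._src import multistep

r_t = jnp.array([0.0, 0.0, 1.0])            # rewards r_1..r_3
discount_t = jnp.array([0.99, 0.99, 0.99])  # discounts at times 1..3
values = jnp.array([0.5, 0.6, 0.7, 0.0])    # v(s_0)..v(s_3); one extra for bootstrap
lambda_ = 0.95

adv_t = multistep.truncated_generalized_advantage_estimation(
    r_t, discount_t, lambda_, values)       # shape (3,), one advantage per step

# Batched over a leading axis, matching the vmap usage in multistep_test.py.
batched_gae = jax.vmap(
    multistep.truncated_generalized_advantage_estimation,
    in_axes=(0, 0, None, 0), out_axes=0)
```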


def general_off_policy_returns_from_action_values(
q_t: Array,
a_t: Array,
@@ -274,6 +320,8 @@ def general_off_policy_returns_from_action_values(
chex.assert_rank([q_t, a_t, r_t, discount_t, c_t, pi_t], [2, 1, 1, 1, 1, 2])
chex.assert_type([q_t, a_t, r_t, discount_t, c_t, pi_t],
[float, int, float, float, float, float])
chex.assert_equal_shape(
[q_t[..., 0], a_t, r_t, discount_t, c_t, pi_t[..., 0]])

# Get the expected values and the values of actually selected actions.
exp_q_t = (pi_t * q_t).sum(axis=-1)
@@ -325,6 +373,7 @@ def general_off_policy_returns_from_q_and_v(
"""
chex.assert_rank([q_t, v_t, r_t, discount_t, c_t], 1)
chex.assert_type([q_t, v_t, r_t, discount_t, c_t], float)
chex.assert_equal_shape([q_t, v_t[:-1], r_t[:-1], discount_t[:-1], c_t])

# Work backwards to compute `G_K-1`, ..., `G_1`, `G_0`.
g = r_t[-1] + discount_t[-1] * v_t[-1] # G_K-1.
Expand All @@ -334,4 +383,3 @@ def general_off_policy_returns_from_q_and_v(
returns.insert(0, g)

return jnp.array(returns)

84 changes: 84 additions & 0 deletions rlax/_src/multistep_test.py
@@ -20,6 +20,7 @@
from absl.testing import parameterized
import chex
import jax
import jax.numpy as jnp
import numpy as np
from rlax._src import multistep

@@ -152,6 +153,89 @@ def test_importance_corrected_td_errors_batch(self):
np.testing.assert_allclose(td_direct, td_from_returns, rtol=1e-5)


class TruncatedGeneralizedAdvantageEstimationTest(parameterized.TestCase):

def setUp(self):
super(TruncatedGeneralizedAdvantageEstimationTest, self).setUp()

self.r_t = jnp.array([[0., 0., 1., 0., -0.5],
[0., 0., 0., 0., 1.]])
self.v_t = jnp.array([[1., 4., -3., -2., -1., -1.],
[-3., -2., -1., 0.0, 5., -1.]])
self.discount_t = jnp.array([[0.99, 0.99, 0.99, 0.99, 0.99],
[0.9, 0.9, 0.9, 0.0, 0.9]])
self.dummy_rho_tm1 = jnp.array([[1., 1., 1., 1., 1],
[1., 1., 1., 1., 1.]])
self.array_lambda = jnp.array([[0.9, 0.9, 0.9, 0.9, 0.9],
[0.9, 0.9, 0.9, 0.9, 0.9]])

# Different expected results for different values of lambda.
self.expected = dict()
self.expected[1.] = np.array(
[[-1.45118, -4.4557, 2.5396, 0.5249, -0.49],
[3., 2., 1., 0., -4.9]],
dtype=np.float32)
self.expected[0.7] = np.array(
[[-0.676979, -5.248167, 2.4846, 0.6704, -0.49],
[2.2899, 1.73, 1., 0., -4.9]],
dtype=np.float32)
self.expected[0.4] = np.array(
[[0.56731, -6.042, 2.3431, 0.815, -0.49],
[1.725, 1.46, 1., 0., -4.9]],
dtype=np.float32)

@chex.all_variants()
@parameterized.named_parameters(
('lambda1', 1.0),
('lambda0.7', 0.7),
('lambda0.4', 0.4))
def test_truncated_gae(self, lambda_):
"""Tests truncated GAE for a full batch."""
batched_advantage_fn_variant = self.variant(jax.vmap(
multistep.truncated_generalized_advantage_estimation,
in_axes=(0, 0, None, 0), out_axes=0))
actual = batched_advantage_fn_variant(
self.r_t, self.discount_t, lambda_, self.v_t)
np.testing.assert_allclose(self.expected[lambda_], actual, atol=1e-3)

@chex.all_variants()
def test_array_lambda(self):
"""Tests that truncated GAE is consistent with scalar or array lambda_."""
scalar_lambda_fn = self.variant(jax.vmap(
multistep.truncated_generalized_advantage_estimation,
in_axes=(0, 0, None, 0), out_axes=0))
array_lambda_fn = self.variant(jax.vmap(
multistep.truncated_generalized_advantage_estimation))
scalar_lambda_result = scalar_lambda_fn(
self.r_t, self.discount_t, 0.9, self.v_t)
array_lambda_result = array_lambda_fn(
self.r_t, self.discount_t, self.array_lambda, self.v_t)
np.testing.assert_allclose(scalar_lambda_result, array_lambda_result,
atol=1e-3)

@chex.all_variants()
@parameterized.named_parameters(
('lambda1', 1.0),
('lambda0.7', 0.7),
('lambda0.4', 0.4))
def test_gae_as_special_case_of_importance_corrected_td_errors(self, lambda_):
"""Tests that truncated GAE yields same output as importance corrected td errors with dummy ratios."""
batched_gae_fn_variant = self.variant(jax.vmap(
multistep.truncated_generalized_advantage_estimation,
in_axes=(0, 0, None, 0), out_axes=0))
gae_result = batched_gae_fn_variant(
self.r_t, self.discount_t, lambda_, self.v_t)

batched_ictd_errors_fn_variant = self.variant(jax.vmap(
multistep.importance_corrected_td_errors))
ictd_errors_result = batched_ictd_errors_fn_variant(
self.r_t,
self.discount_t,
self.dummy_rho_tm1,
jnp.ones_like(self.discount_t) * lambda_,
self.v_t)
np.testing.assert_allclose(gae_result, ictd_errors_result, atol=1e-3)

if __name__ == '__main__':
jax.config.update('jax_numpy_rank_promotion', 'raise')
absltest.main()