[RLlib] Trajectory view API: Enable by default for PPO, IMPALA, PG, A3C (tf and torch). (ray-project#11747)
sven1977 authored Nov 12, 2020
1 parent 59ccbc0 commit 62c7ab5
Showing 38 changed files with 578 additions and 477 deletions.
4 changes: 2 additions & 2 deletions rllib/agents/a3c/a3c.py
@@ -37,8 +37,8 @@
# Workers sample async. Note that this increases the effective
# rollout_fragment_length by up to 5x due to async buffering of batches.
"sample_async": True,
# Switch on Trajectory View API for A2/3C by default.
# NOTE: Only supported for PyTorch so far.
# Use the new "trajectory view API" to collect samples and produce
# model- and policy inputs.
"_use_trajectory_view_api": True,
})
# __sphinx_doc_end__
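For readers trying the flag out, here is a minimal sketch of toggling it on a trainer config. It assumes the RLlib 1.0-era A2CTrainer API; only the `_use_trajectory_view_api` key comes from this commit, everything else is illustrative.

# Minimal sketch (assumed RLlib ~1.0 trainer API): the trajectory view API
# is now on by default for A2C/A3C; set the flag to False to fall back to
# the old sample-collection path.
import ray
from ray.rllib.agents.a3c import A2CTrainer

ray.init()
config = {
    "env": "CartPole-v0",
    "framework": "torch",
    "_use_trajectory_view_api": True,  # default after this commit
}
trainer = A2CTrainer(config=config)
print(trainer.train()["episode_reward_mean"])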
42 changes: 0 additions & 42 deletions rllib/agents/a3c/a3c_torch_policy.py
@@ -1,13 +1,8 @@
import gym
from typing import Dict

import ray
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()
@@ -89,42 +84,6 @@ def _value(self, obs):
return self.model.value_function()[0]


def view_requirements_fn(policy: Policy) -> Dict[str, ViewRequirement]:
"""Function defining the view requirements for training/postprocessing.
These go on top of the Policy's Model's own view requirements used for
the action computing forward passes.
Args:
policy (Policy): The Policy that requires the returned
ViewRequirements.
Returns:
Dict[str, ViewRequirement]: The Policy's view requirements.
"""
ret = {
# Next obs are needed for PPO postprocessing, but not in loss.
SampleBatch.NEXT_OBS: ViewRequirement(
SampleBatch.OBS, shift=1, used_for_training=False),
# Created during postprocessing.
Postprocessing.ADVANTAGES: ViewRequirement(shift=0),
Postprocessing.VALUE_TARGETS: ViewRequirement(shift=0),
# Needed for PPO's loss function.
SampleBatch.ACTION_DIST_INPUTS: ViewRequirement(shift=0),
SampleBatch.ACTION_LOGP: ViewRequirement(shift=0),
SampleBatch.VF_PREDS: ViewRequirement(shift=0),
}
# If policy is recurrent, have to add state_out for PPO postprocessing
# (calculating GAE from next-obs and last state-out).
if policy.is_recurrent():
init_state = policy.get_initial_state()
for i, s in enumerate(init_state):
ret["state_out_{}".format(i)] = ViewRequirement(
space=gym.spaces.Box(-1.0, 1.0, shape=(s.shape[0], )),
used_for_training=False)
return ret


A3CTorchPolicy = build_torch_policy(
name="A3CTorchPolicy",
get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
@@ -135,5 +94,4 @@ def view_requirements_fn(policy: Policy) -> Dict[str, ViewRequirement]:
extra_grad_process_fn=apply_grad_clipping,
optimizer_fn=torch_optimizer,
mixins=[ValueNetworkMixin],
view_requirements_fn=view_requirements_fn,
)
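The deleted view_requirements_fn above doubles as a reference for the ViewRequirement API itself. A minimal sketch, mirroring only what the removed code shows (the constructor signature is taken from that code and may differ in other RLlib versions):

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement

extra_views = {
    # NEXT_OBS declared as OBS shifted one step into the future; only
    # needed for postprocessing, not for the loss.
    SampleBatch.NEXT_OBS: ViewRequirement(
        SampleBatch.OBS, shift=1, used_for_training=False),
    # An unshifted column produced during sampling/postprocessing.
    SampleBatch.VF_PREDS: ViewRequirement(shift=0),
}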
9 changes: 8 additions & 1 deletion rllib/agents/dqn/tests/test_simple_q.py
@@ -48,7 +48,14 @@ def test_simple_q_loss_function(self):
SampleBatch.ACTIONS: np.array([0, 1]),
SampleBatch.REWARDS: np.array([0.4, -1.23]),
SampleBatch.DONES: np.array([False, False]),
SampleBatch.NEXT_OBS: np.random.random(size=(2, 4))
SampleBatch.NEXT_OBS: np.random.random(size=(2, 4)),
SampleBatch.EPS_ID: np.array([1234, 1234]),
SampleBatch.AGENT_INDEX: np.array([0, 0]),
SampleBatch.ACTION_LOGP: np.array([-0.1, -0.1]),
SampleBatch.ACTION_DIST_INPUTS: np.array([[0.1, 0.2],
[-0.1, -0.2]]),
SampleBatch.ACTION_PROB: np.array([0.1, 0.2]),
"q_values": np.array([[0.1, 0.2], [0.2, 0.1]]),
}
# Get model vars for computing expected model outs (q-vals).
# 0=layer-kernel; 1=layer-bias; 2=q-val-kernel; 3=q-val-bias
4 changes: 4 additions & 0 deletions rllib/agents/impala/impala.py
@@ -91,6 +91,10 @@
# Callback for APPO to use to update KL, target network periodically.
# The input to the callback is the learner fetches dict.
"after_train_step": None,

# Use the new "trajectory view API" to collect samples and produce
# model- and policy inputs.
"_use_trajectory_view_api": True,
})
# __sphinx_doc_end__
# yapf: enable
4 changes: 1 addition & 3 deletions rllib/agents/impala/vtrace_torch_policy.py
@@ -4,7 +4,6 @@

import ray
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
from ray.rllib.agents.impala.vtrace_tf_policy import postprocess_trajectory
import ray.rllib.agents.impala.vtrace_torch as vtrace
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.sample_batch import SampleBatch
@@ -209,7 +208,7 @@ def make_time_major(policy, seq_lens, tensor, drop_last=False):
T = tensor.shape[0] // B
else:
# Important: chop the tensor into batches at known episode cut
# boundaries. TODO(ekl) this is kind of a hack
# boundaries.
T = policy.config["rollout_fragment_length"]
B = tensor.shape[0] // T
rs = torch.reshape(tensor, [B, T] + list(tensor.shape[1:]))
@@ -266,7 +265,6 @@ def setup_mixins(policy, obs_space, action_space, config):
loss_fn=build_vtrace_loss,
get_default_config=lambda: ray.rllib.agents.impala.impala.DEFAULT_CONFIG,
stats_fn=stats,
postprocess_fn=postprocess_trajectory,
extra_grad_process_fn=apply_grad_clipping,
optimizer_fn=choose_optimizer,
before_init=setup_mixins,
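The make_time_major hunk above keeps the reshape-at-known-episode-boundaries path. As a rough, stand-alone illustration of that reshape (plain PyTorch, not the RLlib helper; a fixed fragment length and no dynamic seq_lens are assumed):

import torch

def to_time_major(tensor, t):
    # Fold a flat [B*T, ...] rollout back into time-major [T, B, ...].
    b = tensor.shape[0] // t
    rs = torch.reshape(tensor, [b, t] + list(tensor.shape[1:]))
    return rs.transpose(1, 0)

flat = torch.arange(12).reshape(6, 2)   # 6 timesteps, feature dim 2
print(to_time_major(flat, t=3).shape)   # torch.Size([3, 2, 2])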
13 changes: 7 additions & 6 deletions rllib/agents/maml/maml_tf_policy.py
@@ -309,12 +309,6 @@ def split_placeholders(self, placeholder, split):

def maml_loss(policy, model, dist_class, train_batch):
logits, state = model.from_batch(train_batch)

policy._loss_input_dict["split"] = tf1.placeholder(
tf.int32,
name="Meta-Update-Splitting",
shape=(policy.config["inner_adaptation_steps"] + 1,
policy.config["num_workers"]))
policy.cur_lr = policy.config["lr"]

if policy.config["worker_index"]:
@@ -413,6 +407,13 @@ def setup_mixins(policy, obs_space, action_space, config):
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
KLCoeffMixin.__init__(policy, config)

# Create the `split` placeholder.
policy._loss_input_dict["split"] = tf1.placeholder(
tf.int32,
name="Meta-Update-Splitting",
shape=(policy.config["inner_adaptation_steps"] + 1,
policy.config["num_workers"]))


MAMLTFPolicy = build_tf_policy(
name="MAMLTFPolicy",
6 changes: 5 additions & 1 deletion rllib/agents/maml/maml_torch_policy.py
@@ -345,6 +345,10 @@ def maml_loss(policy, model, dist_class, train_batch):
else:
policy.var_list = model.named_parameters()

# `split` may not exist yet (during test-loss call), use a dummy value.
# Cannot use get here due to train_batch being a TrackingDict.
split = train_batch["split"] if "split" in train_batch else \
torch.tensor([[8, 8], [8, 8]])
policy.loss_obj = MAMLLoss(
model=model,
dist_class=dist_class,
@@ -357,7 +361,7 @@
policy_vars=policy.var_list,
obs=train_batch[SampleBatch.CUR_OBS],
num_tasks=policy.config["num_workers"],
split=train_batch["split"],
split=split,
config=policy.config,
inner_adaptation_steps=policy.config["inner_adaptation_steps"],
entropy_coeff=policy.config["entropy_coeff"],
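The guard added above sidesteps dict.get, since the incoming train batch is a tracking dict that records which keys the loss accesses. In isolation the pattern is just this (a sketch; the helper name is illustrative and the dummy shape is copied from the diff):

import torch

def get_split_or_dummy(train_batch):
    # Membership test instead of .get(): the batch object may only support
    # __contains__/__getitem__, not the full dict API.
    if "split" in train_batch:
        return train_batch["split"]
    return torch.tensor([[8, 8], [8, 8]])

print(get_split_or_dummy({}))  # falls back to the dummy split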
4 changes: 2 additions & 2 deletions rllib/agents/pg/pg.py
@@ -30,8 +30,8 @@
"num_workers": 0,
# Learning rate.
"lr": 0.0004,
# Switch on Trajectory View API for PG by default.
# NOTE: Only supported for PyTorch so far.
# Use the new "trajectory view API" to collect samples and produce
# model- and policy inputs.
"_use_trajectory_view_api": True,
})

2 changes: 0 additions & 2 deletions rllib/agents/pg/pg_torch_policy.py
@@ -5,7 +5,6 @@
from typing import Dict, List, Type, Union

import ray
from ray.rllib.agents.a3c.a3c_torch_policy import view_requirements_fn
from ray.rllib.agents.pg.utils import post_process_advantages
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
@@ -79,5 +78,4 @@ def pg_loss_stats(policy: Policy,
loss_fn=pg_torch_loss,
stats_fn=pg_loss_stats,
postprocess_fn=post_process_advantages,
view_requirements_fn=view_requirements_fn,
)
31 changes: 16 additions & 15 deletions rllib/agents/pg/tests/test_pg.py
@@ -41,14 +41,14 @@ def test_pg_loss_functions(self):

# Fake CartPole episode of n time steps.
train_batch = {
SampleBatch.CUR_OBS: np.array([[0.1, 0.2, 0.3,
0.4], [0.5, 0.6, 0.7, 0.8],
[0.9, 1.0, 1.1, 1.2]]),
SampleBatch.OBS: np.array([[0.1, 0.2, 0.3,
0.4], [0.5, 0.6, 0.7, 0.8],
[0.9, 1.0, 1.1, 1.2]]),
SampleBatch.ACTIONS: np.array([0, 1, 1]),
SampleBatch.PREV_ACTIONS: np.array([1, 0, 1]),
SampleBatch.REWARDS: np.array([1.0, 1.0, 1.0]),
SampleBatch.PREV_REWARDS: np.array([-1.0, -1.0, -1.0]),
SampleBatch.DONES: np.array([False, False, True])
SampleBatch.DONES: np.array([False, False, True]),
SampleBatch.EPS_ID: np.array([1234, 1234, 1234]),
SampleBatch.AGENT_INDEX: np.array([0, 0, 0]),
}

for fw, sess in framework_iterator(config, session=True):
@@ -63,31 +63,32 @@
# to train_batch dict.
# A = [0.99^2 * 1.0 + 0.99 * 1.0 + 1.0, 0.99 * 1.0 + 1.0, 1.0] =
# [2.9701, 1.99, 1.0]
train_batch = pg.post_process_advantages(policy, train_batch)
train_batch_ = pg.post_process_advantages(policy,
train_batch.copy())
if fw == "torch":
train_batch = policy._lazy_tensor_dict(train_batch)
train_batch_ = policy._lazy_tensor_dict(train_batch_)

# Check Advantage values.
check(train_batch[Postprocessing.ADVANTAGES], [2.9701, 1.99, 1.0])
check(train_batch_[Postprocessing.ADVANTAGES], [2.9701, 1.99, 1.0])

# Actual loss results.
if sess:
results = policy.get_session().run(
policy._loss,
feed_dict=policy._get_loss_inputs_dict(
train_batch, shuffle=False))
train_batch_, shuffle=False))
else:
results = (pg.pg_tf_loss
if fw in ["tf2", "tfe"] else pg.pg_torch_loss)(
policy,
policy.model,
dist_class=dist_cls,
train_batch=train_batch)
train_batch=train_batch_)

# Calculate expected results.
if fw != "torch":
expected_logits = fc(
fc(train_batch[SampleBatch.CUR_OBS],
fc(train_batch_[SampleBatch.OBS],
vars[0],
vars[1],
framework=fw),
@@ -96,16 +97,16 @@
framework=fw)
else:
expected_logits = fc(
fc(train_batch[SampleBatch.CUR_OBS],
fc(train_batch_[SampleBatch.OBS],
vars[2],
vars[3],
framework=fw),
vars[0],
vars[1],
framework=fw)
expected_logp = dist_cls(expected_logits, policy.model).logp(
train_batch[SampleBatch.ACTIONS])
adv = train_batch[Postprocessing.ADVANTAGES]
train_batch_[SampleBatch.ACTIONS])
adv = train_batch_[Postprocessing.ADVANTAGES]
if sess:
expected_logp = sess.run(expected_logp)
elif fw == "torch":
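The expected advantages in the test above ([2.9701, 1.99, 1.0]) are plain discounted returns with gamma=0.99 and no value baseline. A quick stand-alone check (sketch only, not RLlib's postprocessing code):

import numpy as np

def discounted_returns(rewards, gamma=0.99):
    # Backward accumulation: R_t = r_t + gamma * R_{t+1}.
    returns = np.zeros(len(rewards))
    running = 0.0
    for i in reversed(range(len(rewards))):
        running = rewards[i] + gamma * running
        returns[i] = running
    return returns

print(discounted_returns([1.0, 1.0, 1.0]))  # [2.9701 1.99   1.    ]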
2 changes: 2 additions & 0 deletions rllib/agents/ppo/appo.py
@@ -76,6 +76,8 @@
"vf_loss_coeff": 0.5,
"entropy_coeff": 0.01,
"entropy_coeff_schedule": None,
# Trajectory View API not supported for APPO yet.
"_use_trajectory_view_api": False,
},
_allow_unknown_configs=True,
)
2 changes: 1 addition & 1 deletion rllib/agents/ppo/ddppo.py
@@ -74,7 +74,7 @@
"truncate_episodes": True,
# This is auto set based on sample batch size.
"train_batch_size": -1,
# Trajectory View API not supported yet for DD-PPO.
# Trajectory View API not supported for DD-PPO yet.
"_use_trajectory_view_api": False,
},
_allow_unknown_configs=True,
5 changes: 3 additions & 2 deletions rllib/agents/ppo/ppo.py
@@ -89,8 +89,9 @@
# Whether to fake GPUs (using CPUs).
# Set this to True for debugging on non-GPU machines (set `num_gpus` > 0).
"_fake_gpus": False,
# Switch on Trajectory View API for PPO by default.
# NOTE: Only supported for PyTorch so far.

# Use the new "trajectory view API" to collect samples and produce
# model- and policy inputs.
"_use_trajectory_view_api": True,
})

4 changes: 1 addition & 3 deletions rllib/agents/ppo/ppo_torch_policy.py
@@ -7,8 +7,7 @@
from typing import Dict, List, Type, Union

import ray
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping, \
view_requirements_fn
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
setup_config
from ray.rllib.evaluation.postprocessing import Postprocessing
@@ -271,5 +270,4 @@ def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
ValueNetworkMixin
],
view_requirements_fn=view_requirements_fn,
)
(Diffs for the remaining 24 changed files not shown.)
