Commit

[RLlib] Fix Trainer.add_policy for num_workers>0 (self play example scripts). (ray-project#17566)
sven1977 authored Aug 5, 2021
1 parent 0eb0e0f commit 3b44726
Showing 7 changed files with 114 additions and 64 deletions.
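For orientation, here is a minimal sketch (not part of the diff) of the workflow this commit fixes: adding a policy on-the-fly to a Trainer that runs remote rollout workers, without passing observation/action spaces. It mirrors the updated test in rllib/agents/tests/test_trainer.py below; the Ray 1.x API and the MultiAgentCartPole import path are assumptions.

# Sketch only, assuming the Ray 1.x RLlib API.
import ray
from ray.rllib.agents import pg
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

ray.init()
config = pg.DEFAULT_CONFIG.copy()
config.update({
    "env": MultiAgentCartPole,
    "env_config": {"config": {"num_agents": 4}},
    "num_workers": 2,  # The fix: add_policy() now also reaches remote workers.
    "multiagent": {
        # Policies may be given as plain IDs; spaces are inferred from the env.
        "policies": {"p0"},
        "policy_mapping_fn": lambda aid, episode, **kwargs: "p0",
    },
})
trainer = pg.PGTrainer(config=config)
trainer.train()

# Add "p1" on-the-fly, without passing observation/action spaces.
trainer.add_policy(
    "p1",
    trainer._policy_class,
    policy_mapping_fn=lambda aid, episode, **kwargs: "p1",
)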
38 changes: 22 additions & 16 deletions rllib/agents/tests/test_trainer.py
@@ -1,4 +1,3 @@
import gym
from random import choice
import unittest

@@ -20,8 +19,6 @@ def tearDownClass(cls):
ray.shutdown()

def test_add_delete_policy(self):
env = gym.make("CartPole-v0")

config = pg.DEFAULT_CONFIG.copy()
config.update({
"env": MultiAgentCartPole,
@@ -30,34 +27,30 @@ def test_add_delete_policy(self):
"num_agents": 4,
},
},
"num_workers": 2, # Test on remote workers as well.
"multiagent": {
# Start with a single policy.
"policies": {
"p0": (None, env.observation_space, env.action_space, {}),
},
"policies": {"p0"},
"policy_mapping_fn": lambda aid, episode, **kwargs: "p0",
"policy_map_capacity": 2,
},
})

# TODO: (sven) this will work for tf, once we have the DynamicTFPolicy
# refactor PR merged.
for _ in framework_iterator(config, frameworks=("tf2", "torch")):
for _ in framework_iterator(config):
trainer = pg.PGTrainer(config=config)
r = trainer.train()
self.assertTrue("p0" in r["policy_reward_min"])
for i in range(1, 4):
checkpoints = []
for i in range(1, 3):

def new_mapping_fn(agent_id, episode, **kwargs):
return f"p{choice([i, i - 1])}"

# Add a new policy.
pid = f"p{i}"
new_pol = trainer.add_policy(
f"p{i}",
pid,
trainer._policy_class,
observation_space=env.observation_space,
action_space=env.action_space,
config={},
# Test changing the mapping fn.
policy_mapping_fn=new_mapping_fn,
# Change the list of policies to train.
@@ -70,9 +63,22 @@ def new_mapping_fn(agent_id, episode, **kwargs):
self.assertTrue(len(pol_map) == i + 1)
r = trainer.train()
self.assertTrue("p1" in r["policy_reward_min"])
checkpoints.append(trainer.save())

# Test restoring from the checkpoint (which has more policies
# than what's defined in the config dict).
test = pg.PGTrainer(config=config)
test.restore(checkpoints[-1])
test.train()
# Test creating an action with the added (and restored) policy.
a = test.compute_single_action(
test.get_policy("p0").observation_space.sample(),
policy_id=pid)
self.assertTrue(test.get_policy("p0").action_space.contains(a))
test.stop()

# Delete all added policies again from trainer.
for i in range(3, 0, -1):
for i in range(2, 0, -1):
trainer.remove_policy(
f"p{i}",
policy_mapping_fn=lambda aid, eps, **kwargs: f"p{i - 1}",
@@ -130,7 +136,7 @@ def test_evaluation_wo_evaluation_worker_set(self):

# Try again using `create_env_on_driver=True`.
# This force-adds the env on the local-worker, so this Trainer
# can `evaluate` even though, it doesn't have an evaluation-worker
# can `evaluate` even though it doesn't have an evaluation-worker
# set.
config["create_env_on_driver"] = True
trainer_w_env_on_driver = a3c.A3CTrainer(config=config)
27 changes: 21 additions & 6 deletions rllib/agents/trainer.py
@@ -951,7 +951,7 @@ def compute_single_action(
unsquash_actions: Optional[bool] = None,
clip_actions: Optional[bool] = None,
) -> TensorStructType:
"""Computes an action for the specified policy on the local Worker.
"""Computes an action for the specified policy on the local worker.
Note that you can also access the policy object through
self.get_policy(policy_id) and call compute_single_action() on it
@@ -982,17 +982,31 @@ def compute_single_action(
any: The computed action if full_fetch=False, or
tuple: The full output of policy.compute_actions() if
full_fetch=True or we have an RNN-based Policy.
Raises:
KeyError: If the `policy_id` cannot be found in this Trainer's
local worker.
"""
policy = self.get_policy(policy_id)
if policy is None:
raise KeyError(
f"PolicyID '{policy_id}' not found in PolicyMap of the "
f"Trainer's local worker!")

local_worker = self.workers.local_worker()

if state is None:
state = []

# Check the preprocessor and preprocess, if necessary.
pp = self.workers.local_worker().preprocessors[policy_id]
pp = local_worker.preprocessors[policy_id]
if type(pp).__name__ != "NoPreprocessor":
observation = pp.transform(observation)
filtered_observation = self.workers.local_worker().filters[policy_id](
filtered_observation = local_worker.filters[policy_id](
observation, update=False)

result = self.get_policy(policy_id).compute_single_action(
# Compute the action.
result = policy.compute_single_action(
filtered_observation,
state,
prev_action,
@@ -1002,10 +1016,12 @@
clip_actions=clip_actions,
explore=explore)

# Return 3-Tuple: Action, states, and extra-action fetches.
if state or full_fetch:
return result
# Ensure backward compatibility.
else:
return result[0] # backwards compatibility
return result[0]

@Deprecated(new="compute_single_action", error=False)
def compute_action(self, *args, **kwargs):
@@ -1193,7 +1209,6 @@ def fn(worker: RolloutWorker):
observation_space=observation_space,
action_space=action_space,
config=config,
policy_config=self.config,
policy_mapping_fn=policy_mapping_fn,
policies_to_train=policies_to_train,
)
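As a quick illustration of the KeyError now raised by compute_single_action() (see the hunk above), a caller can probe whether a dynamically added policy has already reached the local worker. This is a hedged sketch, not part of the diff; `trainer` and the policy IDs "p0"/"p1" are assumed to come from a setup like the one sketched after the commit header.

# Sketch only: assumes a Trainer holding "p0" and, possibly, a later-added "p1".
obs = trainer.get_policy("p0").observation_space.sample()
try:
    action = trainer.compute_single_action(obs, policy_id="p1")
except KeyError:
    # "p1" is not (yet) in the local worker's PolicyMap -> fall back to "p0".
    action = trainer.compute_single_action(obs, policy_id="p0")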
59 changes: 34 additions & 25 deletions rllib/evaluation/rollout_worker.py
@@ -2,13 +2,13 @@
import numpy as np
import gym
import logging
import pickle
import platform
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, \
TYPE_CHECKING, Union

import ray
from ray import cloudpickle as pickle
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.external_env import ExternalEnv
@@ -29,7 +29,6 @@
from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils import force_list, merge_dicts
from ray.rllib.utils.annotations import DeveloperAPI
@@ -1057,7 +1056,6 @@ def add_policy(
observation_space: Optional[gym.spaces.Space] = None,
action_space: Optional[gym.spaces.Space] = None,
config: Optional[PartialTrainerConfigDict] = None,
policy_config: Optional[TrainerConfigDict] = None,
policy_mapping_fn: Optional[Callable[
[AgentID, "MultiAgentEpisode"], PolicyID]] = None,
policies_to_train: Optional[List[PolicyID]] = None,
@@ -1095,14 +1093,16 @@ def add_policy(
policy_dict = _determine_spaces_for_multi_agent_dict(
{
policy_id: PolicySpec(policy_cls, observation_space,
action_space, config)
action_space, config or {})
},
self.env,
spaces=self.spaces,
policy_config=policy_config)

policy_config=self.policy_config,
)
self._build_policy_map(
policy_dict, policy_config, seed=policy_config.get("seed"))
policy_dict,
self.policy_config,
seed=self.policy_config.get("seed"))
new_policy = self.policy_map[policy_id]

self.filters[policy_id] = get_filter(
@@ -1244,24 +1244,40 @@ def get_filters(self, flush_after: bool = False) -> dict:
@DeveloperAPI
def save(self) -> bytes:
filters = self.get_filters(flush_after=True)
state = {
pid: self.policy_map[pid].get_state()
for pid in self.policy_map
}
return pickle.dumps({"filters": filters, "state": state})
state = {}
policy_specs = {}
for pid in self.policy_map:
state[pid] = self.policy_map[pid].get_state()
policy_specs[pid] = self.policy_map.policy_specs[pid]
return pickle.dumps({
"filters": filters,
"state": state,
"policy_specs": policy_specs,
})

@DeveloperAPI
def restore(self, objs: bytes) -> None:
objs = pickle.loads(objs)
self.sync_filters(objs["filters"])
for pid, state in objs["state"].items():
if pid not in self.policy_map:
logger.warning(
f"pid={pid} not found in policy_map! It was probably added"
" on-the-fly and is not part of the static `config."
"multiagent.policies` dict. Ignoring it for now.")
continue
self.policy_map[pid].set_state(state)
pol_spec = objs.get("policy_specs", {}).get(pid)
if not pol_spec:
logger.warning(
f"PolicyID '{pid}' was probably added on-the-fly (not"
" part of the static `multiagent.policies` config) and"
" no PolicySpec was found in the pickled policy state. "
f"Will not add '{pid}', but ignore it for now.")
else:
self.add_policy(
policy_id=pid,
policy_cls=pol_spec.policy_class,
observation_space=pol_spec.observation_space,
action_space=pol_spec.action_space,
config=pol_spec.config,
)
else:
self.policy_map[pid].set_state(state)

@DeveloperAPI
def set_global_vars(self, global_vars: dict) -> None:
@@ -1480,10 +1496,3 @@ def _validate_env(env: Any) -> EnvType:
"ExternalEnv, VectorEnv, or BaseEnv. The provided env creator "
"function returned {} ({}).".format(env, type(env)))
return env


def _has_tensorflow_graph(policy_dict: MultiAgentPolicyConfigDict) -> bool:
for policy, _, _, _ in policy_dict.values():
if issubclass(policy, TFPolicy):
return True
return False
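The save()/restore() changes above mean a RolloutWorker snapshot now carries each policy's PolicySpec, so policies that were added on-the-fly (and are therefore absent from the static multiagent config) can be re-created on restore. A hedged sketch of that round-trip; `worker_a`, `worker_b`, and the "extra" policy ID are illustrative assumptions, not part of the diff.

# Sketch only: worker_a/worker_b are RolloutWorkers built from the same static
# config; "extra" was added to worker_a at runtime via add_policy().
state_bytes = worker_a.save()    # now also pickles PolicyMap.policy_specs
worker_b.restore(state_bytes)    # re-adds "extra" from its stored PolicySpec
assert "extra" in worker_b.policy_map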
18 changes: 14 additions & 4 deletions rllib/evaluation/sampler.py
@@ -554,8 +554,12 @@ def new_episode(env_id):
extra_batch_callback,
env_id=env_id)
# Call each policy's Exploration.on_episode_start method.
# types: Policy
for p in worker.policy_map.values():
# Note: This may break the exploration (e.g. ParameterNoise) of
# policies in the `policy_map` that have not been used recently
# (and are therefore stashed to disk). However, we certainly do not
# want to loop through all (even stashed) policies here, as that
# would defeat the purpose of the LRU policy caching.
for p in worker.policy_map.cache.values():
if getattr(p, "exploration", None) is not None:
p.exploration.on_episode_start(
policy=p,
@@ -902,8 +906,14 @@ def _process_observations(
if ma_sample_batch:
outputs.append(ma_sample_batch)

# Call each policy's Exploration.on_episode_end method.
for p in worker.policy_map.values():
# Call each (in-memory) policy's Exploration.on_episode_end
# method.
# Note: This may break the exploration (e.g. ParameterNoise) of
# policies in the `policy_map` that have not been used recently
# (and are therefore stashed to disk). However, we certainly do not
# want to loop through all (even stashed) policies here, as that
# would defeat the purpose of the LRU policy caching.
for p in worker.policy_map.cache.values():
if getattr(p, "exploration", None) is not None:
p.exploration.on_episode_end(
policy=p,
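The switch from policy_map.values() to policy_map.cache.values() in both hunks above follows from PolicyMap's LRU semantics: looking up every valid policy ID could reinstate stashed policies from disk, while the cache holds only the policies currently resident in memory. An illustrative (non-diff) contrast, assuming `worker` is a RolloutWorker whose map has stashed some policies:

# Illustrative only.
resident = list(worker.policy_map.cache.values())                   # in-memory only, no disk I/O
everything = [worker.policy_map[pid] for pid in worker.policy_map]  # may reload stashed policies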
2 changes: 2 additions & 0 deletions rllib/examples/self_play_with_open_spiel.py
@@ -41,6 +41,7 @@
default="tf",
help="The DL framework specifier.")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--num-workers", type=int, default=2)
parser.add_argument(
"--from-checkpoint",
type=str,
@@ -197,6 +198,7 @@ def policy_mapping_fn(agent_id, episode, **kwargs):
# Always just train the "main" policy.
"policies_to_train": ["main"],
},
"num_workers": args.num_workers,
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
"framework": args.framework,
31 changes: 19 additions & 12 deletions rllib/policy/policy_map.py
@@ -2,7 +2,7 @@
import gym
import os
import pickle
from typing import Callable, Dict, Optional, Type, TYPE_CHECKING
from typing import Callable, Dict, Optional, Set, Type, TYPE_CHECKING

from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.annotations import override
@@ -39,10 +39,19 @@ def __init__(
"""Initializes a PolicyMap instance.
Args:
maxlen (int): The maximum number of policies to hold in memory.
worker_index (int): The worker index of the RolloutWorker this map
resides in.
num_workers (int): The total number of remote workers in the
WorkerSet to which this map's RolloutWorker belongs to.
capacity (int): The maximum number of policies to hold in memory.
The least recently used ones are written to disk/S3 and retrieved
when needed.
path (str):
path (str): The path under which to store the policy pickle files.
Files will be named: [policy_id].[worker_index].policy.pkl.
policy_config (TrainerConfigDict): The Trainer's base config dict.
session_creator (Optional[Callable[[], tf1.Session]]): An optional
tf1.Session creation callable.
seed (int): An optional seed (used to seed tf policies).
"""
super().__init__()

@@ -53,22 +62,22 @@ def __init__(

# The file extension for stashed policies (that are no longer available
# in-memory but can be reinstated any time from storage).
self.extension = ".policy.pkl"
self.extension = f".{self.worker_index}.policy.pkl"

# Dictionary of keys that may be looked up (cached or not).
self.valid_keys = set()
self.valid_keys: Set[str] = set()
# The actual cache with the in-memory policy objects.
self.cache = {}
self.cache: Dict[str, Policy] = {}
# The doubly-linked list holding the currently in-memory objects.
self.deque = deque(maxlen=capacity or 10)
# The file path where to store overflowing policies.
self.path = path or "."
# The core config to use. Each single policy's config override is
# added on top of this.
self.policy_config = policy_config or {}
self.policy_config: TrainerConfigDict = policy_config or {}
# The orig classes/obs+act spaces, and config overrides of the
# Policies.
self.policy_specs = {} # type: Dict[PolicyID, PolicySpec]
self.policy_specs: Dict[PolicyID, PolicySpec] = {}

def create_policy(self, policy_id: PolicyID, policy_cls: Type["Policy"],
observation_space: gym.Space, action_space: gym.Space,
@@ -140,7 +149,7 @@ def create_policy(self, policy_id: PolicyID, policy_cls: Type["Policy"],
def __getitem__(self, item):
# Never seen this key -> Error.
if item not in self.valid_keys:
raise KeyError(f"'{item}' not a valid key!")
raise KeyError(f"PolicyID '{item}' not found in this PolicyMap!")

# Item already in cache -> Rearrange deque (least recently used) and
# return.
@@ -250,9 +259,7 @@ def _stash_to_disk(self):
policy_state = policy.get_state()
# Closes policy's tf session, if any.
self._close_session(policy)
# Remove from memory.
# TODO: (sven) This should clear the tf Graph as well, if the Trainer
# would not hold parts of the graph (e.g. in tf multi-GPU setups).
# Remove from memory. This will clear the tf Graph as well.
del self.cache[delkey]
# Write state to disk.
with open(self.path + "/" + delkey + self.extension, "wb") as f:
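To make the stash/reinstate cycle above concrete, here is a hedged illustration of lookups on a PolicyMap that holds more valid policy IDs than its in-memory capacity; the IDs and the map's prior contents are assumptions, not part of the diff.

# Assumes `policy_map` has capacity=2 but three valid policy IDs, so one of
# them ("p_old") has been stashed to "<path>/p_old.<worker_index>.policy.pkl".
pol = policy_map["p_old"]            # reinstated from disk on access
assert "p_old" in policy_map.cache   # now resident in memory again
# The least recently used in-memory policy was stashed to make room; it stays
# a valid key and is reloaded transparently on its next lookup.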
