Skip to content

Commit

Permalink
Merge pull request #454 from instadeepai/feat/single-vs-multiple-optim
Browse files Browse the repository at this point in the history
Feature/MAPPO Obs Networks Fix + Multiple Optims
  • Loading branch information
KaleabTessera authored Mar 23, 2022
2 parents 5da3856 + 82116fe commit 1021b3f
Show file tree
Hide file tree
Showing 13 changed files with 227 additions and 62 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ repos:
additional_dependencies: [flake8-isort]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.910
rev: v0.941
hooks:
- id: mypy
exclude: ^docs/
Expand Down
23 changes: 15 additions & 8 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ We include a number of systems running on continuous control tasks.

- **MAD4PG**:
a MAD4PG system running on the RoboCup environment.
- *Recurrent*
- *Recurrent*
- [state_based][robocup_mad4pg_ff_state_based].

## Discrete control
Expand All @@ -83,26 +83,31 @@ We also include a number of systems running on discrete action space environment

- **VDN**:
a VDN system running on the discrete action space simple_spread MPE environment.
- *Recurrent*
- *Recurrent*
- [centralised][debug_vdn_rec_cen].

### PettingZoo - Multi-Agent Atari

- **MADQN**:
a MADQN system running on the two-player competitive Atari Pong environment.
- *Recurrent*
- *Recurrent*
- [decentralised][pz_madqn_pong_rec_dec].

- **MAPPO**:
a MAPPO system running on two-player cooperative Atari Pong.
- *Feedforward*
- [decentralised][pz_mappo_coop_pong_ff_dec].

### PettingZoo - Multi-Agent Particle Environment

- **MADDPG**:
a MADDPG system running on the Simple Speaker Listener environment.
- *Feedforward*
- *Feedforward*
- [decentralised][pz_maddpg_mpe_ssl_ff_dec].

- **MADDPG**:
a MADDPG system running on the Simple Spread environment.
- *Feedforward*
- *Feedforward*
- [decentralised][pz_maddpg_mpe_ss_ff_dec].

### SMAC - StarCraft Multi-Agent Challenge
Expand All @@ -116,19 +121,19 @@ We also include a number of systems running on discrete action space environment

- **QMIX**:
a QMIX system running on the SMAC environment.
- *Recurrent*
- *Recurrent*
- [centralised][smac_qmix_rec_cen].

- **VDN**:
a VDN system running on the SMAC environment.
- *Recurrent*
- *Recurrent*
- [centralised][smac_vdn_rec_cen].

### OpenSpiel - Tic Tac Toe

- **MADQN**:
a MADQN system running on the OpenSpiel environment.
- *Feedforward*
- *Feedforward*
- [decentralised][openspiel_madqn_ff_dec].

<!-- Examples -->
Expand Down Expand Up @@ -179,6 +184,8 @@ We also include a number of systems running on discrete action space environment

[pz_madqn_pong_rec_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/tf/petting_zoo/atari/pong/recurrent/decentralised/run_madqn.py

[pz_mappo_coop_pong_ff_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/tf/petting_zoo/butterfly/cooperative_pong/feedforward/decentralised/run_mappo.py

[pz_maddpg_mpe_ssl_ff_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/tf/petting_zoo/mpe/simple_speaker_listener/feedforward/decentralised/run_maddpg.py

[pz_maddpg_mpe_ss_ff_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/tf/petting_zoo/mpe/simple_spread/feedforward/decentralised/run_maddpg.py
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ def main(_: Any) -> None:
network_factory=network_factory,
logger_factory=logger_factory,
num_executors=1,
optimizer=snt.optimizers.Adam(learning_rate=5e-4),
policy_optimizer=snt.optimizers.Adam(learning_rate=5e-4),
critic_optimizer=snt.optimizers.Adam(learning_rate=5e-4),
checkpoint_subpath=checkpoint_dir,
max_gradient_norm=40.0,
architecture=architectures.CentralisedValueCritic,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# python3
# Copyright 2021 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example running MAPPO on PettingZoo Butterfly Cooperative Pong."""

import functools
from datetime import datetime
from typing import Any

import launchpad as lp
import numpy as np
from absl import app, flags
from acme.tf.networks import AtariTorso
from supersuit import dtype_v0

from mava.systems.tf import mappo
from mava.utils import lp_utils
from mava.utils.environments import pettingzoo_utils
from mava.utils.loggers import logger_utils

# Handle to the absl flag registry; the values are read inside `main`.
FLAGS = flags.FLAGS

# --- Environment selection -------------------------------------------------
flags.DEFINE_string(
    name="env_class",
    default="butterfly",
    help="Pettingzoo environment class, e.g. atari (str).",
)
flags.DEFINE_string(
    name="env_name",
    default="cooperative_pong_v3",
    help="Pettingzoo environment name, e.g. pong (str).",
)

# --- Experiment bookkeeping ------------------------------------------------
# The default identifier is the timestamp taken when this module is loaded.
flags.DEFINE_string(
    name="mava_id",
    default=str(datetime.now()),
    help="Experiment identifier that can be used to continue experiments.",
)
flags.DEFINE_string(
    name="base_dir",
    default="~/mava",
    help="Base dir to store experiments.",
)


def main(_: Any) -> None:
    """Assemble the MAPPO example system and launch it as a local program.

    Args:
        _: positional argument handed over by `app.run`; it is not used.
    """
    # Environment factory: the env class/name come from the command-line
    # flags, and supersuit's `dtype_v0` wrapper is requested with
    # dtype=np.float32 as a preprocessing wrapper.
    env_kwargs = {
        "env_class": FLAGS.env_class,
        "env_name": FLAGS.env_name,
        "env_preprocess_wrappers": [(dtype_v0, {"dtype": np.float32})],
    }
    make_env = functools.partial(pettingzoo_utils.make_environment, **env_kwargs)

    # Network factory: the default MAPPO networks, with an AtariTorso
    # instance supplied as the observation network.
    make_networks = lp_utils.partial_kwargs(
        mappo.make_default_networks,
        observation_network=AtariTorso(),
    )

    # Checkpointer appends "Checkpoints" to this path.
    checkpoint_subpath = f"{FLAGS.base_dir}/{FLAGS.mava_id}"

    # Logger factory: logs to the terminal and to tensorboard, with a
    # 10 second interval between log writes.
    logger_kwargs = {
        "directory": FLAGS.base_dir,
        "to_terminal": True,
        "to_tensorboard": True,
        "time_stamp": FLAGS.mava_id,
        "time_delta": 10,
    }
    make_logger = functools.partial(logger_utils.make_logger, **logger_kwargs)

    # Describe the distributed system first, then build the program from it.
    system = mappo.MAPPO(
        environment_factory=make_env,
        network_factory=make_networks,
        logger_factory=make_logger,
        num_executors=1,
        checkpoint_subpath=checkpoint_subpath,
        num_epochs=5,
        batch_size=32,
    )
    program = system.build()

    # Ensure only the trainer runs on gpu, while other processes run on cpu.
    node_names = program.groups.keys()
    resources = lp_utils.to_device(
        program_nodes=node_names,
        nodes_on_gpu=["trainer"],
    )

    # Launch every node as a local process attached to the current terminal.
    lp.launch(
        program,
        lp.LaunchType.LOCAL_MULTI_PROCESSING,
        terminal="current_terminal",
        local_resources=resources,
    )


# Script entry point: when executed directly, hand `main` to absl's `app.run`.
if __name__ == "__main__":
    app.run(main)
10 changes: 7 additions & 3 deletions mava/systems/tf/mappo/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ class MAPPOConfig:
Args:
environment_spec: description of the action and observation spaces etc. for
each agent in the system.
optimizer: optimizer(s) for updating networks.
policy_optimizer: optimizer(s) for updating policy networks.
critic_optimizer: optimizer for updating critic networks. This is not
used if using single optim.
agent_net_keys: (dict, optional): specifies what network each agent uses.
Defaults to {}.
checkpoint_minute_interval (int): The number of minutes to wait between
Expand Down Expand Up @@ -74,7 +76,8 @@ class MAPPOConfig:
"""

environment_spec: specs.EnvironmentSpec
optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]]
policy_optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]]
critic_optimizer: snt.Optimizer
agent_net_keys: Dict[str, str]
checkpoint_minute_interval: int
sequence_length: int = 10
Expand Down Expand Up @@ -319,7 +322,8 @@ def make_trainer(
critic_networks=critic_networks,
dataset=dataset,
agent_net_keys=agent_net_keys,
optimizer=self._config.optimizer,
critic_optimizer=self._config.critic_optimizer,
policy_optimizer=self._config.policy_optimizer,
minibatch_size=self._config.minibatch_size,
num_epochs=self._config.num_epochs,
discount=self._config.discount,
Expand Down
6 changes: 5 additions & 1 deletion mava/systems/tf/mappo/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def make_default_networks(
256,
),
critic_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (512, 512, 256),
observation_network: snt.Module = None,
seed: Optional[int] = None,
) -> Dict[str, snt.Module]:
"""Default networks for mappo.
Expand Down Expand Up @@ -81,7 +82,10 @@ def make_default_networks(
for key in specs.keys():

# Create the shared observation network; here simply a state-less operation.
observation_network = tf2_utils.to_sonnet_module(tf.identity)
if observation_network is None:
observation_network = tf2_utils.to_sonnet_module(tf.identity)
else:
observation_network = observation_network

# Note: The discrete case must be placed first as it inherits from BoundedArray.
if isinstance(specs[key].actions, dm_env.specs.DiscreteArray): # discrete
Expand Down
23 changes: 17 additions & 6 deletions mava/systems/tf/mappo/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ def __init__(
shared_weights: bool = True,
agent_net_keys: Dict[str, str] = {},
executor_variable_update_period: int = 100,
optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]] = snt.optimizers.Adam(
policy_optimizer: Union[
snt.Optimizer, Dict[str, snt.Optimizer]
] = snt.optimizers.Adam(learning_rate=5e-4),
critic_optimizer: Optional[snt.Optimizer] = snt.optimizers.Adam(
learning_rate=5e-4
),
discount: float = 0.99,
Expand All @@ -82,7 +85,7 @@ def __init__(
train_loop_fn_kwargs: Dict = {},
eval_loop_fn_kwargs: Dict = {},
evaluator_interval: Optional[dict] = None,
learning_rate_scheduler_fn: Optional[Callable[[int], None]] = None,
learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None,
normalize_advantage: bool = False,
):
"""Initialise the system
Expand Down Expand Up @@ -114,8 +117,10 @@ def __init__(
Defaults to {}.
executor_variable_update_period : number of steps before
updating executor variables from the variable source. Defaults to 100.
optimizer : optimizer(s) for updating networks.
policy_optimizer : optimizer(s) for updating policy networks.
Defaults to snt.optimizers.Adam(learning_rate=5e-4).
critic_optimizer : optimizer for updating critic
networks. This is not used if using single optim.
discount : discount factor to use for TD updates. Defaults
to 0.99.
lambda_gae : scalar determining the mix of bootstrapping
Expand Down Expand Up @@ -160,8 +165,13 @@ def __init__(
to the training loop. Defaults to {}.
eval_loop_fn_kwargs: possible keyword arguments to send to
the evaluation loop. Defaults to {}.
learning_rate_scheduler_fn: an optional learning rate scheduler for
the optimiser.
learning_rate_scheduler_fn: dict with two functions/classes (one for the
policy and one for the critic optimizer), that takes in a trainer
step t and returns the current learning rate,
e.g. {"policy": policy_lr_schedule, "critic": critic_lr_schedule}.
See
examples/debugging/simple_spread/feedforward/decentralised/run_maddpg_lr_schedule.py
for an example.
evaluator_interval: An optional condition that is used to
evaluate/test system performance after [evaluator_interval]
condition has been met. If None, evaluation will
Expand Down Expand Up @@ -267,7 +277,8 @@ def __init__(
sequence_length=self._sequence_length,
sequence_period=self._sequence_period,
checkpoint=checkpoint,
optimizer=optimizer,
policy_optimizer=policy_optimizer,
critic_optimizer=critic_optimizer,
checkpoint_subpath=checkpoint_subpath,
checkpoint_minute_interval=checkpoint_minute_interval,
evaluator_interval=evaluator_interval,
Expand Down
Loading

0 comments on commit 1021b3f

Please sign in to comment.