Skip to content

Commit

Permalink
mrege: Merge remote-tracking branch 'origin/develop' into feature/sma…
Browse files Browse the repository at this point in the history
…c-env-upgrades
  • Loading branch information
KaleabTessera committed Nov 8, 2021
2 parents c93232a + 57ae3c8 commit 80bdefb
Show file tree
Hide file tree
Showing 52 changed files with 2,696 additions and 1,260 deletions.
19 changes: 0 additions & 19 deletions .pylintrc

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def main(_: Any) -> None:
# Networks.
network_factory = lp_utils.partial_kwargs(
mad4pg.make_default_networks,
vmin=-10,
vmax=50,
)

# Checkpointer appends "Checkpoints" to checkpoint_dir.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ def main(_: Any) -> None:
)

# Networks.
network_factory = lp_utils.partial_kwargs(mad4pg.make_default_networks)
network_factory = lp_utils.partial_kwargs(
mad4pg.make_default_networks,
vmin=-10,
vmax=50,
)

# Checkpointer appends "Checkpoints" to checkpoint_dir.
checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# python3
# Copyright 2021 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example running feedforward MADDPG on debug MPE environments.
NB: Using multiple trainers with non-shared weights is still in its
experimental phase of development. This feature will become faster and
more stable in future Mava updates."""

import functools
from datetime import datetime
from typing import Any

import launchpad as lp
import sonnet as snt
from absl import app, flags

from mava.systems.tf import maddpg
from mava.systems.tf.maddpg import make_default_networks
from mava.utils import enums, lp_utils
from mava.utils.environments import debugging_utils
from mava.utils.loggers import logger_utils

FLAGS = flags.FLAGS
flags.DEFINE_string(
"env_name",
"simple_spread",
"Debugging environment name (str).",
)
flags.DEFINE_string(
"action_space",
"continuous",
"Environment action space type (str).",
)
flags.DEFINE_string(
"mava_id",
str(datetime.now()),
"Experiment identifier that can be used to continue experiments.",
)
flags.DEFINE_string("base_dir", "~/mava/", "Base dir to store experiments.")


def main(_: Any) -> None:

# environment
environment_factory = functools.partial(
debugging_utils.make_environment,
env_name=FLAGS.env_name,
action_space=FLAGS.action_space,
)

# networks
network_factory = lp_utils.partial_kwargs(make_default_networks)

# Checkpointer appends "Checkpoints" to checkpoint_dir
checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}"

# Log every [log_every] seconds.
log_every = 10
logger_factory = functools.partial(
logger_utils.make_logger,
directory=FLAGS.base_dir,
to_terminal=True,
to_tensorboard=True,
time_stamp=FLAGS.mava_id,
time_delta=log_every,
)

# distributed program
"""NB: Using multiple trainers with non-shared weights is still in its
experimental phase of development. This feature will become faster and
more stable in future Mava updates."""
program = maddpg.MADDPG(
environment_factory=environment_factory,
network_factory=network_factory,
logger_factory=logger_factory,
num_executors=2,
shared_weights=False,
trainer_networks=enums.Trainer.one_trainer_per_network,
network_sampling_setup=enums.NetworkSampler.fixed_agent_networks,
policy_optimizer=snt.optimizers.Adam(learning_rate=1e-4),
critic_optimizer=snt.optimizers.Adam(learning_rate=1e-4),
checkpoint_subpath=checkpoint_dir,
max_gradient_norm=40.0,
).build()

# Ensure only trainer runs on gpu, while other processes run on cpu.
local_resources = lp_utils.to_device(
program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"]
)

lp.launch(
program,
lp.LaunchType.LOCAL_MULTI_PROCESSING,
terminal="current_terminal",
local_resources=local_resources,
)


if __name__ == "__main__":
app.run(main)
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def main(_: Any) -> None:
# Networks.
network_factory = lp_utils.partial_kwargs(
mad4pg.make_default_networks,
vmin=-10,
vmax=50,
)

# Checkpointer appends "Checkpoints" to checkpoint_dir.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ def main(_: Any) -> None:

# Networks.
network_factory = lp_utils.partial_kwargs(
mad4pg.make_default_networks, archecture_type=ArchitectureType.recurrent
mad4pg.make_default_networks,
vmin=-10,
vmax=50,
archecture_type=ArchitectureType.recurrent,
)

# Checkpointer appends "Checkpoints" to checkpoint_dir.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ def main(_: Any) -> None:
)

# Networks.
network_factory = lp_utils.partial_kwargs(mad4pg.make_default_networks)
network_factory = lp_utils.partial_kwargs(
mad4pg.make_default_networks,
vmin=-150,
vmax=150,
)

# Checkpointer appends "Checkpoints" to checkpoint_dir.
checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ def main(_: Any) -> None:
)

# Networks.
network_factory = lp_utils.partial_kwargs(mad4pg.make_default_networks)
network_factory = lp_utils.partial_kwargs(
mad4pg.make_default_networks,
vmin=-150,
vmax=150,
)

# Checkpointer appends "Checkpoints" to checkpoint_dir.
checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}"
Expand Down
5 changes: 4 additions & 1 deletion examples/robocup/recurrent/state_based/run_mad4pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ def main(_: Any) -> None:

# Networks.
network_factory = lp_utils.partial_kwargs(
mad4pg.make_default_networks, archecture_type=ArchitectureType.recurrent
mad4pg.make_default_networks,
archecture_type=ArchitectureType.recurrent,
vmin=-5,
vmax=5,
)

# Checkpointer appends "Checkpoints" to checkpoint_dir.
Expand Down
Loading

0 comments on commit 80bdefb

Please sign in to comment.