Skip to content

Commit

Permalink
Merge branch 'develop' into feature/eval-intervals
Browse files Browse the repository at this point in the history
  • Loading branch information
arnupretorius authored Nov 18, 2021
2 parents c393002 + fa3a5ea commit 52cd450
Show file tree
Hide file tree
Showing 6 changed files with 725 additions and 0 deletions.
6 changes: 6 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ ENV SC2PATH /home/app/mava/3rdparty/StarCraftII
COPY . /home/app/mava
RUN python -m pip uninstall -y enum34
RUN python -m pip install --upgrade pip

# pyparsing is required as a prerequisite to the flatland install.
# The actual package installation order does not seem to correlate
# with the order of packages in flatland_requirements (system.py).
# Therefore the package is manually installed here.
# Invoked as `python -m pip`, like every other step, so the pin lands in the
# same interpreter that the following installs use.
RUN python -m pip install pyparsing==3.0.3
RUN python -m pip install -e .[flatland]
RUN python -m pip install -e .[open_spiel]
RUN python -m pip install -e .[tf,envs,reverb,launchpad,testing_formatting,record_episode]
Expand Down
100 changes: 100 additions & 0 deletions examples/meltingpot/test_on_scenarios.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# python3
# Copyright 2021 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import functools
from datetime import datetime
from typing import Any

import launchpad as lp
import sonnet as snt
from absl import app, flags

from mava.components.tf.modules.exploration.exploration_scheduling import (
LinearExplorationScheduler,
)
from mava.systems.tf import madqn
from mava.utils import lp_utils
from mava.utils.environments.meltingpot_utils import EnvironmentFactory
from mava.utils.loggers import logger_utils

# Command-line flags for the scenario evaluation script.
FLAGS = flags.FLAGS

# The default experiment id is the timestamp taken when this module is loaded.
flags.DEFINE_string(
    "mava_id",
    str(datetime.now()),
    "Experiment identifier that can be used to continue experiments.",
)
flags.DEFINE_string("base_dir", "./logs", "Base dir to store experiments.")
flags.DEFINE_string("scenario", "clean_up_0", "scenario to evaluate on")


def main(_: Any) -> None:
    """Build a MADQN program for the selected scenario and launch it locally.

    Args:
        _ (Any): unused.
    """
    # Factories handed to the system: environment and networks.
    make_env = EnvironmentFactory(scenario=FLAGS.scenario)
    make_networks = lp_utils.partial_kwargs(madqn.make_default_networks)

    # Checkpointer appends "Checkpoints" to this path.
    experiment_path = f"{FLAGS.base_dir}/{FLAGS.mava_id}"

    # Seconds between two consecutive log writes.
    log_interval = 10
    make_experiment_logger = functools.partial(
        logger_utils.make_logger,
        directory=FLAGS.base_dir,
        to_terminal=True,
        to_tensorboard=True,
        time_stamp=FLAGS.mava_id,
        time_delta=log_interval,
    )

    # Assemble the distributed system, then turn it into a launchable program.
    system = madqn.MADQN(
        environment_factory=make_env,  # type: ignore
        network_factory=make_networks,
        logger_factory=make_experiment_logger,
        num_executors=1,
        exploration_scheduler_fn=LinearExplorationScheduler,
        epsilon_min=0.05,
        epsilon_decay=1e-4,
        importance_sampling_exponent=0.2,
        optimizer=snt.optimizers.Adam(learning_rate=1e-4),
        checkpoint_subpath=experiment_path,
    )
    program = system.build()

    # Only the trainer node is placed on the gpu; all other nodes use the cpu.
    node_names = program.groups.keys()
    resources = lp_utils.to_device(program_nodes=node_names, nodes_on_gpu=["trainer"])

    # Run every node as a local process attached to the current terminal.
    lp.launch(
        program,
        lp.LaunchType.LOCAL_MULTI_PROCESSING,
        terminal="current_terminal",
        local_resources=resources,
    )


# Script entry point: hand `main` to the absl app runner.
if __name__ == "__main__":
    app.run(main)
99 changes: 99 additions & 0 deletions examples/meltingpot/train_on_substrates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# python3
# Copyright 2021 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import functools
from datetime import datetime
from typing import Any

import launchpad as lp
import sonnet as snt
from absl import app, flags

from mava.components.tf.modules.exploration.exploration_scheduling import (
LinearExplorationScheduler,
)
from mava.systems.tf import madqn
from mava.utils import lp_utils
from mava.utils.environments.meltingpot_utils import EnvironmentFactory
from mava.utils.loggers import logger_utils

# Command-line flags for the substrate training script.
FLAGS = flags.FLAGS

# The default experiment id is the timestamp taken when this module is loaded.
flags.DEFINE_string(
    name="mava_id",
    default=str(datetime.now()),
    help="Experiment identifier that can be used to continue experiments.",
)
flags.DEFINE_string(
    name="base_dir",
    default="./logs",
    help="Base dir to store experiments.",
)
flags.DEFINE_string(
    name="substrate",
    default="clean_up",
    help="substrate to train on.",
)


def main(_: Any) -> None:
    """Build a MADQN program for the selected substrate and launch it locally.

    Args:
        _ (Any): unused.
    """
    # Factories handed to the system: environment and networks.
    env_factory = EnvironmentFactory(substrate=FLAGS.substrate)
    net_factory = lp_utils.partial_kwargs(madqn.make_default_networks)

    # Checkpointer appends "Checkpoints" to this path.
    run_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}"

    # Seconds between two consecutive log writes.
    seconds_between_logs = 10
    log_factory = functools.partial(
        logger_utils.make_logger,
        directory=FLAGS.base_dir,
        to_terminal=True,
        to_tensorboard=True,
        time_stamp=FLAGS.mava_id,
        time_delta=seconds_between_logs,
    )

    # Assemble the distributed system, then turn it into a launchable program.
    madqn_system = madqn.MADQN(
        environment_factory=env_factory,
        network_factory=net_factory,
        logger_factory=log_factory,
        num_executors=1,
        exploration_scheduler_fn=LinearExplorationScheduler,
        epsilon_min=0.05,
        epsilon_decay=1e-4,
        importance_sampling_exponent=0.2,
        optimizer=snt.optimizers.Adam(learning_rate=1e-4),
        checkpoint_subpath=run_dir,
    )
    program = madqn_system.build()

    # Only the trainer node is placed on the gpu; all other nodes use the cpu.
    device_map = lp_utils.to_device(
        program_nodes=program.groups.keys(),
        nodes_on_gpu=["trainer"],
    )

    # Run every node as a local process attached to the current terminal.
    lp.launch(
        program,
        lp.LaunchType.LOCAL_MULTI_PROCESSING,
        terminal="current_terminal",
        local_resources=device_map,
    )


# Script entry point: hand `main` to the absl app runner.
if __name__ == "__main__":
    app.run(main)
122 changes: 122 additions & 0 deletions mava/utils/environments/meltingpot_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# python3
# Copyright 2021 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Union

# These imports are allowed to fail: if any module is missing, the module
# still loads and only the two annotation names get a fallback.
try:
    from meltingpot.python import scenario, substrate  # type: ignore
    from meltingpot.python.scenario import AVAILABLE_SCENARIOS, Scenario  # type: ignore
    from meltingpot.python.substrate import (  # type: ignore
        AVAILABLE_SUBSTRATES,
        Substrate,
    )
    from ml_collections import config_dict  # type: ignore

    from mava.wrappers.meltingpot import MeltingpotEnvWrapper
except ModuleNotFoundError:
    # Scenario/Substrate are used in signatures below, which are evaluated at
    # definition time, so they must exist even without meltingpot.
    # NOTE: the other names (scenario, substrate, AVAILABLE_*, config_dict,
    # MeltingpotEnvWrapper) get no fallback and stay undefined in this case.
    Scenario = Any
    Substrate = Any


class EnvironmentFactory:
    """Callable factory that builds a wrapped meltingpot environment.

    Exactly one of a substrate name or a scenario name is chosen at
    construction time; calling the instance loads that environment and wraps
    it in ``MeltingpotEnvWrapper``.
    """

    def __init__(
        self, substrate: Union[str, None] = None, scenario: Union[str, None] = None
    ):
        """Initializes the env factory object

        Sets the substrate/scenario using the available ones in meltingpot.

        Args:
            substrate (str, optional): what substrate to use. Defaults to None.
            scenario (str, optional): what scenario to use. Defaults to None.

        Raises:
            AssertionError: if neither or both of substrate/scenario are
                given, or if the given name is not an available one.
        """
        # At least one of the two must be given; otherwise no environment
        # builder would be selected and __call__ would fail later.
        assert (substrate is not None) or (
            scenario is not None
        ), "substrate or scenario must be specified"
        # ... and never both at once.
        assert (substrate is None) or (
            scenario is None
        ), "Cannot specify both substrate and scenario"

        if substrate is not None:
            substrates = list(AVAILABLE_SUBSTRATES)
            assert (
                substrate in substrates
            ), f"substrate cannot be {substrate}, use any of {substrates}"
            self._substrate_name = substrate
            self._env_fn = self._substrate
        else:
            # The assertions above guarantee scenario is set on this branch.
            scenarios = list(AVAILABLE_SCENARIOS)
            assert (
                scenario in scenarios
            ), f"scenario cannot be {scenario}, use any of {scenarios}"
            self._scenario_name = scenario
            self._env_fn = self._scenario

    def _substrate(self) -> Substrate:
        """Returns a substrate as an environment

        Returns:
            [Substrate]: the loaded substrate wrapped in MeltingpotEnvWrapper
        """
        env = load_substrate(self._substrate_name)
        return MeltingpotEnvWrapper(env)

    def _scenario(self) -> Scenario:
        """Returns a scenario as an environment

        Returns:
            [Scenario]: the loaded scenario wrapped in MeltingpotEnvWrapper
        """
        env = load_scenario(self._scenario_name)
        return MeltingpotEnvWrapper(env)

    def __call__(self, evaluation: bool = False) -> Union[Substrate, Scenario]:
        """Creates an environment

        Args:
            evaluation (bool, optional): accepted but not used by this
                factory. Defaults to False.

        Returns:
            (Union[Substrate, Scenario]): The created environment
        """
        env = self._env_fn()  # type: ignore
        return env


def load_substrate(substrate_name: str) -> Substrate:
    """Builds the substrate registered under the given name

    Args:
        substrate_name (str): substrate name
    Returns:
        Substrate: A multi-agent environment
    """
    # Fetch the named config, wrap it in a ConfigDict and build from it.
    substrate_config = config_dict.ConfigDict(substrate.get_config(substrate_name))
    return substrate.build(substrate_config)


def load_scenario(scenario_name: str) -> Scenario:
    """Builds the scenario registered under the given name

    Args:
        scenario_name (str): scenario name
    Returns:
        Scenario: A multi-agent environment with background bots
    """
    # Fetch the named config, wrap it in a ConfigDict and build from it.
    scenario_config = config_dict.ConfigDict(scenario.get_config(scenario_name))
    return scenario.build(scenario_config)
Loading

0 comments on commit 52cd450

Please sign in to comment.