forked from instadeepai/Mava
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'develop' into feature/eval-intervals
- Loading branch information
Showing
6 changed files
with
725 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
# python3 | ||
# Copyright 2021 InstaDeep Ltd. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import functools | ||
from datetime import datetime | ||
from typing import Any | ||
|
||
import launchpad as lp | ||
import sonnet as snt | ||
from absl import app, flags | ||
|
||
from mava.components.tf.modules.exploration.exploration_scheduling import ( | ||
LinearExplorationScheduler, | ||
) | ||
from mava.systems.tf import madqn | ||
from mava.utils import lp_utils | ||
from mava.utils.environments.meltingpot_utils import EnvironmentFactory | ||
from mava.utils.loggers import logger_utils | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
flags.DEFINE_string( | ||
"mava_id", | ||
str(datetime.now()), | ||
"Experiment identifier that can be used to continue experiments.", | ||
) | ||
flags.DEFINE_string("base_dir", "./logs", "Base dir to store experiments.") | ||
flags.DEFINE_string("scenario", "clean_up_0", "scenario to evaluste on") | ||
|
||
|
||
def main(_: Any) -> None: | ||
"""Evaluate on a scenario | ||
Args: | ||
_ (Any): ... | ||
""" | ||
|
||
# Environment. | ||
environment_factory = EnvironmentFactory(scenario=FLAGS.scenario) | ||
|
||
# Networks. | ||
network_factory = lp_utils.partial_kwargs(madqn.make_default_networks) | ||
|
||
# Checkpointer appends "Checkpoints" to checkpoint_dir | ||
checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" | ||
|
||
# Log every [log_every] seconds. | ||
log_every = 10 | ||
logger_factory = functools.partial( | ||
logger_utils.make_logger, | ||
directory=FLAGS.base_dir, | ||
to_terminal=True, | ||
to_tensorboard=True, | ||
time_stamp=FLAGS.mava_id, | ||
time_delta=log_every, | ||
) | ||
|
||
# distributed program | ||
program = madqn.MADQN( | ||
environment_factory=environment_factory, # type: ignore | ||
network_factory=network_factory, | ||
logger_factory=logger_factory, | ||
num_executors=1, | ||
exploration_scheduler_fn=LinearExplorationScheduler, | ||
epsilon_min=0.05, | ||
epsilon_decay=1e-4, | ||
importance_sampling_exponent=0.2, | ||
optimizer=snt.optimizers.Adam(learning_rate=1e-4), | ||
checkpoint_subpath=checkpoint_dir, | ||
).build() | ||
|
||
# Ensure only trainer runs on gpu, while other processes run on cpu. | ||
local_resources = lp_utils.to_device( | ||
program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] | ||
) | ||
|
||
# Launch. | ||
lp.launch( | ||
program, | ||
lp.LaunchType.LOCAL_MULTI_PROCESSING, | ||
terminal="current_terminal", | ||
local_resources=local_resources, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
app.run(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# python3 | ||
# Copyright 2021 InstaDeep Ltd. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import functools | ||
from datetime import datetime | ||
from typing import Any | ||
|
||
import launchpad as lp | ||
import sonnet as snt | ||
from absl import app, flags | ||
|
||
from mava.components.tf.modules.exploration.exploration_scheduling import ( | ||
LinearExplorationScheduler, | ||
) | ||
from mava.systems.tf import madqn | ||
from mava.utils import lp_utils | ||
from mava.utils.environments.meltingpot_utils import EnvironmentFactory | ||
from mava.utils.loggers import logger_utils | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
flags.DEFINE_string( | ||
"mava_id", | ||
str(datetime.now()), | ||
"Experiment identifier that can be used to continue experiments.", | ||
) | ||
flags.DEFINE_string("base_dir", "./logs", "Base dir to store experiments.") | ||
flags.DEFINE_string("substrate", "clean_up", "substrate to train on.") | ||
|
||
|
||
def main(_: Any) -> None: | ||
"""Train on substrate | ||
Args: | ||
_ (Any): ... | ||
""" | ||
# Environment. | ||
environment_factory = EnvironmentFactory(substrate=FLAGS.substrate) | ||
|
||
# Networks. | ||
network_factory = lp_utils.partial_kwargs(madqn.make_default_networks) | ||
|
||
# Checkpointer appends "Checkpoints" to checkpoint_dir | ||
checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" | ||
|
||
# Log every [log_every] seconds. | ||
log_every = 10 | ||
logger_factory = functools.partial( | ||
logger_utils.make_logger, | ||
directory=FLAGS.base_dir, | ||
to_terminal=True, | ||
to_tensorboard=True, | ||
time_stamp=FLAGS.mava_id, | ||
time_delta=log_every, | ||
) | ||
|
||
# distributed program | ||
program = madqn.MADQN( | ||
environment_factory=environment_factory, | ||
network_factory=network_factory, | ||
logger_factory=logger_factory, | ||
num_executors=1, | ||
exploration_scheduler_fn=LinearExplorationScheduler, | ||
epsilon_min=0.05, | ||
epsilon_decay=1e-4, | ||
importance_sampling_exponent=0.2, | ||
optimizer=snt.optimizers.Adam(learning_rate=1e-4), | ||
checkpoint_subpath=checkpoint_dir, | ||
).build() | ||
|
||
# Ensure only trainer runs on gpu, while other processes run on cpu. | ||
local_resources = lp_utils.to_device( | ||
program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] | ||
) | ||
|
||
# Launch. | ||
lp.launch( | ||
program, | ||
lp.LaunchType.LOCAL_MULTI_PROCESSING, | ||
terminal="current_terminal", | ||
local_resources=local_resources, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
app.run(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# python3 | ||
# Copyright 2021 InstaDeep Ltd. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Any, Union | ||
|
||
try: | ||
from meltingpot.python import scenario, substrate # type: ignore | ||
from meltingpot.python.scenario import AVAILABLE_SCENARIOS, Scenario # type: ignore | ||
from meltingpot.python.substrate import ( # type: ignore | ||
AVAILABLE_SUBSTRATES, | ||
Substrate, | ||
) | ||
from ml_collections import config_dict # type: ignore | ||
|
||
from mava.wrappers.meltingpot import MeltingpotEnvWrapper | ||
except ModuleNotFoundError: | ||
Scenario = Any | ||
Substrate = Any | ||
|
||
|
||
class EnvironmentFactory: | ||
def __init__(self, substrate: str = None, scenario: str = None): | ||
"""Initializes the env factory object | ||
sets the substrate/scenario using the available ones in meltingpot | ||
Args: | ||
substrate (str, optional): what substrate to use. Defaults to None. | ||
scenario (str, optional): what scenario to use. Defaults to None. | ||
""" | ||
assert (substrate is None) or ( | ||
scenario is None | ||
), "substrate or scenario must be specified" | ||
assert not ( | ||
substrate is not None and scenario is not None | ||
), "Cannot specify both substrate and scenario" | ||
|
||
if substrate is not None: | ||
substrates = [*AVAILABLE_SUBSTRATES] | ||
assert ( | ||
substrate in substrates | ||
), f"substrate cannot be f{substrate}, use any of {substrates}" | ||
self._substrate_name = substrate | ||
self._env_fn = self._substrate | ||
|
||
elif scenario is not None: | ||
scenarios = [*[k for k in AVAILABLE_SCENARIOS]] | ||
assert ( | ||
scenario in scenarios | ||
), f"substrate cannot be f{substrate}, use any of {scenarios}" | ||
self._scenario_name = scenario | ||
self._env_fn = self._scenario | ||
|
||
def _substrate(self) -> Substrate: | ||
"""Returns a substrate as an environment | ||
Returns: | ||
[Substrate]: A substrate | ||
""" | ||
env = load_substrate(self._substrate_name) | ||
return MeltingpotEnvWrapper(env) | ||
|
||
def _scenario(self) -> Scenario: | ||
"""Returns a scenario as an environment | ||
Returns: | ||
[Scenario]: A scenario or None | ||
""" | ||
|
||
env = load_scenario(self._scenario_name) | ||
return MeltingpotEnvWrapper(env) | ||
|
||
def __call__(self, evaluation: bool = False) -> Union[Substrate, Scenario]: | ||
"""Creates an environment | ||
Returns: | ||
(Union[Substrate, Scenario]): The created environment | ||
""" | ||
env = self._env_fn() # type: ignore | ||
return env | ||
|
||
|
||
def load_substrate(substrate_name: str) -> Substrate: | ||
"""Loads a substrate from the available substrates | ||
Args: | ||
substrate_name (str): substrate name | ||
Returns: | ||
Substrate: A multi-agent environment | ||
""" | ||
config = substrate.get_config(substrate_name) | ||
env_config = config_dict.ConfigDict(config) | ||
|
||
return substrate.build(env_config) | ||
|
||
|
||
def load_scenario(scenario_name: str) -> Scenario: | ||
"""Loads a substrate from the available substrates | ||
Args: | ||
scenerio_name (str): scenario name | ||
Returns: | ||
Scenario: A multi-agent environment with background bots | ||
""" | ||
config = scenario.get_config(scenario_name) | ||
env_config = config_dict.ConfigDict(config) | ||
|
||
return scenario.build(env_config) |
Oops, something went wrong.