diff --git a/Dockerfile.tf b/Dockerfile.tf index 3b459aad6..05487eecf 100755 --- a/Dockerfile.tf +++ b/Dockerfile.tf @@ -43,12 +43,11 @@ RUN AutoROM -v ########################################################## # SMAC image FROM tf-core AS sc2 -## Install SC2 game -RUN apt-get install -y wget ## Install smac environment RUN apt-get -y install git RUN pip install .[sc2] -RUN ./bash_scripts/install_sc2.sh +# We use the pz wrapper for smac +RUN pip install .[pz] ENV SC2PATH /home/app/mava/3rdparty/StarCraftII ########################################################## diff --git a/Makefile b/Makefile index cda535b1a..158630c37 100755 --- a/Makefile +++ b/Makefile @@ -42,7 +42,7 @@ endif IMAGE = $(DOCKER_IMAGE_NAME):$(DOCKER_IMAGE_TAG) # make file commands build: - DOCKER_BUILDKIT=1 docker build --tag $(IMAGE) -f Dockerfile.tf --target $(DOCKER_IMAGE_TAG) --build-arg record=$(record) --progress=plain . + DOCKER_BUILDKIT=1 docker build --tag $(IMAGE) -f Dockerfile.tf --target $(DOCKER_IMAGE_TAG) --build-arg record=$(record) . run: $(DOCKER_RUN) python $(example) --base_dir /home/app/mava/logs/ @@ -62,6 +62,9 @@ run-tests: run-integration-tests: $(DOCKER_RUN) /bin/bash bash_scripts/tests.sh true +run-checks: + $(DOCKER_RUN) /bin/bash bash_scripts/check_format.sh + push: docker login -docker push $(IMAGE) diff --git a/bash_scripts/install_sc2.sh b/bash_scripts/install_sc2.sh index 821202c9c..a6a70715e 100755 --- a/bash_scripts/install_sc2.sh +++ b/bash_scripts/install_sc2.sh @@ -10,8 +10,8 @@ echo 'SC2PATH is set to '$SC2PATH if [ ! -d $SC2PATH ]; then echo 'StarCraftII is not installed. Installing now ...'; - wget http://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip --progress=dot -e dotbytes=100M - unzip -qq iagreetotheeula SC2.4.10.zip + wget http://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip --progress=dot -e dotbytes=50M + unzip -P iagreetotheeula SC2.4.10.zip rm -rf SC2.4.10.zip else echo 'StarCraftII is already installed.' diff --git a/examples/smac/feedforward/decentralised/run_madqn.py b/examples/smac/feedforward/decentralised/run_madqn.py index cffa87b3d..317a1298d 100644 --- a/examples/smac/feedforward/decentralised/run_madqn.py +++ b/examples/smac/feedforward/decentralised/run_madqn.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Example running MADQN on multi-agent Starcraft 2 (SMAC) environment.""" import functools from datetime import datetime @@ -28,7 +27,7 @@ ) from mava.systems.tf import madqn from mava.utils import lp_utils -from mava.utils.environments import smac_utils +from mava.utils.environments import pettingzoo_utils from mava.utils.loggers import logger_utils FLAGS = flags.FLAGS @@ -47,10 +46,10 @@ def main(_: Any) -> None: - + """Example running MADQN on multi-agent Starcraft 2 (SMAC) environment.""" # environment environment_factory = functools.partial( - smac_utils.make_environment, map_name=FLAGS.map_name + pettingzoo_utils.make_environment, env_class="smac", env_name=FLAGS.map_name ) # Networks. diff --git a/examples/smac/feedforward/decentralised/run_qmix.py b/examples/smac/feedforward/decentralised/run_qmix.py index 0fedcfe39..2ba3d2b0c 100644 --- a/examples/smac/feedforward/decentralised/run_qmix.py +++ b/examples/smac/feedforward/decentralised/run_qmix.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Example running QMIX on SMAC environments.""" + import functools from datetime import datetime from typing import Any @@ -25,7 +25,7 @@ from mava.components.tf.modules.exploration import LinearExplorationTimestepScheduler from mava.systems.tf import qmix from mava.utils import lp_utils -from mava.utils.environments import smac_utils +from mava.utils.environments import pettingzoo_utils from mava.utils.loggers import logger_utils FLAGS = flags.FLAGS @@ -44,9 +44,10 @@ def main(_: Any) -> None: + """Example running QMIX on SMAC environments.""" # Environment. environment_factory = functools.partial( - smac_utils.make_environment, map_name=FLAGS.map_name + pettingzoo_utils.make_environment, env_class="smac", env_name=FLAGS.map_name ) # Networks. diff --git a/examples/smac/feedforward/decentralised/run_vdn.py b/examples/smac/feedforward/decentralised/run_vdn.py index cba480467..3b087e7e8 100644 --- a/examples/smac/feedforward/decentralised/run_vdn.py +++ b/examples/smac/feedforward/decentralised/run_vdn.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Example running VDN on multi-agent Starcraft 2 (SMAC) environment.""" import functools from datetime import datetime @@ -26,7 +25,7 @@ from mava.components.tf.modules.exploration import LinearExplorationTimestepScheduler from mava.systems.tf import vdn from mava.utils import lp_utils -from mava.utils.environments import smac_utils +from mava.utils.environments import pettingzoo_utils from mava.utils.loggers import logger_utils FLAGS = flags.FLAGS @@ -45,10 +44,10 @@ def main(_: Any) -> None: - + """Example running VDN on multi-agent Starcraft 2 (SMAC) environment.""" # environment environment_factory = functools.partial( - smac_utils.make_environment, map_name=FLAGS.map_name + pettingzoo_utils.make_environment, env_class="smac", env_name=FLAGS.map_name ) # Networks. diff --git a/examples/smac/feedforward/decentralised/run_vdn_record.py b/examples/smac/feedforward/decentralised/run_vdn_record.py index 83419ad96..90a16507e 100644 --- a/examples/smac/feedforward/decentralised/run_vdn_record.py +++ b/examples/smac/feedforward/decentralised/run_vdn_record.py @@ -13,10 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Example running VDN on multi-agent Starcraft 2 (SMAC) environment, -while recording agents. -""" import functools from datetime import datetime @@ -29,7 +25,7 @@ from mava.components.tf.modules.exploration import LinearExplorationTimestepScheduler from mava.systems.tf import vdn from mava.utils import lp_utils -from mava.utils.environments import smac_utils +from mava.utils.environments import pettingzoo_utils from mava.utils.loggers import logger_utils from mava.wrappers.environment_loop_wrappers import MonitorParallelEnvironmentLoop @@ -49,10 +45,10 @@ def main(_: Any) -> None: - + """Example running VDN on SMAC,while recording agents.""" # environment environment_factory = functools.partial( - smac_utils.make_environment, map_name=FLAGS.map_name + pettingzoo_utils.make_environment, env_class="smac", env_name=FLAGS.map_name ) # Networks. diff --git a/examples/smac/recurrent/decentralised/run_madqn.py b/examples/smac/recurrent/decentralised/run_madqn.py index f64cb11f2..8fa1d2bc1 100644 --- a/examples/smac/recurrent/decentralised/run_madqn.py +++ b/examples/smac/recurrent/decentralised/run_madqn.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Example running recurrent MADQN on multi-agent Starcraft 2 (SMAC) environment.""" import functools from datetime import datetime @@ -33,7 +32,7 @@ from mava.components.tf.networks.epsilon_greedy import EpsilonGreedy from mava.systems.tf import madqn from mava.utils import lp_utils -from mava.utils.environments import smac_utils +from mava.utils.environments import pettingzoo_utils from mava.utils.loggers import logger_utils FLAGS = flags.FLAGS @@ -96,10 +95,10 @@ def custom_recurrent_network( def main(_: Any) -> None: - + """Example running recurrent MADQN on multi-agent Starcraft 2 (SMAC) environment.""" # environment environment_factory = functools.partial( - smac_utils.make_environment, map_name=FLAGS.map_name + pettingzoo_utils.make_environment, env_class="smac", env_name=FLAGS.map_name ) # Networks. diff --git a/mava/utils/environments/smac_utils.py b/mava/utils/environments/smac_utils.py deleted file mode 100644 index 066cf25d9..000000000 --- a/mava/utils/environments/smac_utils.py +++ /dev/null @@ -1,66 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Starcraft 2 environment factory.""" - -from typing import Any, Dict, Optional - -import dm_env - -try: - from smac.env import StarCraft2Env - - _has_smac = True -except ModuleNotFoundError: - _has_smac = False - pass -from mava.wrappers import SMACEnvWrapper # type:ignore - - -def load_smac_env(env_config: Dict[str, Any]) -> "StarCraft2Env": - """Loads a smac environment given a config dict. Also, the possible agents in the - environment are set""" - if _has_smac: - env = StarCraft2Env(**env_config) - env.possible_agents = list(range(env.n_agents)) - else: - raise Exception("Smac is not installed.") - return env - - -def make_environment( - evaluation: bool = False, - map_name: str = "3m", - random_seed: Optional[int] = None, - **kwargs: Any, -) -> dm_env.Environment: - """Wraps an starcraft 2 environment. - - Args: - map_name: str, name of micromanagement level. - - Returns: - A starcraft 2 smac environment wrapped as a DeepMind environment. - """ - if _has_smac: - del evaluation - - env = StarCraft2Env(map_name=map_name, seed=random_seed, **kwargs) - - # wrap starcraft 2 environment - environment = SMACEnvWrapper(env) - else: - raise Exception("Smac is not installed.") - return environment diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index 4b819db61..a4eb4cfd0 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -25,12 +25,6 @@ PettingZooParallelEnvWrapper, ) from mava.wrappers.robocup import RoboCupWrapper -from mava.wrappers.smac import SMACEnvWrapper - -try: - from smac.env import StarCraft2Env -except ModuleNotFoundError: - pass from mava.wrappers.system_trainer_statistics import ( DetailedTrainerStatistics, NetworkStatisticsActorCritic, diff --git a/mava/wrappers/smac.py b/mava/wrappers/smac.py deleted file mode 100644 index a873d36ce..000000000 --- a/mava/wrappers/smac.py +++ /dev/null @@ -1,347 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# See SMAC here: https://github.com/oxwhirl/smac -# Documentation available at smac/blob/master/docs/smac.md - -"""Wraps a StarCraft II MARL environment (SMAC) as a dm_env environment.""" - -from typing import Any, Dict, List, Tuple - -import dm_env -import numpy as np -from acme import specs -from acme.wrappers.gym_wrapper import _convert_to_spec -from gym import spaces -from gym.spaces import Box, Discrete - -try: - from pettingzoo.utils.env import ParallelEnv - from smac.env import StarCraft2Env -except ModuleNotFoundError: - pass - - -from mava import types -from mava.utils.wrapper_utils import convert_np_type, parameterized_restart -from mava.wrappers.env_wrappers import ParallelEnvWrapper - - -class SMACEnvWrapper(ParallelEnvWrapper): - """Wraps a StarCraft II MARL environment (SMAC) as a Mava Parallel environment. - Based on RLlib & Pettingzoo wrapper provided by SMAC. - Args: - ParallelEnvWrapper ([type]): [description] - """ - - def __init__(self, environment: "StarCraft2Env") -> None: - """Create a new multi-agent StarCraft env compatible with Mava. - Args: - environment (StarCraft2Env): Arguments to pass to the underlying - smac.env.starcraft.StarCraft2Env instance. - """ - - self._environment = environment - self.reset() - - def _get_agents(self) -> List: - """Function that returns agent names and ids. - Returns: - List: list containing agents in format {agent_name}_{agent_id}. - """ - agent_types = { - self._environment.marine_id: "marine", - self._environment.marauder_id: "marauder", - self._environment.medivac_id: "medivac", - self._environment.hydralisk_id: "hydralisk", - self._environment.zergling_id: "zergling", - self._environment.baneling_id: "baneling", - self._environment.stalker_id: "stalker", - self._environment.colossus_id: "colossus", - self._environment.zealot_id: "zealot", - } - - agents = [] - for agent_id, agent_info in self._environment.agents.items(): - agents.append(f"{agent_types[agent_info.unit_type]}_{agent_id}") - return agents - - def _observe_all(self, obs_list: List) -> Dict: - """Function that combibnes all agent observations into a single dict. - Args: - obs_list (List): list of all agent observations. - Returns: - Dict: dict containing agent observations and action masks. - """ - observe = {} - for i, obs in enumerate(obs_list): - observe[self.possible_agents[i]] = { - "observation": obs, - "action_mask": ( - np.array(self._environment.get_avail_agent_actions(i)).astype(bool) - ).astype(int), - } - return observe - - def reset(self) -> Tuple[dm_env.TimeStep, np.array]: - """Resets the env and returns observations from ready agents. - Returns: - obs (dict): New observations for each ready agent. - """ - self._env_done = False - self._reset_next_step = False - self._step_type = dm_env.StepType.FIRST - - # reset internal SC2 env - obs_list, state = self._environment.reset() - - # Initialize Spaces - # Agents only become populated after reset - self._possible_agents = self._get_agents() - self._agents = self._possible_agents[:] - - self.action_spaces = { - agent: Discrete(self._environment.get_total_actions()) - for agent in self._agents - } - self.observation_spaces = { - agent: spaces.Dict( - { - "observation": Box( - -1, - 1, - shape=(self._environment.get_obs_size(),), - dtype="float32", - ), - "action_mask": Box( - 0, - 1, - shape=(self.action_spaces[agent].n,), - dtype=self.action_spaces[agent].dtype, - ), - } - ) - for agent in self._agents - } - - # Convert observations - observe = self._observe_all(obs_list) - observations = self._convert_observations( - observe, {agent: False for agent in self._possible_agents} - ) - - # create discount spec - discount_spec = self.discount_spec() - self._discounts = { - agent: convert_np_type(discount_spec[agent].dtype, 1) - for agent in self._possible_agents - } - - # create rewards spec - rewards_spec = self.reward_spec() - rewards = { - agent: convert_np_type(rewards_spec[agent].dtype, 0) - for agent in self._possible_agents - } - - # dm_env timestep - timestep = parameterized_restart(rewards, self._discounts, observations) - - return timestep, {"s_t": state} - - def step(self, actions: Dict[str, np.ndarray]) -> Tuple[dm_env.TimeStep, np.array]: - """Returns observations from ready agents. - The returns are dicts mapping from agent_id strings to values. The - number of agents in the env can vary over time. - Returns - ------- - obs (dict): New observations for each ready agent. - rewards (dict): Reward values for each ready agent. If the - episode is just started, the value will be None. - dones (dict): Done values for each ready agent. The special key - "__all__" (required) is used to indicate env termination. - infos (dict): Optional info values for each agent id. - """ - if self._reset_next_step: - return self.reset() - - actions_feed = [actions[key] for key in self._agents] - reward, terminated, _ = self._environment.step(actions_feed) - obs_list = self._environment.get_obs() - state = self._environment.get_state() - self._env_done = terminated - - observe = self._observe_all(obs_list) - dones = {agent: terminated for agent in self._possible_agents} - - observations = self._convert_observations(observe, dones) - self._agents = list(observe.keys()) - rewards_spec = self.reward_spec() - - # Handle empty rewards - if not reward: - rewards = { - agent: convert_np_type(rewards_spec[agent].dtype, 0) - for agent in self._possible_agents - } - else: - rewards = { - agent: convert_np_type(rewards_spec[agent].dtype, reward) - for agent in self._agents - } - - if self.env_done(): - self._step_type = dm_env.StepType.LAST - self._reset_next_step = True - else: - self._step_type = dm_env.StepType.MID - - timestep = dm_env.TimeStep( - observation=observations, - reward=rewards, - discount=self._discounts, - step_type=self._step_type, - ) - - self.reward = rewards - - return timestep, {"s_t": state} - - def env_done(self) -> bool: - """Returns a bool indicating if all agents in env are done. - Returns: - bool: Bool indicating if all agents are done. - """ - return self._env_done - - def _convert_observations( - self, observes: Dict[str, np.ndarray], dones: Dict[str, bool] - ) -> types.Observation: - """Converts observations to correct Mava format. - Args: - observes (Dict[str, np.ndarray]): Dict containing agent observations. - dones (Dict[str, bool]): Dict indicating which agents are done. - Returns: - types.Observation: Correct format observations (OLT). - """ - observations: Dict[str, types.OLT] = {} - for agent, observation in observes.items(): - if isinstance(observation, dict) and "action_mask" in observation: - legals = observation["action_mask"] - observation = observation["observation"] - else: - legals = np.ones( - _convert_to_spec(self.action_space).shape, - dtype=self.action_space.dtype, - ) - observations[agent] = types.OLT( - observation=observation, - legal_actions=legals, - terminal=np.asarray([dones[agent]], dtype=np.float32), - ) - - return observations - - def observation_spec(self) -> types.Observation: - """Function returns observation spec (format) of the env. - Returns: - types.Observation: Observation spec. - """ - return { - agent: types.OLT( - observation=_convert_to_spec( - self.observation_spaces[agent]["observation"] - ), - legal_actions=_convert_to_spec( - self.observation_spaces[agent]["action_mask"] - ), - terminal=specs.Array((1,), np.float32), - ) - for agent in self._possible_agents - } - - def action_spec(self) -> Dict[str, specs.DiscreteArray]: - """Function returns action spec (format) of the env. - Returns: - Dict[str, specs.DiscreteArray]: action spec. - """ - return { - agent: _convert_to_spec(self.action_spaces[agent]) - for agent in self._possible_agents - } - - def reward_spec(self) -> Dict[str, specs.Array]: - """Function returns reward spec (format) of the env. - Returns: - Dict[str, specs.Array]: reward spec. - """ - return {agent: specs.Array((), np.float32) for agent in self._possible_agents} - - def discount_spec(self) -> Dict[str, specs.BoundedArray]: - """Function returns discount spec (format) of the env. - Returns: - Dict[str, specs.BoundedArray]: discount spec. - """ - return { - agent: specs.BoundedArray((), np.float32, minimum=0, maximum=1.0) - for agent in self._possible_agents - } - - def extra_spec(self) -> Dict[str, specs.BoundedArray]: - """Function returns extra spec (format) of the env. - Returns: - Dict[str, specs.BoundedArray]: extra spec. - """ - state = self._environment.get_state() - # TODO (dries): What should the real bounds be of the state spec? - return { - "s_t": specs.BoundedArray( - state.shape, np.float32, minimum=float("-inf"), maximum=float("inf") - ) - } - - def seed(self, random_seed: int) -> None: - """Function to seed the environment. - Args: - random_seed (int): random seed used when seeding the env. - """ - self._environment._seed = random_seed - # Reset after setting seed - self.env.full_restart() - - @property - def agents(self) -> List: - """Returns active/not done agents in the env. - Returns: - List: active agents in the env. - """ - return self._agents - - @property - def possible_agents(self) -> List: - """Returns all posible agents in the env. - Returns: - List: all possible agents in the env. - """ - return self._possible_agents - - @property - def environment(self) -> "ParallelEnv": - """Returns the wrapped environment.""" - return self._environment - - def __getattr__(self, name: str) -> Any: - """Expose any other attributes of the underlying environment.""" - return getattr(self._environment, name) diff --git a/tests/conftest.py b/tests/conftest.py index ae514046d..0ff3e4845 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,6 +28,9 @@ from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.schedule_generators import sparse_schedule_generator + from mava.utils.environments.flatland_utils import load_flatland_env + from mava.wrappers.flatland import FlatlandEnvWrapper + _has_flatland = True except ModuleNotFoundError: _has_flatland = False @@ -41,9 +44,6 @@ from mava.types import Observation, Reward try: - if _has_flatland: - from mava.utils.environments.flatland_utils import load_flatland_env - from mava.wrappers.flatland import FlatlandEnvWrapper from mava.utils.environments.open_spiel_utils import load_open_spiel_env from mava.wrappers.open_spiel import OpenSpielSequentialWrapper except ImportError: @@ -89,14 +89,31 @@ class Helpers: - # Check all props are not none @staticmethod def verify_all_props_not_none(props_which_should_not_be_none: list) -> bool: + """Check all props are not none + + Args: + props_which_should_not_be_none : vars which should have a value. + + Returns: + bool indicating if vars are not none. + """ return all(prop is not None for prop in props_which_should_not_be_none) - # Return an env - currently Pettingzoo envs. @staticmethod def get_env(env_spec: EnvSpec) -> Union[AECEnv, ParallelEnv]: + """Return an env based on an env spec. + + Args: + env_spec : decription of env. + + Raises: + Exception: No appropriate env found. + + Returns: + an envrionment. + """ env = None if env_spec.env_source == EnvSource.PettingZoo: mod = importlib.import_module(env_spec.env_name) @@ -113,11 +130,21 @@ def get_env(env_spec: EnvSpec) -> Union[AECEnv, ParallelEnv]: env.reset() # type:ignore return env - # Returns a wrapper function. @staticmethod def get_wrapper_function( env_spec: EnvSpec, ) -> dm_env.Environment: + """Returns a wrapper function. + + Args: + env_spec : decription of env. + + Raises: + Exception: No env wrapper found. + + Returns: + an envrionment wrapper. + """ wrapper: dm_env.Environment = None if env_spec.env_source == EnvSource.PettingZoo: if env_spec.env_type == EnvType.Parallel: @@ -132,11 +159,21 @@ def get_wrapper_function( raise Exception("Env_spec is not valid.") return wrapper - # Returns an env loop. @staticmethod def get_env_loop( env_spec: EnvSpec, ) -> acme.core.Worker: + """Returns an env loop. + + Args: + env_spec : decription of env. + + Raises: + Exception: Unable to find env loop. + + Returns: + env loop. + """ env_loop = None if env_spec.env_type == EnvType.Parallel: env_loop = ParallelEnvironmentLoop @@ -146,9 +183,6 @@ def get_env_loop( raise Exception("Env_spec is not valid.") return env_loop - """Function that retrieves a mocked env, based on - env_spec.""" - @staticmethod def get_mocked_env( env_spec: EnvSpec, @@ -158,12 +192,17 @@ def get_mocked_env( ParallelMAContinuousEnvironment, SequentialMAContinuousEnvironment, ]: - env = None - if not hasattr(env_spec, "env_name"): - raise Exception("No env_name passed in.") + """Function that retrieves a mocked env. + + Args: + env_spec : decription of env. - if not hasattr(env_spec, "env_type"): - raise Exception("No env_type passed in.") + Raises: + Exception: no valid env found. + + Returns: + a mocked environment. + """ env_name = env_spec.env_name if env_name is MockedEnvironments.Mocked_Dicrete: @@ -207,6 +246,14 @@ def get_mocked_env( def is_mocked_env( env_name: str, ) -> bool: + """Returns bool indicating if env is mocked or not. + + Args: + env_spec : decription of env. + + Returns: + bool indicating if env is mocked or not. + """ mock = False if ( env_name is MockedEnvironments.Mocked_Continous @@ -215,12 +262,18 @@ def is_mocked_env( mock = True return mock - # Returns a wrapped env and specs @staticmethod def get_wrapped_env( env_spec: EnvSpec, **kwargs: Any ) -> Tuple[dm_env.Environment, acme.specs.EnvironmentSpec]: + """Returns a wrapped env and specs. + + Args: + env_spec : decription of env. + Returns: + a wrapped env and specs. + """ specs = None if Helpers.is_mocked_env(env_spec.env_name): wrapped_env = Helpers.get_mocked_env(env_spec) @@ -232,17 +285,30 @@ def get_wrapped_env( specs = Helpers.get_pz_env_spec(wrapped_env)._specs return wrapped_env, specs - # Returns a petting zoo environment spec. + # @staticmethod def get_pz_env_spec(environment: dm_env.Environment) -> dm_env.Environment: + """Returns a petting zoo environment spec. + + Args: + environment : an env. + + Returns: + a petting zoo environment spec. + """ return mava_specs.MAEnvironmentSpec(environment) - # Seeds action space @staticmethod def seed_action_space( env_wrapper: Union[PettingZooAECEnvWrapper, PettingZooParallelEnvWrapper], random_seed: int, ) -> None: + """Seeds action space. + + Args: + env_wrapper : an env wrapper. + random_seed : random seed to be used. + """ [ env_wrapper.action_spaces[agent].seed(random_seed) for agent in env_wrapper.agents @@ -250,6 +316,15 @@ def seed_action_space( @staticmethod def compare_dicts(dictA: Dict, dictB: Dict) -> bool: + """Function that check if two dicts are equal. + + Args: + dictA : dict A. + dictB : dict B. + + Returns: + bool indicating if dicts are equal or not. + """ typesA = [type(k) for k in dictA.values()] typesB = [type(k) for k in dictB.values()] @@ -257,6 +332,11 @@ def compare_dicts(dictA: Dict, dictB: Dict) -> bool: @staticmethod def assert_valid_episode(episode_result: Dict) -> None: + """Function that checks if a valid episode was run. + + Args: + episode_result : result dict from an episode. + """ assert ( episode_result["episode_length"] > 0 and episode_result["mean_episode_return"] is not None @@ -271,6 +351,13 @@ def assert_env_reset( dm_env_timestep: dm_env.TimeStep, env_spec: EnvSpec, ) -> None: + """Assert env are reset correctly. + + Args: + wrapped_env : wrapped env. + dm_env_timestep : timestep. + env_spec : env spec. + """ if env_spec.env_type == EnvType.Parallel: rewards_spec = wrapped_env.reward_spec() expected_rewards = { @@ -318,6 +405,15 @@ def verify_observations_are_normalized( min: int = 0, max: int = 1, ) -> None: + """Verify observations are normalized. + + Args: + observations : env obs. + agents : env agents. + env_spec : env spec. + min : min for normalization. + max : max for normalization. + """ if env_spec.env_type == EnvType.Parallel: for agent in agents: assert ( @@ -337,6 +433,15 @@ def verify_observations_are_normalized( def verify_reward_is_normalized( rewards: Reward, agents: List, env_spec: EnvSpec, min: int = 0, max: int = 1 ) -> None: + """Verify reward is normalized. + + Args: + rewards : rewards. + agents : env agents. + env_spec : env spec. + min : min for normalization. + max : max for normalization. + """ if env_spec.env_type == EnvType.Parallel: for agent in agents: assert ( @@ -350,6 +455,13 @@ def verify_reward_is_normalized( def verify_observations_are_standardized( observations: Observation, agents: List, env_spec: EnvSpec ) -> None: + """Verify obs are standardized. + + Args: + observations : env observations. + agents : env agents. + env_spec : env spec. + """ if env_spec.env_type == EnvType.Parallel: for agent in agents: npt.assert_almost_equal( @@ -369,10 +481,20 @@ def verify_observations_are_standardized( @staticmethod def mock_done() -> bool: + """Mock env being done. + + Returns: + returns true. + """ return True @typing.no_type_check @pytest.fixture def helpers() -> Helpers: + """Return helper class. + + Returns: + helpers class. + """ return Helpers