diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 8e4a708c9..36e32484c 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,7 +2,7 @@ # Default code owners for repo -* @arnupretorius @DriesSmit @RuanJohn @jcformanek @siddarthsingh1 @sash-a @OmaymaMahjoub @ulricharmel @callumtilbury @WiemKhlifi +* @RuanJohn @sash-a @OmaymaMahjoub @WiemKhlifi @SimonDuToit @Louay-Ben-nessir # Add specific code owners for certain files or folders below diff --git a/mava/configs/arch/anakin.yaml b/mava/configs/arch/anakin.yaml index e117729d0..b026cc90e 100644 --- a/mava/configs/arch/anakin.yaml +++ b/mava/configs/arch/anakin.yaml @@ -1,4 +1,5 @@ # --- Anakin config --- +architecture_name: anakin # --- Training --- num_envs: 16 # Number of vectorised environments per device. diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml new file mode 100644 index 000000000..52ee0ffbf --- /dev/null +++ b/mava/configs/arch/sebulba.yaml @@ -0,0 +1,25 @@ +# --- Sebulba config --- +architecture_name: sebulba + +# --- Training --- +num_envs: 32 # number of environments per thread. + +# --- Evaluation --- +evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select + # an action which corresponds to the greatest logit. If false, the policy will sample + # from the logits. +num_eval_episodes: 32 # Number of episodes to evaluate per evaluation. +num_evaluation: 100 # Number of evenly spaced evaluations to perform during training. +num_absolute_metric_eval_episodes: 320 # Number of episodes to evaluate the absolute metric (the final evaluation). +absolute_metric: True # Whether the absolute metric should be computed. For more details +# on the absolute metric please see: https://arxiv.org/abs/2209.10485 + +# --- Sebulba devices config --- +n_threads_per_executor: 2 # num of different threads/env batches per actor +actor_device_ids: [0] # ids of actor devices +learner_device_ids: [0] # ids of learner devices +rollout_queue_size : 5 +# The size of the pipeline queue determines the extent of off-policy training allowed. A larger value permits more off-policy training. +# Too large of a value with too many actors will lead to all of the updates getting wasted in old episodes +# Too small of a value and the utility of having multiple actors is lost. +# A value of 1 with a single actor leads to almost strictly on-policy training. 
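To make the `rollout_queue_size` trade-off described above concrete, here is a minimal, self-contained Python sketch of a bounded queue sitting between a single actor thread and a learner loop. It is illustrative only: `toy_actor`, `toy_learner` and `QUEUE_SIZE` are hypothetical names, and this is not Mava's actual `Pipeline`, which additionally shards trajectories onto the learner devices. With `maxsize=1` the actor blocks until the learner has consumed the previous rollout, so the learner never trains on data more than one update old (close to on-policy); a larger queue lets the actor run further ahead, so consumed rollouts were generated by correspondingly older parameters.

    import queue
    import threading
    import time

    QUEUE_SIZE = 1  # cf. arch.rollout_queue_size: bounds how far the actor may run ahead
    NUM_UPDATES = 5
    rollout_queue: "queue.Queue[int]" = queue.Queue(maxsize=QUEUE_SIZE)

    def toy_actor(stop: threading.Event) -> None:
        """Produce dummy 'rollouts' as fast as possible; put() blocks while the queue is full."""
        rollout_id = 0
        while not stop.is_set():
            try:
                rollout_queue.put(rollout_id, timeout=0.1)
                rollout_id += 1
            except queue.Full:
                pass  # learner is behind: wait rather than generating ever-staler data

    def toy_learner() -> None:
        """Consume rollouts; the bounded queue caps how stale the consumed data can be."""
        for update in range(NUM_UPDATES):
            rollout_id = rollout_queue.get()  # blocks until the actor has data ready
            time.sleep(0.05)  # stand-in for a gradient update
            print(f"update {update}: trained on rollout {rollout_id}, "
                  f"queue depth now {rollout_queue.qsize()} (<= {QUEUE_SIZE})")

    stop_event = threading.Event()
    threading.Thread(target=toy_actor, args=(stop_event,), daemon=True).start()
    toy_learner()
    stop_event.set()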
diff --git a/mava/configs/default/ff_hasac.yaml b/mava/configs/default/ff_hasac.yaml index 36448357d..0347d171f 100644 --- a/mava/configs/default/ff_hasac.yaml +++ b/mava/configs/default/ff_hasac.yaml @@ -4,7 +4,7 @@ defaults: - arch: anakin - system: sac/ff_hasac - network: mlp # [mlp, cnn] - - env: mabrax # [mabrax] + - env: mabrax # [mabrax, mpe] hydra: searchpath: diff --git a/mava/configs/default/ff_ippo.yaml b/mava/configs/default/ff_ippo.yaml index 5e2cd4dbf..5b938fef5 100644 --- a/mava/configs/default/ff_ippo.yaml +++ b/mava/configs/default/ff_ippo.yaml @@ -3,7 +3,7 @@ defaults: - arch: anakin - system: ppo/ff_ippo - network: mlp # [mlp, cnn] - - env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax] + - env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe] - _self_ hydra: diff --git a/mava/configs/default/ff_ippo_sebulba.yaml b/mava/configs/default/ff_ippo_sebulba.yaml new file mode 100644 index 000000000..d0ecfae97 --- /dev/null +++ b/mava/configs/default/ff_ippo_sebulba.yaml @@ -0,0 +1,11 @@ +defaults: + - logger: logger + - arch: sebulba + - system: ppo/ff_ippo + - network: mlp # [mlp, continuous_mlp, cnn] + - env: lbf_gym # [rware_gym, lbf_gym, smaclite_gym] + - _self_ + +hydra: + searchpath: + - file://mava/configs diff --git a/mava/configs/default/ff_isac.yaml b/mava/configs/default/ff_isac.yaml index c9ff0bb28..bda1bae7c 100644 --- a/mava/configs/default/ff_isac.yaml +++ b/mava/configs/default/ff_isac.yaml @@ -4,7 +4,7 @@ defaults: - arch: anakin - system: sac/ff_isac - network: mlp - - env: mabrax # [mabrax] + - env: mabrax # [mabrax, mpe] hydra: searchpath: diff --git a/mava/configs/default/ff_mappo.yaml b/mava/configs/default/ff_mappo.yaml index 76fd980c7..9f3953b00 100644 --- a/mava/configs/default/ff_mappo.yaml +++ b/mava/configs/default/ff_mappo.yaml @@ -3,7 +3,7 @@ defaults: - arch: anakin - system: ppo/ff_mappo - network: mlp # [mlp, cnn] - - env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax] + - env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe] - _self_ hydra: diff --git a/mava/configs/default/ff_masac.yaml b/mava/configs/default/ff_masac.yaml index 123cc6c67..235d02e8f 100644 --- a/mava/configs/default/ff_masac.yaml +++ b/mava/configs/default/ff_masac.yaml @@ -4,7 +4,7 @@ defaults: - arch: anakin - system: sac/ff_masac - network: mlp - - env: mabrax # [mabrax] + - env: mabrax # [mabrax, mpe] hydra: searchpath: diff --git a/mava/configs/default/ff_sable.yaml b/mava/configs/default/ff_sable.yaml index bcf11797c..406c605ca 100644 --- a/mava/configs/default/ff_sable.yaml +++ b/mava/configs/default/ff_sable.yaml @@ -3,7 +3,7 @@ defaults: - arch: anakin - system: sable/ff_sable - network: ff_retention - - env: rware # [cleaner, connector, gigastep, lbf, rware, smax] + - env: rware # [cleaner, connector, gigastep, lbf, rware, smax, mpe] - _self_ hydra: diff --git a/mava/configs/default/mat.yaml b/mava/configs/default/mat.yaml index 393781c63..9e73740b2 100644 --- a/mava/configs/default/mat.yaml +++ b/mava/configs/default/mat.yaml @@ -3,7 +3,7 @@ defaults: - arch: anakin - system: mat/mat - network: transformer - - env: rware # [gigastep, lbf, mabrax, matrax, rware, smax] + - env: rware # [gigastep, lbf, mabrax, matrax, rware, smax, mpe] - _self_ hydra: diff --git a/mava/configs/default/rec_ippo.yaml b/mava/configs/default/rec_ippo.yaml index 71208d0c2..91fbab8c9 100644 --- a/mava/configs/default/rec_ippo.yaml +++ b/mava/configs/default/rec_ippo.yaml @@ -3,7 
+3,7 @@ defaults: - arch: anakin - system: ppo/rec_ippo - network: rnn # [rnn, rcnn] - - env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax] + - env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe] - _self_ hydra: diff --git a/mava/configs/default/rec_mappo.yaml b/mava/configs/default/rec_mappo.yaml index 72d96f0fc..d3b145f5a 100644 --- a/mava/configs/default/rec_mappo.yaml +++ b/mava/configs/default/rec_mappo.yaml @@ -3,7 +3,7 @@ defaults: - arch: anakin - system: ppo/rec_mappo - network: rnn # [rnn, rcnn] - - env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax] + - env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe] - _self_ hydra: diff --git a/mava/configs/default/rec_sable.yaml b/mava/configs/default/rec_sable.yaml index 7dbdbbbc8..654441bd4 100644 --- a/mava/configs/default/rec_sable.yaml +++ b/mava/configs/default/rec_sable.yaml @@ -3,7 +3,7 @@ defaults: - arch: anakin - system: sable/rec_sable - network: rec_retention - - env: rware # [cleaner, connector, gigastep, lbf, rware, smax] + - env: rware # [cleaner, connector, gigastep, lbf, rware, smax, mabrax, mpe] - _self_ hydra: diff --git a/mava/configs/env/lbf_gym.yaml b/mava/configs/env/lbf_gym.yaml new file mode 100644 index 000000000..f001e0913 --- /dev/null +++ b/mava/configs/env/lbf_gym.yaml @@ -0,0 +1,25 @@ +# ---Environment Configs--- +defaults: + - _self_ + +env_name: LevelBasedForaging # Used for logging purposes. +scenario: + name: lbforaging + task_name: Foraging-8x8-2p-1f-v3 + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return + +# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used. +# This should not be changed. +implicit_agent_id: False +# Whether or not to log the winrate of this environment. This should not be changed as not all +# environments have a winrate metric. +log_win_rate: False + +# Whether or not to sum the returned rewards over all of the agents. +use_shared_rewards: True + +kwargs: + max_episode_steps: 100 diff --git a/mava/configs/env/mpe.yaml b/mava/configs/env/mpe.yaml new file mode 100644 index 000000000..d58f07547 --- /dev/null +++ b/mava/configs/env/mpe.yaml @@ -0,0 +1,19 @@ +# --- Environment Configs--- +defaults: + - _self_ + - scenario: simple_spread_3ag # [simple_spread_3ag, simple_spread_5ag, simple_spread_10ag] + +env_name: MPE # Used for logging purposes and selection of the corresponding wrapper. + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return + +implicit_agent_id: False +# Whether or not to log the winrate of this environment. This should not be changed as not all +# environments have a winrate metric. +log_win_rate: False + +kwargs: + # Note: We only support `Continuous` actions for now but the `Discrete` version works as well. + action_type: Continuous # Whether agent action spaces are "Continuous" or "Discrete". diff --git a/mava/configs/env/rware_gym.yaml b/mava/configs/env/rware_gym.yaml new file mode 100644 index 000000000..facf7f8d7 --- /dev/null +++ b/mava/configs/env/rware_gym.yaml @@ -0,0 +1,25 @@ +# ---Environment Configs--- +defaults: + - _self_ + +env_name: RobotWarehouse # Used for logging purposes.
+scenario: + name: rware + task_name: rware-tiny-2ag-v2 # [rware-tiny-2ag-v2, rware-tiny-4ag-v2, rware-tiny-4ag-easy-v2, rware-small-4ag-v2] + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return + +# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used. +# This should not be changed. +implicit_agent_id: False +# Whether or not to log the winrate of this environment. This should not be changed as not all +# environments have a winrate metric. +log_win_rate: False + +# Whether or not to sum the returned rewards over all of the agents. +use_shared_rewards: True + +kwargs: + max_episode_steps: 500 diff --git a/mava/configs/env/scenario/simple_spread_10ag.yaml b/mava/configs/env/scenario/simple_spread_10ag.yaml new file mode 100644 index 000000000..c6e89a831 --- /dev/null +++ b/mava/configs/env/scenario/simple_spread_10ag.yaml @@ -0,0 +1,8 @@ +# The config of the simple_spread_10ag scenario. +name: MPE_simple_spread_v3 +task_name: simple_spread_10ag + +task_config: + num_agents: 10 + num_landmarks: 10 + local_ratio: 0.5 diff --git a/mava/configs/env/scenario/simple_spread_3ag.yaml b/mava/configs/env/scenario/simple_spread_3ag.yaml new file mode 100644 index 000000000..a3a37339f --- /dev/null +++ b/mava/configs/env/scenario/simple_spread_3ag.yaml @@ -0,0 +1,8 @@ +# The config of the simple_spread_3ag scenario. +name: MPE_simple_spread_v3 +task_name: simple_spread_3ag + +task_config: + num_agents: 3 + num_landmarks: 3 + local_ratio: 0.5 diff --git a/mava/configs/env/scenario/simple_spread_5ag.yaml b/mava/configs/env/scenario/simple_spread_5ag.yaml new file mode 100644 index 000000000..696fbdce7 --- /dev/null +++ b/mava/configs/env/scenario/simple_spread_5ag.yaml @@ -0,0 +1,8 @@ +# The config of the simple_spread_5ag scenario. +name: MPE_simple_spread_v3 +task_name: simple_spread_5ag + +task_config: + num_agents: 5 + num_landmarks: 5 + local_ratio: 0.5 diff --git a/mava/configs/env/smaclite_gym.yaml b/mava/configs/env/smaclite_gym.yaml new file mode 100644 index 000000000..967daec88 --- /dev/null +++ b/mava/configs/env/smaclite_gym.yaml @@ -0,0 +1,25 @@ +# ---Environment Configs--- +defaults: + - _self_ + +env_name: SMACLite # Used for logging purposes. +scenario: + name: smaclite + task_name: smaclite/2s3z-v0 # smaclite/ + ['10m_vs_11m-v0', '27m_vs_30m-v0', '3s5z_vs_3s6z-v0', '2s3z-v0', '3s5z-v0', '2c_vs_64zg-v0', '2s_vs_1sc-v0', '3s_vs_5z-v0'] + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return + +# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used. +# This should not be changed. +implicit_agent_id: False +# Whether or not to log the winrate of this environment. This should not be changed as not all +# environments have a winrate metric. +log_win_rate: True + +# Whether or not to sum the returned rewards over all of the agents.
+use_shared_rewards: True + +kwargs: + max_episode_steps: 500 diff --git a/mava/evaluator.py b/mava/evaluator.py index 8ed1cd001..6b2fda203 100644 --- a/mava/evaluator.py +++ b/mava/evaluator.py @@ -19,6 +19,7 @@ import jax import jax.numpy as jnp +import numpy as np from chex import Array, PRNGKey from flax.core.frozen_dict import FrozenDict from jax import tree @@ -36,6 +37,7 @@ RecActorApply, State, ) +from mava.wrappers.gym import GymToJumanji # Optional extras that are passed out of the actor and then into the actor in the next step ActorState: TypeAlias = Dict[str, Any] @@ -207,3 +209,116 @@ def eval_act_fn( return action.squeeze(0), {_hidden_state: hidden_state} return eval_act_fn + + +def get_sebulba_eval_fn( + env_maker: Callable[[int, int], GymToJumanji], + act_fn: EvalActFn, + config: DictConfig, + np_rng: np.random.Generator, + absolute_metric: bool, +) -> Tuple[EvalFn, Any]: + """Creates a function that can be used to evaluate agents on a given environment. + + Args: + ---- + env_maker: A function to create the environment instances. + act_fn: A function that takes in params, timestep, key and optionally a state + and returns actions and optionally a state (see `EvalActFn`). + config: The system config. + np_rng: Random number generator for seeding environment. + absolute_metric: Whether or not this evaluator calculates the absolute_metric. + This determines how many evaluation episodes it does. + """ + n_devices = jax.device_count() + eval_episodes = ( + config.arch.num_absolute_metric_eval_episodes + if absolute_metric + else config.arch.num_eval_episodes + ) + + n_parallel_envs = min(eval_episodes, config.arch.num_envs) + episode_loops = math.ceil(eval_episodes / n_parallel_envs) + env = env_maker(config, n_parallel_envs) + + act_fn = jax.jit( + act_fn, device=jax.local_devices()[config.arch.actor_device_ids[0]] + ) # Evaluate using the first actor device + + # Warn if the number of eval episodes is not divisible by the number of parallel envs. + if eval_episodes % n_parallel_envs != 0: + warnings.warn( + f"Number of evaluation episodes ({eval_episodes}) is not divisible by `num_envs` * " + f"`num_devices` ({n_parallel_envs} * {n_devices}). Some extra evaluations will be " + f"executed. New number of evaluation episodes = {episode_loops * n_parallel_envs}", + stacklevel=2, + ) + + def eval_fn(params: FrozenDict, key: PRNGKey, init_act_state: ActorState) -> Metrics: + """Evaluates the given params on an environment and returns relevant metrics. + + Metrics are collected by the `RecordEpisodeMetrics` wrapper: episode return and length, + also win rate for environments that support it. + + Returns: Dict[str, Array] - dictionary of metric name to metric values for each episode. + """ + + def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]: + """Simulates `num_envs` episodes.""" + + # Generate a list of random seeds within the 32-bit integer range, using a seeded RNG.
+ seeds = np_rng.integers(np.iinfo(np.int32).max, size=n_parallel_envs).tolist() + ts = env.reset(seed=seeds) + + timesteps_array = [ts] + + actor_state = init_act_state + finished_eps = ts.last() + + while not finished_eps.all(): + key, act_key = jax.random.split(key) + action, actor_state = act_fn(params, ts, act_key, actor_state) + cpu_action = jax.device_get(action) + ts = env.step(cpu_action) + timesteps_array.append(ts) + + finished_eps = np.logical_or(finished_eps, ts.last()) + + timesteps = jax.tree.map(lambda *x: np.stack(x), *timesteps_array) + + metrics = timesteps.extras["episode_metrics"] + if config.env.log_win_rate: + metrics["won_episode"] = timesteps.extras["won_episode"] + + # find the first instance of done to get the metrics at that timestep, we don't + # care about subsequent steps because we only want the results from the first episode + done_idx = np.argmax(timesteps.last(), axis=0) + metrics = jax.tree_map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics) + del metrics["is_terminal_step"] # unneeded for logging + + return key, metrics + + # This loop is important because we don't want too many parallel envs. + # So in evaluation we have num_envs parallel envs and loop enough times + # so that we do at least `eval_episodes` number of episodes. + metrics_array = [] + for _ in range(episode_loops): + key, metric = _episode(key) + metrics_array.append(metric) + + # flatten metrics + metrics: Metrics = jax.tree_map(lambda *x: np.array(x).reshape(-1), *metrics_array) + return metrics + + def timed_eval_fn(params: FrozenDict, key: PRNGKey, init_act_state: ActorState) -> Metrics: + """Wrapper around eval function to time it and add in steps per second metric.""" + start_time = time.time() + + metrics = eval_fn(params, key, init_act_state) + + end_time = time.time() + total_timesteps = jnp.sum(metrics["episode_length"]) + metrics["steps_per_second"] = total_timesteps / (end_time - start_time) + return metrics + + return timed_eval_fn, env diff --git a/mava/networks/sable_network.py b/mava/networks/sable_network.py index e626bfc16..40b1fd615 100644 --- a/mava/networks/sable_network.py +++ b/mava/networks/sable_network.py @@ -26,8 +26,10 @@ from mava.networks.torsos import SwiGLU from mava.networks.utils.sable import ( act_encoder_fn, - autoregressive_act, - train_decoder_fn, + continuous_autoregressive_act, + continuous_train_decoder_fn, + discrete_autoregressive_act, + discrete_train_decoder_fn, train_encoder_fn, ) from mava.systems.sable.types import HiddenStates, SableNetworkConfig @@ -352,7 +354,7 @@ class SableNetwork(nn.Module): action_space_type: str = _DISCRETE def setup(self) -> None: - if self.action_space_type not in [_DISCRETE]: + if self.action_space_type not in [_DISCRETE, _CONTINUOUS]: raise ValueError(f"Invalid action space type: {self.action_space_type}") assert ( @@ -385,15 +387,27 @@ def setup(self) -> None: train_encoder_fn, chunk_size=self.memory_config.chunk_size, ) - self.train_decoder_fn = partial( - train_decoder_fn, n_agents=self.n_agents, chunk_size=self.memory_config.chunk_size - ) - self.act_encoder_fn = partial( act_encoder_fn, chunk_size=self.n_agents_per_chunk, ) - self.autoregressive_act = autoregressive_act + if self.action_space_type == _CONTINUOUS: + self.train_decoder_fn = partial( + continuous_train_decoder_fn, + n_agents=self.n_agents, + chunk_size=self.memory_config.chunk_size, + action_dim=self.action_dim, + ) + self.autoregressive_act = partial( + continuous_autoregressive_act, action_dim=self.action_dim + ) + else: +
self.train_decoder_fn = partial( + discrete_train_decoder_fn, + n_agents=self.n_agents, + chunk_size=self.memory_config.chunk_size, + ) + self.autoregressive_act = discrete_autoregressive_act # type: ignore def __call__( self, @@ -424,9 +438,7 @@ def __call__( rng_key=rng_key, ) - action_log = jnp.squeeze(action_log, axis=-1) value = jnp.squeeze(value, axis=-1) - entropy = jnp.squeeze(entropy, axis=-1) return value, action_log, entropy def get_actions( @@ -467,7 +479,5 @@ def get_actions( decoder_cross_retn=updated_dec_hs[1], ) - output_actions = jnp.squeeze(output_actions, axis=-1) - output_actions_log = jnp.squeeze(output_actions_log, axis=-1) value = jnp.squeeze(value, axis=-1) return output_actions, output_actions_log, value, updated_hs diff --git a/mava/networks/utils/sable/__init__.py b/mava/networks/utils/sable/__init__.py index d26b9f645..21b8a46f7 100644 --- a/mava/networks/utils/sable/__init__.py +++ b/mava/networks/utils/sable/__init__.py @@ -14,8 +14,10 @@ # ruff: noqa: F401 from mava.networks.utils.sable.decode import ( - autoregressive_act, - train_decoder_fn, + continuous_autoregressive_act, + continuous_train_decoder_fn, + discrete_autoregressive_act, + discrete_train_decoder_fn, ) from mava.networks.utils.sable.encode import ( act_encoder_fn, diff --git a/mava/networks/utils/sable/decode.py b/mava/networks/utils/sable/decode.py index c9befeb36..47edecf0f 100644 --- a/mava/networks/utils/sable/decode.py +++ b/mava/networks/utils/sable/decode.py @@ -18,16 +18,22 @@ import distrax import jax import jax.numpy as jnp +import tensorflow_probability.substrates.jax.distributions as tfd from flax import linen as nn +from mava.networks.distributions import TanhTransformedDistribution + # General shapes legend: # B: batch size # S: sequence length # A: number of actions # N: number of agents +# Constant to avoid numerical instability +_MIN_SCALE = 1e-3 + -def train_decoder_fn( +def discrete_train_decoder_fn( decoder: nn.Module, obs_rep: chex.Array, action: chex.Array, @@ -43,7 +49,7 @@ def train_decoder_fn( # Delete `rng_key` since it is not used in discrete action space del rng_key - shifted_actions = get_shifted_actions(action, legal_actions, n_agents=n_agents) + shifted_actions = get_shifted_discrete_actions(action, legal_actions, n_agents=n_agents) logit = jnp.zeros_like(legal_actions, dtype=jnp.float32) # Apply the decoder per chunk @@ -73,14 +79,14 @@ def train_decoder_fn( distribution = distrax.Categorical(logits=masked_logits) action_log_prob = distribution.log_prob(action) - action_log_prob = jnp.expand_dims(action_log_prob, axis=-1) - entropy = jnp.expand_dims(distribution.entropy(), axis=-1) - return action_log_prob, entropy + return action_log_prob, distribution.entropy() -def get_shifted_actions(action: chex.Array, legal_actions: chex.Array, n_agents: int) -> chex.Array: - """Get the shifted action sequence for predicting the next action.""" +def get_shifted_discrete_actions( + action: chex.Array, legal_actions: chex.Array, n_agents: int +) -> chex.Array: + """Get the shifted discrete action sequence for predicting the next action.""" B, S, A = legal_actions.shape # Create a shifted action sequence for predicting the next action @@ -102,7 +108,7 @@ def get_shifted_actions(action: chex.Array, legal_actions: chex.Array, n_agents: return shifted_actions -def autoregressive_act( +def discrete_autoregressive_act( decoder: nn.Module, obs_rep: chex.Array, hstates: chex.Array, @@ -141,5 +147,122 @@ def autoregressive_act( shifted_actions = shifted_actions.at[:, i + 1, 1:].set( 
jax.nn.one_hot(action[:, 0], A), mode="drop" ) + output_actions = output_action.astype(jnp.int32) + output_actions = jnp.squeeze(output_actions, axis=-1) + output_action_log = jnp.squeeze(output_action_log, axis=-1) + return output_actions, output_action_log, hstates + + +def continuous_train_decoder_fn( + decoder: nn.Module, + obs_rep: chex.Array, + action: chex.Array, + legal_actions: chex.Array, + hstates: chex.Array, + dones: chex.Array, + step_count: chex.Array, + n_agents: int, + chunk_size: int, + action_dim: int, + rng_key: Optional[chex.PRNGKey] = None, +) -> Tuple[chex.Array, chex.Array]: + """Parallel action log-prob and entropy computation for continuous action spaces.""" + # Delete `legal_actions` since it is not used in continuous action space + del legal_actions + + B, S, _ = action.shape + shifted_actions = get_shifted_continuous_actions(action, action_dim, n_agents=n_agents) + act_mean = jnp.zeros((B, S, action_dim), dtype=jnp.float32) + + # Apply the decoder per chunk + num_chunks = shifted_actions.shape[1] // chunk_size + for chunk_id in range(0, num_chunks): + start_idx = chunk_id * chunk_size + end_idx = (chunk_id + 1) * chunk_size + # Chunk obs_rep, shifted_actions, dones, and step_count + chunked_obs_rep = obs_rep[:, start_idx:end_idx] + chunk_shifted_actions = shifted_actions[:, start_idx:end_idx] + chunk_dones = dones[:, start_idx:end_idx] + chunk_step_count = step_count[:, start_idx:end_idx] + chunked_act_mean, hstates = decoder( + action=chunk_shifted_actions, + obs_rep=chunked_obs_rep, + hstates=hstates, + dones=chunk_dones, + step_count=chunk_step_count, + ) + act_mean = act_mean.at[:, start_idx:end_idx].set(chunked_act_mean) + + action_std = jax.nn.softplus(decoder.log_std) + _MIN_SCALE + + base_distribution = tfd.Normal(loc=act_mean, scale=action_std) + distribution = tfd.Independent( + TanhTransformedDistribution(base_distribution), + reinterpreted_batch_ndims=1, + ) + + action_log_prob = distribution.log_prob(action) + entropy = distribution.entropy(seed=rng_key) + + return action_log_prob, entropy + + +def get_shifted_continuous_actions( + action: chex.Array, action_dim: int, n_agents: int +) -> chex.Array: + """Get the shifted continuous action sequence for predicting the next action.""" + B, S, _ = action.shape + + shifted_actions = jnp.zeros((B, S, action_dim)) + start_timestep_token = jnp.zeros(action_dim) + shifted_actions = shifted_actions.at[:, 1:, :].set(action[:, :-1, :]) + shifted_actions = shifted_actions.at[:, ::n_agents, :].set(start_timestep_token) + + return shifted_actions + + +def continuous_autoregressive_act( + decoder: nn.Module, + obs_rep: chex.Array, + hstates: chex.Array, + legal_actions: chex.Array, + step_count: chex.Array, + action_dim: int, + key: chex.PRNGKey, +) -> Tuple[chex.Array, chex.Array, chex.Array]: + # Delete `legal_actions` since it is not used in continuous action space + del legal_actions + + B, N = step_count.shape + shifted_actions = jnp.zeros((B, N, action_dim)) + output_action = jnp.zeros((B, N, action_dim)) + output_action_log = jnp.zeros((B, N)) + + # Apply the decoder autoregressively + for i in range(N): + act_mean, hstates = decoder.recurrent( + action=shifted_actions[:, i : i + 1, :], + obs_rep=obs_rep[:, i : i + 1, :], + hstates=hstates, + step_count=step_count[:, i : i + 1], + ) + action_std = jax.nn.softplus(decoder.log_std) + _MIN_SCALE + + key, sample_key = jax.random.split(key) + + base_distribution = tfd.Normal(loc=act_mean, scale=action_std) + distribution = tfd.Independent( + TanhTransformedDistribution(base_distribution),
+ reinterpreted_batch_ndims=1, + ) + + # the action and raw action are now just identical. + action = distribution.sample(seed=sample_key) + action_log = distribution.log_prob(action) + + output_action = output_action.at[:, i, :].set(action[:, i, :]) + output_action_log = output_action_log.at[:, i].set(action_log[:, i]) + # Adds all except the last action to shifted_actions, as it is out of range + shifted_actions = shifted_actions.at[:, i + 1, :].set(action[:, i, :], mode="drop") - return output_action.astype(jnp.int32), output_action_log, hstates + return output_action, output_action_log, hstates diff --git a/mava/systems/__init__.py b/mava/systems/__init__.py deleted file mode 100644 index 21db9ec1c..000000000 --- a/mava/systems/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2022 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/mava/systems/mat/anakin/mat.py b/mava/systems/mat/anakin/mat.py index df9a1150e..7eb229774 100644 --- a/mava/systems/mat/anakin/mat.py +++ b/mava/systems/mat/anakin/mat.py @@ -42,6 +42,7 @@ ) from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import ( merge_leading_dims, unreplicate_batch_dim, @@ -49,7 +50,6 @@ ) from mava.utils.logger import LogEvent, MavaLogger from mava.utils.network_utils import get_action_head -from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -330,7 +330,7 @@ def learner_setup( init_x = env.observation_spec().generate_value() init_x = tree.map(lambda x: x[None, ...], init_x) - _, action_space_type = get_action_head(env) + _, action_space_type = get_action_head(env.action_spec()) if action_space_type == "discrete": init_action = jnp.zeros((1, config.system.num_agents), dtype=jnp.int32) diff --git a/mava/systems/ppo/anakin/ff_ippo.py b/mava/systems/ppo/anakin/ff_ippo.py index 55a3a1ccf..a722afd21 100644 --- a/mava/systems/ppo/anakin/ff_ippo.py +++ b/mava/systems/ppo/anakin/ff_ippo.py @@ -35,6 +35,7 @@ from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv, Metrics from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import ( merge_leading_dims, unreplicate_batch_dim, @@ -42,7 +43,6 @@ ) from mava.utils.logger import LogEvent, MavaLogger from mava.utils.network_utils import get_action_head -from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -335,7 +335,7 @@ def learner_setup( # Define network and optimiser. 
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) - action_head, _ = get_action_head(env) + action_head, _ = get_action_head(env.action_spec()) actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim) critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) diff --git a/mava/systems/ppo/anakin/ff_mappo.py b/mava/systems/ppo/anakin/ff_mappo.py index adff6cb2c..649696a3f 100644 --- a/mava/systems/ppo/anakin/ff_mappo.py +++ b/mava/systems/ppo/anakin/ff_mappo.py @@ -34,10 +34,10 @@ from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv, Metrics from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger from mava.utils.network_utils import get_action_head -from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -333,7 +333,7 @@ def learner_setup( # Define network and optimiser. actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) - action_head, _ = get_action_head(env) + action_head, _ = get_action_head(env.action_spec()) actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim) critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) diff --git a/mava/systems/ppo/anakin/rec_ippo.py b/mava/systems/ppo/anakin/rec_ippo.py index f7ee96a5a..f8b2baa60 100644 --- a/mava/systems/ppo/anakin/rec_ippo.py +++ b/mava/systems/ppo/anakin/rec_ippo.py @@ -49,10 +49,10 @@ ) from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger from mava.utils.network_utils import get_action_head -from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -426,7 +426,7 @@ def learner_setup( # Define network and optimisers. 
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso) - action_head, _ = get_action_head(env) + action_head, _ = get_action_head(env.action_spec()) actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim) critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso) diff --git a/mava/systems/ppo/anakin/rec_mappo.py b/mava/systems/ppo/anakin/rec_mappo.py index b9b89b305..3cb9c8763 100644 --- a/mava/systems/ppo/anakin/rec_mappo.py +++ b/mava/systems/ppo/anakin/rec_mappo.py @@ -49,10 +49,10 @@ ) from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger from mava.utils.network_utils import get_action_head -from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -428,7 +428,7 @@ def learner_setup( # Define network and optimiser. actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso) - action_head, _ = get_action_head(env) + action_head, _ = get_action_head(env.action_spec()) actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim) critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py new file mode 100644 index 000000000..6f34c0b1a --- /dev/null +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -0,0 +1,751 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
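A note on the new file that follows: the Sebulba ff_ippo system runs environment rollouts in Python threads while a separate learner thread trains on device, and the two sides exchange data through a trajectory `Pipeline` in one direction and a `ParamsSource` in the other. The sketch below illustrates only the parameter-distribution half of that pattern with a hypothetical `LatestParams` class; it is not Mava's `ParamsSource`, which additionally moves the parameters onto each actor's device in a background thread.

    import threading
    from typing import Generic, TypeVar

    T = TypeVar("T")

    class LatestParams(Generic[T]):
        """Single-slot, thread-safe holder: the writer overwrites, readers see the newest value."""

        def __init__(self, initial: T) -> None:
            self._value = initial
            self._lock = threading.Lock()

        def update(self, value: T) -> None:  # called by the learner after every update
            with self._lock:
                self._value = value

        def get(self) -> T:  # called by each actor thread before acting
            with self._lock:
                return self._value

    params = LatestParams({"version": 0})
    params.update({"version": 1})        # learner publishes fresh parameters
    assert params.get()["version"] == 1  # actors always read the most recent copy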
+ +import copy +import queue +import threading +import warnings +from collections import defaultdict +from queue import Queue +from typing import Any, Dict, List, Sequence, Tuple + +import chex +import hydra +import jax +import jax.debug +import jax.numpy as jnp +import numpy as np +import optax +from colorama import Fore, Style +from flax.core.frozen_dict import FrozenDict +from jax import tree +from jax.experimental import mesh_utils +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding, PartitionSpec, Sharding +from numpy.typing import NDArray +from omegaconf import DictConfig, OmegaConf +from rich.pretty import pprint + +from mava.evaluator import get_sebulba_eval_fn as get_eval_fn +from mava.evaluator import make_ff_eval_act_fn +from mava.networks import FeedForwardActor as Actor +from mava.networks import FeedForwardValueNet as Critic +from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition +from mava.types import ( + ActorApply, + CriticApply, + ExperimentOutput, + Observation, + SebulbaLearnerFn, +) +from mava.utils import make_env as environments +from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_sebulba_config, check_total_timesteps +from mava.utils.jax_utils import merge_leading_dims, switch_leading_axes +from mava.utils.logger import LogEvent, MavaLogger +from mava.utils.network_utils import get_action_head +from mava.utils.sebulba import ParamsSource, Pipeline, RecordTimeTo, ThreadLifetime +from mava.utils.training import make_learning_rate +from mava.wrappers.episode_metrics import get_final_step_metrics +from mava.wrappers.gym import GymToJumanji + + +def rollout( + key: chex.PRNGKey, + env: GymToJumanji, + config: DictConfig, + rollout_queue: Pipeline, + params_source: ParamsSource, + apply_fns: Tuple[ActorApply, CriticApply], + actor_device: int, + seeds: List[int], + thread_lifetime: ThreadLifetime, +) -> None: + """Runs rollouts to collect trajectories from the environment. + + Args: + key (chex.PRNGKey): The PRNG key. + env (GymToJumanji): The environment to step for collecting trajectories. + config (DictConfig): Configuration settings for the environment and rollout. + rollout_queue (Pipeline): Queue for sending collected rollouts to the learner. + params_source (ParamsSource): Source for fetching the latest network parameters + from the learner. + apply_fns (Tuple): Functions for running the actor and critic networks. + actor_device (Device): Actor device to use for rollout. + seeds (List[int]): Seeds for initializing the environment. + thread_lifetime (ThreadLifetime): Manages the thread's lifecycle. + """ + name = threading.current_thread().name + print(f"{Fore.BLUE}{Style.BRIGHT}Thread {name} started{Style.RESET_ALL}") + actor_apply_fn, critic_apply_fn = apply_fns + num_agents, num_envs = config.system.num_agents, config.arch.num_envs + move_to_device = lambda x: jax.device_put(x, device=actor_device) + + @jax.jit + def act_fn( + params: Params, + observation: Observation, + key: chex.PRNGKey, + ) -> Tuple: + """Get action and value.""" + actor_policy = actor_apply_fn(params.actor_params, observation) + action = actor_policy.sample(seed=key) + log_prob = actor_policy.log_prob(action) + # It may be faster to calculate the values in the learner as + # then we won't need to pass critic params to actors.
+ value = critic_apply_fn(params.critic_params, observation).squeeze() + return action, log_prob, value + + timestep = env.reset(seed=seeds) + dones = np.repeat(timestep.last(), num_agents).reshape(num_envs, -1) + + # Loop till the desired num_updates is reached. + while not thread_lifetime.should_stop(): + # Rollout + traj: List[PPOTransition] = [] + actor_timings: Dict[str, List[float]] = defaultdict(list) + with RecordTimeTo(actor_timings["rollout_time"]): + for _ in range(config.system.rollout_length): + with RecordTimeTo(actor_timings["get_params_time"]): + params = params_source.get() # Get the latest parameters from the learner + + obs_tpu = tree.map(move_to_device, timestep.observation) + + # Get action and value + with RecordTimeTo(actor_timings["compute_action_time"]): + key, act_key = jax.random.split(key) + action, log_prob, value = act_fn(params, obs_tpu, act_key) + cpu_action = jax.device_get(action) + + # Step environment + with RecordTimeTo(actor_timings["env_step_time"]): + timestep = env.step(cpu_action) + + dones = np.repeat(timestep.last(), num_agents).reshape(num_envs, -1) + + # Append data to storage + traj.append( + PPOTransition( + dones, + action, + value, + timestep.reward, + log_prob, + obs_tpu, + timestep.extras["episode_metrics"], + ) + ) + + # send trajectories to learner + with RecordTimeTo(actor_timings["rollout_put_time"]): + try: + rollout_queue.put(traj, timestep, actor_timings) + except queue.Full: + err = "Waited too long to add to the rollout queue, killing the actor thread" + warnings.warn(err, stacklevel=2) + break + + env.close() + + +def get_learner_step_fn( + apply_fns: Tuple[ActorApply, CriticApply], + update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], + config: DictConfig, +) -> SebulbaLearnerFn[LearnerState, PPOTransition]: + """Get the learner function.""" + + num_envs = config.arch.num_envs + num_learner_envs = int(num_envs // len(config.arch.learner_device_ids)) + + # Get apply and update functions for actor and critic networks. + actor_apply_fn, critic_apply_fn = apply_fns + actor_update_fn, critic_update_fn = update_fns + + def _update_step( + learner_state: LearnerState, + traj_batch: PPOTransition, + ) -> Tuple[LearnerState, Tuple]: + """A single update of the network. + + This function calculates advantages and targets based on the trajectories + from the actor and updates the actor and critic networks based on the losses. + + Args: + learner_state (LearnerState): contains all the items needed for learning. + traj_batch (PPOTransition): the batch of data to learn with. 
+ """ + + def _calculate_gae( + traj_batch: PPOTransition, last_val: chex.Array + ) -> Tuple[chex.Array, chex.Array]: + """Calculate the GAE.""" + + gamma, gae_lambda = config.system.gamma, config.system.gae_lambda + + def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tuple: + """Calculate the GAE for a single transition.""" + gae, next_value = gae_and_next_value + done, value, reward = transition.done, transition.value, transition.reward + + delta = reward + gamma * next_value * (1 - done) - value + gae = delta + gamma * gae_lambda * (1 - done) * gae + return (gae, value), gae + + _, advantages = jax.lax.scan( + _get_advantages, + (jnp.zeros_like(last_val), last_val), + traj_batch, + reverse=True, + unroll=16, + ) + return advantages, advantages + traj_batch.value + + # Calculate advantage + params, opt_states, key, _, final_timestep = learner_state + last_val = critic_apply_fn(params.critic_params, final_timestep.observation) + advantages, targets = _calculate_gae(traj_batch, last_val) + + def _update_epoch(update_state: Tuple, _: Any) -> Tuple: + """Update the network for a single epoch.""" + + def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: + """Update the network for a single minibatch.""" + + # Unpack train state and batch info + params, opt_states, key = train_state + traj_batch, advantages, targets = batch_info + + def _actor_loss_fn( + actor_params: FrozenDict, + traj_batch: PPOTransition, + gae: chex.Array, + key: chex.PRNGKey, + ) -> Tuple: + """Calculate the actor loss.""" + # Rerun network + actor_policy = actor_apply_fn(actor_params, traj_batch.obs) + log_prob = actor_policy.log_prob(traj_batch.action) + + # Calculate actor loss + ratio = jnp.exp(log_prob - traj_batch.log_prob) + gae = (gae - gae.mean()) / (gae.std() + 1e-8) + loss_actor1 = ratio * gae + loss_actor2 = ( + jnp.clip( + ratio, + 1.0 - config.system.clip_eps, + 1.0 + config.system.clip_eps, + ) + * gae + ) + loss_actor = -jnp.minimum(loss_actor1, loss_actor2) + loss_actor = loss_actor.mean() + # The seed will be used in the TanhTransformedDistribution: + entropy = actor_policy.entropy(seed=key).mean() + + total_loss_actor = loss_actor - config.system.ent_coef * entropy + return total_loss_actor, (loss_actor, entropy) + + def _critic_loss_fn( + critic_params: FrozenDict, traj_batch: PPOTransition, targets: chex.Array + ) -> Tuple: + """Calculate the critic loss.""" + # Rerun network + value = critic_apply_fn(critic_params, traj_batch.obs) + + # Calculate value loss + value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( + -config.system.clip_eps, config.system.clip_eps + ) + value_losses = jnp.square(value - targets) + value_losses_clipped = jnp.square(value_pred_clipped - targets) + value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() + + critic_total_loss = config.system.vf_coef * value_loss + return critic_total_loss, (value_loss) + + # Calculate actor loss + key, entropy_key = jax.random.split(key) + actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) + actor_loss_info, actor_grads = actor_grad_fn( + params.actor_params, traj_batch, advantages, entropy_key + ) + + # Calculate critic loss + critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) + critic_loss_info, critic_grads = critic_grad_fn( + params.critic_params, traj_batch, targets + ) + + # Compute the parallel mean (pmean) over the batch. + # This calculation is inspired by the Anakin architecture demo notebook. 
+ # available at https://tinyurl.com/26tdzs5x + # pmean over learner devices. + actor_grads, actor_loss_info = jax.lax.pmean( + (actor_grads, actor_loss_info), + axis_name="learner_devices", + ) + + # pmean over learner devices. + critic_grads, critic_loss_info = jax.lax.pmean( + (critic_grads, critic_loss_info), axis_name="learner_devices" + ) + + # Update actor params and optimiser state + actor_updates, actor_new_opt_state = actor_update_fn( + actor_grads, opt_states.actor_opt_state + ) + actor_new_params = optax.apply_updates(params.actor_params, actor_updates) + + # Update critic params and optimiser state + critic_updates, critic_new_opt_state = critic_update_fn( + critic_grads, opt_states.critic_opt_state + ) + critic_new_params = optax.apply_updates(params.critic_params, critic_updates) + + # Pack new params and optimiser state + new_params = Params(actor_new_params, critic_new_params) + new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) + # Pack loss info + actor_total_loss, (actor_loss, entropy) = actor_loss_info + critic_total_loss, (value_loss) = critic_loss_info + total_loss = critic_total_loss + actor_total_loss + loss_info = { + "total_loss": total_loss, + "value_loss": value_loss, + "actor_loss": actor_loss, + "entropy": entropy, + } + return (new_params, new_opt_state, key), loss_info + + params, opt_states, traj_batch, advantages, targets, key = update_state + key = jnp.squeeze(key, axis=0) # Remove the learner_devices axis + key, shuffle_key, entropy_key = jax.random.split(key, 3) + key = jnp.expand_dims(key, axis=0) # add the learner_devices axis for shape consitency + # Shuffle minibatches + batch_size = config.system.rollout_length * num_learner_envs + permutation = jax.random.permutation(shuffle_key, batch_size) + batch = (traj_batch, advantages, targets) + batch = tree.map(lambda x: merge_leading_dims(x, 2), batch) + shuffled_batch = tree.map(lambda x: jnp.take(x, permutation, axis=0), batch) + minibatches = tree.map( + lambda x: jnp.reshape(x, (config.system.num_minibatches, -1, *x.shape[1:])), + shuffled_batch, + ) + # Update minibatches + (params, opt_states, _), loss_info = jax.lax.scan( + _update_minibatch, (params, opt_states, entropy_key), minibatches + ) + + update_state = (params, opt_states, traj_batch, advantages, targets, key) + return update_state, loss_info + + update_state = (params, opt_states, traj_batch, advantages, targets, key) + # Update epochs + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, config.system.ppo_epochs + ) + + params, opt_states, traj_batch, advantages, targets, key = update_state + learner_state = LearnerState(params, opt_states, key, None, learner_state.timestep) + metric = traj_batch.info + return learner_state, (metric, loss_info) + + def learner_fn( + learner_state: LearnerState, traj_batch: PPOTransition + ) -> ExperimentOutput[LearnerState]: + """Learner function. + + This function represents the learner, it updates the network parameters + by iteratively applying the `_update_step` function for a fixed number of + updates. The `_update_step` function is vectorized over a batch of inputs. + + Args: + learner_state (NamedTuple): + - params (Params): The initial model parameters. + - opt_states (OptStates): The initial optimizer state. + - key (chex.PRNGKey): The random number generator state. + - env_state (LogEnvState): The environment state. + - timesteps (TimeStep): The last timestep of the rollout. 
+ """ + # This function is shard mapped on the batch axis, but `_update_step` needs + # the first axis to be time + traj_batch = tree.map(switch_leading_axes, traj_batch) + learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch) + + return ExperimentOutput( + learner_state=learner_state, + episode_metrics=episode_info, + train_metrics=loss_info, + ) + + return learner_fn + + +def learner_thread( + learn_fn: SebulbaLearnerFn[LearnerState, PPOTransition], + learner_state: LearnerState, + config: DictConfig, + eval_queue: Queue, + pipeline: Pipeline, + params_sources: Sequence[ParamsSource], +) -> None: + for _ in range(config.arch.num_evaluation): + # Create the lists to store metrics and timings for this learning iteration. + metrics: List[Tuple[Dict, Dict]] = [] + rollout_times_array: List[Dict] = [] + learn_times: Dict[str, List[float]] = defaultdict(list) + + with RecordTimeTo(learn_times["learner_time_per_eval"]): + for _ in range(config.system.num_updates_per_eval): + # Get the trajectory batch from the pipeline + # This is blocking so it will wait until the pipeline has data. + with RecordTimeTo(learn_times["rollout_get_time"]): + traj_batch, timestep, rollout_time = pipeline.get(block=True) + + # Replace the timestep in the learner state with the latest timestep + # This means the learner has access to the entire trajectory as well as + # an additional timestep which it can use to bootstrap. + learner_state = learner_state._replace(timestep=timestep) + # Update the networks + with RecordTimeTo(learn_times["learning_time"]): + learner_state, ep_metrics, train_metrics = learn_fn(learner_state, traj_batch) + + metrics.append((ep_metrics, train_metrics)) + rollout_times_array.append(rollout_time) + + # Update all the params sources so all actors can get the latest params + params = jax.block_until_ready(learner_state.params) + for source in params_sources: + source.update(params) + + # Pass all the metrics and params to the main thread (evaluator) for logging and evaluation + ep_metrics, train_metrics = tree.map(lambda *x: np.asarray(x), *metrics) + rollout_times: Dict[str, NDArray] = tree.map(lambda *x: np.mean(x), *rollout_times_array) + timing_dict = rollout_times | learn_times + timing_dict = tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list)) + + eval_queue.put((ep_metrics, train_metrics, learner_state, timing_dict)) + + +def learner_setup( + key: chex.PRNGKey, config: DictConfig, learner_devices: List +) -> Tuple[ + SebulbaLearnerFn[LearnerState, PPOTransition], + Tuple[ActorApply, CriticApply], + LearnerState, + Sharding, +]: + """Initialise learner_fn, network and learner state.""" + + # Create temporary environments. + env = environments.make_gym_env(config, config.arch.num_envs) + # Get number of agents and actions. + action_space = env.single_action_space + config.system.num_agents = len(action_space) + config.system.num_actions = int(action_space[0].n) + + devices = mesh_utils.create_device_mesh((len(learner_devices),), devices=learner_devices) + mesh = Mesh(devices, axis_names=("learner_devices",)) + model_spec = PartitionSpec() + data_spec = PartitionSpec("learner_devices") + learner_sharding = NamedSharding(mesh, model_spec) + + # PRNG keys. + key, actor_key, critic_key = jax.random.split(key, 3) + + # Define network and optimiser.
+ actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) + action_head, _ = get_action_head(action_space) + actor_action_head = hydra.utils.instantiate(action_head, action_dim=config.system.num_actions) + + critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) + + actor_network = Actor(torso=actor_torso, action_head=actor_action_head) + critic_network = Critic(torso=critic_torso) + + actor_lr = make_learning_rate(config.system.actor_lr, config) + critic_lr = make_learning_rate(config.system.critic_lr, config) + + actor_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(actor_lr, eps=1e-5), + ) + critic_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(critic_lr, eps=1e-5), + ) + + # Initialise observation: Select only obs for a single agent. + init_obs = jnp.array([env.single_observation_space.sample()]) + init_action_mask = jnp.ones((config.system.num_agents, config.system.num_actions)) + init_x = Observation(init_obs, init_action_mask) + + # Initialise actor params and optimiser state. + actor_params = actor_network.init(actor_key, init_x) + actor_opt_state = actor_optim.init(actor_params) + + # Initialise critic params and optimiser state. + critic_params = critic_network.init(critic_key, init_x) + critic_opt_state = critic_optim.init(critic_params) + + # Pack params. + params = Params(actor_params, critic_params) + + # Pack apply and update functions. + apply_fns = (actor_network.apply, critic_network.apply) + update_fns = (actor_optim.update, critic_optim.update) + + # defines how the learner state is sharded: params, opt and key = sharded, timestep = sharded + learn_state_spec = LearnerState(model_spec, model_spec, data_spec, None, data_spec) + learn = get_learner_step_fn(apply_fns, update_fns, config) + learn = jax.jit( + shard_map( + learn, + mesh=mesh, + in_specs=(learn_state_spec, data_spec), + out_specs=ExperimentOutput(learn_state_spec, data_spec, data_spec), + ) + ) + + # Load model from checkpoint if specified. + if config.logger.checkpointing.load_model: + loaded_checkpoint = Checkpointer( + model_name=config.logger.system_name, + **config.logger.checkpointing.load_args, # Other checkpoint args + ) + # Restore the learner state from the checkpoint + restored_params, _ = loaded_checkpoint.restore_params(input_params=params) + # Update the params + params = restored_params + + # Define params to be replicated across devices and batches. + key, *step_keys = jax.random.split(key, len(learner_devices) + 1) + step_keys = jnp.stack(step_keys, 0) + opt_states = OptStates(actor_opt_state, critic_opt_state) + + # Duplicate learner across Learner devices. + params, opt_states, step_keys = jax.device_put( + (params, opt_states, step_keys), learner_sharding + ) + + # Initialise learner state. 
+ init_learner_state = LearnerState(params, opt_states, step_keys, None, None) # type: ignore + env.close() + + return learn, apply_fns, init_learner_state, learner_sharding # type: ignore + + +def run_experiment(_config: DictConfig) -> float: + """Runs experiment.""" + config = copy.deepcopy(_config) + + local_devices = jax.local_devices() + devices = jax.devices() + err = "Local and global devices must be the same, we dont support multihost yet" + assert len(local_devices) == len(devices), err + learner_devices = [devices[d_id] for d_id in config.arch.learner_device_ids] + actor_devices = [local_devices[device_id] for device_id in config.arch.actor_device_ids] + + # JAX and numpy RNGs + key = jax.random.PRNGKey(config.system.seed) + np_rng = np.random.default_rng(config.system.seed) + + # Setup learner. + learn, apply_fns, learner_state, learner_sharding = learner_setup(key, config, learner_devices) + + # Setup evaluator. + # One key per device for evaluation. + eval_act_fn = make_ff_eval_act_fn(apply_fns[0], config) + evaluator, evaluator_envs = get_eval_fn( + environments.make_gym_env, eval_act_fn, config, np_rng, absolute_metric=False + ) + + # Calculate total timesteps. + config = check_total_timesteps(config) + check_sebulba_config(config) + + steps_per_rollout = ( + config.system.rollout_length * config.arch.num_envs * config.system.num_updates_per_eval + ) + + # Logger setup + logger = MavaLogger(config) + print_cfg: Dict = OmegaConf.to_container(config, resolve=True) + print_cfg["arch"]["devices"] = jax.devices() + pprint(print_cfg) + + # Set up checkpointer + save_checkpoint = config.logger.checkpointing.save_model + if save_checkpoint: + checkpointer = Checkpointer( + metadata=config, # Save all config as metadata in the checkpoint + model_name=config.logger.system_name, + **config.logger.checkpointing.save_args, # Checkpoint args + ) + + # Executor setup and launch. 
+ inital_params = jax.device_put(learner_state.params, actor_devices[0]) # unreplicate + + # the rollout queue/ the pipe between actor and learner + pipe_lifetime = ThreadLifetime() + pipe = Pipeline(config.arch.rollout_queue_size, learner_sharding, pipe_lifetime) + pipe.start() + + params_sources: List[ParamsSource] = [] + actor_threads: List[threading.Thread] = [] + actor_lifetime = ThreadLifetime() + params_sources_lifetime = ThreadLifetime() + + # Create the actor threads + print(f"{Fore.BLUE}{Style.BRIGHT}Starting up actor threads...{Style.RESET_ALL}") + for actor_device in actor_devices: + # Create 1 params source per device + params_source = ParamsSource(inital_params, actor_device, params_sources_lifetime) + params_source.start() + params_sources.append(params_source) + # Create multiple rollout threads per actor device + for thread_id in range(config.arch.n_threads_per_executor): + key, act_key = jax.random.split(key) + seeds = np_rng.integers(np.iinfo(np.int32).max, size=config.arch.num_envs).tolist() + act_key = jax.device_put(key, actor_device) + + actor = threading.Thread( + target=rollout, + args=( + act_key, + # We have to do this here, creating envs inside actor threads causes deadlocks + environments.make_gym_env(config, config.arch.num_envs), + config, + pipe, + params_source, + apply_fns, + actor_device, + seeds, + actor_lifetime, + ), + name=f"Actor-{actor_device}-{thread_id}", + ) + actor_threads.append(actor) + + # Start the actors simultaneously + for actor in actor_threads: + actor.start() + + eval_queue: Queue = Queue() + threading.Thread( + target=learner_thread, + name="Learner", + args=(learn, learner_state, config, eval_queue, pipe, params_sources), + ).start() + + max_episode_return = -np.inf + best_params_cpu = jax.device_get(inital_params.actor_params) + + # This is the main loop, all it does is evaluation and logging. + # Acting and learning is happening in their own threads. + # This loop waits for the learner to finish an update before evaluation and logging. + for eval_step in range(config.arch.num_evaluation): + # Sync with the learner - the get() is blocking so it keeps eval and learning in step. 
+        episode_metrics, train_metrics, learner_state, time_metrics = eval_queue.get()
+
+        t = int(steps_per_rollout * (eval_step + 1))
+        time_metrics |= {"timestep": t, "pipeline_size": pipe.qsize()}
+        logger.log(time_metrics, t, eval_step, LogEvent.MISC)
+
+        episode_metrics, ep_completed = get_final_step_metrics(episode_metrics)
+        episode_metrics["steps_per_second"] = steps_per_rollout / time_metrics["rollout_time"]
+        if ep_completed:
+            logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
+
+        train_metrics["learner_step"] = (eval_step + 1) * config.system.num_updates_per_eval
+        train_metrics["learner_steps_per_second"] = (
+            config.system.num_updates_per_eval
+        ) / time_metrics["learner_time_per_eval"]
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)
+
+        learner_state_cpu = jax.device_get(learner_state)
+        key, eval_key = jax.random.split(key, 2)
+        eval_metrics = evaluator(learner_state_cpu.params.actor_params, eval_key, {})
+        logger.log(eval_metrics, t, eval_step, LogEvent.EVAL)
+
+        episode_return = np.mean(eval_metrics["episode_return"])
+
+        if save_checkpoint:  # Save a checkpoint of the learner state
+            checkpointer.save(
+                timestep=steps_per_rollout * (eval_step + 1),
+                unreplicated_learner_state=learner_state_cpu,
+                episode_return=episode_return,
+            )
+
+        if config.arch.absolute_metric and max_episode_return <= episode_return:
+            best_params_cpu = copy.deepcopy(learner_state_cpu.params.actor_params)
+            max_episode_return = float(episode_return)
+
+    evaluator_envs.close()
+    eval_performance = float(np.mean(eval_metrics[config.env.eval_metric]))
+
+    # Measure absolute metric.
+    if config.arch.absolute_metric:
+        print(f"{Fore.BLUE}{Style.BRIGHT}Measuring absolute metric...{Style.RESET_ALL}")
+        abs_metric_evaluator, abs_metric_evaluator_envs = get_eval_fn(
+            environments.make_gym_env, eval_act_fn, config, np_rng, absolute_metric=True
+        )
+        key, eval_key = jax.random.split(key, 2)
+        eval_metrics = abs_metric_evaluator(best_params_cpu, eval_key, {})
+
+        t = int(steps_per_rollout * (eval_step + 1))
+        logger.log(eval_metrics, t, eval_step, LogEvent.ABSOLUTE)
+        abs_metric_evaluator_envs.close()
+
+    # Stop all the threads.
+    logger.stop()
+    actor_lifetime.stop()
+    pipe.clear()  # We clear the pipeline before stopping the actor threads to avoid deadlock
+    print(f"{Fore.RED}{Style.BRIGHT}Pipe cleared{Style.RESET_ALL}")
+    print(f"{Fore.RED}{Style.BRIGHT}Stopping actor threads...{Style.RESET_ALL}")
+    for actor in actor_threads:
+        actor.join()
+        print(f"{Fore.RED}{Style.BRIGHT}{actor.name} stopped{Style.RESET_ALL}")
+    print(f"{Fore.RED}{Style.BRIGHT}Stopping pipeline...{Style.RESET_ALL}")
+    pipe_lifetime.stop()
+    pipe.join()
+    print(f"{Fore.RED}{Style.BRIGHT}Stopping params sources...{Style.RESET_ALL}")
+    params_sources_lifetime.stop()
+    for params_source in params_sources:
+        params_source.join()
+    print(f"{Fore.RED}{Style.BRIGHT}All threads stopped...{Style.RESET_ALL}")
+
+    return eval_performance
+
+
+@hydra.main(
+    config_path="../../../configs/default/",
+    config_name="ff_ippo_sebulba.yaml",
+    version_base="1.2",
+)
+def hydra_entry_point(cfg: DictConfig) -> float:
+    """Experiment entry point."""
+    # Allow dynamic attributes.
+    OmegaConf.set_struct(cfg, False)
+    cfg.logger.system_name = "ff_ippo_sebulba"
+
+    # Run experiment.
+ eval_performance = run_experiment(cfg) + print(f"{Fore.CYAN}{Style.BRIGHT}IPPO experiment completed{Style.RESET_ALL}") + return eval_performance + + +if __name__ == "__main__": + hydra_entry_point() diff --git a/mava/systems/ppo/types.py b/mava/systems/ppo/types.py index 70b37afd5..9e56e17f8 100644 --- a/mava/systems/ppo/types.py +++ b/mava/systems/ppo/types.py @@ -19,7 +19,7 @@ from optax._src.base import OptState from typing_extensions import NamedTuple -from mava.types import Action, Done, HiddenState, State, Value +from mava.types import Action, Done, HiddenState, Observation, State, Value class Params(NamedTuple): @@ -73,7 +73,7 @@ class PPOTransition(NamedTuple): value: Value reward: chex.Array log_prob: chex.Array - obs: chex.Array + obs: Observation class RNNPPOTransition(NamedTuple): diff --git a/mava/systems/q_learning/anakin/rec_iql.py b/mava/systems/q_learning/anakin/rec_iql.py index 6a11df8f5..a5a876ccd 100644 --- a/mava/systems/q_learning/anakin/rec_iql.py +++ b/mava/systems/q_learning/anakin/rec_iql.py @@ -47,13 +47,13 @@ from mava.types import MarlEnv, Observation from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import ( switch_leading_axes, unreplicate_batch_dim, unreplicate_n_dims, ) from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps from mava.wrappers import episode_metrics diff --git a/mava/systems/q_learning/anakin/rec_qmix.py b/mava/systems/q_learning/anakin/rec_qmix.py index 2b485bd09..7dcccf75c 100644 --- a/mava/systems/q_learning/anakin/rec_qmix.py +++ b/mava/systems/q_learning/anakin/rec_qmix.py @@ -47,13 +47,13 @@ from mava.types import MarlEnv, Observation from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import ( switch_leading_axes, unreplicate_batch_dim, unreplicate_n_dims, ) from mava.utils.logger import LogEvent, MavaLogger -from mava.utils.total_timestep_checker import check_total_timesteps from mava.wrappers import episode_metrics diff --git a/mava/systems/sable/anakin/ff_sable.py b/mava/systems/sable/anakin/ff_sable.py index 6d65ac679..5a7f08859 100644 --- a/mava/systems/sable/anakin/ff_sable.py +++ b/mava/systems/sable/anakin/ff_sable.py @@ -42,10 +42,10 @@ from mava.types import Action, ExperimentOutput, LearnerFn, MarlEnv, Metrics from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger from mava.utils.network_utils import get_action_head -from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -164,8 +164,7 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple: def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: """Update the network for a single minibatch.""" - # Unpack train state and batch info - params, opt_state = train_state + params, opt_state, key = train_state traj_batch, advantages, targets = batch_info def _loss_fn( @@ -173,6 +172,7 @@ def _loss_fn( traj_batch: Transition, gae: chex.Array, value_targets: chex.Array, + rng_key: 
chex.PRNGKey, ) -> Tuple: """Calculate Sable loss.""" # Rerun network @@ -181,6 +181,7 @@ def _loss_fn( observation=traj_batch.obs, action=traj_batch.action, dones=traj_batch.done, + rng_key=rng_key, ) # Calculate actor loss @@ -216,13 +217,9 @@ def _loss_fn( return total_loss, (actor_loss, entropy, value_loss) # Calculate loss + key, entropy_key = jax.random.split(key) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - loss_info, grads = grad_fn( - params, - traj_batch, - advantages, - targets, - ) + loss_info, grads = grad_fn(params, traj_batch, advantages, targets, entropy_key) # Compute the parallel mean (pmean) over the batch. # This pmean could be a regular mean as the batch axis is on the same device. @@ -242,12 +239,12 @@ def _loss_fn( "entropy": entropy, } - return (new_params, new_opt_state), loss_info + return (new_params, new_opt_state, key), loss_info (params, opt_states, traj_batch, advantages, targets, key) = update_state # Shuffle minibatches - key, batch_shuffle_key, agent_shuffle_key = jax.random.split(key, 3) + key, batch_shuffle_key, agent_shuffle_key, entropy_key = jax.random.split(key, 4) # Shuffle batch batch_size = config.system.rollout_length * config.arch.num_envs @@ -267,9 +264,9 @@ def _loss_fn( ) # Update minibatches - (params, opt_states), loss_info = jax.lax.scan( + (params, opt_states, entropy_key), loss_info = jax.lax.scan( _update_minibatch, - (params, opt_states), + (params, opt_states, entropy_key), minibatches, ) @@ -338,7 +335,7 @@ def learner_setup( key, net_key = keys # Get number of agents and actions. - action_dim = int(env.action_spec().num_values[0]) + action_dim = env.action_dim n_agents = env.action_spec().shape[0] config.system.num_agents = n_agents config.system.num_actions = action_dim @@ -356,7 +353,7 @@ def learner_setup( # Set positional encoding to False, since ff-sable does not use temporal dependencies. config.network.memory_config.timestep_positional_encoding = False - _, action_space_type = get_action_head(env) + _, action_space_type = get_action_head(env.action_spec()) # Define network. 
sable_network = SableNetwork( diff --git a/mava/systems/sable/anakin/rec_sable.py b/mava/systems/sable/anakin/rec_sable.py index f8291e798..97b83cc37 100644 --- a/mava/systems/sable/anakin/rec_sable.py +++ b/mava/systems/sable/anakin/rec_sable.py @@ -43,10 +43,10 @@ from mava.types import Action, ExperimentOutput, LearnerFn, MarlEnv, Metrics from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import concat_time_and_agents, unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger from mava.utils.network_utils import get_action_head -from mava.utils.total_timestep_checker import check_total_timesteps from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -171,7 +171,7 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple: def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: """Update the network for a single minibatch.""" - params, opt_state = train_state + params, opt_state, key = train_state traj_batch, advantages, targets, prev_hstates = batch_info def _loss_fn( @@ -180,6 +180,7 @@ def _loss_fn( gae: chex.Array, value_targets: chex.Array, prev_hstates: HiddenStates, + rng_key: chex.PRNGKey, ) -> Tuple: """Calculate Sable loss.""" # Rerun network @@ -189,6 +190,7 @@ def _loss_fn( traj_batch.action, prev_hstates, traj_batch.done, + rng_key, ) # Calculate actor loss @@ -225,8 +227,16 @@ def _loss_fn( return total_loss, (actor_loss, entropy, value_loss) # Calculate loss + key, entropy_key = jax.random.split(key) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - loss_info, grads = grad_fn(params, traj_batch, advantages, targets, prev_hstates) + loss_info, grads = grad_fn( + params, + traj_batch, + advantages, + targets, + prev_hstates, + entropy_key, + ) # Compute the parallel mean (pmean) over the batch. # This pmean could be a regular mean as the batch axis is on the same device. @@ -246,12 +256,12 @@ def _loss_fn( "entropy": entropy, } - return (new_params, new_opt_state), loss_info + return (new_params, new_opt_state, key), loss_info (params, opt_states, traj_batch, advantages, targets, key, prev_hstates) = update_state # Shuffle minibatches - key, batch_shuffle_key, agent_shuffle_key = jax.random.split(key, 3) + key, batch_shuffle_key, agent_shuffle_key, entropy_key = jax.random.split(key, 4) # Shuffle batch batch_size = config.arch.num_envs @@ -280,9 +290,9 @@ def _loss_fn( ) # UPDATE MINIBATCHES - (params, opt_states), loss_info = jax.lax.scan( + (params, opt_states, entropy_key), loss_info = jax.lax.scan( _update_minibatch, - (params, opt_states), + (params, opt_states, entropy_key), (*minibatches, prev_hs_minibatch), ) @@ -353,7 +363,7 @@ def learner_setup( key, net_key = keys # Get number of agents and actions. - action_dim = int(env.action_spec().num_values[0]) + action_dim = env.action_dim n_agents = env.action_spec().shape[0] config.system.num_agents = n_agents config.system.num_actions = action_dim @@ -366,7 +376,7 @@ def learner_setup( else: config.network.memory_config.chunk_size = config.system.rollout_length * n_agents - _, action_space_type = get_action_head(env) + _, action_space_type = get_action_head(env.action_spec()) # Define network. 
sable_network = SableNetwork( diff --git a/mava/systems/sac/anakin/ff_hasac.py b/mava/systems/sac/anakin/ff_hasac.py index 0ea26ba9e..043db91d9 100644 --- a/mava/systems/sac/anakin/ff_hasac.py +++ b/mava/systems/sac/anakin/ff_hasac.py @@ -52,6 +52,7 @@ from mava.utils import make_env as environments from mava.utils.centralised_training import get_joint_action from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import ( tree_at_set, tree_slice, @@ -60,7 +61,6 @@ ) from mava.utils.logger import LogEvent, MavaLogger from mava.utils.network_utils import get_action_head -from mava.utils.total_timestep_checker import check_total_timesteps from mava.wrappers import episode_metrics # General shape comment guideline: @@ -153,7 +153,7 @@ def replicate(x: Any) -> Any: # Making actor network actor_torso = hydra.utils.instantiate(cfg.network.actor_network.pre_torso) - action_head, _ = get_action_head(env) + action_head, _ = get_action_head(env.action_spec()) actor_action_head = hydra.utils.instantiate( action_head, action_dim=env.action_dim, independent_std=False ) diff --git a/mava/systems/sac/anakin/ff_isac.py b/mava/systems/sac/anakin/ff_isac.py index 9d70984fc..12416d542 100644 --- a/mava/systems/sac/anakin/ff_isac.py +++ b/mava/systems/sac/anakin/ff_isac.py @@ -49,10 +49,10 @@ from mava.types import MarlEnv, Observation from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger from mava.utils.network_utils import get_action_head -from mava.utils.total_timestep_checker import check_total_timesteps from mava.wrappers import episode_metrics @@ -111,7 +111,7 @@ def replicate(x: Any) -> Any: # Making actor network actor_torso = hydra.utils.instantiate(cfg.network.actor_network.pre_torso) - action_head, _ = get_action_head(env) + action_head, _ = get_action_head(env.action_spec()) actor_action_head = hydra.utils.instantiate( action_head, action_dim=env.action_dim, independent_std=False ) diff --git a/mava/systems/sac/anakin/ff_masac.py b/mava/systems/sac/anakin/ff_masac.py index 3b99fcdd7..693364d68 100644 --- a/mava/systems/sac/anakin/ff_masac.py +++ b/mava/systems/sac/anakin/ff_masac.py @@ -50,10 +50,10 @@ from mava.utils import make_env as environments from mava.utils.centralised_training import get_joint_action, get_updated_joint_actions from mava.utils.checkpointing import Checkpointer +from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger from mava.utils.network_utils import get_action_head -from mava.utils.total_timestep_checker import check_total_timesteps from mava.wrappers import episode_metrics @@ -114,7 +114,7 @@ def replicate(x: Any) -> Any: # Making actor network actor_torso = hydra.utils.instantiate(cfg.network.actor_network.pre_torso) - action_head, _ = get_action_head(env) + action_head, _ = get_action_head(env.action_spec()) actor_action_head = hydra.utils.instantiate( action_head, action_dim=env.action_dim, independent_std=False ) diff --git a/mava/types.py b/mava/types.py index 7cd52759b..4072629dc 100644 --- a/mava/types.py +++ b/mava/types.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
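Since `get_action_head` now receives an action spec or Gymnasium space instead of the whole env (as in the call sites above), the same helper can serve both the Jumanji/Anakin and Gymnasium/Sebulba paths. A minimal sketch of the two call paths; the 3-agent, 5-action shapes are illustrative only:

    import numpy as np
    from gymnasium.spaces import MultiDiscrete
    from jumanji import specs

    from mava.utils.network_utils import get_action_head

    # Jumanji path (Anakin systems): pass the env's action spec.
    head_cfg, head_type = get_action_head(specs.MultiDiscreteArray(np.array([5, 5, 5])))
    assert head_type == "discrete" and head_cfg["_target_"].endswith("DiscreteActionHead")

    # Gymnasium path (Sebulba systems): pass a gym space instead.
    head_cfg, head_type = get_action_head(MultiDiscrete([5, 5, 5]))
    assert head_type == "discrete"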
-from typing import Any, Callable, Dict, Generic, Protocol, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, Generic, Optional, Protocol, Tuple, TypeVar, Union import chex import jumanji.specs as specs @@ -117,7 +117,7 @@ class Observation(NamedTuple): agents_view: chex.Array # (num_agents, num_obs_features) action_mask: chex.Array # (num_agents, num_actions) - step_count: chex.Array # (num_agents, ) + step_count: Optional[chex.Array] = None # (num_agents, ) class ObservationGlobalState(NamedTuple): @@ -130,7 +130,7 @@ class ObservationGlobalState(NamedTuple): agents_view: chex.Array # (num_agents, num_obs_features) action_mask: chex.Array # (num_agents, num_actions) global_state: chex.Array # (num_agents, num_agents * num_obs_features) - step_count: chex.Array # (num_agents, ) + step_count: Optional[chex.Array] = None # (num_agents, ) RNNObservation: TypeAlias = Tuple[Observation, Done] @@ -140,6 +140,7 @@ class ObservationGlobalState(NamedTuple): # `MavaState` is the main type passed around in our systems. It is often used as a scan carry. # Types like: `LearnerState` (mava/systems//types.py) are `MavaState`s. MavaState = TypeVar("MavaState") +MavaTransition = TypeVar("MavaTransition") class ExperimentOutput(NamedTuple, Generic[MavaState]): @@ -151,6 +152,7 @@ class ExperimentOutput(NamedTuple, Generic[MavaState]): LearnerFn = Callable[[MavaState], ExperimentOutput[MavaState]] +SebulbaLearnerFn = Callable[[MavaState, MavaTransition], ExperimentOutput[MavaState]] ActorApply = Callable[[FrozenDict, Observation], Distribution] CriticApply = Callable[[FrozenDict, Observation], Value] RecActorApply = Callable[ diff --git a/mava/utils/total_timestep_checker.py b/mava/utils/config.py similarity index 56% rename from mava/utils/total_timestep_checker.py rename to mava/utils/config.py index c2cda8320..23484311b 100644 --- a/mava/utils/total_timestep_checker.py +++ b/mava/utils/config.py @@ -18,9 +18,37 @@ from omegaconf import DictConfig +def check_sebulba_config(config: DictConfig) -> None: + """Checks that the given config does not have conflicting values.""" + assert ( + config.system.num_updates > config.arch.num_evaluation + ), "Number of updates per evaluation must be less than total number of updates." + config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation + + assert config.arch.num_envs % len(config.arch.learner_device_ids) == 0, ( + "Number of environments must be divisible by the number of learner." + + "The output of each actor is equally split across the learners." + ) + + num_eval_samples = ( + int(config.arch.num_envs / len(config.arch.learner_device_ids)) + * config.system.rollout_length + ) + assert num_eval_samples % config.system.num_minibatches == 0, ( + f"Number of training samples per evaluator ({num_eval_samples})" + + f"must be divisible by num_minibatches ({config.system.num_minibatches})." + ) + + def check_total_timesteps(config: DictConfig) -> DictConfig: """Check if total_timesteps is set, if not, set it based on the other parameters""" - n_devices = len(jax.devices()) + + if config.arch.architecture_name == "anakin": + n_devices = len(jax.devices()) + update_batch_size = config.system.update_batch_size + else: + n_devices = 1 # We only use a single device's output when updating. 
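+        # Sebulba consumes a single actor device's rollouts per update, so the Anakin
+        # device/update-batch multipliers collapse to 1 in the timestep arithmetic below.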
+ update_batch_size = 1 if config.system.total_timesteps is None: config.system.num_updates = int(config.system.num_updates) @@ -28,7 +56,7 @@ def check_total_timesteps(config: DictConfig) -> DictConfig: n_devices * config.system.num_updates * config.system.rollout_length - * config.system.update_batch_size + * update_batch_size * config.arch.num_envs ) else: @@ -36,7 +64,7 @@ def check_total_timesteps(config: DictConfig) -> DictConfig: config.system.num_updates = int( config.system.total_timesteps // config.system.rollout_length - // config.system.update_batch_size + // update_batch_size // config.arch.num_envs // n_devices ) diff --git a/mava/utils/jax_utils.py b/mava/utils/jax_utils.py index 0684d4070..2425210e9 100644 --- a/mava/utils/jax_utils.py +++ b/mava/utils/jax_utils.py @@ -105,5 +105,4 @@ def unreplicate_batch_dim(x: Any) -> Any: def switch_leading_axes(arr: chex.Array) -> chex.Array: """Switches the first two axes, generally used for BT -> TB.""" - arr = tree.map(lambda x: jax.numpy.swapaxes(x, 0, 1), arr) - return arr + return tree.map(lambda x: x.swapaxes(0, 1), arr) diff --git a/mava/utils/logger.py b/mava/utils/logger.py index dbb55e082..75280b9c8 100644 --- a/mava/utils/logger.py +++ b/mava/utils/logger.py @@ -153,8 +153,11 @@ class NeptuneLogger(BaseLogger): def __init__(self, cfg: DictConfig, unique_token: str) -> None: tags = list(cfg.logger.kwargs.neptune_tag) project = cfg.logger.kwargs.neptune_project + mode = ( + "async" if cfg.arch.architecture_name == "anakin" else "sync" + ) # async logging leads to deadlocks in sebulba - self.logger = neptune.init_run(project=project, tags=tags) + self.logger = neptune.init_run(project=project, tags=tags, mode=mode) self.logger["config"] = stringify_unsupported(cfg) self.detailed_logging = cfg.logger.kwargs.detailed_neptune_logging @@ -175,6 +178,7 @@ def log_stat(self, key: str, value: float, step: int, eval_step: int, event: Log if not self.detailed_logging and not is_main_metric: return + value = value.item() if isinstance(value, (jax.Array, np.ndarray)) else value self.logger[f"{event.value}/{key}"].log(value, step=step) def stop(self) -> None: @@ -341,7 +345,7 @@ def get_logger_path(config: DictConfig, logger_type: str) -> str: def describe(x: ArrayLike) -> Union[Dict[str, ArrayLike], ArrayLike]: """Generate summary statistics for an array of metrics (mean, std, min, max).""" - if not isinstance(x, jax.Array) or x.ndim == 0: + if not isinstance(x, (jax.Array, np.ndarray)) or x.ndim == 0: return x # np instead of jnp because we don't jit here diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 03b4678f0..0a56367c8 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -14,6 +14,10 @@ from typing import Tuple +import gymnasium +import gymnasium as gym +import gymnasium.vector +import gymnasium.wrappers import jaxmarl import jumanji import matrax @@ -40,13 +44,20 @@ CleanerWrapper, ConnectorWrapper, GigastepWrapper, + GymAgentIDWrapper, + GymRecordEpisodeMetrics, + GymToJumanji, LbfWrapper, MabraxWrapper, MatraxWrapper, + MPEWrapper, RecordEpisodeMetrics, RwareWrapper, + SmacWrapper, SmaxWrapper, + UoeWrapper, VectorConnectorWrapper, + async_multiagent_worker, ) # Registry mapping environment names to their generator and wrapper classes. @@ -63,9 +74,15 @@ # Registry mapping environment names directly to the corresponding wrapper classes. 
_matrax_registry = {"Matrax": MatraxWrapper} -_jaxmarl_registry = {"Smax": SmaxWrapper, "MaBrax": MabraxWrapper} +_jaxmarl_registry = {"Smax": SmaxWrapper, "MaBrax": MabraxWrapper, "MPE": MPEWrapper} _gigastep_registry = {"Gigastep": GigastepWrapper} +_gym_registry = { + "RobotWarehouse": UoeWrapper, + "LevelBasedForaging": UoeWrapper, + "SMACLite": SmacWrapper, +} + def add_extra_wrappers( train_env: MarlEnv, eval_env: MarlEnv, config: DictConfig @@ -133,6 +150,8 @@ def make_jaxmarl_env(config: DictConfig, add_global_state: bool = False) -> Tupl kwargs = dict(config.env.kwargs) if "smax" in config.env.env_name.lower(): kwargs["scenario"] = map_name_to_scenario(config.env.scenario.task_name) + elif "mpe" in config.env.env_name.lower(): + kwargs.update(config.env.scenario.task_config) # Create jaxmarl envs. train_env: MarlEnv = _jaxmarl_registry[config.env.env_name]( @@ -207,6 +226,44 @@ def make_gigastep_env( return train_env, eval_env +def make_gym_env( + config: DictConfig, + num_env: int, + add_global_state: bool = False, +) -> GymToJumanji: + """ + Create a gymnasium environment. + + Args: + config (Dict): The configuration of the environment. + num_env (int) : The number of parallel envs to create. + add_global_state (bool): Whether to add the global state to the observation. Default False. + + Returns: + Async environments. + """ + wrapper = _gym_registry[config.env.env_name] + config.system.add_agent_id = config.system.add_agent_id & (~config.env.implicit_agent_id) + + def create_gym_env(config: DictConfig, add_global_state: bool = False) -> gymnasium.Env: + registered_name = f"{config.env.scenario.name}:{config.env.scenario.task_name}" + env = gym.make(registered_name, disable_env_checker=True, **config.env.kwargs) + wrapped_env = wrapper(env, config.env.use_shared_rewards, add_global_state) + if config.system.add_agent_id: + wrapped_env = GymAgentIDWrapper(wrapped_env) + wrapped_env = GymRecordEpisodeMetrics(wrapped_env) + return wrapped_env + + envs = gymnasium.vector.AsyncVectorEnv( + [lambda: create_gym_env(config, add_global_state) for _ in range(num_env)], + worker=async_multiagent_worker, + ) + + envs = GymToJumanji(envs) + + return envs + + def make(config: DictConfig, add_global_state: bool = False) -> Tuple[MarlEnv, MarlEnv]: """ Create environments for training and evaluation. diff --git a/mava/utils/network_utils.py b/mava/utils/network_utils.py index a2949bdd3..b16c46054 100644 --- a/mava/utils/network_utils.py +++ b/mava/utils/network_utils.py @@ -12,19 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
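The `make_gym_env` factory above wraps each scenario in its suite wrapper plus optional agent IDs and episode metrics, vectorises everything with `AsyncVectorEnv` (using the multi-agent worker), and converts the result to the Jumanji-style API via `GymToJumanji`. A rough usage sketch, assuming `config` is an already composed `ff_ippo_sebulba` Hydra config with one of the gym scenarios selected:

    import numpy as np

    from mava.utils import make_env as environments

    envs = environments.make_gym_env(config, num_env=config.arch.num_envs)

    timestep = envs.reset(seed=list(range(config.arch.num_envs)))  # numpy-backed TimeStep
    for _ in range(10):  # random actions, purely for illustration
        actions = np.stack(
            [envs.single_action_space.sample() for _ in range(config.arch.num_envs)]
        )
        timestep = envs.step(actions)  # workers auto-reset finished episodes
    envs.close()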
-from typing import Dict, Tuple +from typing import Dict, Tuple, Union -from jumanji.specs import DiscreteArray, MultiDiscreteArray - -from mava.types import MarlEnv +from gymnasium.spaces import Discrete, MultiDiscrete, Space +from jumanji.specs import DiscreteArray, MultiDiscreteArray, Spec _DISCRETE = "discrete" _CONTINUOUS = "continuous" -def get_action_head(env: MarlEnv) -> Tuple[Dict[str, str], str]: +def get_action_head(action_types: Union[Spec, Space]) -> Tuple[Dict[str, str], str]: """Returns the appropriate action head config based on the environment action_spec.""" - if isinstance(env.action_spec(), (DiscreteArray, MultiDiscreteArray)): + if isinstance(action_types, (DiscreteArray, MultiDiscreteArray, Discrete, MultiDiscrete)): return {"_target_": "mava.networks.heads.DiscreteActionHead"}, _DISCRETE return {"_target_": "mava.networks.heads.ContinuousActionHead"}, _CONTINUOUS diff --git a/mava/utils/sebulba.py b/mava/utils/sebulba.py new file mode 100644 index 000000000..dc51140f5 --- /dev/null +++ b/mava/utils/sebulba.py @@ -0,0 +1,192 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import queue +import threading +import time +from typing import Any, Dict, List, Sequence, Tuple, Union + +import jax +import jax.numpy as jnp +from colorama import Fore, Style +from jax import tree +from jax.sharding import Sharding +from jumanji.types import TimeStep + +# todo: remove the ppo dependencies when we make sebulba for other systems +from mava.systems.ppo.types import Params, PPOTransition + +QUEUE_PUT_TIMEOUT = 100 + + +class ThreadLifetime: + """Simple class for a mutable boolean that can be used to signal a thread to stop.""" + + def __init__(self) -> None: + self._stop = False + + def should_stop(self) -> bool: + return self._stop + + def stop(self) -> None: + self._stop = True + + +@jax.jit +def _stack_trajectory(trajectory: List[PPOTransition]) -> PPOTransition: + """Stack a list of parallel_env transitions into a single + transition of shape [rollout_len, num_envs, ...].""" + return tree.map(lambda *x: jnp.stack(x, axis=0).swapaxes(0, 1), *trajectory) # type: ignore + + +# Modified from https://github.com/instadeepai/sebulba/blob/main/sebulba/core.py +class Pipeline(threading.Thread): + """ + The `Pipeline` shards trajectories into learner devices, + ensuring trajectories are consumed in the right order to avoid being off-policy + and limit the max number of samples in device memory at one time to avoid OOM issues. + """ + + def __init__(self, max_size: int, learner_sharding: Sharding, lifetime: ThreadLifetime): + """ + Initializes the pipeline with a maximum size and the devices to shard trajectories across. + + Args: + max_size: The maximum number of trajectories to keep in the pipeline. + learner_sharding: The sharding used for the learner's update function. + lifetime: A `ThreadLifetime` which is used to stop this thread. 
+ """ + super().__init__(name="Pipeline") + + self.sharding = learner_sharding + self.tickets_queue: queue.Queue = queue.Queue() + self._queue: queue.Queue = queue.Queue(maxsize=max_size) + self.lifetime = lifetime + + def run(self) -> None: + """This function ensures that trajectories on the queue are consumed in the right order. The + start_condition and end_condition are used to ensure that only 1 thread is processing an + item from the queue at one time, ensuring predictable memory usage. + """ + while not self.lifetime.should_stop(): + try: + start_condition, end_condition = self.tickets_queue.get(timeout=1) + with end_condition: + with start_condition: + start_condition.notify() + end_condition.wait() + except queue.Empty: + continue + + def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict) -> None: + """Put a trajectory on the queue to be consumed by the learner.""" + start_condition, end_condition = (threading.Condition(), threading.Condition()) + with start_condition: + self.tickets_queue.put((start_condition, end_condition)) + start_condition.wait() # wait to be allowed to start + + # [Transition(num_envs)] * rollout_len -> Transition[done=(num_envs, rollout_len, ...)] + traj = _stack_trajectory(traj) + traj, timestep = jax.device_put((traj, timestep), device=self.sharding) + + # We block on the `put` to ensure that actors wait for the learners to catch up. + # This ensures two things: + # The actors don't get too far ahead of the learners, which could lead to off-policy data. + # The actors don't "waste" samples by generating samples that the learners can't consume. + # However, we put a timeout of 100 seconds to avoid deadlocks in case the learner + # is not consuming the data. This is a safety measure and should not normally occur. + # We use a try-finally so the lock is released even if an exception is raised. + try: + self._queue.put( + (traj, timestep, time_dict), + block=True, + timeout=QUEUE_PUT_TIMEOUT, + ) + except queue.Full: + print( + f"{Fore.RED}{Style.BRIGHT}Pipeline is full and actor has timed out, " + f"this should not happen. A deadlock might be occurring{Style.RESET_ALL}" + ) + finally: + with end_condition: + end_condition.notify() # notify that we have finished + + def qsize(self) -> int: + """Returns the number of trajectories in the pipeline.""" + return self._queue.qsize() + + def get( + self, block: bool = True, timeout: Union[float, None] = None + ) -> Tuple[PPOTransition, TimeStep, Dict]: + """Get a trajectory from the pipeline.""" + return self._queue.get(block, timeout) # type: ignore + + def clear(self) -> None: + """Clear the pipeline.""" + while not self._queue.empty(): + try: + self._queue.get(block=False) + except queue.Empty: + break + + +class ParamsSource(threading.Thread): + """A `ParamSource` is a component that allows networks params to be passed from a + `Learner` component to `Actor` components. + """ + + def __init__(self, init_value: Params, device: jax.Device, lifetime: ThreadLifetime): + super().__init__(name=f"ParamsSource-{device.id}") + self.value: Params = jax.device_put(init_value, device) + self.device = device + self.new_value: queue.Queue = queue.Queue() + self.lifetime = lifetime + + def run(self) -> None: + """This function is responsible for updating the value of the `ParamSource` when a new value + is available. 
+ """ + while not self.lifetime.should_stop(): + try: + waiting = self.new_value.get(block=True, timeout=1) + self.value = jax.device_put(waiting, self.device) + except queue.Empty: + continue + + def update(self, new_params: Params) -> None: + """Update the value of the `ParamSource` with a new value. + + Args: + new_params: The new value to update the `ParamSource` with. + """ + self.new_value.put(new_params) + + def get(self) -> Params: + """Get the current value of the `ParamSource`.""" + return self.value + + +class RecordTimeTo: + """Context manager to record the runtime in a `with` block""" + + def __init__(self, to: Any): + self.to = to + + def __enter__(self) -> None: + self.start = time.monotonic() + + def __exit__(self, *args: Any) -> None: + end = time.monotonic() + self.to.append(end - self.start) diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index d50b316b1..fc9dadb31 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -16,7 +16,15 @@ from mava.wrappers.auto_reset_wrapper import AutoResetWrapper from mava.wrappers.episode_metrics import RecordEpisodeMetrics from mava.wrappers.gigastep import GigastepWrapper -from mava.wrappers.jaxmarl import MabraxWrapper, SmaxWrapper +from mava.wrappers.gym import ( + GymAgentIDWrapper, + GymRecordEpisodeMetrics, + GymToJumanji, + SmacWrapper, + UoeWrapper, + async_multiagent_worker, +) +from mava.wrappers.jaxmarl import MabraxWrapper, MPEWrapper, SmaxWrapper from mava.wrappers.jumanji import ( CleanerWrapper, ConnectorWrapper, diff --git a/mava/wrappers/episode_metrics.py b/mava/wrappers/episode_metrics.py index e9e130819..f4c34002e 100644 --- a/mava/wrappers/episode_metrics.py +++ b/mava/wrappers/episode_metrics.py @@ -17,6 +17,7 @@ import chex import jax import jax.numpy as jnp +import numpy as np from jax import tree from jumanji.types import TimeStep from jumanji.wrappers import Wrapper @@ -120,12 +121,12 @@ def get_final_step_metrics(metrics: Dict[str, chex.Array]) -> Tuple[Dict[str, ch expects arrays for computing summary statistics on the episode metrics. """ is_final_ep = metrics.pop("is_terminal_step") - has_final_ep_step = bool(jnp.any(is_final_ep)) + has_final_ep_step = bool(np.any(is_final_ep)) final_metrics: Dict[str, chex.Array] # If it didn't make it to the final step, return zeros. if not has_final_ep_step: - final_metrics = tree.map(jnp.zeros_like, metrics) + final_metrics = tree.map(np.zeros_like, metrics) else: final_metrics = tree.map(lambda x: x[is_final_ep], metrics) diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py new file mode 100644 index 000000000..9258bde6a --- /dev/null +++ b/mava/wrappers/gym.py @@ -0,0 +1,411 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import traceback +import warnings +from dataclasses import field +from enum import IntEnum +from multiprocessing import Queue +from multiprocessing.connection import Connection +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union + +import gymnasium +import gymnasium.vector.async_vector_env +import numpy as np +from gymnasium import spaces +from gymnasium.spaces.utils import is_space_dtype_shape_equiv +from gymnasium.vector.utils import write_to_shared_memory +from numpy.typing import NDArray + +from mava.types import Observation, ObservationGlobalState + +if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239 + from dataclasses import dataclass +else: + from chex import dataclass + +# Filter out the warnings +warnings.filterwarnings("ignore", module="gymnasium.utils.passive_env_checker") + + +# needed to avoid host -> device transfers when calling TimeStep.last() +class StepType(IntEnum): + """Copy of Jumanji's step type but without jax arrays""" + + FIRST = 0 + MID = 1 + LAST = 2 + + +@dataclass +class TimeStep: + step_type: StepType + reward: NDArray + discount: NDArray + observation: Union[Observation, ObservationGlobalState] + extras: Dict = field(default_factory=dict) + + def first(self) -> NDArray: + return self.step_type == StepType.FIRST + + def mid(self) -> NDArray: + return self.step_type == StepType.MID + + def last(self) -> NDArray: + return self.step_type == StepType.LAST + + +class UoeWrapper(gymnasium.Wrapper): + """A base wrapper for multi-agent environments developed by the University of Edinburgh. + This wrapper is compatible with the RobotWarehouse and Level-Based Foraging environments. + """ + + def __init__( + self, + env: gymnasium.Env, + use_shared_rewards: bool = True, + add_global_state: bool = False, + ): + """Initialise the gym wrapper + Args: + env (gymnasium.env): gymnasium env instance. + use_shared_rewards (bool, optional): Use individual or shared rewards. + Defaults to False. + add_global_state (bool, optional) : Create global observations. Defaults to False. + """ + super().__init__(env) + self._env = env + self.use_shared_rewards = use_shared_rewards + self.add_global_state = add_global_state + self.num_agents = len(self._env.action_space) + self.num_actions = self._env.action_space[0].n + + # Tuple(Box(...) * N) --> Box(N, ...) + single_obs = self.observation_space[0] # type: ignore + shape = (self.num_agents, *single_obs.shape) + low = np.tile(single_obs.low, (self.num_agents, 1)) + high = np.tile(single_obs.high, (self.num_agents, 1)) + self.observation_space = spaces.Box(low=low, high=high, shape=shape, dtype=single_obs.dtype) + + # Tuple(Discrete(...) * N) --> MultiDiscrete(... 
* N) + self.action_space = spaces.MultiDiscrete([self.num_actions] * self.num_agents) + + def reset( + self, seed: Optional[int] = None, options: Optional[dict] = None + ) -> Tuple[NDArray, Dict]: + if seed is not None: + self.env.unwrapped.seed(seed) + + agents_view, info = self._env.reset() + + info["action_mask"] = self.get_action_mask(info) + if self.add_global_state: + info["global_obs"] = self.get_global_obs(agents_view) + + return np.array(agents_view), info + + def step(self, actions: List) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: + agents_view, reward, terminated, truncated, info = self._env.step(actions) + + info["action_mask"] = self.get_action_mask(info) + if self.add_global_state: + info["global_obs"] = self.get_global_obs(agents_view) + + if self.use_shared_rewards: + reward = np.array([np.array(reward).sum()] * self.num_agents) + else: + reward = np.array(reward) + + return agents_view, reward, terminated, truncated, info + + def get_action_mask(self, info: Dict) -> NDArray: + if "action_mask" in info: + return np.array(info["action_mask"]) + return np.ones((self.num_agents, self.num_actions), dtype=np.float32) + + def get_global_obs(self, obs: NDArray) -> NDArray: + global_obs = np.concatenate(obs, axis=0) + return np.tile(global_obs, (self.num_agents, 1)) + + +class SmacWrapper(UoeWrapper): + """A wrapper that converts actions step to integers.""" + + def reset( + self, seed: Optional[int] = None, options: Optional[dict] = None + ) -> Tuple[NDArray, Dict]: + agents_view, info = super().reset() + info["won_episode"] = info["battle_won"] + return agents_view, info + + def step(self, actions: List) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: + # Convert actions to integers before passing them to the environment + actions = [int(action) for action in actions] + + agents_view, reward, terminated, truncated, info = super().step(actions) + info["won_episode"] = info["battle_won"] + + return agents_view, reward, terminated, truncated, info + + def get_action_mask(self, info: Dict) -> NDArray: + return np.array(self._env.unwrapped.get_avail_actions()) + + +class GymRecordEpisodeMetrics(gymnasium.Wrapper): + """Record the episode returns and lengths.""" + + def __init__(self, env: gymnasium.Env): + super().__init__(env) + self._env = env + self.running_count_episode_return = 0.0 + self.running_count_episode_length = 0.0 + + def reset( + self, seed: Optional[int] = None, options: Optional[dict] = None + ) -> Tuple[NDArray, Dict]: + agents_view, info = self._env.reset(seed, options) + + # Reset the metrics + self.running_count_episode_return = 0.0 + self.running_count_episode_length = 0.0 + + # Create the metrics dict + metrics = { + "episode_return": self.running_count_episode_return, + "episode_length": self.running_count_episode_length, + "is_terminal_step": False, + } + + info["metrics"] = metrics + + return agents_view, info + + def step(self, actions: NDArray) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: + agents_view, reward, terminated, truncated, info = self._env.step(actions) + + self.running_count_episode_return += float(np.mean(reward)) + self.running_count_episode_length += 1 + + metrics = { + "episode_return": self.running_count_episode_return, + "episode_length": self.running_count_episode_length, + "is_terminal_step": np.logical_or(terminated, truncated).all().item(), + } + + info["metrics"] = metrics + + return agents_view, reward, terminated, truncated, info + + +class GymAgentIDWrapper(gymnasium.Wrapper): + """Add one hot agent IDs to 
observation.""" + + def __init__(self, env: gymnasium.Env): + super().__init__(env) + + self.agent_ids = np.eye(self.env.num_agents) + self.observation_space = self.modify_space(self.env.observation_space) + + def reset( + self, seed: Optional[int] = None, options: Optional[dict] = None + ) -> Tuple[NDArray, Dict]: + """Reset the environment.""" + obs, info = self.env.reset(seed, options) + obs = np.concatenate([self.agent_ids, obs], axis=1) + return obs, info + + def step(self, action: list) -> Tuple[NDArray, float, bool, bool, Dict]: + """Step the environment.""" + obs, reward, terminated, truncated, info = self.env.step(action) + obs = np.concatenate([self.agent_ids, obs], axis=1) + return obs, reward, terminated, truncated, info + + def modify_space(self, space: spaces.Space) -> spaces.Space: + if isinstance(space, spaces.Box): + new_shape = (space.shape[0], space.shape[1] + self.env.num_agents) + high = np.concatenate((space.high, np.ones_like(self.agent_ids)), axis=1) + low = np.concatenate((space.low, np.zeros_like(self.agent_ids)), axis=1) + return spaces.Box(low=low, high=high, shape=new_shape, dtype=space.dtype) + elif isinstance(space, spaces.Tuple): + return spaces.Tuple(self.modify_space(s) for s in space) + else: + raise ValueError(f"Space {type(space)} is not currently supported.") + + +class GymToJumanji: + """Converts from the Gym API to the Jumanji API.""" + + def __init__(self, env: gymnasium.vector.VectorEnv): + self.env = env + self.single_action_space = env.unwrapped.single_action_space + self.single_observation_space = env.unwrapped.single_observation_space + + def reset(self, seed: Optional[list[int]] = None, options: Optional[dict] = None) -> TimeStep: + obs, info = self.env.reset(seed=seed, options=options) # type: ignore + + num_agents = len(self.env.single_action_space) # type: ignore + num_envs = self.env.num_envs + + step_type = np.full(num_envs, StepType.FIRST) + rewards = np.zeros((num_envs, num_agents), dtype=float) + teminated = np.zeros(num_envs, dtype=float) + + timestep = self._create_timestep(obs, step_type, teminated, rewards, info) + + return timestep + + def step(self, action: list) -> TimeStep: + obs, rewards, terminated, truncated, info = self.env.step(action) + + ep_done = np.logical_or(terminated, truncated) + step_type = np.where(ep_done, StepType.LAST, StepType.MID) + + timestep = self._create_timestep(obs, step_type, terminated, rewards, info) + + return timestep + + def _format_observation( + self, obs: NDArray, info: Dict + ) -> Union[Observation, ObservationGlobalState]: + """Create an observation from the raw observation and environment state.""" + + action_mask = np.stack(info["action_mask"]) + obs_data = {"agents_view": obs, "action_mask": action_mask} + + if "global_obs" in info: + global_obs = np.array(info["global_obs"]) + obs_data["global_state"] = global_obs + return ObservationGlobalState(**obs_data) + else: + return Observation(**obs_data) + + def _create_timestep( + self, obs: NDArray, step_type: NDArray, terminated: NDArray, rewards: NDArray, info: Dict + ) -> TimeStep: + observation = self._format_observation(obs, info) + # Filter out the masks and auxiliary data + extras = {} + extras["episode_metrics"] = { + key: value for key, value in info["metrics"].items() if key[0] != "_" + } + if "won_episode" in info: + extras["won_episode"] = info["won_episode"] + + return TimeStep( + step_type=step_type, # type: ignore + reward=rewards, + discount=1.0 - terminated, + observation=observation, + extras=extras, + ) + + def close(self) 
-> None: + self.env.close() + + +# Copied form Gymnasium/blob/main/gymnasium/vector/async_vector_env.py +# Modified to work with multiple agents +# Note: The worker handles auto-resetting the environments. +# Each environment resets when all of its agents have either terminated or been truncated. +def async_multiagent_worker( # CCR001 + index: int, + env_fn: Callable, + pipe: Connection, + parent_pipe: Connection, + shared_memory: Union[NDArray, dict[str, Any], tuple[Any, ...]], + error_queue: Queue, +) -> None: + env = env_fn() + observation_space = env.observation_space + action_space = env.action_space + parent_pipe.close() + + try: + while True: + command, data = pipe.recv() + + if command == "reset": + observation, info = env.reset(**data) + if shared_memory: + write_to_shared_memory(observation_space, index, observation, shared_memory) + observation = None + pipe.send(((observation, info), True)) + elif command == "step": + # Modified the step function to align with 'AutoResetWrapper'. + # The environment resets immediately upon termination or truncation. + ( + observation, + reward, + terminated, + truncated, + info, + ) = env.step(data) + if np.logical_or(terminated, truncated).all(): + observation, new_info = env.reset() + info["action_mask"] = new_info["action_mask"] + + if shared_memory: + write_to_shared_memory(observation_space, index, observation, shared_memory) + observation = None + + pipe.send(((observation, reward, terminated, truncated, info), True)) + elif command == "close": + pipe.send((None, True)) + break + elif command == "_call": + name, args, kwargs = data + if name in ["reset", "step", "close", "_setattr", "_check_spaces"]: + raise ValueError( + f"Trying to call function `{name}` with \ + `call`, use `{name}` directly instead." + ) + + attr = env.get_wrapper_attr(name) + if callable(attr): + pipe.send((attr(*args, **kwargs), True)) + else: + pipe.send((attr, True)) + elif command == "_setattr": + name, value = data + env.set_wrapper_attr(name, value) + pipe.send((None, True)) + elif command == "_check_spaces": + obs_mode, single_obs_space, single_action_space = data + pipe.send( + ( + ( + ( + single_obs_space == observation_space + if obs_mode == "same" + else is_space_dtype_shape_equiv(single_obs_space, observation_space) + ), + single_action_space == action_space, + ), + True, + ) + ) + else: + raise RuntimeError( + f"Received unknown command `{command}`. Must be one of \ + [`reset`, `step`, `close`, `_call`, `_setattr`, `_check_spaces`]." 
+ ) + except (KeyboardInterrupt, Exception): + error_type, error_message, _ = sys.exc_info() + trace = traceback.format_exc() + + error_queue.put((index, error_type, error_message, trace)) + pipe.send((None, False)) + finally: + env.close() diff --git a/mava/wrappers/jaxmarl.py b/mava/wrappers/jaxmarl.py index aa343b3ad..23322fc0c 100644 --- a/mava/wrappers/jaxmarl.py +++ b/mava/wrappers/jaxmarl.py @@ -27,6 +27,7 @@ from jaxmarl.environments import SMAX from jaxmarl.environments import spaces as jaxmarl_spaces from jaxmarl.environments.mabrax import MABraxEnv +from jaxmarl.environments.mpe.simple_spread import SimpleSpreadMPE from jaxmarl.environments.multi_agent_env import MultiAgentEnv from jumanji import specs from jumanji.types import StepType, TimeStep, restart @@ -139,13 +140,13 @@ def jaxmarl_space_to_jumanji_spec(space: jaxmarl_spaces.Space) -> specs.Spec: ) elif _is_dict(space): # Jumanji needs something to hold the specs - contructor = namedtuple("SubSpace", list(space.spaces.keys())) # type: ignore + constructor = namedtuple("SubSpace", list(space.spaces.keys())) # type: ignore # Recursively convert spaces to specs sub_specs = { sub_space_name: jaxmarl_space_to_jumanji_spec(sub_space) for sub_space_name, sub_space in space.spaces.items() } - return specs.Spec(constructor=contructor, name="", **sub_specs) + return specs.Spec(constructor=constructor, name="", **sub_specs) elif _is_tuple(space): # Jumanji needs something to hold the specs field_names = [f"sub_space_{i}" for i in range(len(space.spaces))] @@ -214,7 +215,6 @@ def reset( def step( self, state: JaxMarlState, action: Array ) -> Tuple[JaxMarlState, TimeStep[Union[Observation, ObservationGlobalState]]]: - # todo: how do you know if it's a truncation with only dones? key, step_key = jax.random.split(state.key) obs, env_state, reward, done, _ = self._env.step( step_key, state.state, unbatchify(action, self.agents) @@ -407,3 +407,37 @@ def get_global_state(self, wrapped_env_state: BraxState, obs: Dict[str, Array]) """Get global state from observation and copy it for each agent.""" # Use the global state of brax. return jnp.tile(wrapped_env_state.obs, (self.num_agents, 1)) + + +class MPEWrapper(JaxMarlWrapper): + """Wrapper for the MPE environment.""" + + def __init__( + self, + env: SimpleSpreadMPE, + has_global_state: bool = False, + ): + super().__init__(env, has_global_state, env.max_steps) + self._env: SimpleSpreadMPE + + @cached_property + def action_dim(self) -> chex.Array: + "Get the actions dim for each agent." + # Adjusted automatically based on the action_type specified in the kwargs. 
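+        # Discrete MPE variants expose `.n` here; continuous variants fall back to the
+        # action-space shape below.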
+ if _is_discrete(self._env.action_space(self.agents[0])): + return self._env.action_space(self.agents[0]).n + return self._env.action_space(self.agents[0]).shape[0] + + @cached_property + def state_size(self) -> chex.Array: + "Get the state size of the global observation" + return self._env.observation_space(self.agents[0]).shape[0] * self.num_agents + + def action_mask(self, wrapped_env_state: Any) -> Array: + """Get action mask for each agent.""" + return jnp.ones((self.num_agents, self.action_dim), dtype=bool) + + def get_global_state(self, wrapped_env_state: Any, obs: Dict[str, Array]) -> Array: + """Get global state from observation and copy it for each agent.""" + global_state = jnp.concatenate([obs[agent_id] for agent_id in obs]) + return jnp.tile(global_state, (self.num_agents, 1)) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index b7395c389..e004a3c23 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -4,12 +4,14 @@ distrax flashbax~=0.1.0 flax>=0.8.1 gigastep @ git+https://github.com/mlech26l/gigastep +gymnasium hydra-core==1.3.2 id-marl-eval @ git+https://github.com/instadeepai/marl-eval jax==0.4.30 jaxlib==0.4.30 -jaxmarl +jaxmarl @ git+https://github.com/RuanJohn/JaxMARL@unpin-jax # This only unpins the version of Jax. jumanji @ git+https://github.com/sash-a/jumanji@old_jumanji # Includes a few extra MARL envs +lbforaging matrax @ git+https://github.com/instadeepai/matrax@4c5d8aa97214848ea659274f16c48918c13e845b mujoco==3.1.3 mujoco-mjx==3.1.3 @@ -18,7 +20,9 @@ numpy==1.26.4 omegaconf optax protobuf~=3.20 +rware scipy==1.12.0 +smaclite @ git+https://github.com/uoe-agents/smaclite.git tensorboard_logger tensorflow_probability type_enforced # needed because gigastep is missing this dependency diff --git a/test/integration_test.py b/test/integration_test.py index 419cf5799..6311e938a 100644 --- a/test/integration_test.py +++ b/test/integration_test.py @@ -41,7 +41,7 @@ discrete_envs = ["gigastep", "lbf", "matrax", "rware", "smax", "vector-connector"] cnn_envs = ["cleaner", "connector"] -continuous_envs = ["mabrax"] +continuous_envs = ["mabrax", "mpe"] def _run_system(system_name: str, cfg: DictConfig) -> float: @@ -82,7 +82,7 @@ def test_ppo_system(fast_config: dict, system_path: str) -> None: def test_sable_system(fast_config: dict, system_path: str) -> None: """Test all sable systems on random envs.""" _, _, system_name = system_path.split(".") - env = random.choice(discrete_envs) + env = random.choice(continuous_envs + discrete_envs) with initialize(version_base=None, config_path=config_path): cfg = compose(config_name=f"{system_name}", overrides=[f"env={env}"]) @@ -159,15 +159,13 @@ def test_discrete_cnn_env(fast_config: dict, env_name: str) -> None: _run_system(system_path, cfg) -# leaving this here for the future if we have some new continuous envs -@pytest.mark.skip(reason="MaBrax is the only continuous env and already tested in test_mava_system") @pytest.mark.parametrize("env_name", continuous_envs) def test_continuous_env(fast_config: dict, env_name: str) -> None: """Test all continuous envs on random systems.""" system_path = random.choice(ppo_systems + sac_systems) _, _, system_name = system_path.split(".") - overrides = [f"env={env_name}"] + with initialize(version_base=None, config_path=config_path): cfg = compose(config_name=f"{system_name}", overrides=overrides) cfg = _get_fast_config(cfg, fast_config)
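For reference, a worked example of the timestep bookkeeping that `check_total_timesteps` and `check_sebulba_config` apply on the Sebulba path (the numbers below are illustrative rather than defaults):

    # Sebulba counts a single device's output per update, so n_devices and
    # update_batch_size are both treated as 1.
    num_envs, rollout_length = 32, 128
    num_updates, num_evaluation = 1_000, 100

    num_updates_per_eval = num_updates // num_evaluation                   # 10
    total_timesteps = 1 * num_updates * rollout_length * 1 * num_envs      # 4_096_000
    steps_per_rollout = rollout_length * num_envs * num_updates_per_eval   # 40_960 steps between evals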