From a723392b23e02dc015e9a43a5fd47008e6bf6f1c Mon Sep 17 00:00:00 2001 From: Louay Ben Nessir Date: Wed, 13 Nov 2024 09:38:27 +0100 Subject: [PATCH] fix: sebulba compatible get_action_head --- mava/evaluator.py | 3 ++- mava/systems/ppo/anakin/ff_ippo.py | 2 +- mava/systems/ppo/anakin/ff_mappo.py | 2 +- mava/systems/ppo/anakin/rec_ippo.py | 2 +- mava/systems/ppo/anakin/rec_mappo.py | 2 +- mava/systems/ppo/sebulba/ff_ippo.py | 7 ++++--- mava/utils/make_env.py | 2 +- mava/utils/network_utils.py | 12 ++++++------ 8 files changed, 17 insertions(+), 15 deletions(-) diff --git a/mava/evaluator.py b/mava/evaluator.py index 21037c2c3..e1b35b7d9 100644 --- a/mava/evaluator.py +++ b/mava/evaluator.py @@ -37,6 +37,7 @@ RecActorApply, State, ) +from mava.wrappers.gym import GymToJumanji # Optional extras that are passed out of the actor and then into the actor in the next step ActorState: TypeAlias = Dict[str, Any] @@ -211,7 +212,7 @@ def eval_act_fn( def get_sebulba_eval_fn( - env_maker: Callable, + env_maker: Callable[[int, int], GymToJumanji], act_fn: EvalActFn, config: DictConfig, np_rng: np.random.Generator, diff --git a/mava/systems/ppo/anakin/ff_ippo.py b/mava/systems/ppo/anakin/ff_ippo.py index 698c505b2..201bd5fc0 100644 --- a/mava/systems/ppo/anakin/ff_ippo.py +++ b/mava/systems/ppo/anakin/ff_ippo.py @@ -362,7 +362,7 @@ def learner_setup( # Define network and optimiser. 
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) - action_head, _ = get_action_head(env) + action_head, _ = get_action_head(env.action_spec()) actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim) critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) diff --git a/mava/systems/ppo/anakin/ff_mappo.py b/mava/systems/ppo/anakin/ff_mappo.py index 3103cc164..680e6361a 100644 --- a/mava/systems/ppo/anakin/ff_mappo.py +++ b/mava/systems/ppo/anakin/ff_mappo.py @@ -346,7 +346,7 @@ def learner_setup( # Define network and optimiser. actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) - action_head, _ = get_action_head(env) + action_head, _ = get_action_head(env.action_spec()) actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim) critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) diff --git a/mava/systems/ppo/anakin/rec_ippo.py b/mava/systems/ppo/anakin/rec_ippo.py index b936262ff..182382ac6 100644 --- a/mava/systems/ppo/anakin/rec_ippo.py +++ b/mava/systems/ppo/anakin/rec_ippo.py @@ -457,7 +457,7 @@ def learner_setup( # Define network and optimisers. 
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso) - action_head, _ = get_action_head(env) + action_head, _ = get_action_head(env.action_spec()) actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim) critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso) diff --git a/mava/systems/ppo/anakin/rec_mappo.py b/mava/systems/ppo/anakin/rec_mappo.py index f1105fe73..671d5cbc5 100644 --- a/mava/systems/ppo/anakin/rec_mappo.py +++ b/mava/systems/ppo/anakin/rec_mappo.py @@ -452,7 +452,7 @@ def learner_setup( # Define network and optimiser. actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso) - action_head, _ = get_action_head(env) + action_head, _ = get_action_head(env.action_spec()) actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim) critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso) diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 468957c46..76d133985 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -54,6 +54,7 @@ from mava.utils.config import check_sebulba_config, check_total_timesteps from mava.utils.jax_utils import merge_leading_dims, switch_leading_axes from mava.utils.logger import LogEvent, MavaLogger +from mava.utils.network_utils import get_action_head from mava.utils.sebulba import ParamsSource, Pipeline, RecordTimeTo, ThreadLifetime from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -466,9 +467,9 
@@ def learner_setup( # Define network and optimiser. actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) - actor_action_head = hydra.utils.instantiate( - config.network.action_head, action_dim=config.system.num_actions - ) + action_head, _ = get_action_head(action_space) + actor_action_head = hydra.utils.instantiate(action_head, action_dim=config.system.num_actions) + critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) actor_network = Actor(torso=actor_torso, action_head=actor_action_head) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 8794093ac..e0360c706 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -55,8 +55,8 @@ SmacWrapper, SmaxWrapper, UoeWrapper, - async_multiagent_worker, VectorConnectorWrapper, + async_multiagent_worker, ) # Registry mapping environment names to their generator and wrapper classes. diff --git a/mava/utils/network_utils.py b/mava/utils/network_utils.py index a2949bdd3..03a7e439f 100644 --- a/mava/utils/network_utils.py +++ b/mava/utils/network_utils.py @@ -12,19 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, Tuple +from typing import Dict, Tuple, Union -from jumanji.specs import DiscreteArray, MultiDiscreteArray +from jumanji.specs import DiscreteArray, MultiDiscreteArray, Spec +from gymnasium.spaces import Discrete, MultiDiscrete, Space -from mava.types import MarlEnv _DISCRETE = "discrete" _CONTINUOUS = "continuous" -def get_action_head(env: MarlEnv) -> Tuple[Dict[str, str], str]: +def get_action_head(action_types: Union[Spec, Space]) -> Tuple[Dict[str, str], str]: """Returns the appropriate action head config based on the environment action_spec.""" - if isinstance(env.action_spec(), (DiscreteArray, MultiDiscreteArray)): + if isinstance(action_types, (DiscreteArray, MultiDiscreteArray, Discrete, MultiDiscrete)): return {"_target_": "mava.networks.heads.DiscreteActionHead"}, _DISCRETE - return {"_target_": "mava.networks.heads.ContinuousActionHead"}, _CONTINUOUS + return {"_target_": "mava.networks.heads.ContinuousActionHead"}, _CONTINUOUS \ No newline at end of file