diff --git a/mava/advanced_usage/ff_ippo_store_experience.py b/mava/advanced_usage/ff_ippo_store_experience.py index b967adca7..fa84a70f8 100644 --- a/mava/advanced_usage/ff_ippo_store_experience.py +++ b/mava/advanced_usage/ff_ippo_store_experience.py @@ -352,9 +352,8 @@ def learner_setup( n_devices = len(jax.devices()) # Get number of actions and agents. - num_actions = int(env.action_spec().num_values[0]) - num_agents = env.action_spec().shape[0] - config.system.num_agents = num_agents + num_actions = env.action_dim + config.system.num_agents = env.num_agents config.system.num_actions = num_actions # PRNG keys. @@ -362,9 +361,7 @@ def learner_setup( # Define network and optimiser. actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) - actor_action_head = hydra.utils.instantiate( - get_action_head(config.env.env_name), action_dim=num_actions - ) + actor_action_head = hydra.utils.instantiate(get_action_head(env), action_dim=num_actions) critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) actor_network = Actor(torso=actor_torso, action_head=actor_action_head) diff --git a/mava/systems/ppo/anakin/ff_ippo.py b/mava/systems/ppo/anakin/ff_ippo.py index bd7358f12..9fed2cea9 100644 --- a/mava/systems/ppo/anakin/ff_ippo.py +++ b/mava/systems/ppo/anakin/ff_ippo.py @@ -362,9 +362,7 @@ def learner_setup( # Define network and optimiser. 
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) - actor_action_head = hydra.utils.instantiate( - get_action_head(config.env.env_name), action_dim=env.action_dim - ) + actor_action_head = hydra.utils.instantiate(get_action_head(env), action_dim=env.action_dim) critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) actor_network = Actor(torso=actor_torso, action_head=actor_action_head) diff --git a/mava/systems/ppo/anakin/ff_mappo.py b/mava/systems/ppo/anakin/ff_mappo.py index 7040dc088..4e30679d0 100644 --- a/mava/systems/ppo/anakin/ff_mappo.py +++ b/mava/systems/ppo/anakin/ff_mappo.py @@ -346,9 +346,7 @@ def learner_setup( # Define network and optimiser. actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) - actor_action_head = hydra.utils.instantiate( - get_action_head(config.env.env_name), action_dim=env.action_dim - ) + actor_action_head = hydra.utils.instantiate(get_action_head(env), action_dim=env.action_dim) critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) actor_network = Actor(torso=actor_torso, action_head=actor_action_head) diff --git a/mava/systems/ppo/anakin/rec_ippo.py b/mava/systems/ppo/anakin/rec_ippo.py index 5ef027028..4b70f6ec7 100644 --- a/mava/systems/ppo/anakin/rec_ippo.py +++ b/mava/systems/ppo/anakin/rec_ippo.py @@ -457,9 +457,7 @@ def learner_setup( # Define network and optimisers. 
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso) - actor_action_head = hydra.utils.instantiate( - get_action_head(config.env.env_name), action_dim=env.action_dim - ) + actor_action_head = hydra.utils.instantiate(get_action_head(env), action_dim=env.action_dim) critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso) diff --git a/mava/systems/ppo/anakin/rec_mappo.py b/mava/systems/ppo/anakin/rec_mappo.py index 38531110f..381e3759f 100644 --- a/mava/systems/ppo/anakin/rec_mappo.py +++ b/mava/systems/ppo/anakin/rec_mappo.py @@ -452,9 +452,7 @@ def learner_setup( # Define network and optimiser. actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso) - actor_action_head = hydra.utils.instantiate( - get_action_head(config.env.env_name), action_dim=env.action_dim - ) + actor_action_head = hydra.utils.instantiate(get_action_head(env), action_dim=env.action_dim) critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso) diff --git a/mava/systems/sac/anakin/ff_isac.py b/mava/systems/sac/anakin/ff_isac.py index 06a65d0ae..9971d9a88 100644 --- a/mava/systems/sac/anakin/ff_isac.py +++ b/mava/systems/sac/anakin/ff_isac.py @@ -111,9 +111,7 @@ def replicate(x: Any) -> Any: # Making actor network actor_torso = hydra.utils.instantiate(cfg.network.actor_network.pre_torso) - actor_action_head = hydra.utils.instantiate( - get_action_head(cfg.env.env_name), action_dim=env.action_dim - ) + actor_action_head = hydra.utils.instantiate(get_action_head(env), action_dim=env.action_dim) actor_network = Actor(actor_torso, actor_action_head) 
actor_params = actor_network.init(actor_key, obs_single_batched) diff --git a/mava/systems/sac/anakin/ff_masac.py b/mava/systems/sac/anakin/ff_masac.py index f72cc7dea..865352d59 100644 --- a/mava/systems/sac/anakin/ff_masac.py +++ b/mava/systems/sac/anakin/ff_masac.py @@ -114,9 +114,7 @@ def replicate(x: Any) -> Any: # Making actor network actor_torso = hydra.utils.instantiate(cfg.network.actor_network.pre_torso) - actor_action_head = hydra.utils.instantiate( - get_action_head(cfg.env.env_name), action_dim=env.action_dim - ) + actor_action_head = hydra.utils.instantiate(get_action_head(env), action_dim=env.action_dim) actor_network = Actor(actor_torso, actor_action_head) actor_params = actor_network.init(actor_key, obs_single_batched) diff --git a/mava/utils/network_utils.py b/mava/utils/network_utils.py index 1a7262e4a..3591beede 100644 --- a/mava/utils/network_utils.py +++ b/mava/utils/network_utils.py @@ -12,24 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Mapping of environments to their respective action heads. 
-action_head_per_env = {
-    "Cleaner": "mava.networks.heads.DiscreteActionHead",
-    "MaConnector": "mava.networks.heads.DiscreteActionHead",
-    "LevelBasedForaging": "mava.networks.heads.DiscreteActionHead",
-    "Matrax": "mava.networks.heads.DiscreteActionHead",
-    "RobotWarehouse": "mava.networks.heads.DiscreteActionHead",
-    "Smax": "mava.networks.heads.DiscreteActionHead",
-    "Gigastep": "mava.networks.heads.DiscreteActionHead",
-    "MaBrax": "mava.networks.heads.ContinuousActionHead",
-}
+from typing import Dict
 
+from jumanji.specs import DiscreteArray, MultiDiscreteArray
 
-def get_action_head(env_name: str) -> dict:
-    """Returns the appropriate action head config based on the environment name."""
-    action_head = action_head_per_env.get(env_name)
+from mava.types import MarlEnv
 
-    if action_head is None:
-        raise ValueError(f"Environment {env_name} is not recognized.")
 
-    return {"_target_": action_head}
+def get_action_head(env: MarlEnv) -> Dict[str, str]:
+    """Return the action head `_target_` config: discrete for discrete specs, else continuous."""
+    if isinstance(env.action_spec(), (DiscreteArray, MultiDiscreteArray)):
+        return {"_target_": "mava.networks.heads.DiscreteActionHead"}
+    # Every other spec type falls through to the continuous head; nothing is raised here.
+    return {"_target_": "mava.networks.heads.ContinuousActionHead"}