Skip to content

Commit

Permalink
feat: set the action head automatically based on env name
Browse files Browse the repository at this point in the history
  • Loading branch information
WiemKhlifi committed Oct 18, 2024
1 parent c4e40ce commit 8ac214a
Show file tree
Hide file tree
Showing 17 changed files with 52 additions and 40 deletions.
2 changes: 1 addition & 1 deletion mava/configs/default/ff_ippo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ defaults:
- logger: logger
- arch: anakin
- system: ppo/ff_ippo
- network: mlp # [mlp, continuous_mlp, cnn]
- network: mlp # [mlp, cnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax]
- _self_

Expand Down
2 changes: 1 addition & 1 deletion mava/configs/default/ff_isac.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ defaults:
- logger: logger
- arch: anakin
- system: sac/ff_isac
- network: continuous_mlp # [continuous_mlp]
- network: mlp
- env: mabrax # [mabrax]

hydra:
Expand Down
2 changes: 1 addition & 1 deletion mava/configs/default/ff_mappo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ defaults:
- logger: logger
- arch: anakin
- system: ppo/ff_mappo
- network: mlp # [mlp, continuous_mlp, cnn]
- network: mlp # [mlp, cnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax]
- _self_

Expand Down
2 changes: 1 addition & 1 deletion mava/configs/default/ff_masac.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ defaults:
- logger: logger
- arch: anakin
- system: sac/ff_masac
- network: continuous_mlp # [continuous_mlp]
- network: mlp
- env: mabrax # [mabrax]

hydra:
Expand Down
3 changes: 0 additions & 3 deletions mava/configs/network/cnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@ actor_network:
use_layer_norm: False
activation: relu

action_head:
_target_: mava.networks.heads.DiscreteActionHead # [DiscreteActionHead, ContinuousActionHead]

critic_network:
pre_torso:
_target_: mava.networks.torsos.CNNTorso
Expand Down
17 changes: 0 additions & 17 deletions mava/configs/network/continuous_mlp.yaml

This file was deleted.

3 changes: 0 additions & 3 deletions mava/configs/network/mlp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@ actor_network:
use_layer_norm: False
activation: relu

action_head:
_target_: mava.networks.heads.DiscreteActionHead # [DiscreteActionHead, ContinuousActionHead]

critic_network:
pre_torso:
_target_: mava.networks.torsos.MLPTorso
Expand Down
3 changes: 0 additions & 3 deletions mava/configs/network/rcnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ actor_network:
use_layer_norm: False
activation: relu

action_head:
_target_: mava.networks.heads.DiscreteActionHead # [DiscreteActionHead, ContinuousActionHead]

critic_network:
pre_torso:
_target_: mava.networks.torsos.CNNTorso
Expand Down
3 changes: 0 additions & 3 deletions mava/configs/network/rnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@ actor_network:
use_layer_norm: False
activation: relu

action_head:
_target_: mava.networks.heads.DiscreteActionHead # [DiscreteActionHead, ContinuousActionHead]

critic_network:
pre_torso:
_target_: mava.networks.torsos.MLPTorso
Expand Down
3 changes: 2 additions & 1 deletion mava/systems/ppo/anakin/ff_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
unreplicate_n_dims,
)
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics
Expand Down Expand Up @@ -362,7 +363,7 @@ def learner_setup(
# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_action_head = hydra.utils.instantiate(
config.network.action_head, action_dim=env.action_dim
get_action_head(config.env.env_name), action_dim=env.action_dim
)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

Expand Down
3 changes: 2 additions & 1 deletion mava/systems/ppo/anakin/ff_mappo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from mava.utils.checkpointing import Checkpointer
from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics
Expand Down Expand Up @@ -346,7 +347,7 @@ def learner_setup(
# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_action_head = hydra.utils.instantiate(
config.network.action_head, action_dim=env.action_dim
get_action_head(config.env.env_name), action_dim=env.action_dim
)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

Expand Down
3 changes: 2 additions & 1 deletion mava/systems/ppo/anakin/rec_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from mava.utils.checkpointing import Checkpointer
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics
Expand Down Expand Up @@ -457,7 +458,7 @@ def learner_setup(
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)
actor_action_head = hydra.utils.instantiate(
config.network.action_head, action_dim=env.action_dim
get_action_head(config.env.env_name), action_dim=env.action_dim
)
critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso)
Expand Down
3 changes: 2 additions & 1 deletion mava/systems/ppo/anakin/rec_mappo.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from mava.utils.checkpointing import Checkpointer
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics
Expand Down Expand Up @@ -452,7 +453,7 @@ def learner_setup(
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)
actor_action_head = hydra.utils.instantiate(
config.network.action_head, action_dim=env.action_dim
get_action_head(config.env.env_name), action_dim=env.action_dim
)
critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso)
Expand Down
3 changes: 2 additions & 1 deletion mava/systems/sac/anakin/ff_isac.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from mava.utils.checkpointing import Checkpointer
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.wrappers import episode_metrics

Expand Down Expand Up @@ -111,7 +112,7 @@ def replicate(x: Any) -> Any:
# Making actor network
actor_torso = hydra.utils.instantiate(cfg.network.actor_network.pre_torso)
actor_action_head = hydra.utils.instantiate(
cfg.network.action_head, action_dim=action_dim, independent_std=False
get_action_head(cfg.env.env_name), action_dim=env.action_dim
)
actor_network = Actor(actor_torso, actor_action_head)
actor_params = actor_network.init(actor_key, obs_single_batched)
Expand Down
3 changes: 2 additions & 1 deletion mava/systems/sac/anakin/ff_masac.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from mava.utils.checkpointing import Checkpointer
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.wrappers import episode_metrics

Expand Down Expand Up @@ -114,7 +115,7 @@ def replicate(x: Any) -> Any:
# Making actor network
actor_torso = hydra.utils.instantiate(cfg.network.actor_network.pre_torso)
actor_action_head = hydra.utils.instantiate(
cfg.network.action_head, action_dim=action_dim, independent_std=False
get_action_head(cfg.env.env_name), action_dim=env.action_dim
)
actor_network = Actor(actor_torso, actor_action_head)
actor_params = actor_network.init(actor_key, obs_single_batched)
Expand Down
35 changes: 35 additions & 0 deletions mava/utils/network_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lookup table: environment name -> dotted import path of the action head class
# to use for it. Every listed environment uses the discrete head except MaBrax,
# which is registered with the continuous head.
action_head_per_env = dict.fromkeys(
    (
        "Cleaner",
        "MaConnector",
        "LevelBasedForaging",
        "Matrax",
        "RobotWarehouse",
        "Smax",
        "Gigastep",
    ),
    "mava.networks.heads.DiscreteActionHead",
)
action_head_per_env["MaBrax"] = "mava.networks.heads.ContinuousActionHead"


def get_action_head(env_name: str) -> dict:
    """Build the action head config for the given environment.

    Looks `env_name` up in the module-level `action_head_per_env` mapping and
    wraps the dotted class path found there in a single-entry dict keyed by
    `_target_`.

    Args:
        env_name: Name of the environment; must be a key of
            `action_head_per_env`.

    Returns:
        A new dict of the form `{"_target_": <dotted path of the head class>}`.

    Raises:
        ValueError: If `env_name` has no entry in `action_head_per_env`.
    """
    target_path = action_head_per_env.get(env_name, None)

    # Success path first: a known environment yields a fresh config dict.
    if target_path is not None:
        return {"_target_": target_path}

    raise ValueError(f"Environment {env_name} is not recognized.")
2 changes: 1 addition & 1 deletion test/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_continuous_env(fast_config: dict, env_name: str) -> None:
system_path = random.choice(ppo_systems + sac_systems)
_, _, system_name = system_path.split(".")

overrides = [f"env={env_name}", "network=continuous_mlp"]
overrides = [f"env={env_name}"]
with initialize(version_base=None, config_path=config_path):
cfg = compose(config_name=f"{system_name}", overrides=overrides)
cfg = _get_fast_config(cfg, fast_config)
Expand Down

0 comments on commit 8ac214a

Please sign in to comment.