Skip to content

Commit

Permalink
feat: use action_spec to select action head type
Browse files Browse the repository at this point in the history
  • Loading branch information
WiemKhlifi committed Oct 18, 2024
1 parent 60d8ffa commit f1cb0f2
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 41 deletions.
9 changes: 3 additions & 6 deletions mava/advanced_usage/ff_ippo_store_experience.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,19 +352,16 @@ def learner_setup(
n_devices = len(jax.devices())

# Get number of actions and agents.
num_actions = int(env.action_spec().num_values[0])
num_agents = env.action_spec().shape[0]
config.system.num_agents = num_agents
num_actions = env.action_dim
config.system.num_agents = env.num_agents
config.system.num_actions = num_actions

# PRNG keys.
key, key_p = keys

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_action_head = hydra.utils.instantiate(
get_action_head(config.env.env_name), action_dim=num_actions
)
actor_action_head = hydra.utils.instantiate(get_action_head(env), action_dim=num_actions)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

actor_network = Actor(torso=actor_torso, action_head=actor_action_head)
Expand Down
4 changes: 1 addition & 3 deletions mava/systems/ppo/anakin/ff_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,7 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_action_head = hydra.utils.instantiate(
get_action_head(config.env.env_name), action_dim=env.action_dim
)
actor_action_head = hydra.utils.instantiate(get_action_head(env), action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

actor_network = Actor(torso=actor_torso, action_head=actor_action_head)
Expand Down
4 changes: 1 addition & 3 deletions mava/systems/ppo/anakin/ff_mappo.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,9 +346,7 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_action_head = hydra.utils.instantiate(
get_action_head(config.env.env_name), action_dim=env.action_dim
)
actor_action_head = hydra.utils.instantiate(get_action_head(env), action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

actor_network = Actor(torso=actor_torso, action_head=actor_action_head)
Expand Down
4 changes: 1 addition & 3 deletions mava/systems/ppo/anakin/rec_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,7 @@ def learner_setup(
# Define network and optimisers.
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)
actor_action_head = hydra.utils.instantiate(
get_action_head(config.env.env_name), action_dim=env.action_dim
)
actor_action_head = hydra.utils.instantiate(get_action_head(env), action_dim=env.action_dim)
critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso)

Expand Down
4 changes: 1 addition & 3 deletions mava/systems/ppo/anakin/rec_mappo.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,9 +452,7 @@ def learner_setup(
# Define network and optimiser.
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)
actor_action_head = hydra.utils.instantiate(
get_action_head(config.env.env_name), action_dim=env.action_dim
)
actor_action_head = hydra.utils.instantiate(get_action_head(env), action_dim=env.action_dim)
critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso)

Expand Down
4 changes: 1 addition & 3 deletions mava/systems/sac/anakin/ff_isac.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,7 @@ def replicate(x: Any) -> Any:

# Making actor network
actor_torso = hydra.utils.instantiate(cfg.network.actor_network.pre_torso)
actor_action_head = hydra.utils.instantiate(
get_action_head(cfg.env.env_name), action_dim=env.action_dim
)
actor_action_head = hydra.utils.instantiate(get_action_head(env), action_dim=env.action_dim)
actor_network = Actor(actor_torso, actor_action_head)
actor_params = actor_network.init(actor_key, obs_single_batched)

Expand Down
4 changes: 1 addition & 3 deletions mava/systems/sac/anakin/ff_masac.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,7 @@ def replicate(x: Any) -> Any:

# Making actor network
actor_torso = hydra.utils.instantiate(cfg.network.actor_network.pre_torso)
actor_action_head = hydra.utils.instantiate(
get_action_head(cfg.env.env_name), action_dim=env.action_dim
)
actor_action_head = hydra.utils.instantiate(get_action_head(env), action_dim=env.action_dim)
actor_network = Actor(actor_torso, actor_action_head)
actor_params = actor_network.init(actor_key, obs_single_batched)

Expand Down
26 changes: 9 additions & 17 deletions mava/utils/network_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Mapping of environments to their respective action heads.
action_head_per_env = {
"Cleaner": "mava.networks.heads.DiscreteActionHead",
"MaConnector": "mava.networks.heads.DiscreteActionHead",
"LevelBasedForaging": "mava.networks.heads.DiscreteActionHead",
"Matrax": "mava.networks.heads.DiscreteActionHead",
"RobotWarehouse": "mava.networks.heads.DiscreteActionHead",
"Smax": "mava.networks.heads.DiscreteActionHead",
"Gigastep": "mava.networks.heads.DiscreteActionHead",
"MaBrax": "mava.networks.heads.ContinuousActionHead",
}
from typing import Dict

from jumanji.specs import DiscreteArray, MultiDiscreteArray

def get_action_head(env_name: str) -> dict:
"""Returns the appropriate action head config based on the environment name."""
action_head = action_head_per_env.get(env_name)
from mava.types import MarlEnv

if action_head is None:
raise ValueError(f"Environment {env_name} is not recognized.")

return {"_target_": action_head}
def get_action_head(env: MarlEnv) -> Dict[str, str]:
    """Select the action head config from the environment's action spec.

    Args:
        env: Environment whose `action_spec()` is inspected.

    Returns:
        A single-entry dict whose `_target_` key holds the import path of the
        discrete head when the spec is a `DiscreteArray` or
        `MultiDiscreteArray`, and of the continuous head for any other spec.
    """
    discrete_spec_types = (DiscreteArray, MultiDiscreteArray)
    action_spec = env.action_spec()

    # Any spec that is not one of the discrete types falls back to continuous.
    head_target = (
        "mava.networks.heads.DiscreteActionHead"
        if isinstance(action_spec, discrete_spec_types)
        else "mava.networks.heads.ContinuousActionHead"
    )
    return {"_target_": head_target}

0 comments on commit f1cb0f2

Please sign in to comment.