Skip to content

Commit

Permalink
fix: sebulba compatible get_action_head
Browse files Browse the repository at this point in the history
  • Loading branch information
Louay-Ben-nessir committed Nov 13, 2024
1 parent 649b93b commit a723392
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 15 deletions.
3 changes: 2 additions & 1 deletion mava/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
RecActorApply,
State,
)
from mava.wrappers.gym import GymToJumanji

# Optional extras that are passed out of the actor and then into the actor in the next step
ActorState: TypeAlias = Dict[str, Any]
Expand Down Expand Up @@ -211,7 +212,7 @@ def eval_act_fn(


def get_sebulba_eval_fn(
env_maker: Callable,
env_maker: Callable[[int, int], GymToJumanji],
act_fn: EvalActFn,
config: DictConfig,
np_rng: np.random.Generator,
Expand Down
2 changes: 1 addition & 1 deletion mava/systems/ppo/anakin/ff_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
action_head, _ = get_action_head(env)
action_head, _ = get_action_head(env.action_spec())
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

Expand Down
2 changes: 1 addition & 1 deletion mava/systems/ppo/anakin/ff_mappo.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
action_head, _ = get_action_head(env)
action_head, _ = get_action_head(env.action_spec())
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

Expand Down
2 changes: 1 addition & 1 deletion mava/systems/ppo/anakin/rec_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def learner_setup(
# Define network and optimisers.
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)
action_head, _ = get_action_head(env)
action_head, _ = get_action_head(env.action_spec())
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso)
Expand Down
2 changes: 1 addition & 1 deletion mava/systems/ppo/anakin/rec_mappo.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def learner_setup(
# Define network and optimiser.
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)
action_head, _ = get_action_head(env)
action_head, _ = get_action_head(env.action_spec())
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso)
Expand Down
7 changes: 4 additions & 3 deletions mava/systems/ppo/sebulba/ff_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from mava.utils.config import check_sebulba_config, check_total_timesteps
from mava.utils.jax_utils import merge_leading_dims, switch_leading_axes
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.sebulba import ParamsSource, Pipeline, RecordTimeTo, ThreadLifetime
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics
Expand Down Expand Up @@ -466,9 +467,9 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_action_head = hydra.utils.instantiate(
config.network.action_head, action_dim=config.system.num_actions
)
action_head, _ = get_action_head(action_space)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=config.system.num_actions)

critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

actor_network = Actor(torso=actor_torso, action_head=actor_action_head)
Expand Down
2 changes: 1 addition & 1 deletion mava/utils/make_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@
SmacWrapper,
SmaxWrapper,
UoeWrapper,
async_multiagent_worker,
VectorConnectorWrapper,
async_multiagent_worker,
)

# Registry mapping environment names to their generator and wrapper classes.
Expand Down
12 changes: 6 additions & 6 deletions mava/utils/network_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Tuple
from typing import Dict, Tuple, Union

from jumanji.specs import DiscreteArray, MultiDiscreteArray
from jumanji.specs import DiscreteArray, MultiDiscreteArray, Spec
from gymnasium.spaces import Discrete, MultiDiscrete, Space

from mava.types import MarlEnv

_DISCRETE = "discrete"
_CONTINUOUS = "continuous"


def get_action_head(env: MarlEnv) -> Tuple[Dict[str, str], str]:
def get_action_head(action_types: Union[Spec, Space]) -> Tuple[Dict[str, str], str]:
"""Returns the appropriate action head config based on the environment action_spec."""
if isinstance(env.action_spec(), (DiscreteArray, MultiDiscreteArray)):
if isinstance(action_types, (DiscreteArray, MultiDiscreteArray, Discrete, MultiDiscrete)):
return {"_target_": "mava.networks.heads.DiscreteActionHead"}, _DISCRETE

return {"_target_": "mava.networks.heads.ContinuousActionHead"}, _CONTINUOUS
return {"_target_": "mava.networks.heads.ContinuousActionHead"}, _CONTINUOUS

0 comments on commit a723392

Please sign in to comment.