
Commit

Merge remote-tracking branch 'upstream/feat/sebulba_arch' into seb-ff-ippo-only

Louay-Ben-nessir committed Jul 23, 2024
2 parents e09fd60 + 6a1fad4 commit 0cae539
Showing 49 changed files with 2,069 additions and 2,193 deletions.
41 changes: 13 additions & 28 deletions .pre-commit-config.yaml
@@ -2,21 +2,21 @@ default_stages: [ "commit", "commit-msg", "push" ]
default_language_version:
python: python3


repos:
- repo: https://github.com/timothycrosley/isort
rev: 5.13.2
hooks:
- id: isort

- repo: https://github.com/psf/black
rev: 24.2.0
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.4.8
hooks:
- id: black
name: "Code formatter"
# Run the linter.
- id: ruff
types_or: [ python ]
args: [ --fix ]
# Run the formatter.
- id: ruff-format
types_or: [ python, pyi, jupyter ]

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: end-of-file-fixer
name: "End of file fixer"
@@ -32,21 +32,6 @@ repos:
- id: trailing-whitespace
name: "Trailing whitespace fixer"

- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
hooks:
- id: flake8
name: "Linter"
additional_dependencies:
- pep8-naming
- flake8-builtins
- flake8-comprehensions
- flake8-bugbear
- flake8-pytest-style
- flake8-cognitive-complexity
- flake8-pyproject
- importlib-metadata<5.0

- repo: local
hooks:
- id: mypy
@@ -57,15 +42,15 @@ repos:
pass_filenames: false

- repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook
rev: v9.11.0
rev: v9.16.0
hooks:
- id: commitlint
name: "Commit linter"
stages: [ commit-msg ]
additional_dependencies: [ '@commitlint/config-conventional' ]

- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.3.0
rev: v1.5.5
hooks:
- id: insert-license
name: "License inserter"
2,394 changes: 1,204 additions & 1,190 deletions examples/Quickstart.ipynb

Large diffs are not rendered by default.

52 changes: 27 additions & 25 deletions mava/advanced_usage/ff_ippo_store_experience.py
@@ -25,16 +25,15 @@
from colorama import Fore, Style
from flashbax.vault import Vault
from flax.core.frozen_dict import FrozenDict
from jumanji.env import Environment
from omegaconf import DictConfig, OmegaConf
from optax._src.base import OptState
from rich.pretty import pprint

from mava.evaluator import make_anakin_eval_fns
from mava.evaluator import get_eval_fn, make_ff_eval_act_fn
from mava.networks import FeedForwardActor as Actor
from mava.networks import FeedForwardValueNet as Critic
from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition
from mava.types import ActorApply, CriticApply, ExperimentOutput, MavaState
from mava.types import ActorApply, CriticApply, ExperimentOutput, MarlEnv, MavaState
from mava.utils.checkpointing import Checkpointer
from mava.utils.jax_utils import (
merge_leading_dims,
@@ -55,13 +54,12 @@


def get_learner_fn(
env: Environment,
env: MarlEnv,
apply_fns: Tuple[ActorApply, CriticApply],
update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn],
config: DictConfig,
) -> StoreExpLearnerFn[LearnerState]:
"""Get the learner function."""

# Get apply and update functions for actor and critic networks.
actor_apply_fn, critic_apply_fn = apply_fns
actor_update_fn, critic_update_fn = update_fns
@@ -75,13 +73,15 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup
losses.
Args:
----
learner_state (NamedTuple):
- params (Params): The current model parameters.
- opt_states (OptStates): The current optimizer states.
- key (PRNGKey): The random number generator state.
- env_state (State): The environment state.
- last_timestep (TimeStep): The last timestep in the current trajectory.
_ (Any): The current metrics info.
"""

def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]:
@@ -155,7 +155,6 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple:

def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple:
"""Update the network for a single minibatch."""

# UNPACK TRAIN STATE AND BATCH INFO
params, opt_states = train_state
traj_batch, advantages, targets = batch_info
@@ -285,7 +284,7 @@ def _critic_loss_fn(
lambda x: jnp.take(x, permutation, axis=0), batch
)
minibatches = jax.tree_util.tree_map(
lambda x: jnp.reshape(x, [config.system.num_minibatches, -1] + list(x.shape[1:])),
lambda x: jnp.reshape(x, (config.system.num_minibatches, -1, *x.shape[1:])),
shuffled_batch,
)

@@ -319,14 +318,15 @@ def learner_fn(
updates. The `_update_step` function is vectorized over a batch of inputs.
Args:
----
learner_state (NamedTuple):
- params (Params): The initial model parameters.
- opt_states (OptStates): The initial optimizer state.
- key (chex.PRNGKey): The random number generator state.
- env_state (LogEnvState): The environment state.
- timesteps (TimeStep): The initial timestep in the initial trajectory.
"""
"""
batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch")

learner_state, (episode_info, loss_info, traj_batch) = jax.lax.scan(
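Note: the batched_update_step defined above follows the usual Anakin pattern, vmap a single update over a batch of learner states and then scan it over successive updates, which is what the jax.lax.scan call started on the line above does. A toy sketch of that pattern with a dummy step function (the step function and shapes are illustrative assumptions only):

    import jax
    import jax.numpy as jnp

    def _step(state, _):
        return state + 1, state  # (new carry, per-step output)

    batched_step = jax.vmap(_step, in_axes=(0, None), axis_name="batch")
    init_state = jnp.zeros(4)  # a batch of 4 independent states
    final_state, outputs = jax.lax.scan(batched_step, init_state, None, length=3)
    # final_state.shape == (4,) and outputs.shape == (3, 4)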
@@ -345,7 +345,7 @@


def learner_setup(
env: Environment, keys: chex.Array, config: DictConfig
env: MarlEnv, keys: chex.Array, config: DictConfig
) -> Tuple[StoreExpLearnerFn[LearnerState], Actor, LearnerState]:
"""Initialise learner_fn, network, optimiser, environment and states."""
# Get available TPU cores.
@@ -412,7 +412,7 @@ def learner_setup(

# Broadcast params and optimiser state to cores and batch.
broadcast = lambda x: jnp.broadcast_to(
x, (n_devices, config.system.update_batch_size) + x.shape
x, (n_devices, config.system.update_batch_size, *x.shape)
)

actor_params = jax.tree_map(broadcast, actor_params)
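Note: like the minibatch reshape earlier in this file, the broadcast target shape is now written as a tuple. A small sketch of what broadcasting a parameter tree to (n_devices, update_batch_size, ...) does, with assumed device count, batch size and parameter shapes:

    import jax
    import jax.numpy as jnp

    n_devices, update_batch_size = 2, 4  # assumed values for illustration
    params = {"w": jnp.ones((3, 3)), "b": jnp.ones((3,))}
    broadcast = lambda x: jnp.broadcast_to(x, (n_devices, update_batch_size, *x.shape))
    batched_params = jax.tree_map(broadcast, params)
    # batched_params["w"].shape == (2, 4, 3, 3) and batched_params["b"].shape == (2, 4, 3)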
@@ -449,8 +449,7 @@ def learner_setup(
return learn, actor_network, init_learner_state


# TODO: fix cognitive complexity
def run_experiment(_config: DictConfig) -> None: # noqa: CCR001
def run_experiment(_config: DictConfig) -> None:
"""Runs experiment."""
# Logger setup
config = copy.deepcopy(_config)
@@ -469,7 +468,8 @@ def run_experiment(_config: DictConfig) -> None: # noqa: CCR001

# Setup evaluator.
eval_keys = jax.random.split(key_e, n_devices)
evaluator, absolute_metric_evaluator = make_anakin_eval_fns(eval_env, actor_network, config)
eval_act_fn = make_ff_eval_act_fn(actor_network, config)
evaluator = get_eval_fn(eval_env, eval_act_fn, config, config.arch.num_eval_episodes)

config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation
steps_per_rollout = (
@@ -547,7 +547,6 @@ def run_experiment(_config: DictConfig) -> None: # noqa: CCR001
@jax.jit
def _reshape_experience(experience: Dict[str, chex.Array]) -> Dict[str, chex.Array]:
"""Reshape experience to match buffer."""

# Swap the T and NE axes (D, NU, UB, T, NE, ...) -> (D, NU, UB, NE, T, ...)
experience: Dict[str, chex.Array] = jax.tree_map(lambda x: x.swapaxes(3, 4), experience)
# Merge 4 leading dimensions into 1. (D, NU, UB, NE, T ...) -> (D * NU * UB * NE, T, ...)
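Note: a standalone toy version of the axis manipulation the two comments above describe, using a plain reshape in place of Mava's merge_leading_dims helper (dummy dimension sizes, illustrative only):

    import jax.numpy as jnp

    x = jnp.zeros((2, 3, 4, 7, 5, 6))  # (D, NU, UB, T, NE, feature)
    x = x.swapaxes(3, 4)               # (D, NU, UB, NE, T, feature)
    x = x.reshape((-1, *x.shape[4:]))  # (D * NU * UB * NE, T, feature) == (120, 7, 6)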
@@ -619,16 +618,16 @@ def _reshape_experience(experience: Dict[str, chex.Array]) -> Dict[str, chex.Arr
eval_keys = eval_keys.reshape(n_devices, -1)

# Evaluate.
evaluator_output = evaluator(trained_params, eval_keys)
jax.block_until_ready(evaluator_output)
eval_metrics = evaluator(trained_params, eval_keys, {})
jax.block_until_ready(eval_metrics)

# Log the results of the evaluation.
elapsed_time = time.time() - start_time
episode_return = jnp.mean(evaluator_output.episode_metrics["episode_return"])
episode_return = jnp.mean(eval_metrics["episode_return"])

steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"]))
evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time
logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.EVAL)
steps_per_eval = int(jnp.sum(eval_metrics["episode_length"]))
eval_metrics["steps_per_second"] = steps_per_eval / elapsed_time
logger.log(eval_metrics, t, eval_step, LogEvent.EVAL)

if save_checkpoint:
# Save checkpoint of learner state
@@ -652,18 +651,21 @@ def _reshape_experience(experience: Dict[str, chex.Array]) -> Dict[str, chex.Arr
if config.arch.absolute_metric:
start_time = time.time()

eval_episodes = config.arch.num_absolute_metric_eval_episodes
abs_metric_evaluator = get_eval_fn(eval_env, eval_act_fn, config, eval_episodes)

key_e, *eval_keys = jax.random.split(key_e, n_devices + 1)
eval_keys = jnp.stack(eval_keys)
eval_keys = eval_keys.reshape(n_devices, -1)

evaluator_output = absolute_metric_evaluator(best_params, eval_keys)
jax.block_until_ready(evaluator_output)
eval_metrics = abs_metric_evaluator(best_params, eval_keys, {})
jax.block_until_ready(eval_metrics)

elapsed_time = time.time() - start_time
steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"]))
steps_per_eval = int(jnp.sum(eval_metrics["episode_length"]))
t = int(steps_per_rollout * (eval_step + 1))
evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time
logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.ABSOLUTE)
eval_metrics["steps_per_second"] = steps_per_eval / elapsed_time
logger.log(eval_metrics, t, eval_step, LogEvent.ABSOLUTE)

# Stop logger
logger.stop()
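Note: taken together, the evaluator changes in this file replace the old make_anakin_eval_fns pair with an act function plus a generic get_eval_fn, and the evaluators now return a plain metrics dict rather than an ExperimentOutput. A condensed sketch of the new call pattern, using only names that appear in the diff above (trained_params, best_params and eval_keys stand in for the values built inside run_experiment):

    eval_act_fn = make_ff_eval_act_fn(actor_network, config)

    # Periodic evaluation during training.
    evaluator = get_eval_fn(eval_env, eval_act_fn, config, config.arch.num_eval_episodes)
    eval_metrics = evaluator(trained_params, eval_keys, {})
    episode_return = jnp.mean(eval_metrics["episode_return"])

    # Final absolute-metric evaluation, with more episodes and the best parameters.
    abs_metric_evaluator = get_eval_fn(
        eval_env, eval_act_fn, config, config.arch.num_absolute_metric_eval_episodes
    )
    eval_metrics = abs_metric_evaluator(best_params, eval_keys, {})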
3 changes: 2 additions & 1 deletion mava/configs/arch/anakin.yaml
@@ -8,7 +8,8 @@ num_envs: 16 # Number of vectorised environments per device.
evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select
# an action which corresponds to the greatest logit. If false, the policy will sample
# from the logits.
num_eval_episodes: 32 # Number of episodes to evaluate per evaluation.
num_evaluation: 200 # Number of evenly spaced evaluations to perform during training.
num_eval_episodes: 32 # Number of episodes to evaluate per evaluation.
num_absolute_metric_eval_episodes: 320 # Number of episodes to evaluate the absolute metric (the final evaluation).
absolute_metric: True # Whether the absolute metric should be computed. For more details
# on the absolute metric please see: https://arxiv.org/abs/2209.10485
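Note: to see how these settings interact with the training script above, the evaluation cadence works out roughly as follows (num_updates is an assumed example value, the rest are the defaults shown here):

    num_updates = 1000                                     # assumed system setting
    num_evaluation = 200                                   # from this config
    num_updates_per_eval = num_updates // num_evaluation   # evaluate every 5 updates
    # each evaluation runs num_eval_episodes (32) episodes; the final absolute-metric
    # evaluation runs num_absolute_metric_eval_episodes (320) with the best parameters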
4 changes: 2 additions & 2 deletions mava/configs/env/connector.yaml
@@ -1,7 +1,7 @@
# ---Environment Configs---
defaults:
- _self_
- scenario: con-10x10x5a # [con-5x5x3a, con-7x7x5a, con-10x10x5a, con-15x15x10a]
- scenario: con-5x5x3a # [con-5x5x3a, con-7x7x5a, con-10x10x10a, con-15x15x23a]
# Further environment config details in "con-10x10x5a" file.

env_name: MaConnector # Used for logging purposes.
@@ -18,4 +18,4 @@ implicit_agent_id: True
log_win_rate: False

kwargs:
time_limit: 100
{} # time limit set in scenario
10 changes: 10 additions & 0 deletions mava/configs/env/scenario/con-10x10x10a.yaml
@@ -0,0 +1,10 @@
# The config of the 10x10x10a scenario
name: MaConnector-v2
task_name: con-10x10x10a

task_config:
grid_size: 10
num_agents: 10

env_kwargs:
time_limit: 100
10 changes: 0 additions & 10 deletions mava/configs/env/scenario/con-10x10x5a.yaml

This file was deleted.

10 changes: 0 additions & 10 deletions mava/configs/env/scenario/con-15x15x10a.yaml

This file was deleted.

10 changes: 10 additions & 0 deletions mava/configs/env/scenario/con-15x15x23a.yaml
@@ -0,0 +1,10 @@
# The config of the 15x15x23a scenario
name: MaConnector-v2
task_name: con-15x15x23a

task_config:
grid_size: 15
num_agents: 23

env_kwargs:
time_limit: 225
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/con-5x5x3a.yaml
@@ -7,4 +7,4 @@ task_config:
num_agents: 3

env_kwargs:
{} # there are no scenario specific env_kwargs for this env
time_limit: 25
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/con-7x7x5a.yaml
@@ -7,4 +7,4 @@ task_config:
num_agents: 5

env_kwargs:
{} # there are no scenario specific env_kwargs for this env
time_limit: 49
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/gym-lbf-10x10-3p-3f.yaml
@@ -9,7 +9,7 @@ task_config:
max_num_food: 3
max_player_level: 2
force_coop: False
max_episode_steps: 50
max_episode_steps: 100
min_player_level : 1
min_food_level : null
max_food_level : null
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/gym-lbf-15x15-3p-5f.yaml
@@ -9,7 +9,7 @@ task_config:
max_num_food: 5
max_player_level: 2
force_coop: False
max_episode_steps: 50
max_episode_steps: 100
min_player_level : 1
min_food_level : null
max_food_level : null
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/gym-lbf-15x15-4p-3f.yaml
@@ -9,7 +9,7 @@ task_config:
max_num_food: 3
max_player_level: 2
force_coop: False
max_episode_steps: 50
max_episode_steps: 100
min_player_level : 1
min_food_level : null
max_food_level : null
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/gym-lbf-15x15-4p-5f.yaml
@@ -9,7 +9,7 @@ task_config:
max_num_food: 5
max_player_level: 2
force_coop: False
max_episode_steps: 50
max_episode_steps: 100
min_player_level : 1
min_food_level : null
max_food_level : null
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/gym-lbf-2s-10x10-3p-3f.yaml
@@ -9,7 +9,7 @@ task_config:
max_num_food: 3
max_player_level: 2
force_coop: False
max_episode_steps: 50
max_episode_steps: 100
min_player_level : 1
min_food_level : null
max_food_level : null
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/gym-lbf-2s-8x8-2p-2f-coop.yaml
@@ -9,7 +9,7 @@ task_config:
max_num_food: 2 # number of food in the environment.
max_player_level: 2 # maximum level of the agents (inclusive).
force_coop: True # force cooperation between agents.
max_episode_steps: 50 # max number of steps per episode.
max_episode_steps: 100 # max number of steps per episode.
min_player_level : 1 # minimum level of the agents (inclusive).
min_food_level : null
max_food_level : null
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/gym-lbf-8x8-2p-2f-coop.yaml
@@ -9,7 +9,7 @@ task_config:
max_num_food: 2
max_player_level: 2
force_coop: True
max_episode_steps: 50
max_episode_steps: 100
min_player_level : 1
min_food_level : null
max_food_level : null
