Merge branch 'develop' into feat/update_juamnji
WiemKhlifi committed Dec 3, 2024
2 parents 9047410 + 80554ef commit 4832e38
Showing 18 changed files with 325 additions and 567 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -186,7 +186,7 @@ Additionally, we also have a [Quickstart notebook][quickstart] that can be used

## Advanced Usage 👽

Mava can be used in a wide array of advanced systems. As an example, we demonstrate recording experience data from one of our PPO systems into a [Flashbax](https://github.com/instadeepai/flashbax) `Vault`. This vault can then easily be integrated into offline MARL systems, such as those found in [OG-MARL](https://github.com/instadeepai/og-marl). See the [Advanced README](./mava/advanced_usage/) for more information.
Mava can be used in a wide array of advanced systems. As an example, we demonstrate recording experience data from one of our PPO systems into a [Flashbax](https://github.com/instadeepai/flashbax) `Vault`. This vault can then easily be integrated into offline MARL systems, such as those found in [OG-MARL](https://github.com/instadeepai/og-marl). See the [Advanced README](./examples/advanced_usage/README.md) for more information.

## Contributing 🤝

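The Advanced Usage paragraph above describes recording experience from a PPO system into a Flashbax `Vault` for offline MARL. As a rough illustration of that idea (not the maintained pipeline in the Advanced README), the sketch below assumes Flashbax's documented buffer and vault APIs (`make_flat_buffer`, `Vault(vault_name=..., experience_structure=...)`, `vault.write`); exact signatures may differ between versions.

```python
# Minimal sketch only: recording rollout data into a Flashbax Vault.
# All names here are illustrative; see examples/advanced_usage for the real pipeline.
import jax.numpy as jnp
import flashbax as fbx
from flashbax.vault import Vault

# A toy transition pytree standing in for what a PPO system would produce.
transition = {"obs": jnp.zeros(4), "action": jnp.int32(0), "reward": jnp.float32(0.0)}

# Build a simple flat buffer and initialise its state from the transition structure.
buffer = fbx.make_flat_buffer(
    max_length=10_000, min_length=1, sample_batch_size=32, add_sequences=False, add_batch_size=None
)
buffer_state = buffer.init(transition)

# Create (or reopen) a vault on disk that shares the buffer's experience structure.
vault = Vault(vault_name="ppo_experience", experience_structure=buffer_state.experience)

# During training: keep adding transitions, then periodically flush to the vault.
buffer_state = buffer.add(buffer_state, transition)
vault.write(buffer_state)  # persists experience not yet written to disk
```

Data persisted this way can then be read back (e.g. with `vault.read()`) by an offline MARL pipeline such as OG-MARL.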
12 changes: 3 additions & 9 deletions examples/Quickstart.ipynb
@@ -413,8 +413,6 @@
" )\n",
"\n",
" # Compute the parallel mean (pmean) over the batch.\n",
" # This calculation is inspired by the Anakin architecture demo notebook.\n",
" # available at https://tinyurl.com/26tdzs5x\n",
" # This pmean could be a regular mean as the batch axis is on the same device.\n",
" actor_grads, actor_loss_info = jax.lax.pmean(\n",
" (actor_grads, actor_loss_info), axis_name=\"batch\"\n",
@@ -1113,7 +1111,8 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "mava",
"language": "python",
"name": "python3"
},
"language_info": {
@@ -1126,12 +1125,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
},
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
"version": "3.12.4"
}
},
"nbformat": 4,
File renamed without changes.
@@ -1,3 +1,4 @@
# type: ignore
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -225,8 +226,6 @@ def _critic_loss_fn(
)

# Compute the parallel mean (pmean) over the batch.
# This calculation is inspired by the Anakin architecture demo notebook.
# available at https://tinyurl.com/26tdzs5x
# This pmean could be a regular mean as the batch axis is on the same device.
actor_grads, actor_loss_info = jax.lax.pmean(
(actor_grads, actor_loss_info), axis_name="batch"
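As the retained comment above notes, `jax.lax.pmean` over the `"batch"` axis reduces to an ordinary mean when that axis lives on a single device. A tiny standalone illustration of that equivalence, using `vmap` with a named axis rather than Mava's actual training setup:

```python
import jax
import jax.numpy as jnp

grads = jnp.arange(4.0)  # pretend per-example gradients on one device

# pmean over the named "batch" axis inside vmap averages across the batch,
# returning the same mean value for every element: [1.5, 1.5, 1.5, 1.5].
mean_grads = jax.vmap(lambda g: jax.lax.pmean(g, axis_name="batch"), axis_name="batch")(grads)

assert jnp.allclose(mean_grads, grads.mean())
```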
4 changes: 2 additions & 2 deletions mava/evaluator.py
@@ -293,7 +293,7 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]:
# find the first instance of done to get the metrics at that timestep; we don't
# care about subsequent steps because we only want the results from the first episode
done_idx = np.argmax(timesteps.last(), axis=0)
metrics = jax.tree_map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics)
metrics = tree.map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics)
del metrics["is_terminal_step"] # unneeded for logging

return key, metrics
@@ -307,7 +307,7 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]:
metrics_array.append(metric)

# flatten metrics
metrics: Metrics = jax.tree_map(lambda *x: np.array(x).reshape(-1), *metrics_array)
metrics: Metrics = tree.map(lambda *x: np.array(x).reshape(-1), *metrics_array)
return metrics

def timed_eval_fn(params: FrozenDict, key: PRNGKey, init_act_state: ActorState) -> Metrics:
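For context on the indexing that `tree.map` now applies: `np.argmax` over a boolean `done` array returns the index of the first `True` per environment, which is how the evaluator picks out the metrics of the first completed episode. A small illustration with made-up shapes (not Mava code):

```python
import numpy as np

# dones, returns: [timesteps, n_envs]
dones = np.array([[False, False],
                  [True,  False],
                  [True,  True ]])
returns = np.array([[0.0, 0.0],
                    [1.5, 0.0],
                    [9.9, 2.0]])

first_done = np.argmax(dones, axis=0)  # index of first True per env -> [1, 2]
first_episode_returns = returns[first_done, np.arange(dones.shape[1])]  # -> [1.5, 2.0]
```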
8 changes: 4 additions & 4 deletions mava/networks/retention.py
@@ -237,12 +237,12 @@ def setup(self) -> None:
self.w_g = self.param(
"w_g",
nn.initializers.normal(stddev=1 / self.embed_dim),
(self.embed_dim, self.head_size),
(self.embed_dim, self.embed_dim),
)
self.w_o = self.param(
"w_o",
nn.initializers.normal(stddev=1 / self.embed_dim),
(self.head_size, self.embed_dim),
(self.embed_dim, self.embed_dim),
)
self.group_norm = nn.GroupNorm(num_groups=self.n_head)

@@ -278,7 +278,7 @@ def __call__(
if self.memory_config.timestep_positional_encoding:
key, query, value = self.pe(key, query, value, step_count)

ret_output = jnp.zeros((B, C, self.head_size), dtype=value.dtype)
ret_output = jnp.zeros((B, C, self.embed_dim), dtype=value.dtype)
for head in range(self.n_head):
y, new_hs = self.retention_heads[head](key, query, value, hstate[:, head], dones)
ret_output = ret_output.at[
@@ -304,7 +304,7 @@ def recurrent(
if self.memory_config.timestep_positional_encoding:
key_n, query_n, value_n = self.pe(key_n, query_n, value_n, step_count)

ret_output = jnp.zeros((B, S, self.head_size), dtype=value_n.dtype)
ret_output = jnp.zeros((B, S, self.embed_dim), dtype=value_n.dtype)
for head in range(self.n_head):
y, new_hs = self.retention_heads[head].recurrent(
key_n, query_n, value_n, hstate[:, head]
90 changes: 29 additions & 61 deletions mava/systems/mat/anakin/mat.py
@@ -37,16 +37,13 @@
ExperimentOutput,
LearnerFn,
MarlEnv,
Metrics,
TimeStep,
)
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import (
merge_leading_dims,
unreplicate_batch_dim,
unreplicate_n_dims,
)
from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.training import make_learning_rate
@@ -83,51 +80,35 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup
_ (Any): The current metrics info.
"""

def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]:
def _env_step(
learner_state: LearnerState, _: Any
) -> Tuple[LearnerState, Tuple[PPOTransition, Metrics]]:
"""Step the environment."""
params, opt_state, key, env_state, last_timestep = learner_state

# SELECT ACTION
# Select action
key, policy_key = jax.random.split(key)
action, log_prob, value = actor_action_select_fn( # type: ignore
params,
last_timestep.observation,
policy_key,
)
# STEP ENVIRONMENT
# Step environment
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)

# LOG EPISODE METRICS
# Repeat along the agent dimension. This is needed to handle the
# shuffling along the agent dimension during training.
info = tree.map(
lambda x: jnp.repeat(x[..., jnp.newaxis], config.system.num_agents, axis=-1),
timestep.extras["episode_metrics"],
)

# SET TRANSITION
done = tree.map(
lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1),
timestep.last(),
)
done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1)
transition = PPOTransition(
done,
action,
value,
timestep.reward,
log_prob,
last_timestep.observation,
info,
done, action, value, timestep.reward, log_prob, last_timestep.observation
)
learner_state = LearnerState(params, opt_state, key, env_state, timestep)
return learner_state, transition
return learner_state, (transition, timestep.extras["episode_metrics"])

# STEP ENVIRONMENT FOR ROLLOUT LENGTH
learner_state, traj_batch = jax.lax.scan(
# Step environment for rollout length
learner_state, (traj_batch, episode_metrics) = jax.lax.scan(
_env_step, learner_state, None, config.system.rollout_length
)

# CALCULATE ADVANTAGE
# Calculate advantage
params, opt_state, key, env_state, last_timestep = learner_state

key, last_val_key = jax.random.split(key)
@@ -171,8 +152,6 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple:

def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple:
"""Update the network for a single minibatch."""

# UNPACK TRAIN STATE AND BATCH INFO
params, opt_state, key = train_state
traj_batch, advantages, targets = batch_info

@@ -184,52 +163,47 @@ def _loss_fn(
entropy_key: chex.PRNGKey,
) -> Tuple:
"""Calculate the actor loss."""
# RERUN NETWORK

# Rerun network
log_prob, value, entropy = actor_apply_fn( # type: ignore
params,
traj_batch.obs,
traj_batch.action,
entropy_key,
)

# CALCULATE ACTOR LOSS
# Calculate actor loss
ratio = jnp.exp(log_prob - traj_batch.log_prob)

# Normalise advantage at minibatch level
gae = (gae - gae.mean()) / (gae.std() + 1e-8)

loss_actor1 = ratio * gae
loss_actor2 = (
actor_loss1 = ratio * gae
actor_loss2 = (
jnp.clip(
ratio,
1.0 - config.system.clip_eps,
1.0 + config.system.clip_eps,
)
* gae
)
loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
loss_actor = loss_actor.mean()
actor_loss = -jnp.minimum(actor_loss1, actor_loss2)
actor_loss = actor_loss.mean()
entropy = entropy.mean()

# CALCULATE VALUE LOSS
# Clipped MSE loss
value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(
-config.system.clip_eps, config.system.clip_eps
)

# MSE LOSS
value_losses = jnp.square(value - value_targets)
value_losses_clipped = jnp.square(value_pred_clipped - value_targets)
value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()

total_loss = (
loss_actor
actor_loss
- config.system.ent_coef * entropy
+ config.system.vf_coef * value_loss
)
return total_loss, (loss_actor, entropy, value_loss)
return total_loss, (actor_loss, entropy, value_loss)

# CALCULATE ACTOR LOSS
# Calculate loss
key, entropy_key = jax.random.split(key)
actor_grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
actor_loss_info, actor_grads = actor_grad_fn(
@@ -248,15 +222,11 @@ def _loss_fn(
(actor_grads, actor_loss_info), axis_name="device"
)

# UPDATE ACTOR PARAMS AND OPTIMISER STATE
# Update params and optimiser state
actor_updates, new_opt_state = actor_update_fn(actor_grads, opt_state)
new_params = optax.apply_updates(params, actor_updates)

# PACK LOSS INFO
total_loss = actor_loss_info[0]
value_loss = actor_loss_info[1][2]
actor_loss = actor_loss_info[1][0]
entropy = actor_loss_info[1][1]
total_loss, (actor_loss, entropy, value_loss) = actor_loss_info
loss_info = {
"total_loss": total_loss,
"value_loss": value_loss,
Expand All @@ -269,7 +239,7 @@ def _loss_fn(
params, opt_state, traj_batch, advantages, targets, key = update_state
key, batch_shuffle_key, agent_shuffle_key, entropy_key = jax.random.split(key, 4)

# SHUFFLE MINIBATCHES
# Shuffle minibatches
batch_size = config.system.rollout_length * config.arch.num_envs
permutation = jax.random.permutation(batch_shuffle_key, batch_size)

@@ -286,7 +256,7 @@ def _loss_fn(
shuffled_batch,
)

# UPDATE MINIBATCHES
# Update minibatches
(params, opt_state, entropy_key), loss_info = jax.lax.scan(
_update_minibatch, (params, opt_state, entropy_key), minibatches
)
@@ -296,17 +266,15 @@ def _loss_fn(

update_state = params, opt_state, traj_batch, advantages, targets, key

# UPDATE EPOCHS
# Update epochs
update_state, loss_info = jax.lax.scan(
_update_epoch, update_state, None, config.system.ppo_epochs
)

params, opt_state, traj_batch, advantages, targets, key = update_state
learner_state = LearnerState(params, opt_state, key, env_state, last_timestep)

metric = traj_batch.info

return learner_state, (metric, loss_info)
return learner_state, (episode_metrics, loss_info)

def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]:
"""Learner function.
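The main structural change in `mat.py` above is that `_env_step` no longer packs episode metrics into `PPOTransition`; it returns `(transition, metrics)` instead, and `jax.lax.scan` stacks both along the time axis. A minimal sketch of that pattern, with toy stand-ins for the Mava types:

```python
import jax
import jax.numpy as jnp

def env_step(carry, _):
    step = carry
    transition = {"reward": step.astype(jnp.float32)}  # stands in for PPOTransition
    metrics = {"episode_return": 2.0 * step}           # stands in for episode metrics
    return step + 1, (transition, metrics)

# scan stacks each leaf of the per-step output tuple along a leading time axis.
final_step, (traj_batch, episode_metrics) = jax.lax.scan(env_step, jnp.array(0), None, length=4)

assert traj_batch["reward"].shape == (4,)
assert episode_metrics["episode_return"].shape == (4,)
```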