feat: ff ippo system clean up
sash-a committed Nov 1, 2024
1 parent 5bf7188 commit f77b782
Showing 2 changed files with 32 additions and 58 deletions.
89 changes: 32 additions & 57 deletions mava/systems/ppo/anakin/ff_ippo.py
@@ -26,14 +26,13 @@
from flax.core.frozen_dict import FrozenDict
from jax import tree
from omegaconf import DictConfig, OmegaConf
-from optax._src.base import OptState
from rich.pretty import pprint

from mava.evaluator import get_eval_fn, make_ff_eval_act_fn
from mava.networks import FeedForwardActor as Actor
from mava.networks import FeedForwardValueNet as Critic
from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition
-from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv
+from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv, Metrics
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
from mava.utils.jax_utils import (
@@ -79,46 +78,36 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup
"""

-def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]:
+def _env_step(
+learner_state: LearnerState, _: Any
+) -> Tuple[LearnerState, Tuple[PPOTransition, Metrics]]:
"""Step the environment."""
params, opt_states, key, env_state, last_timestep = learner_state

-# SELECT ACTION
+# Select action
key, policy_key = jax.random.split(key)
actor_policy = actor_apply_fn(params.actor_params, last_timestep.observation)
value = critic_apply_fn(params.critic_params, last_timestep.observation)

action = actor_policy.sample(seed=policy_key)
log_prob = actor_policy.log_prob(action)

-# STEP ENVIRONMENT
+# Step environment
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)

-# LOG EPISODE METRICS
-done = tree.map(
-lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1),
-timestep.last(),
-)
-info = timestep.extras["episode_metrics"]

+done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1)
transition = PPOTransition(
-done,
-action,
-value,
-timestep.reward,
-log_prob,
-last_timestep.observation,
-info,
+done, action, value, timestep.reward, log_prob, last_timestep.observation
)
learner_state = LearnerState(params, opt_states, key, env_state, timestep)
-return learner_state, transition
+return learner_state, (transition, timestep.extras["episode_metrics"])

-# STEP ENVIRONMENT FOR ROLLOUT LENGTH
-learner_state, traj_batch = jax.lax.scan(
+# Step environment for rollout length
+learner_state, (traj_batch, episode_metrics) = jax.lax.scan(
_env_step, learner_state, None, config.system.rollout_length
)

-# CALCULATE ADVANTAGE
+# Calculate advantage
params, opt_states, key, env_state, last_timestep = learner_state
last_val = critic_apply_fn(params.critic_params, last_timestep.observation)

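A note on the `done` broadcast in the hunk above: the new one-liner replaces the `tree.map`/`jnp.repeat` pair with plain array methods. A quick equivalence check under assumed shapes (`num_envs=4`, `num_agents=3` are made up here, not values from the config):

import jax.numpy as jnp

num_envs, num_agents = 4, 3
last = jnp.array([True, False, False, True])  # stands in for timestep.last(), shape (num_envs,)

# Old style: functional jnp.repeat, then reshape to (num_envs, num_agents).
old = jnp.repeat(last, num_agents).reshape(num_envs, -1)
# New style: the same repeat as an array method.
new = last.repeat(num_agents).reshape(num_envs, -1)

assert bool((old == new).all())
print(new.shape)  # (4, 3): one done flag per agent in every environment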
@@ -156,23 +145,21 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple:

def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple:
"""Update the network for a single minibatch."""
-# UNPACK TRAIN STATE AND BATCH INFO
params, opt_states, key = train_state
traj_batch, advantages, targets = batch_info

def _actor_loss_fn(
actor_params: FrozenDict,
-actor_opt_state: OptState,
traj_batch: PPOTransition,
gae: chex.Array,
key: chex.PRNGKey,
) -> Tuple:
"""Calculate the actor loss."""
-# RERUN NETWORK
+# Rerun network
actor_policy = actor_apply_fn(actor_params, traj_batch.obs)
log_prob = actor_policy.log_prob(traj_batch.action)

-# CALCULATE ACTOR LOSS
+# Calculate actor loss
ratio = jnp.exp(log_prob - traj_batch.log_prob)
gae = (gae - gae.mean()) / (gae.std() + 1e-8)
loss_actor1 = ratio * gae
@@ -194,15 +181,14 @@ def _actor_loss_fn(

def _critic_loss_fn(
critic_params: FrozenDict,
-critic_opt_state: OptState,
traj_batch: PPOTransition,
targets: chex.Array,
) -> Tuple:
"""Calculate the critic loss."""
-# RERUN NETWORK
+# Rerun network
value = critic_apply_fn(critic_params, traj_batch.obs)

-# CALCULATE VALUE LOSS
+# Calculate value loss
value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(
-config.system.clip_eps, config.system.clip_eps
)
@@ -211,26 +197,19 @@ def _critic_loss_fn(
value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()

critic_total_loss = config.system.vf_coef * value_loss
-return critic_total_loss, (value_loss)
+return critic_total_loss, value_loss

-# CALCULATE ACTOR LOSS
+# Calculate actor loss
key, entropy_key = jax.random.split(key)
actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True)
actor_loss_info, actor_grads = actor_grad_fn(
-params.actor_params,
-opt_states.actor_opt_state,
-traj_batch,
-advantages,
-entropy_key,
+params.actor_params, traj_batch, advantages, entropy_key
)

-# CALCULATE CRITIC LOSS
+# Calculate critic loss
critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True)
critic_loss_info, critic_grads = critic_grad_fn(
-params.critic_params,
-opt_states.critic_opt_state,
-traj_batch,
-targets,
+params.critic_params, traj_batch, targets
)

# Compute the parallel mean (pmean) over the batch.
@@ -253,39 +232,36 @@ def _critic_loss_fn(
(critic_grads, critic_loss_info), axis_name="device"
)

-# UPDATE ACTOR PARAMS AND OPTIMISER STATE
+# Update params and optimiser state
actor_updates, actor_new_opt_state = actor_update_fn(
actor_grads, opt_states.actor_opt_state
)
actor_new_params = optax.apply_updates(params.actor_params, actor_updates)

-# UPDATE CRITIC PARAMS AND OPTIMISER STATE
critic_updates, critic_new_opt_state = critic_update_fn(
critic_grads, opt_states.critic_opt_state
)
critic_new_params = optax.apply_updates(params.critic_params, critic_updates)

-# PACK NEW PARAMS AND OPTIMISER STATE
new_params = Params(actor_new_params, critic_new_params)
new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state)

-# PACK LOSS INFO
-total_loss = actor_loss_info[0] + critic_loss_info[0]
-value_loss = critic_loss_info[1]
-actor_loss = actor_loss_info[1][0]
-entropy = actor_loss_info[1][1]
+actor_loss, (raw_actor_loss, entropy) = actor_loss_info
+critic_loss, raw_critic_loss = critic_loss_info
+
+total_loss = actor_loss + critic_loss
loss_info = {
"total_loss": total_loss,
"value_loss": value_loss,
"actor_loss": actor_loss,
"value_loss": raw_critic_loss,
"actor_loss": raw_actor_loss,
"entropy": entropy,
}
return (new_params, new_opt_state, entropy_key), loss_info

params, opt_states, traj_batch, advantages, targets, key = update_state
key, shuffle_key, entropy_key = jax.random.split(key, 3)

-# SHUFFLE MINIBATCHES
+# Shuffle minibatches
batch_size = config.system.rollout_length * config.arch.num_envs
permutation = jax.random.permutation(shuffle_key, batch_size)
batch = (traj_batch, advantages, targets)
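On the loss bookkeeping in the hunk above: with `has_aux=True`, `jax.value_and_grad` returns `((loss, aux), grads)`, which is what lets the new code unpack `actor_loss, (raw_actor_loss, entropy) = actor_loss_info` directly. A minimal sketch with a toy loss (all names below are illustrative, not Mava's):

import jax
import jax.numpy as jnp

def loss_fn(w: jnp.ndarray, x: jnp.ndarray):
    raw_loss = jnp.mean((w * x) ** 2)
    entropy = jnp.float32(0.1)  # stand-in for a real entropy term
    # The first element is differentiated; the tuple after it is auxiliary output.
    return raw_loss, (raw_loss, entropy)

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
loss_info, grads = grad_fn(jnp.ones(3), jnp.arange(3.0))
loss, (raw_loss, entropy) = loss_info  # unpacks exactly like the diff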
@@ -296,7 +272,7 @@ def _critic_loss_fn(
shuffled_batch,
)

-# UPDATE MINIBATCHES
+# Update minibatches
(params, opt_states, entropy_key), loss_info = jax.lax.scan(
_update_minibatch, (params, opt_states, entropy_key), minibatches
)
@@ -306,15 +282,14 @@ def _critic_loss_fn(

update_state = (params, opt_states, traj_batch, advantages, targets, key)

-# UPDATE EPOCHS
+# Update epochs
update_state, loss_info = jax.lax.scan(
_update_epoch, update_state, None, config.system.ppo_epochs
)

params, opt_states, traj_batch, advantages, targets, key = update_state
learner_state = LearnerState(params, opt_states, key, env_state, last_timestep)
-metric = traj_batch.info
-return learner_state, (metric, loss_info)
+return learner_state, (episode_metrics, loss_info)

def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]:
"""Learner function.
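The thread running through this file's changes: `_env_step` now returns `(transition, metrics)` as its scan output, so `jax.lax.scan` hands back the stacked trajectory and the stacked episode metrics side by side, and `PPOTransition` no longer needs an `info` field. A self-contained sketch of that pattern with toy types (nothing below is Mava's real API):

from typing import Any, Dict, NamedTuple, Tuple

import jax
import jax.numpy as jnp

class Transition(NamedTuple):
    done: jnp.ndarray
    reward: jnp.ndarray

def env_step(state: jnp.ndarray, _: Any) -> Tuple[jnp.ndarray, Tuple[Transition, Dict]]:
    new_state = state + 1  # toy dynamics: the carry is a step counter
    transition = Transition(done=new_state % 5 == 0, reward=jnp.float32(1.0))
    metrics = {"episode_return": new_state.astype(jnp.float32)}
    # Metrics ride along in the scan output instead of living on the transition.
    return new_state, (transition, metrics)

rollout_length = 8
state, (traj_batch, episode_metrics) = jax.lax.scan(
    env_step, jnp.int32(0), None, rollout_length
)
print(traj_batch.done.shape)                    # (8,)
print(episode_metrics["episode_return"].shape)  # (8,)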
1 change: 0 additions & 1 deletion mava/systems/ppo/types.py
@@ -75,7 +75,6 @@ class PPOTransition(NamedTuple):
reward: chex.Array
log_prob: chex.Array
obs: chex.Array
-info: Dict


class RNNPPOTransition(NamedTuple):
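For reference, the `PPOTransition` that results from this deletion, reconstructed from the diff context and the call site in `_env_step` (a sketch, not a verbatim copy of the file):

from typing import NamedTuple

import chex

class PPOTransition(NamedTuple):
    """Trajectory data for the PPO update; per-step metrics no longer live here."""

    done: chex.Array
    action: chex.Array
    value: chex.Array
    reward: chex.Array
    log_prob: chex.Array
    obs: chex.Array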
