feat: minor refactor to sebulba utils
sash-a committed Oct 10, 2024
1 parent fd8aece commit ae53415
Showing 11 changed files with 84 additions and 57 deletions.
2 changes: 1 addition & 1 deletion mava/systems/ppo/anakin/ff_ippo.py
@@ -35,13 +35,13 @@
from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
+from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import (
    merge_leading_dims,
    unreplicate_batch_dim,
    unreplicate_n_dims,
)
from mava.utils.logger import LogEvent, MavaLogger
-from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics

2 changes: 1 addition & 1 deletion mava/systems/ppo/anakin/ff_mappo.py
@@ -34,9 +34,9 @@
from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
+from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
-from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics

2 changes: 1 addition & 1 deletion mava/systems/ppo/anakin/rec_ippo.py
@@ -48,9 +48,9 @@
)
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
+from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
-from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics

2 changes: 1 addition & 1 deletion mava/systems/ppo/anakin/rec_mappo.py
@@ -48,9 +48,9 @@
)
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
+from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
-from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics

26 changes: 8 additions & 18 deletions mava/systems/ppo/sebulba/ff_ippo.py
@@ -39,25 +39,13 @@
from mava.networks import FeedForwardActor as Actor
from mava.networks import FeedForwardValueNet as Critic
from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition
-from mava.types import (
-    ActorApply,
-    CriticApply,
-    ExperimentOutput,
-    Observation,
-    SebulbaLearnerFn,
-)
+from mava.types import ActorApply, CriticApply, ExperimentOutput, Observation, SebulbaLearnerFn
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
+from mava.utils.config import check_sebulba_config, check_total_timesteps
from mava.utils.jax_utils import merge_leading_dims
from mava.utils.logger import LogEvent, MavaLogger
-from mava.utils.sebulba_utils import (
-    ParamsSource,
-    Pipeline,
-    RecordTimeTo,
-    ThreadLifetime,
-    check_config,
-)
-from mava.utils.total_timestep_checker import check_total_timesteps
+from mava.utils.sebulba import ParamsSource, Pipeline, RecordTimeTo, ThreadLifetime
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics

@@ -95,7 +83,7 @@ def rollout(
    # Define the util functions: select action function and prepare data to share it with learner.
    @jax.jit
    def get_action_and_value(
-        params: FrozenDict,
+        params: Params,
        observation: Observation,
        key: chex.PRNGKey,
    ) -> Tuple:
@@ -147,7 +135,8 @@ def get_action_and_value(

        # Append data to storage
        reward = timestep.reward
-        info = timestep.extras
+        info = timestep.extras  # todo: [metrics]?
+        # todo: when logging make sure timing dict has parent timing/...
        traj.append(
            PPOTransition(
                cached_next_dones, action, value, reward, log_prob, cached_next_obs, info
@@ -547,7 +536,7 @@ def run_experiment(_config: DictConfig) -> float:

    # Calculate total timesteps.
    config = check_total_timesteps(config)
-    check_config(config)
+    check_sebulba_config(config)

    steps_per_rollout = (
        config.system.rollout_length * config.arch.num_envs * config.system.num_updates_per_eval
@@ -674,6 +663,7 @@ def run_experiment(_config: DictConfig) -> float:
        t = int(steps_per_rollout * (eval_step + 1))
        logger.log(eval_metrics, t, eval_step, LogEvent.ABSOLUTE)
        abs_metric_evaluator_envs.close()
+
    # Stop the logger.
    logger.stop()

4 changes: 2 additions & 2 deletions mava/systems/ppo/types.py
@@ -20,7 +20,7 @@
from optax._src.base import OptState
from typing_extensions import NamedTuple

-from mava.types import Action, Done, HiddenState, State, Value
+from mava.types import Action, Done, HiddenState, Observation, State, Value


class Params(NamedTuple):
@@ -74,7 +74,7 @@ class PPOTransition(NamedTuple):
    value: Value
    reward: chex.Array
    log_prob: chex.Array
-    obs: chex.Array
+    obs: Observation
    info: Dict


2 changes: 1 addition & 1 deletion mava/systems/q_learning/anakin/rec_iql.py
@@ -48,13 +48,13 @@
from mava.types import Observation
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
+from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import (
    switch_leading_axes,
    unreplicate_batch_dim,
    unreplicate_n_dims,
)
from mava.utils.logger import LogEvent, MavaLogger
-from mava.utils.total_timestep_checker import check_total_timesteps
from mava.wrappers import episode_metrics


2 changes: 1 addition & 1 deletion mava/systems/sac/anakin/ff_isac.py
@@ -49,9 +49,9 @@
from mava.types import MarlEnv, Observation
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
+from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
-from mava.utils.total_timestep_checker import check_total_timesteps
from mava.wrappers import episode_metrics


2 changes: 1 addition & 1 deletion mava/systems/sac/anakin/ff_masac.py
@@ -50,9 +50,9 @@
from mava.utils import make_env as environments
from mava.utils.centralised_training import get_joint_action, get_updated_joint_actions
from mava.utils.checkpointing import Checkpointer
+from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
-from mava.utils.total_timestep_checker import check_total_timesteps
from mava.wrappers import episode_metrics


22 changes: 22 additions & 0 deletions mava/utils/total_timestep_checker.py → mava/utils/config.py
@@ -18,6 +18,28 @@
from omegaconf import DictConfig


+def check_sebulba_config(config: DictConfig) -> None:
+    """Checks that the given config does not have conflicting values."""
+    assert (
+        config.system.num_updates > config.arch.num_evaluation
+    ), "Number of updates per evaluation must be less than total number of updates."
+    config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation
+
+    assert config.arch.num_envs % len(config.arch.learner_device_ids) == 0, (
+        "Number of environments must be divisible by the number of learner."
+        + "The output of each actor is equally split across the learners."
+    )
+
+    num_eval_samples = (
+        int(config.arch.num_envs / len(config.arch.learner_device_ids))
+        * config.system.rollout_length
+    )
+    assert num_eval_samples % config.system.num_minibatches == 0, (
+        f"Number of training samples per evaluator ({num_eval_samples})"
+        + f"must be divisible by num_minibatches ({config.system.num_minibatches})."
+    )
+
+
def check_total_timesteps(config: DictConfig) -> DictConfig:
    """Check if total_timesteps is set, if not, set it based on the other parameters"""

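For illustration, a minimal usage sketch of the relocated check_sebulba_config (not part of the commit; the config values are hypothetical, chosen only so all three assertions pass):

from omegaconf import OmegaConf

from mava.utils.config import check_sebulba_config

config = OmegaConf.create(
    {
        "system": {"num_updates": 100, "rollout_length": 16, "num_minibatches": 2},
        "arch": {"num_evaluation": 10, "num_envs": 8, "learner_device_ids": [0, 1]},
    }
)

check_sebulba_config(config)
assert config.system.num_updates_per_eval == 10  # derived in-place by the check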
75 changes: 45 additions & 30 deletions mava/utils/sebulba_utils.py → mava/utils/sebulba.py
@@ -20,13 +20,16 @@

import jax
import jax.numpy as jnp
+from colorama import Fore, Style
+from jax import tree
from jumanji.types import TimeStep
-from omegaconf import DictConfig

-from mava.systems.ppo.types import Params, PPOTransition  # todo: remove the ppo dependencies
+# todo: remove the ppo dependencies
+from mava.systems.ppo.types import Params, PPOTransition

+QUEUE_PUT_TIMEOUT = 180


# Copied from https://github.com/instadeepai/sebulba/blob/main/sebulba/core.py
class ThreadLifetime:
    """Simple class for a mutable boolean that can be used to signal a thread to stop."""

@@ -40,6 +43,14 @@ def stop(self) -> None:
        self._stop = True


+@jax.jit
+def _stack_trajectory(trajectory: List[PPOTransition]) -> PPOTransition:
+    """Stack a list of parallel_env transitions into a single
+    transition of shape [rollout_len, num_envs, ...]."""
+    return tree.map(lambda *x: jnp.stack(x, axis=0), *trajectory)  # type: ignore
+
+
+# Modified from https://github.com/instadeepai/sebulba/blob/main/sebulba/core.py
class Pipeline(threading.Thread):
    """
    The `Pipeline` shards trajectories into `learner_devices`,
@@ -54,6 +65,7 @@ def __init__(self, max_size: int, learner_devices: List[jax.Device], lifetime: T
        Args:
            max_size: The maximum number of trajectories to keep in the pipeline.
            learner_devices: The devices to shard trajectories across.
+            lifetime: A `ThreadLifetime` which is used to stop this thread.
        """
        super().__init__(name="Pipeline")
        self.learner_devices = learner_devices
@@ -83,19 +95,39 @@ def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict
        self.tickets_queue.put((start_condition, end_condition))
        start_condition.wait()  # wait to be allowed to start

-        # [PPOTransition()] * rollout_len --> PPOTransition[done=(rollout_len, num_envs, ...)]
-        sharded_traj = jax.tree.map(lambda *x: self.shard_split_playload(jnp.stack(x), 1), *traj)
+        # [Transition(num_envs)] * rollout_len --> Transition[done=(rollout_len, num_envs, ...)]
+        traj = _stack_trajectory(traj)
+        # Split trajectory on the num envs axis so each learner device gets a valid full rollout
+        sharded_traj = jax.tree.map(lambda x: self.shard_split_playload(x, axis=1), traj)

        # Timestep[(num_envs, num_agents, ...), ...] -->
        # [(num_envs / num_learner_devices, num_agents, ...)] * num_learner_devices
        sharded_timestep = jax.tree.map(self.shard_split_playload, timestep)

-        # The lock has to be released even if an exception is raised.
+        # We block on the put to ensure that actors wait for the learners to catch up. This does two
+        # things:
+        # 1. It ensures that the actors don't get too far ahead of the learners, which could lead to
+        # off-policy data.
+        # 2. It ensures that the actors don't in a sense "waste" samples and their time by
+        # generating samples that the learners can't consume.
+        # However, we put a timeout of 180 seconds to avoid deadlocks in case the learner
+        # is not consuming the data. This is a safety measure and should not be hit in normal
+        # operation. We use a try-finally since the lock has to be released even if an exception
+        # is raised.
        try:
-            self._queue.put((sharded_traj, sharded_timestep, time_dict), timeout=90)
+            self._queue.put(
+                (sharded_traj, sharded_timestep, time_dict),
+                block=True,
+                timeout=QUEUE_PUT_TIMEOUT,
+            )
+        except queue.Full:  # todo: check if this is needed because we catch this exception outside
+            print(
+                f"{Fore.RED}{Style.BRIGHT}Pipeline is full and actor has timed out, "
+                f"this should not happen. A deadlock might be occurring{Style.RESET_ALL}"
+            )
        finally:
            with end_condition:
-                end_condition.notify()  # tell we have finish
+                end_condition.notify()  # notify that we have finished

    def qsize(self) -> int:
        """Returns the number of trajectories in the pipeline."""
Expand All @@ -107,6 +139,11 @@ def get(
"""Get a trajectory from the pipeline."""
return self._queue.get(block, timeout) # type: ignore

def clear(self) -> None:
"""Clear the pipeline."""
while not self._queue.empty():
self._queue.get()

def shard_split_playload(self, payload: Any, axis: int = 0) -> Any:
split_payload = jnp.split(payload, len(self.learner_devices), axis=axis)
return jax.device_put_sharded(split_payload, devices=self.learner_devices)
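To make the stack-then-shard shapes concrete, a toy sketch using plain arrays in place of PPOTransition fields (not part of the commit; assumes num_envs divides evenly across local devices):

import jax
import jax.numpy as jnp

devices = jax.local_devices()
rollout_len, num_envs = 4, 2 * len(devices)

# One leaf per rollout step, each with a leading (num_envs,) axis.
traj = [jnp.arange(num_envs) for _ in range(rollout_len)]

# _stack_trajectory: [x(num_envs)] * rollout_len -> x(rollout_len, num_envs)
stacked = jax.tree.map(lambda *x: jnp.stack(x, axis=0), *traj)
assert stacked.shape == (rollout_len, num_envs)

# shard_split_playload(..., axis=1): split the env axis so every device gets a full rollout.
splits = jnp.split(stacked, len(devices), axis=1)
sharded = jax.device_put_sharded(splits, devices)
assert sharded.shape == (len(devices), rollout_len, num_envs // len(devices))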
@@ -158,25 +195,3 @@ def __enter__(self) -> None:
    def __exit__(self, *args: Any) -> None:
        end = time.monotonic()
        self.to.append(end - self.start)


-def check_config(config: DictConfig) -> None:
-    """Checks that the given config does not have conflicting values."""
-    assert (
-        config.system.num_updates > config.arch.num_evaluation
-    ), "Number of updates per evaluation must be less than total number of updates."
-    config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation
-
-    assert config.arch.num_envs % len(config.arch.learner_device_ids) == 0, (
-        "Number of environments must be divisible by the number of learner."
-        + "The output of each actor is equally split across the learners."
-    )
-
-    num_eval_samples = (
-        int(config.arch.num_envs / len(config.arch.learner_device_ids))
-        * config.system.rollout_length
-    )
-    assert num_eval_samples % config.system.num_minibatches == 0, (
-        f"Number of training samples per evaluator ({num_eval_samples})"
-        + f"must be divisible by num_minibatches ({config.system.num_minibatches})."
-    )
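For orientation, a hedged sketch of how an actor thread might compose these utilities; `env`, `actor_fn`, `params_source.get()`, and the `should_stop()` accessor on ThreadLifetime are assumptions, since their definitions fall outside this diff:

from collections import defaultdict
from typing import Any, Callable, Dict, List

from mava.utils.sebulba import ParamsSource, Pipeline, RecordTimeTo, ThreadLifetime


def actor_loop(
    env: Any,  # hypothetical vectorised environment
    actor_fn: Callable,  # hypothetical: (params, timestep) -> (transition, next_timestep)
    params_source: ParamsSource,
    pipeline: Pipeline,
    rollout_len: int,
    lifetime: ThreadLifetime,
) -> None:
    timestep = env.reset()
    while not lifetime.should_stop():  # assumed accessor, not shown in this diff
        traj: List[Any] = []
        times: Dict[str, List[float]] = defaultdict(list)
        with RecordTimeTo(times["rollout_time"]):
            for _ in range(rollout_len):
                transition, timestep = actor_fn(params_source.get(), timestep)
                traj.append(transition)
        # put() blocks until a learner frees space, timing out after QUEUE_PUT_TIMEOUT (180s).
        pipeline.put(traj, timestep, times)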
