
Commit e8a0c07
fix: updated sebulba
Louay-Ben-nessir committed Nov 14, 2024
1 parent 27bdc2f commit e8a0c07
Showing 3 changed files with 24 additions and 21 deletions.
30 changes: 13 additions & 17 deletions mava/systems/ppo/sebulba/ff_ippo.py
@@ -45,7 +45,7 @@
from mava.types import (
ActorApply,
CriticApply,
-    ExperimentOutput,
+    Metrics,
Observation,
SebulbaLearnerFn,
)
@@ -113,6 +113,7 @@ def act_fn(
while not thread_lifetime.should_stop():
# Rollout
traj: List[PPOTransition] = []
+ episode_metrics: List[Dict] = []
actor_timings: Dict[str, List[float]] = defaultdict(list)
with RecordTimeTo(actor_timings["rollout_time"]):
for _ in range(config.system.rollout_length):
@@ -142,14 +143,14 @@ def act_fn(
timestep.reward,
log_prob,
obs_tpu,
timestep.extras["episode_metrics"],
)
)
+ episode_metrics.append(timestep.extras["episode_metrics"])

# send trajectories to learner
with RecordTimeTo(actor_timings["rollout_put_time"]):
try:
- rollout_queue.put(traj, timestep, actor_timings)
+ rollout_queue.put(traj, timestep, (actor_timings, episode_metrics))
except queue.Full:
err = "Waited too long to add to the rollout queue, killing the actor thread"
warnings.warn(err, stacklevel=2)
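
Taken together, the actor-side change looks like this. A standalone sketch with invented metric keys and timing values (Mava's real schema may differ):

```python
# Standalone sketch of the actor-side payload after this change. The metric
# keys ("episode_return", "episode_length") and timing values are invented.
from collections import defaultdict
from typing import Dict, List

rollout_length = 4
actor_timings: Dict[str, List[float]] = defaultdict(list)
episode_metrics: List[Dict] = []

for step in range(rollout_length):
    actor_timings["rollout_time"].append(0.01)  # stand-in for RecordTimeTo
    # Previously stored inside each PPOTransition; now collected on the side.
    episode_metrics.append({"episode_return": float(step), "episode_length": step})

# The third argument to Pipeline.put is now a (timings, metrics) tuple.
payload = (actor_timings, episode_metrics)
assert len(payload[1]) == rollout_length  # one metrics dict per env step
```
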
@@ -175,7 +176,7 @@ def get_learner_step_fn(
def _update_step(
learner_state: LearnerState,
traj_batch: PPOTransition,
- ) -> Tuple[LearnerState, Tuple]:
+ ) -> Tuple[LearnerState, Metrics]:
"""A single update of the network.
This function calculates advantages and targets based on the trajectories
@@ -216,7 +217,7 @@ def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tup
last_val = critic_apply_fn(params.critic_params, final_timestep.observation)
advantages, targets = _calculate_gae(traj_batch, last_val)

- def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
+ def _update_epoch(update_state: Tuple, _: Any) -> Tuple[Tuple, Metrics]:
"""Update the network for a single epoch."""

def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple:
@@ -359,12 +360,11 @@ def _critic_loss_fn(

params, opt_states, traj_batch, advantages, targets, key = update_state
learner_state = LearnerState(params, opt_states, key, None, learner_state.timestep)
- metric = traj_batch.info
- return learner_state, (metric, loss_info)
+ return learner_state, loss_info

def learner_fn(
learner_state: LearnerState, traj_batch: PPOTransition
- ) -> ExperimentOutput[LearnerState]:
+ ) -> Tuple[LearnerState, Metrics]:
"""Learner function.
This function represents the learner, it updates the network parameters
@@ -382,13 +382,9 @@ def learner_fn(
# This function is shard mapped on the batch axis, but `_update_step` needs
# the first axis to be time
traj_batch = tree.map(switch_leading_axes, traj_batch)
- learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch)
+ learner_state, loss_info = _update_step(learner_state, traj_batch)

- return ExperimentOutput(
-     learner_state=learner_state,
-     episode_metrics=episode_info,
-     train_metrics=loss_info,
- )
+ return learner_state, loss_info

return learner_fn
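
The comment in `learner_fn` above is the one subtle step here: the function is shard-mapped over the batch axis, but `_update_step` scans over time, so the leading axes must be swapped first. A hedged sketch, assuming `switch_leading_axes` simply swaps the first two axes of every pytree leaf:

```python
# Hedged sketch of the batch->time transpose mentioned in learner_fn; assumes
# switch_leading_axes swaps the first two axes of each pytree leaf.
import jax
import jax.numpy as jnp

def switch_leading_axes(x: jnp.ndarray) -> jnp.ndarray:
    return jnp.swapaxes(x, 0, 1)

traj_batch = {"obs": jnp.zeros((8, 16, 4))}  # (batch, time, features)
traj_batch = jax.tree.map(switch_leading_axes, traj_batch)
print(traj_batch["obs"].shape)  # (16, 8, 4): time leads, as _update_step expects
```
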

@@ -412,15 +408,15 @@ def learner_thread(
# Get the trajectory batch from the pipeline
# This is blocking so it will wait until the pipeline has data.
with RecordTimeTo(learn_times["rollout_get_time"]):
- traj_batch, timestep, rollout_time = pipeline.get(block=True)
+ traj_batch, timestep, rollout_time, ep_metrics = pipeline.get(block=True)

# Replace the timestep in the learner state with the latest timestep
# This means the learner has access to the entire trajectory as well as
# an additional timestep which it can use to bootstrap.
learner_state = learner_state._replace(timestep=timestep)
# Update the networks
with RecordTimeTo(learn_times["learning_time"]):
- learner_state, ep_metrics, train_metrics = learn_fn(learner_state, traj_batch)
+ learner_state, train_metrics = learn_fn(learner_state, traj_batch)

metrics.append((ep_metrics, train_metrics))
rollout_times_array.append(rollout_time)
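
Worth noting why `learner_thread` still replaces its timestep even though the metrics moved: the extra timestep exists purely so the learner can bootstrap the value of the last state when computing advantages. A minimal numeric sketch with a toy dot-product critic (not Mava's network):

```python
# Minimal numeric sketch of value bootstrapping; the critic here is a toy
# dot-product, not Mava's network.
import jax.numpy as jnp

def critic_apply_fn(params, obs):
    return obs @ params

params = jnp.ones(3)
last_obs = jnp.array([0.1, 0.2, 0.3])  # from the extra, most recent timestep
last_val = critic_apply_fn(params, last_obs)

# One-step TD error for the final transition: last_val closes the GAE recursion.
reward, discount, value = 1.0, 0.99, 0.5
delta = reward + discount * last_val - value
print(float(delta))  # 1.094
```
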
Expand Down Expand Up @@ -515,7 +511,7 @@ def learner_setup(
learn,
mesh=mesh,
in_specs=(learn_state_spec, data_spec),
- out_specs=ExperimentOutput(learn_state_spec, data_spec, data_spec),
+ out_specs=(learn_state_spec, data_spec),
)
)

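The `out_specs` edit in `learner_setup` falls out mechanically from the new return type: `shard_map` expects `out_specs` to mirror the pytree structure of the wrapped function's return, so the `ExperimentOutput`-shaped spec becomes a plain 2-tuple. A self-contained sketch; the mesh axis name and the toy `learn` are invented:

```python
# Sketch of a shard_map whose out_specs is a 2-tuple matching a
# (state, metrics) return, as learner_fn now has.
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
learn_state_spec = P()   # replicated learner state (assumption)
data_spec = P("batch")   # data and metrics sharded over the batch axis

def learn(state, batch):
    return state + 1, batch * 2.0  # (new_state, metrics), like learner_fn now

learn_fn = shard_map(
    learn,
    mesh=mesh,
    in_specs=(learn_state_spec, data_spec),
    out_specs=(learn_state_spec, data_spec),  # 2-tuple mirroring the return
)

batch = jnp.arange(2.0 * len(jax.devices()))  # divides evenly across devices
state, metrics = learn_fn(jnp.array(0), batch)
```
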
2 changes: 1 addition & 1 deletion mava/types.py
@@ -152,7 +152,7 @@ class ExperimentOutput(NamedTuple, Generic[MavaState]):


LearnerFn = Callable[[MavaState], ExperimentOutput[MavaState]]
- SebulbaLearnerFn = Callable[[MavaState, MavaTransition], ExperimentOutput[MavaState]]
+ SebulbaLearnerFn = Callable[[MavaState, MavaTransition], Tuple[MavaState, Metrics]]
ActorApply = Callable[[FrozenDict, Observation], Distribution]
CriticApply = Callable[[FrozenDict, Observation], Value]
RecActorApply = Callable[
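For illustration, a toy function with the same shape as the updated alias; `int` and `List[float]` stand in for `MavaState` and `MavaTransition`:

```python
# Toy function shaped like the new SebulbaLearnerFn: it takes a state plus a
# transition batch and returns (state, metrics) rather than an ExperimentOutput.
from typing import Callable, Dict, List, Tuple

Metrics = Dict[str, List[float]]
ToyLearnerFn = Callable[[int, List[float]], Tuple[int, Metrics]]

def toy_learner(state: int, traj_batch: List[float]) -> Tuple[int, Metrics]:
    loss = sum(traj_batch) / max(len(traj_batch), 1)
    return state + 1, {"total_loss": [loss]}

fn: ToyLearnerFn = toy_learner
print(fn(0, [0.5, 1.5]))  # (1, {'total_loss': [1.0]})
```
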
13 changes: 10 additions & 3 deletions mava/utils/sebulba.py
@@ -27,6 +27,7 @@

# todo: remove the ppo dependencies when we make sebulba for other systems
from mava.systems.ppo.types import Params, PPOTransition
+ from mava.types import Metrics

QUEUE_PUT_TIMEOUT = 100

@@ -90,7 +91,9 @@ def run(self) -> None:
except queue.Empty:
continue

- def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict) -> None:
+ def put(
+     self, traj: Sequence[PPOTransition], timestep: TimeStep, metrics: Tuple[Dict, List[Dict]]
+ ) -> None:
"""Put a trajectory on the queue to be consumed by the learner."""
start_condition, end_condition = (threading.Condition(), threading.Condition())
with start_condition:
@@ -101,6 +104,10 @@ def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict
traj = _stack_trajectory(traj)
traj, timestep = jax.device_put((traj, timestep), device=self.sharding)

+ time_dict, episode_metrics = metrics
+ # [{'metric1': value1, ...}] * rollout_len -> {'metric1': [value1, value2, ...], ...}
+ episode_metrics = _stack_trajectory(episode_metrics)

# We block on the `put` to ensure that actors wait for the learners to catch up.
# This ensures two things:
# The actors don't get too far ahead of the learners, which could lead to off-policy data.
@@ -110,7 +117,7 @@ def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict
# We use a try-finally so the lock is released even if an exception is raised.
try:
self._queue.put(
- (traj, timestep, time_dict),
+ (traj, timestep, time_dict, episode_metrics),
block=True,
timeout=QUEUE_PUT_TIMEOUT,
)
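
The stacking step that the new comment in `put()` describes can be sketched as follows; Mava's actual `_stack_trajectory` may be implemented differently, but tree-mapping a stack over same-keyed dicts is one standard approach:

```python
# Hedged sketch of list-of-dicts -> dict-of-arrays stacking.
import jax
import numpy as np

def stack_trajectory(trajectory):
    """[{'m': v1}, {'m': v2}, ...] -> {'m': array([v1, v2, ...])}"""
    return jax.tree.map(lambda *leaves: np.stack(leaves), *trajectory)

episode_metrics = [{"episode_return": 1.0}, {"episode_return": 2.5}]
print(stack_trajectory(episode_metrics))  # {'episode_return': array([1. , 2.5])}
```
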
@@ -129,7 +136,7 @@ def qsize(self) -> int:

def get(
self, block: bool = True, timeout: Union[float, None] = None
- ) -> Tuple[PPOTransition, TimeStep, Dict]:
+ ) -> Tuple[PPOTransition, TimeStep, Dict, Metrics]:
"""Get a trajectory from the pipeline."""
return self._queue.get(block, timeout) # type: ignore

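End to end, the new queue protocol has this shape; a plain `queue.Queue` stands in for the `Pipeline` here:

```python
# put() unpacks its (timings, metrics) argument and enqueues a 4-tuple,
# which get() hands to the learner thread unchanged.
import queue
from typing import Any, Tuple

q: "queue.Queue[Tuple[Any, ...]]" = queue.Queue()

traj = ["transition_0", "transition_1"]           # stand-in PPOTransitions
timestep = "final_timestep"                       # stand-in TimeStep
time_dict = {"rollout_time": [0.02]}
episode_metrics = {"episode_return": [1.0, 2.5]}  # already stacked inside put()

q.put((traj, timestep, time_dict, episode_metrics), block=True)
traj, timestep, rollout_time, ep_metrics = q.get(block=True)
print(rollout_time, ep_metrics)
```
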
