diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py
index 6f34c0b1a..97a464184 100644
--- a/mava/systems/ppo/sebulba/ff_ippo.py
+++ b/mava/systems/ppo/sebulba/ff_ippo.py
@@ -45,7 +45,7 @@ from mava.types import (
     ActorApply,
     CriticApply,
-    ExperimentOutput,
+    Metrics,
     Observation,
     SebulbaLearnerFn,
 )
@@ -113,6 +113,7 @@ def act_fn(
     while not thread_lifetime.should_stop():
         # Rollout
         traj: List[PPOTransition] = []
+        episode_metrics: List[Dict] = []
         actor_timings: Dict[str, List[float]] = defaultdict(list)
         with RecordTimeTo(actor_timings["rollout_time"]):
             for _ in range(config.system.rollout_length):
@@ -142,14 +143,14 @@ def act_fn(
                         timestep.reward,
                         log_prob,
                         obs_tpu,
-                        timestep.extras["episode_metrics"],
                     )
                 )
+                episode_metrics.append(timestep.extras["episode_metrics"])
 
         # send trajectories to learner
         with RecordTimeTo(actor_timings["rollout_put_time"]):
             try:
-                rollout_queue.put(traj, timestep, actor_timings)
+                rollout_queue.put(traj, timestep, (actor_timings, episode_metrics))
             except queue.Full:
                 err = "Waited too long to add to the rollout queue, killing the actor thread"
                 warnings.warn(err, stacklevel=2)
@@ -175,7 +176,7 @@ def get_learner_step_fn(
     def _update_step(
         learner_state: LearnerState,
         traj_batch: PPOTransition,
-    ) -> Tuple[LearnerState, Tuple]:
+    ) -> Tuple[LearnerState, Metrics]:
         """A single update of the network.
 
         This function calculates advantages and targets based on the trajectories
@@ -216,7 +217,7 @@ def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tup
         last_val = critic_apply_fn(params.critic_params, final_timestep.observation)
         advantages, targets = _calculate_gae(traj_batch, last_val)
 
-        def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
+        def _update_epoch(update_state: Tuple, _: Any) -> Tuple[Tuple, Metrics]:
             """Update the network for a single epoch."""
 
             def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple:
@@ -359,12 +360,11 @@ def _critic_loss_fn(
         params, opt_states, traj_batch, advantages, targets, key = update_state
         learner_state = LearnerState(params, opt_states, key, None, learner_state.timestep)
 
-        metric = traj_batch.info
-        return learner_state, (metric, loss_info)
+        return learner_state, loss_info
 
     def learner_fn(
         learner_state: LearnerState, traj_batch: PPOTransition
-    ) -> ExperimentOutput[LearnerState]:
+    ) -> Tuple[LearnerState, Metrics]:
         """Learner function.
 
         This function represents the learner, it updates the network parameters
@@ -382,13 +382,9 @@ def learner_fn(
         # This function is shard mapped on the batch axis, but `_update_step` needs
         # the first axis to be time
         traj_batch = tree.map(switch_leading_axes, traj_batch)
-        learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch)
+        learner_state, loss_info = _update_step(learner_state, traj_batch)
 
-        return ExperimentOutput(
-            learner_state=learner_state,
-            episode_metrics=episode_info,
-            train_metrics=loss_info,
-        )
+        return learner_state, loss_info
 
     return learner_fn
 
@@ -412,7 +408,7 @@ def learner_thread(
                 # Get the trajectory batch from the pipeline
                 # This is blocking so it will wait until the pipeline has data.
                 with RecordTimeTo(learn_times["rollout_get_time"]):
-                    traj_batch, timestep, rollout_time = pipeline.get(block=True)
+                    traj_batch, timestep, rollout_time, ep_metrics = pipeline.get(block=True)
 
                 # Replace the timestep in the learner state with the latest timestep
                 # This means the learner has access to the entire trajectory as well as
@@ -420,7 +416,7 @@ def learner_thread(
                 learner_state = learner_state._replace(timestep=timestep)
                 # Update the networks
                 with RecordTimeTo(learn_times["learning_time"]):
-                    learner_state, ep_metrics, train_metrics = learn_fn(learner_state, traj_batch)
+                    learner_state, train_metrics = learn_fn(learner_state, traj_batch)
 
                 metrics.append((ep_metrics, train_metrics))
                 rollout_times_array.append(rollout_time)
@@ -515,7 +511,7 @@ def learner_setup(
             learn,
             mesh=mesh,
             in_specs=(learn_state_spec, data_spec),
-            out_specs=ExperimentOutput(learn_state_spec, data_spec, data_spec),
+            out_specs=(learn_state_spec, data_spec),
         )
     )
 
diff --git a/mava/types.py b/mava/types.py
index 4072629dc..d60175f50 100644
--- a/mava/types.py
+++ b/mava/types.py
@@ -152,7 +152,7 @@ class ExperimentOutput(NamedTuple, Generic[MavaState]):
 
 
 LearnerFn = Callable[[MavaState], ExperimentOutput[MavaState]]
-SebulbaLearnerFn = Callable[[MavaState, MavaTransition], ExperimentOutput[MavaState]]
+SebulbaLearnerFn = Callable[[MavaState, MavaTransition], Tuple[MavaState, Metrics]]
 ActorApply = Callable[[FrozenDict, Observation], Distribution]
 CriticApply = Callable[[FrozenDict, Observation], Value]
 RecActorApply = Callable[
diff --git a/mava/utils/sebulba.py b/mava/utils/sebulba.py
index dc51140f5..1fe441e2a 100644
--- a/mava/utils/sebulba.py
+++ b/mava/utils/sebulba.py
@@ -27,6 +27,7 @@
 
 # todo: remove the ppo dependencies when we make sebulba for other systems
 from mava.systems.ppo.types import Params, PPOTransition
+from mava.types import Metrics
 
 QUEUE_PUT_TIMEOUT = 100
 
@@ -90,7 +91,9 @@ def run(self) -> None:
             except queue.Empty:
                 continue
 
-    def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict) -> None:
+    def put(
+        self, traj: Sequence[PPOTransition], timestep: TimeStep, metrics: Tuple[Dict, List[Dict]]
+    ) -> None:
         """Put a trajectory on the queue to be consumed by the learner."""
         start_condition, end_condition = (threading.Condition(), threading.Condition())
         with start_condition:
@@ -101,6 +104,10 @@ def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict
         traj = _stack_trajectory(traj)
         traj, timestep = jax.device_put((traj, timestep), device=self.sharding)
 
+        time_dict, episode_metrics = metrics
+        # [{'metric1' : value1, ...}] * rollout_len -> {'metric1' : [value1, value2, ...], ...}
+        episode_metrics = _stack_trajectory(episode_metrics)
+
         # We block on the `put` to ensure that actors wait for the learners to catch up.
         # This ensures two things:
         # The actors don't get too far ahead of the learners, which could lead to off-policy data.
@@ -110,7 +117,7 @@ def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict
         # We use a try-finally so the lock is released even if an exception is raised.
         try:
             self._queue.put(
-                (traj, timestep, time_dict),
+                (traj, timestep, time_dict, episode_metrics),
                 block=True,
                 timeout=QUEUE_PUT_TIMEOUT,
             )
@@ -129,7 +136,7 @@ def qsize(self) -> int:
 
     def get(
         self, block: bool = True, timeout: Union[float, None] = None
-    ) -> Tuple[PPOTransition, TimeStep, Dict]:
+    ) -> Tuple[PPOTransition, TimeStep, Dict, Metrics]:
         """Get a trajectory from the pipeline."""
         return self._queue.get(block, timeout)  # type: ignore
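
To make the new actor-to-learner metrics contract concrete, here is a minimal, self-contained sketch. It is not Mava's actual implementation: `stack_trajectory` is an illustrative stand-in for the `_stack_trajectory` helper, and a plain `queue.Queue` stands in for `Pipeline` (no sharding or `device_put`). It shows the per-step episode-metric dicts collected by the actor being packed next to the timings, stacked into one dict of `[rollout_len, ...]` arrays at `put` time, and coming back out of `get` as the fourth tuple element.

```python
import queue
from collections import defaultdict
from typing import Dict, List

import jax
import jax.numpy as jnp


def stack_trajectory(trajectory: List[Dict]) -> Dict:
    """Stack a list of per-step metric dicts into one dict of arrays
    with a leading [rollout_len] axis (mirrors `_stack_trajectory`)."""
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves, axis=0), *trajectory)


# Actor side: one episode-metrics dict is appended per environment step, and the
# rollout timings travel next to them as (actor_timings, episode_metrics).
rollout_len, num_envs = 4, 2
episode_metrics: List[Dict] = [
    {
        "episode_return": jnp.zeros(num_envs),
        "episode_length": jnp.zeros(num_envs, dtype=jnp.int32),
    }
    for _ in range(rollout_len)
]
actor_timings: Dict[str, List[float]] = defaultdict(list)
actor_timings["rollout_time"].append(0.01)
metrics = (actor_timings, episode_metrics)

# Pipeline.put side: split the tuple, stack the list of dicts into a dict of
# [rollout_len, ...] arrays, and enqueue a 4-tuple for the learner.
time_dict, ep_metrics_list = metrics
stacked_metrics = stack_trajectory(ep_metrics_list)
pipe: queue.Queue = queue.Queue()
pipe.put(("traj_placeholder", "timestep_placeholder", time_dict, stacked_metrics))

# Learner side: get() now returns (traj_batch, timestep, time_dict, episode_metrics).
traj_batch, timestep, rollout_time, ep_metrics = pipe.get(block=True)
print(ep_metrics["episode_return"].shape)  # (4, 2)
```

Because the episode metrics no longer ride inside `PPOTransition`, they are not carried through the learner's update scans, which is what allows `learner_fn` to return just `(learner_state, train_metrics)` instead of an `ExperimentOutput`.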
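
On the learner side, a small illustrative loop (all names below are placeholders, not the real Mava objects) shows where the two metric streams now originate: episode metrics arrive from the pipeline alongside the trajectory, while the learner function, per the new `SebulbaLearnerFn` signature, returns only `(learner_state, train_metrics)`.

```python
from typing import Any, Dict, List, Tuple


# Hypothetical stand-ins for Pipeline.get and the jitted learner function.
def pipeline_get(block: bool = True) -> Tuple[Any, Any, Dict, Dict]:
    # (traj_batch, timestep, time_dict, episode_metrics)
    return "traj_batch", "timestep", {"rollout_time": [0.02]}, {"episode_return": [1.0, 2.0]}


def learn_fn(learner_state: Any, traj_batch: Any) -> Tuple[Any, Dict]:
    # New contract: only (learner_state, train_metrics), no ExperimentOutput.
    return learner_state, {"total_loss": 0.1}


learner_state: Any = "initial_learner_state"
metrics: List[Tuple[Dict, Dict]] = []
rollout_times: List[Dict] = []

for _ in range(2):  # pretend config.system.num_updates_per_eval == 2
    # Episode metrics are produced by the actors and travel through the pipeline...
    traj_batch, timestep, rollout_time, ep_metrics = pipeline_get(block=True)
    # ...while the learner function now returns only training metrics.
    learner_state, train_metrics = learn_fn(learner_state, traj_batch)
    metrics.append((ep_metrics, train_metrics))
    rollout_times.append(rollout_time)
```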