From e13a6e15c698f76f694d94c5a5c0200ea4ba82db Mon Sep 17 00:00:00 2001
From: Ruan de Kock
Date: Tue, 22 Oct 2024 10:10:48 +0200
Subject: [PATCH] chore: type hints

---
 mava/systems/q_learning/anakin/rec_iql.py  |  4 +++-
 mava/systems/q_learning/anakin/rec_qmix.py | 23 ++++++++++----------
 mava/systems/q_learning/types.py           | 25 ++++++++++++----------
 3 files changed, 28 insertions(+), 24 deletions(-)

diff --git a/mava/systems/q_learning/anakin/rec_iql.py b/mava/systems/q_learning/anakin/rec_iql.py
index 2a3c9783c..fa49fbc5e 100644
--- a/mava/systems/q_learning/anakin/rec_iql.py
+++ b/mava/systems/q_learning/anakin/rec_iql.py
@@ -443,7 +443,9 @@ def update_q(
 
         return next_params, next_opt_state, q_loss_info
 
-    def train(train_state: TrainState, _: Any) -> Tuple[TrainState, Metrics]:
+    def train(
+        train_state: TrainState[QNetParams], _: Any
+    ) -> Tuple[TrainState[QNetParams], Metrics]:
         """Sample, train and repack."""
         # unpack and get keys
         buffer_state, params, opt_states, t_train, key = train_state
diff --git a/mava/systems/q_learning/anakin/rec_qmix.py b/mava/systems/q_learning/anakin/rec_qmix.py
index f518c62bf..84e40127a 100644
--- a/mava/systems/q_learning/anakin/rec_qmix.py
+++ b/mava/systems/q_learning/anakin/rec_qmix.py
@@ -28,7 +28,6 @@
 from flax.core.scope import FrozenVariableDict
 from flax.linen import FrozenDict
 from jax import Array, tree
-from jumanji.env import Environment
 from jumanji.types import TimeStep
 from omegaconf import DictConfig, OmegaConf
 from rich.pretty import pprint
@@ -46,7 +45,7 @@
     TrainState,
     Transition,
 )
-from mava.types import Observation, ObservationGlobalState
+from mava.types import MarlEnv, Observation
 from mava.utils import make_env as environments
 from mava.utils.checkpointing import Checkpointer
 from mava.utils.jax_utils import (
@@ -59,12 +58,14 @@
 from mava.wrappers import episode_metrics
 
 
+# (env, eval_env), learner_state, q_net, q_mixer, opt, rb, logger, key
 def init(
     cfg: DictConfig,
 ) -> Tuple[
-    Tuple[Environment, Environment],
+    Tuple[MarlEnv, MarlEnv],
     LearnerState,
     RecQNetwork,
+    QMixNetwork,
     optax.GradientTransformation,
     TrajectoryBuffer,
     MavaLogger,
@@ -222,7 +223,7 @@ def replicate(x: Any) -> Any:
 
 def make_update_fns(
     cfg: DictConfig,
-    env: Environment,
+    env: MarlEnv,
     q_net: RecQNetwork,
     mixer: QMixNetwork,
     opt: optax.GradientTransformation,
@@ -230,7 +231,7 @@ def make_update_fns(
 ) -> Any:
     def select_eps_greedy_action(
         action_selection_state: ActionSelectionState,
-        obs: ObservationGlobalState,
+        obs: Observation,
         term_or_trunc: Array,
     ) -> Tuple[ActionSelectionState, Array]:
         """Select action to take in eps-greedy way. Batch and agent dims are included."""
@@ -364,12 +365,8 @@ def update_q(
         """Update the Q parameters."""
 
         # Get data aligned with current/next timestep
-        data_first: Dict[str, chex.Array] = jax.tree_map(
-            lambda x: x[:, :-1, ...], data
-        )  # (B, T, ...)
-        data_next: Dict[str, chex.Array] = jax.tree_map(
-            lambda x: x[:, 1:, ...], data
-        )  # (B, T, ...)
+        data_first = jax.tree_map(lambda x: x[:, :-1, ...], data)  # (B, T, ...)
+        data_next = jax.tree_map(lambda x: x[:, 1:, ...], data)  # (B, T, ...)
 
         first_reward = data_first.reward
         next_done = data_next.term_or_trunc
@@ -450,7 +447,9 @@ def update_q(
 
         return next_params, next_opt_state, q_loss_info
 
-    def train(train_state: TrainState, _: Any) -> TrainState:
+    def train(
+        train_state: TrainState[QMIXParams], _: Any
+    ) -> Tuple[TrainState[QMIXParams], Metrics]:
         """Sample, train and repack."""
 
         buffer_state, params, opt_states, t_train, key = train_state
diff --git a/mava/systems/q_learning/types.py b/mava/systems/q_learning/types.py
index f72ed4372..e5f33062d 100644
--- a/mava/systems/q_learning/types.py
+++ b/mava/systems/q_learning/types.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, NamedTuple
+from typing import Dict, Generic, TypeVar
 
 import optax
 from chex import PRNGKey
@@ -19,7 +19,7 @@
 from flax.core.scope import FrozenVariableDict
 from jax import Array
 from jumanji.env import State
-from typing_extensions import TypeAlias
+from typing_extensions import NamedTuple, TypeAlias
 
 from mava.types import Observation
 
@@ -90,18 +90,21 @@ class ActionState(NamedTuple):
     term_or_trunc: Array
 
 
-class TrainState(NamedTuple):
+class QMIXParams(NamedTuple):
+    online: FrozenVariableDict
+    target: FrozenVariableDict
+    mixer_online: FrozenVariableDict
+    mixer_target: FrozenVariableDict
+
+
+QLearningParams = TypeVar("QLearningParams", QNetParams, QMIXParams)
+
+
+class TrainState(NamedTuple, Generic[QLearningParams]):
     """The carry in the training loop."""
 
     buffer_state: BufferState
-    params: QNetParams
+    params: QLearningParams
     opt_state: optax.OptState
     train_steps: Array
     key: PRNGKey
-
-
-class QMIXParams(NamedTuple):
-    online: FrozenVariableDict
-    target: FrozenVariableDict
-    mixer_online: FrozenVariableDict
-    mixer_target: FrozenVariableDict
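
Note below the final hunk (not part of the patch; git am/apply ignore trailing text): a minimal, self-contained sketch of the generic-NamedTuple pattern the types.py hunk introduces. It shows why the NamedTuple import moves to typing_extensions: typing.NamedTuple only accepts a Generic[...] base from Python 3.11, while the typing_extensions backport accepts it on older interpreters. Field types are simplified to plain dicts here as placeholders for the real FrozenVariableDict and Mava types.

    from typing import Generic, TypeVar

    # typing.NamedTuple rejects multiple inheritance with Generic before
    # Python 3.11; the typing_extensions backport supports it.
    from typing_extensions import NamedTuple


    class QNetParams(NamedTuple):
        online: dict  # placeholder for FrozenVariableDict
        target: dict


    class QMIXParams(NamedTuple):
        online: dict
        target: dict
        mixer_online: dict
        mixer_target: dict


    # A value-restricted TypeVar: TrainState may only be parameterised by
    # one of the two concrete params containers.
    QLearningParams = TypeVar("QLearningParams", QNetParams, QMIXParams)


    class TrainState(NamedTuple, Generic[QLearningParams]):
        params: QLearningParams
        train_steps: int  # placeholder for a jax Array counter


    def train(state: TrainState[QNetParams]) -> TrainState[QNetParams]:
        # A type checker now knows state.params is QNetParams, so an access
        # like state.params.mixer_online would be flagged as an error.
        return state._replace(train_steps=state.train_steps + 1)


    print(train(TrainState(QNetParams({}, {}), 0)))

This is what lets rec_iql.py annotate its loop carry as TrainState[QNetParams] and rec_qmix.py as TrainState[QMIXParams] while sharing one class, instead of the old arrangement where TrainState hard-coded params: QNetParams and the QMIX system reused it with the wrong params type.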