Skip to content

Commit

Permalink
chore: type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
RuanJohn committed Oct 22, 2024
1 parent 3b8d761 commit e13a6e1
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 24 deletions.
4 changes: 3 additions & 1 deletion mava/systems/q_learning/anakin/rec_iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,9 @@ def update_q(

return next_params, next_opt_state, q_loss_info

def train(train_state: TrainState, _: Any) -> Tuple[TrainState, Metrics]:
def train(
train_state: TrainState[QNetParams], _: Any
) -> Tuple[TrainState[QNetParams], Metrics]:
"""Sample, train and repack."""
# unpack and get keys
buffer_state, params, opt_states, t_train, key = train_state
Expand Down
23 changes: 11 additions & 12 deletions mava/systems/q_learning/anakin/rec_qmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from flax.core.scope import FrozenVariableDict
from flax.linen import FrozenDict
from jax import Array, tree
from jumanji.env import Environment
from jumanji.types import TimeStep
from omegaconf import DictConfig, OmegaConf
from rich.pretty import pprint
Expand All @@ -46,7 +45,7 @@
TrainState,
Transition,
)
from mava.types import Observation, ObservationGlobalState
from mava.types import MarlEnv, Observation
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
from mava.utils.jax_utils import (
Expand All @@ -59,12 +58,14 @@
from mava.wrappers import episode_metrics


# (env, eval_env), learner_state, q_net, q_mixer, opt, rb, logger, key
def init(
cfg: DictConfig,
) -> Tuple[
Tuple[Environment, Environment],
Tuple[MarlEnv, MarlEnv],
LearnerState,
RecQNetwork,
QMixNetwork,
optax.GradientTransformation,
TrajectoryBuffer,
MavaLogger,
Expand Down Expand Up @@ -222,15 +223,15 @@ def replicate(x: Any) -> Any:

def make_update_fns(
cfg: DictConfig,
env: Environment,
env: MarlEnv,
q_net: RecQNetwork,
mixer: QMixNetwork,
opt: optax.GradientTransformation,
rb: TrajectoryBuffer,
) -> Any:
def select_eps_greedy_action(
action_selection_state: ActionSelectionState,
obs: ObservationGlobalState,
obs: Observation,
term_or_trunc: Array,
) -> Tuple[ActionSelectionState, Array]:
"""Select action to take in eps-greedy way. Batch and agent dims are included."""
Expand Down Expand Up @@ -364,12 +365,8 @@ def update_q(
"""Update the Q parameters."""

# Get data aligned with current/next timestep
data_first: Dict[str, chex.Array] = jax.tree_map(
lambda x: x[:, :-1, ...], data
) # (B, T, ...)
data_next: Dict[str, chex.Array] = jax.tree_map(
lambda x: x[:, 1:, ...], data
) # (B, T, ...)
data_first = jax.tree_map(lambda x: x[:, :-1, ...], data) # (B, T, ...)
data_next = jax.tree_map(lambda x: x[:, 1:, ...], data) # (B, T, ...)

first_reward = data_first.reward
next_done = data_next.term_or_trunc
Expand Down Expand Up @@ -450,7 +447,9 @@ def update_q(

return next_params, next_opt_state, q_loss_info

def train(train_state: TrainState, _: Any) -> TrainState:
def train(
train_state: TrainState[QMIXParams], _: Any
) -> Tuple[TrainState[QMIXParams], Metrics]:
"""Sample, train and repack."""

buffer_state, params, opt_states, t_train, key = train_state
Expand Down
25 changes: 14 additions & 11 deletions mava/systems/q_learning/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, NamedTuple
from typing import Dict, Generic, TypeVar

import optax
from chex import PRNGKey
from flashbax.buffers.trajectory_buffer import TrajectoryBufferState
from flax.core.scope import FrozenVariableDict
from jax import Array
from jumanji.env import State
from typing_extensions import TypeAlias
from typing_extensions import NamedTuple, TypeAlias

from mava.types import Observation

Expand Down Expand Up @@ -90,18 +90,21 @@ class ActionState(NamedTuple):
term_or_trunc: Array


class TrainState(NamedTuple):
class QMIXParams(NamedTuple):
online: FrozenVariableDict
target: FrozenVariableDict
mixer_online: FrozenVariableDict
mixer_target: FrozenVariableDict


QLearningParams = TypeVar("QLearningParams", QNetParams, QMIXParams)


class TrainState(NamedTuple, Generic[QLearningParams]):
    """The carry in the training loop.

    Generic over the parameter container so the same state type serves both
    the IQL (QNetParams) and QMIX (QMIXParams) systems.
    """

    buffer_state: BufferState  # replay-buffer state sampled from during training
    params: QLearningParams  # QNetParams or QMIXParams, per the type argument
    opt_state: optax.OptState  # optimizer state paired with `params`
    train_steps: Array  # count of training steps taken so far
    key: PRNGKey  # PRNG key threaded through the training loop


class QMIXParams(NamedTuple):
    """Online and target parameters for the QMIX agent network and mixer."""

    online: FrozenVariableDict  # online (learned) agent Q-network params
    target: FrozenVariableDict  # target agent Q-network params
    mixer_online: FrozenVariableDict  # online mixing-network params
    mixer_target: FrozenVariableDict  # target mixing-network params

0 comments on commit e13a6e1

Please sign in to comment.