Commit: chore: clean up comments

RuanJohn committed Oct 22, 2024
1 parent 4915b97 commit 3b8d761
Showing 1 changed file with 23 additions and 38 deletions.
61 changes: 23 additions & 38 deletions mava/systems/q_learning/anakin/rec_qmix.py
@@ -90,15 +90,15 @@ def replicate(x: Any) -> Any:
key, q_key = jax.random.split(key, 2)

# Shape legend:
- # T: Time (dummy dimension size = 1)
- # B: Batch (dummy dimension size = 1)
+ # T: Time
+ # B: Batch
# A: Agent
# Make dummy inputs to init recurrent Q network -> need shape (T, B, A, ...)
init_obs = env.observation_spec().generate_value() # (A, ...)
# (B, T, A, ...)
init_obs_batched = tree.map(lambda x: x[jnp.newaxis, jnp.newaxis, ...], init_obs)
init_term_or_trunc = jnp.zeros((1, 1, 1), dtype=bool) # (T, B, 1)
- init_x = (init_obs_batched, init_term_or_trunc) # pack the RNN dummy inputs
+ init_x = (init_obs_batched, init_term_or_trunc)
# (B, A, ...)
init_hidden_state = ScannedRNN.initialize_carry(
(cfg.arch.num_envs, num_agents), cfg.network.hidden_state_dim
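
A minimal sketch (not part of the diff) of the dummy-dimension trick above: tree.map adds size-1 time and batch axes so a single generated observation can initialise the recurrent network. Shapes and names here are illustrative.

    import jax.numpy as jnp
    from jax import tree

    obs = {"agents_view": jnp.zeros((4, 10))}  # (A, ...) = 4 agents, 10 features
    # Add leading size-1 time and batch dims, as in init_obs_batched above.
    obs_batched = tree.map(lambda x: x[jnp.newaxis, jnp.newaxis, ...], obs)
    print(obs_batched["agents_view"].shape)  # (1, 1, 4, 10)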
@@ -141,7 +141,7 @@ def replicate(x: Any) -> Any:
# Pack params
params = QMIXParams(q_params, q_target_params, mixer_online_params, mixer_target_params)

- # OPTIMISER
+ # Optimiser
opt = optax.chain(
optax.adam(learning_rate=cfg.system.q_lr),
)
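
For reference, a sketch of the standard optax update cycle implied by the optimiser above (the learning rate and params are illustrative, not from the diff):

    import jax
    import jax.numpy as jnp
    import optax

    opt = optax.chain(optax.adam(learning_rate=3e-4))
    params = {"w": jnp.ones((2, 2))}
    opt_state = opt.init(params)
    grads = jax.tree.map(jnp.ones_like, params)  # stand-in gradients
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)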
@@ -152,12 +152,12 @@ def replicate(x: Any) -> Any:
opt_state = replicate(opt_state)
init_hidden_state = replicate(init_hidden_state)

- init_acts = env.action_spec().generate_value() # (A,)
+ init_acts = env.action_spec().generate_value()
init_transition = Transition(
obs=init_obs, # (A, ...)
- action=init_acts,
+ action=init_acts, # (A,)
reward=jnp.zeros((1,), dtype=float),
- terminal=jnp.zeros((1,), dtype=bool), # one flag for all agents
+ terminal=jnp.zeros((1,), dtype=bool),
term_or_trunc=jnp.zeros((1,), dtype=bool),
next_obs=init_obs,
)
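
An illustrative stand-in for the Transition container built above; the real definition lives elsewhere in Mava, so the field types here are an assumption (field names and shapes follow the diff):

    from typing import NamedTuple
    import chex

    class Transition(NamedTuple):
        obs: chex.Array            # (A, ...)
        action: chex.Array         # (A,)
        reward: chex.Array         # (1,)
        terminal: chex.Array       # (1,)
        term_or_trunc: chex.Array  # (1,)
        next_obs: chex.Array       # (A, ...)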
@@ -265,21 +265,18 @@ def action_step(action_state: ActionState, _: Any) -> Tuple[ActionState, Dict]:
parameters for the next step.
"""

- # Unpack
action_selection_state, env_state, buffer_state, obs, terminal, term_or_trunc = action_state

- # select the actions to take
next_action_selection_state, action = select_eps_greedy_action(
action_selection_state, obs, term_or_trunc
)

- # step env with selected actions
next_env_state, next_timestep = jax.vmap(env.step)(env_state, action)

# Get reward
reward = jnp.mean(
next_timestep.reward, axis=-1, keepdims=True
- ) # NOTE (ruan): combine agent rewards, different to IQL.
+ ) # NOTE: Combine agent rewards, since QMIX is cooperative.

transition = Transition(
obs, action, reward, terminal, term_or_trunc, next_timestep.extras["real_next_obs"]
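
A worked sketch of the team-reward aggregation in the NOTE above: per-agent rewards are averaged into one shared reward (values illustrative):

    import jax.numpy as jnp

    per_agent_reward = jnp.array([[1.0, 0.0, 2.0],
                                  [0.5, 0.5, 0.5]])  # (B=2 envs, A=3 agents)
    team_reward = jnp.mean(per_agent_reward, axis=-1, keepdims=True)
    print(team_reward)  # [[1.0], [0.5]] with shape (B, 1)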
@@ -288,13 +285,11 @@ def action_step(action_state: ActionState, _: Any) -> Tuple[ActionState, Dict]:
transition = tree.map(lambda x: x[:, jnp.newaxis, ...], transition)
next_buffer_state = rb.add(buffer_state, transition)

- # Nexts
next_obs = next_timestep.observation
- # make compatible with network input and transition storage in next step
+ # Make compatible with network input and transition storage in next step
next_terminal = (1 - next_timestep.discount[..., 0, jnp.newaxis]).astype(bool)
next_term_or_trunc = next_timestep.last()[..., jnp.newaxis]

- # Repack
new_act_state = ActionState(
next_action_selection_state,
next_env_state,
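
A minimal sketch of the time-axis insertion used before rb.add above: each field goes from (B, ...) to (B, 1, ...) so the buffer can stack transitions along time (shapes illustrative):

    import jax.numpy as jnp
    from jax import tree

    transition = {"reward": jnp.zeros((8, 1))}  # (B, ...)
    transition = tree.map(lambda x: x[:, jnp.newaxis, ...], transition)
    print(transition["reward"].shape)  # (8, 1, 1) = (B, 1, ...)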
@@ -333,23 +328,23 @@ def q_loss_fn(
"""The portion of the calculation to grad, namely online apply and mse with target."""
q_online_params, online_mixer_params = online_params

- # axes switched here to scan over time
+ # Axes switched to scan over time
hidden_state, obs_term_or_trunc = prep_inputs_to_scannedrnn(obs, term_or_trunc)

- # get online q values of all actions
+ # Get online q values of all actions
_, q_online = q_net.apply(
q_online_params, hidden_state, obs_term_or_trunc, method="get_q_values"
)
q_online = switch_leading_axes(q_online) # (T, B, ...) -> (B, T, ...)
- # get the q values of the taken actions and remove extra dim
+ # Get the q values of the taken actions and remove extra dim
q_online = jnp.squeeze(
jnp.take_along_axis(q_online, action[..., jnp.newaxis], axis=-1), axis=-1
)

q_online = mixer.apply(
online_mixer_params, q_online, obs.global_state[:, :, 0, ...]
- ) # B,T,A,... -> B,T,1,... # NOTE states are replicated over agents thats
- # why we only take first one
+ ) # B,T,A,... -> B,T,1,...
+ # NOTE: States are replicated over agents, so we only take the first one.

q_loss = jnp.mean((q_online - target) ** 2)
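
A worked sketch of the take_along_axis gather above, which selects each agent's Q-value for the action actually taken (values illustrative):

    import jax.numpy as jnp

    q = jnp.array([[[1.0, 2.0, 3.0],
                    [4.0, 5.0, 6.0]]])  # (B=1, A=2, num_actions=3)
    action = jnp.array([[2, 0]])        # (B, A) indices of taken actions
    q_taken = jnp.squeeze(
        jnp.take_along_axis(q, action[..., jnp.newaxis], axis=-1), axis=-1
    )
    print(q_taken)  # [[3. 4.]]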

@@ -372,15 +367,15 @@ def update_q(
data_first: Dict[str, chex.Array] = jax.tree_map(
lambda x: x[:, :-1, ...], data
) # (B, T, ...)
- data_next: Dict[str, chex.Array] = jax.tree_map(lambda x: x[:, 1:, ...], data)
+ data_next: Dict[str, chex.Array] = jax.tree_map(
+     lambda x: x[:, 1:, ...], data
+ )  # (B, T, ...)

first_reward = data_first.reward
next_done = data_next.term_or_trunc

- # Eps defaults to 0
- # Get the greedy action
- ###############################################################
- # using the distribution instead.
+ # Get the greedy action using the distribution.
+ # Epsilon defaults to 0.
hidden_state, next_obs_term_or_trunc = prep_inputs_to_scannedrnn(
data.obs,
data.term_or_trunc,
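
A sketch of the first/next split over the time axis above, which pairs each step with its successor for TD learning (shapes illustrative):

    import jax.numpy as jnp
    from jax import tree

    data = {"reward": jnp.arange(5.0)[jnp.newaxis, :, jnp.newaxis]}  # (B=1, T=5, 1)
    data_first = tree.map(lambda x: x[:, :-1, ...], data)  # steps 0..3
    data_next = tree.map(lambda x: x[:, 1:, ...], data)    # steps 1..4
    print(data_first["reward"].shape, data_next["reward"].shape)  # (1, 4, 1) (1, 4, 1)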
@@ -389,7 +384,6 @@ def update_q(
next_action = next_greedy_dist.mode() # (T, B, ...)
next_action = switch_leading_axes(next_action) # (T, B, ...) -> (B, T, ...)
next_action = next_action[:, 1:, ...] # (B, T, ...)
- ###############################################################

hidden_state, next_obs_term_or_trunc = prep_inputs_to_scannedrnn(
data.obs, data.term_or_trunc
@@ -413,8 +407,6 @@ def update_q(
# TD Target
target_q_val = first_reward + (1.0 - next_done) * cfg.system.gamma * next_q_val

- # Update Q function.
-
q_grad_fn = jax.grad(q_loss_fn, has_aux=True)
q_grads, q_loss_info = q_grad_fn(
(params.online, params.mixer_online),
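
A worked sketch of the TD target above, target = r + (1 - done) * gamma * Q_next (values illustrative):

    import jax.numpy as jnp

    gamma = 0.99
    reward = jnp.array([[1.0], [0.0]])      # (B, 1)
    next_done = jnp.array([[0.0], [1.0]])   # (B, 1)
    next_q_val = jnp.array([[2.0], [5.0]])  # (B, 1)
    target = reward + (1.0 - next_done) * gamma * next_q_val
    print(target)  # [[2.98], [0.]]; the done episode bootstraps nothing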
@@ -461,14 +453,12 @@ def update_q(
def train(train_state: TrainState, _: Any) -> TrainState:
"""Sample, train and repack."""

- # unpack and get keys
buffer_state, params, opt_states, t_train, key = train_state
next_key, buff_key = jax.random.split(key, 2)

- # sample
data = rb.sample(buffer_state, buff_key).experience

- # learn
+ # Learn
next_params, next_opt_states, q_loss_info = update_q(params, opt_states, data, t_train)

next_train_state = TrainState(
@@ -478,17 +468,15 @@ def train(train_state: TrainState, _: Any) -> TrainState:
return next_train_state, q_loss_info

# ---- Act-train loop ----
-
scanned_act = lambda state: lax.scan(action_step, state, None, length=cfg.system.rollout_length)
scanned_train = lambda state: lax.scan(train, state, None, length=cfg.system.epochs)

- # interact and train
+ # Act and train
def update_step(
learner_state: LearnerState, _: Any
) -> Tuple[LearnerState, Tuple[Metrics, Metrics]]:
"""Interact, then learn. The _ at the end of a var means updated."""

- # unpack and get random keys
(
obs,
terminal,
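
A minimal sketch of the lax.scan pattern used by scanned_act and scanned_train above: run a step function for a fixed number of iterations with no scanned inputs (xs=None):

    import jax.numpy as jnp
    from jax import lax

    def step(state, _):
        return state + 1, state  # (new carry, per-step output)

    final_state, outputs = lax.scan(step, jnp.array(0), None, length=4)
    print(final_state, outputs)  # 4 [0 1 2 3]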
@@ -548,7 +536,7 @@ def update_step(
donate_argnums=0,
)

- return pmaped_updated_step # type:ignore
+ return pmaped_updated_step


def run_experiment(cfg: DictConfig) -> float:
Expand Down Expand Up @@ -618,9 +606,6 @@ def eval_act_fn(
jax.block_until_ready(learner_state)

# Log:
- # Multiply by learn steps here because anakin steps per second is learn + act steps
- # But we want to make sure we're counting env steps correctly so it's not included
- # in the loop counter.
elapsed_time = time.time() - start_time
eps = jnp.maximum(
cfg.system.eps_min, 1 - (t / cfg.system.eps_decay) * (1 - cfg.system.eps_min)
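
A worked sketch of the linear epsilon-decay schedule above (the eps_min and eps_decay values are illustrative, not from the config):

    import jax.numpy as jnp

    eps_min, eps_decay = 0.05, 1e5
    for t in (0, 50_000, 200_000):
        eps = jnp.maximum(eps_min, 1 - (t / eps_decay) * (1 - eps_min))
        print(t, float(eps))  # 1.0 at t=0, 0.525 midway, clipped at 0.05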
Expand Down Expand Up @@ -655,7 +640,7 @@ def eval_act_fn(
# Checkpoint:
if cfg.logger.checkpointing.save_model:
# Save checkpoint of learner state
- unreplicated_learner_state = unreplicate_n_dims(learner_state) # type: ignore
+ unreplicated_learner_state = unreplicate_n_dims(learner_state)
checkpointer.save(
timestep=t,
unreplicated_learner_state=unreplicated_learner_state,
