Commit
chore: shape comments legend
RuanJohn committed Oct 28, 2024
1 parent 738ec3c commit e49a22f
Showing 2 changed files with 23 additions and 20 deletions.
22 changes: 12 additions & 10 deletions mava/systems/q_learning/anakin/rec_iql.py
@@ -85,17 +85,19 @@ def replicate(x: Any) -> Any:
num_agents = env.num_agents

key, q_key = jax.random.split(key, 2)

# Shape legend:
- # T: Time (dummy dimension size = 1)
- # B: Batch (dummy dimension size = 1)
- # A: Agent
- # Make dummy inputs to init recurrent Q network -> need shape (T, B, A, ...)
- init_obs = env.observation_spec().generate_value() # (A, ...)
- # (B, T, A, ...)
+ # T: Time
+ # B: Batch
+ # N: Agent
+
+ # Make dummy inputs to init recurrent Q network -> need shape (T, B, N, ...)
+ init_obs = env.observation_spec().generate_value() # (N, ...)
+ # (B, T, N, ...)
init_obs_batched = tree.map(lambda x: x[jnp.newaxis, jnp.newaxis, ...], init_obs)
init_term_or_trunc = jnp.zeros((1, 1, 1), dtype=bool) # (T, B, 1)
init_x = (init_obs_batched, init_term_or_trunc) # pack the RNN dummy inputs
- # (B, A, ...)
+ # (B, N, ...)
init_hidden_state = ScannedRNN.initialize_carry(
(cfg.arch.num_envs, num_agents), cfg.network.hidden_state_dim
)
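
For orientation, a minimal sketch of how these dummy shapes fit together, assuming standard JAX (jnp, jax.tree.map); the agent count and feature size below are made-up placeholders:

import jax
import jax.numpy as jnp

num_agents = 3   # N (placeholder)
obs_dim = 5      # illustrative per-agent feature size

# Dummy per-agent observation, shape (N, obs_dim)
init_obs = jnp.zeros((num_agents, obs_dim))

# Add leading batch and time axes: (N, ...) -> (1, 1, N, ...), i.e. (B, T, N, ...)
init_obs_batched = jax.tree.map(lambda x: x[jnp.newaxis, jnp.newaxis, ...], init_obs)

# One termination/truncation flag per (T, B) slot: shape (T, B, 1) = (1, 1, 1)
init_term_or_trunc = jnp.zeros((1, 1, 1), dtype=bool)

init_x = (init_obs_batched, init_term_or_trunc)  # packed RNN dummy input
# init_obs_batched.shape == (1, 1, 3, 5); init_term_or_trunc.shape == (1, 1, 1)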
@@ -128,9 +130,9 @@ def replicate(x: Any) -> Any:
init_hidden_state = replicate(init_hidden_state)

# Create dummy transition
- init_acts = env.action_spec().generate_value() # (A,)
+ init_acts = env.action_spec().generate_value() # (N,)
init_transition = Transition(
- obs=init_obs, # (A, ...)
+ obs=init_obs, # (N, ...)
action=init_acts,
reward=jnp.zeros((num_agents,), dtype=float),
terminal=jnp.zeros((1,), dtype=bool), # one flag for all agents
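
As a rough sketch of what such a dummy transition holds, if Transition were written as a plain NamedTuple (an illustrative stand-in, not Mava's actual class; only the fields visible in this hunk are shown):

from typing import NamedTuple
import jax
import jax.numpy as jnp

class Transition(NamedTuple):  # illustrative stand-in for Mava's Transition
    obs: jax.Array        # (N, ...)
    action: jax.Array     # (N,)
    reward: jax.Array     # (N,): one reward per agent in rec_iql
    terminal: jax.Array   # (1,): one flag shared by all agents

num_agents = 3  # placeholder
init_transition = Transition(
    obs=jnp.zeros((num_agents, 5)),
    action=jnp.zeros((num_agents,), dtype=int),
    reward=jnp.zeros((num_agents,), dtype=float),
    terminal=jnp.zeros((1,), dtype=bool),
)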
@@ -226,7 +228,7 @@ def select_eps_greedy_action(
new_key, explore_key = jax.random.split(key, 2)

action = eps_greedy_dist.sample(seed=explore_key)
- action = action[0, ...] # (1, B, A) -> (B, A)
+ action = action[0, ...] # (1, B, N) -> (B, N)

next_action_selection_state = ActionSelectionState(
params, next_hidden_state, t + cfg.arch.num_envs, new_key
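
A hedged sketch of the sampling step around that squeeze, assuming distrax's EpsilonGreedy distribution over per-agent Q-values; the shapes and epsilon below are illustrative, not taken from the config:

import distrax
import jax
import jax.numpy as jnp

num_envs, num_agents, num_actions = 4, 3, 6  # B, N, action-space size (placeholders)
key = jax.random.PRNGKey(0)

# Q-values for one time step, shape (1, B, N, num_actions)
q_values = jnp.zeros((1, num_envs, num_agents, num_actions))
eps_greedy_dist = distrax.EpsilonGreedy(preferences=q_values, epsilon=0.1)

new_key, explore_key = jax.random.split(key, 2)
action = eps_greedy_dist.sample(seed=explore_key)  # (1, B, N)
action = action[0, ...]                            # drop the time axis -> (B, N)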
21 changes: 11 additions & 10 deletions mava/systems/q_learning/anakin/rec_qmix.py
@@ -91,14 +91,15 @@ def replicate(x: Any) -> Any:
# Shape legend:
# T: Time
# B: Batch
- # A: Agent
- # Make dummy inputs to init recurrent Q network -> need shape (T, B, A, ...)
- init_obs = env.observation_spec().generate_value() # (A, ...)
- # (B, T, A, ...)
+ # N: Agent
+
+ # Make dummy inputs to init recurrent Q network -> need shape (T, B, N, ...)
+ init_obs = env.observation_spec().generate_value() # (N, ...)
+ # (B, T, N, ...)
init_obs_batched = tree.map(lambda x: x[jnp.newaxis, jnp.newaxis, ...], init_obs)
init_term_or_trunc = jnp.zeros((1, 1, 1), dtype=bool) # (T, B, 1)
init_x = (init_obs_batched, init_term_or_trunc)
- # (B, A, ...)
+ # (B, N, ...)
init_hidden_state = ScannedRNN.initialize_carry(
(cfg.arch.num_envs, num_agents), cfg.network.hidden_state_dim
)
@@ -164,8 +165,8 @@ def replicate(x: Any) -> Any:
# episode horizon has been reached. We use this exclusively in QMIX.
# Terminal refers to individual agent dones. We keep this here for consistency with IQL.
init_transition = Transition(
- obs=init_obs, # (A, ...)
- action=init_acts, # (A,)
+ obs=init_obs, # (N, ...)
+ action=init_acts, # (N,)
reward=jnp.zeros((1,), dtype=float),
terminal=jnp.zeros((1,), dtype=bool),
term_or_trunc=jnp.zeros((1,), dtype=bool),
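
To make the rec_iql/rec_qmix difference concrete, a small shape comparison of the dummy transition fields as they appear in these two diffs (values are placeholders):

import jax.numpy as jnp

num_agents = 3  # N (placeholder)

# rec_iql: one reward per agent, a single shared done flag
iql_reward = jnp.zeros((num_agents,), dtype=float)   # (N,)
iql_terminal = jnp.zeros((1,), dtype=bool)           # (1,)

# rec_qmix: a single (team-level) reward, plus both done variants
qmix_reward = jnp.zeros((1,), dtype=float)           # (1,)
qmix_terminal = jnp.zeros((1,), dtype=bool)          # kept for consistency with IQL
qmix_term_or_trunc = jnp.zeros((1,), dtype=bool)     # set on termination or horizon truncation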
@@ -261,7 +262,7 @@ def select_eps_greedy_action(
new_key, explore_key = jax.random.split(key, 2)

action = eps_greedy_dist.sample(seed=explore_key)
- action = action[0, ...] # (1, B, A) -> (B, A)
+ action = action[0, ...] # (1, B, N) -> (B, N)

# repack new selection params
next_action_selection_state = ActionSelectionState(
@@ -352,7 +353,7 @@ def q_loss_fn(
# NOTE: States are replicated over agents so we only take the first one
q_online = mixer.apply(
online_mixer_params, q_online, obs.global_state[:, :, 0, ...]
- ) # B,T,A,... -> B,T,1,...
+ ) # (B, T, N, ...) -> (B, T, 1, ...)

q_loss = jnp.mean((q_online - target) ** 2)
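
A minimal sketch of what this mixing step does shape-wise: per-agent chosen-action Q-values are combined into one value per (batch, time) step, with the global state taken from agent index 0 because it is replicated across agents. The mixer below is a deliberately crude stand-in (a plain sum over agents); the real QMIX mixer conditions its mixing weights on the state:

import jax.numpy as jnp

B, T, N, state_dim = 4, 8, 3, 10  # placeholders

q_online = jnp.zeros((B, T, N, 1))                # per-agent chosen-action Q-values
global_state = jnp.zeros((B, T, N, state_dim))    # identical across the agent axis

state = global_state[:, :, 0, ...]                # (B, T, state_dim)

# Stand-in "mixer": sum over agents, (B, T, N, 1) -> (B, T, 1, 1)
q_tot = jnp.sum(q_online, axis=2, keepdims=True)

target = jnp.zeros_like(q_tot)                    # placeholder TD target
q_loss = jnp.mean((q_tot - target) ** 2)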

@@ -406,7 +407,7 @@ def update_q(

next_q_val = mixer.apply(
params.mixer_target, next_q_val, data_next.obs.global_state[:, :, 0, ...]
- ) # B,T,A,... -> B,T,1,...
+ ) # (B, T, N, ...) -> (B, T, 1, ...)

# TD Target
target_q_val = reward + (1.0 - next_done) * cfg.system.gamma * next_q_val
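
The TD target line is standard one-step bootstrapping, target = r + (1 - done) * gamma * Q_target(next). A tiny numeric sketch with made-up values:

import jax.numpy as jnp

gamma = 0.99                      # cfg.system.gamma in the diff
reward = jnp.array([1.0])
next_done = jnp.array([0.0])      # 1.0 would cut the bootstrap at episode end
next_q_val = jnp.array([2.5])     # mixed target-network value for the next step

target_q_val = reward + (1.0 - next_done) * gamma * next_q_val
# target_q_val == [3.475]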
