chore: rename data variables in training
RuanJohn committed Oct 28, 2024
1 parent fc09189 commit 5f3f8e0
Showing 2 changed files with 27 additions and 27 deletions.
26 changes: 13 additions & 13 deletions mava/systems/q_learning/anakin/rec_iql.py
@@ -326,26 +326,26 @@ def q_loss_fn(
return q_loss, loss_info

def update_q(
- params: QNetParams, opt_states: optax.OptState, data: Transition, t_train: int
+ params: QNetParams, opt_states: optax.OptState, data_full: Transition, t_train: int
) -> Tuple[QNetParams, optax.OptState, Metrics]:
"""Update the Q parameters."""
# Get data aligned with current/next timestep
- data_t0 = tree.map(lambda x: x[:, :-1, ...], data)
- data_t1 = tree.map(lambda x: x[:, 1:, ...], data)
+ data = tree.map(lambda x: x[:, :-1, ...], data_full)
+ data_next = tree.map(lambda x: x[:, 1:, ...], data_full)

- obs = data_t0.obs
- term_or_trunc = data_t0.term_or_trunc
- reward = data_t0.reward
- action = data_t0.action
+ obs = data.obs
+ term_or_trunc = data.term_or_trunc
+ reward = data.reward
+ action = data.action

# The three following variables all come from the same time step.
# They are stored and accessed in this way because of the `AutoResetWrapper`.
- # At the end of an episode `data_t0.next_obs` and `data_t1.obs` will be
- # different, which is why we need to store both. Thus `data_t0.next_obs`
- # aligns with the `terminal` from `data_t1`.
- next_obs = data_t0.next_obs
- next_term_or_trunc = data_t1.term_or_trunc
- next_terminal = data_t1.terminal
+ # At the end of an episode `data.next_obs` and `data_next.obs` will be
+ # different, which is why we need to store both. Thus `data.next_obs`
+ # aligns with the `terminal` from `data_next`.
+ next_obs = data.next_obs
+ next_term_or_trunc = data_next.term_or_trunc
+ next_terminal = data_next.terminal

# Scan over each sample
hidden_state, next_obs_term_or_trunc = prep_inputs_to_scannedrnn(
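The alignment done in this hunk is easier to see outside the diff. Below is a minimal, self-contained sketch of the same time-step split under the new names; the Transition tuple, its field names, and the shapes are illustrative stand-ins rather than Mava's actual types, and jax.tree_util.tree_map stands in for the file's tree.map.

import jax
import jax.numpy as jnp
from typing import NamedTuple

class Transition(NamedTuple):
    # Hypothetical stand-in for the sampled batch; every field is (B, T, ...).
    obs: jnp.ndarray
    next_obs: jnp.ndarray  # differs from obs at t+1 when the AutoResetWrapper resets
    reward: jnp.ndarray
    action: jnp.ndarray
    terminal: jnp.ndarray
    term_or_trunc: jnp.ndarray

def split_timesteps(data_full: Transition):
    """Align fields with the current and next timestep, as update_q does."""
    data = jax.tree_util.tree_map(lambda x: x[:, :-1, ...], data_full)      # timesteps 0..T-2
    data_next = jax.tree_util.tree_map(lambda x: x[:, 1:, ...], data_full)  # timesteps 1..T-1
    # data.next_obs holds the true final observation of an episode, while
    # data_next.obs holds the post-reset observation, so the bootstrap pairs
    # data.next_obs with data_next.terminal / data_next.term_or_trunc.
    return data, data_next

B, T, A = 32, 10, 4  # made-up batch size, rollout length, number of agents
dummy = Transition(*(jnp.zeros((B, T, A)) for _ in Transition._fields))
data, data_next = split_timesteps(dummy)
assert data.next_obs.shape == (B, T - 1, A)
assert data_next.terminal.shape == (B, T - 1, A)

The point restated from the comment in the diff: with auto-reset environments, this pairing keeps the post-reset observation out of the bootstrap target.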
28 changes: 14 additions & 14 deletions mava/systems/q_learning/anakin/rec_qmix.py
@@ -368,29 +368,29 @@ def q_loss_fn(
return q_loss, loss_info

def update_q(
- params: QMIXParams, opt_states: optax.OptState, data: Transition, t_train: int
+ params: QMIXParams, opt_states: optax.OptState, data_full: Transition, t_train: int
) -> Tuple[QMIXParams, optax.OptState, Metrics]:
"""Update the Q parameters."""

# Get data aligned with current/next timestep
- data_t0 = tree.map(lambda x: x[:, :-1, ...], data) # (B, T, ...)
- data_t1 = tree.map(lambda x: x[:, 1:, ...], data) # (B, T, ...)
+ data = tree.map(lambda x: x[:, :-1, ...], data_full) # (B, T, ...)
+ data_next = tree.map(lambda x: x[:, 1:, ...], data_full) # (B, T, ...)

- reward_t0 = data_t0.reward
- next_done = data_t1.term_or_trunc
+ reward = data.reward
+ next_done = data_next.term_or_trunc

# Get the greedy action using the distribution.
# Epsilon defaults to 0.
hidden_state, next_obs_term_or_trunc = prep_inputs_to_scannedrnn(
- data.obs, data.term_or_trunc
+ data_full.obs, data_full.term_or_trunc
) # (T, B, ...)
_, next_greedy_dist = q_net.apply(params.online, hidden_state, next_obs_term_or_trunc)
next_action = next_greedy_dist.mode() # (T, B, ...)
next_action = switch_leading_axes(next_action) # (T, B, ...) -> (B, T, ...)
next_action = next_action[:, 1:, ...] # (B, T, ...)

hidden_state, next_obs_term_or_trunc = prep_inputs_to_scannedrnn(
- data.obs, data.term_or_trunc
+ data_full.obs, data_full.term_or_trunc
) # (T, B, ...)

_, next_q_vals_target = q_net.apply(
@@ -405,23 +405,23 @@ def update_q(
)

next_q_val = mixer.apply(
- params.mixer_target, next_q_val, data_t1.obs.global_state[:, :, 0, ...]
+ params.mixer_target, next_q_val, data_next.obs.global_state[:, :, 0, ...]
) # B,T,A,... -> B,T,1,...

# TD Target
- target_q_val = reward_t0 + (1.0 - next_done) * cfg.system.gamma * next_q_val
+ target_q_val = reward + (1.0 - next_done) * cfg.system.gamma * next_q_val

q_grad_fn = jax.grad(q_loss_fn, has_aux=True)
q_grads, q_loss_info = q_grad_fn(
(params.online, params.mixer_online),
- data_t0.obs,
- data_t0.term_or_trunc,
- data_t0.action,
+ data.obs,
+ data.term_or_trunc,
+ data.action,
target_q_val,
)
q_loss_info["mean_reward_t0"] = jnp.mean(reward_t0)
q_loss_info["mean_reward_t0"] = jnp.mean(reward)
q_loss_info["mean_next_qval"] = jnp.mean(next_q_val)
q_loss_info["done"] = jnp.mean(data.term_or_trunc)
q_loss_info["done"] = jnp.mean(data_full.term_or_trunc)

# Mean over the device and batch dimension.
q_grads, q_loss_info = lax.pmean((q_grads, q_loss_info), axis_name="device")
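For reference, here is the TD target formed in the hunk above, written out with the renamed variables. This is an illustrative sketch with made-up shapes and a made-up discount rather than code from the repository; in rec_qmix.py, next_q_val comes from the target mixer and next_done is data_next.term_or_trunc.

import jax.numpy as jnp

gamma = 0.99                           # stands in for cfg.system.gamma
B, T = 32, 9                           # batch size and the T-1 aligned timesteps
reward = jnp.ones((B, T, 1))           # data.reward
next_done = jnp.zeros((B, T, 1))       # data_next.term_or_trunc
next_q_val = jnp.full((B, T, 1), 5.0)  # stand-in for the mixed target Q-value

# One-step bootstrapped target: immediate reward plus the discounted next
# value, masked out when the next step terminated or was truncated.
target_q_val = reward + (1.0 - next_done) * gamma * next_q_val
assert target_q_val.shape == (B, T, 1)

Because the batch was already sliced into the first T-1 and last T-1 timesteps, reward, next_done and next_q_val line up elementwise.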
