Skip to content

Commit

Permalink
feat: add option for hard or soft target updates
Browse files Browse the repository at this point in the history
  • Loading branch information
RuanJohn committed Oct 22, 2024
1 parent e13a6e1 commit 63eb99f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 17 deletions.
6 changes: 1 addition & 5 deletions mava/configs/system/q_learning/rec_qmix.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@ seed: 1

# --- Agent observations ---
add_agent_id: True
observe_step_count: False
# This flag toggles whether the agent IDs are rotational (i.e. only two continuous values on the
# unit circle) if True; otherwise agent IDs are one-hot.
rotational_agent_ids: False

# --- RL hyperparameters ---
min_buffer_size: 32
Expand All @@ -28,7 +24,7 @@ q_lr: 3e-5 # the learning rate of the Q network optimizer
max_grad_norm: 10 # value used to clip optimiser - set big for no clipping

# other
hard_update: False
hard_update: True
update_period: 200
tau: 0.01 # smoothing coefficient for soft target network updates (used when hard_update is False)
gamma: 0.99 # discount factor
Expand Down
26 changes: 14 additions & 12 deletions mava/systems/q_learning/anakin/rec_qmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,19 +424,21 @@ def update_q(
(params.online, params.mixer_online), q_updates
)

# TODO (ruan): Implement soft target network update.
# Target network update.
# next_target_params = jax.lax.select(
# cfg.system.hard_update,
next_target_params = optax.periodic_update(
next_online_params, params.target, t_train, cfg.system.update_period
)
next_mixer_target_params = optax.periodic_update(
next_mixer_params, params.mixer_target, t_train, cfg.system.update_period
)
# optax.incremental_update(next_online_params, params.target, cfg.system.tau)
# )

if cfg.system.hard_update:
next_target_params = optax.periodic_update(
next_online_params, params.target, t_train, cfg.system.update_period
)
next_mixer_target_params = optax.periodic_update(
next_mixer_params, params.mixer_target, t_train, cfg.system.update_period
)
else:
next_target_params = optax.incremental_update(
next_online_params, params.target, cfg.system.tau
)
next_mixer_target_params = optax.incremental_update(
next_mixer_params, params.mixer_target, cfg.system.tau
)
# Repack params and opt_states.
next_params = QMIXParams(
next_online_params,
Expand Down

0 comments on commit 63eb99f

Please sign in to comment.