Commit
fix: give each learner a unique random key
Louay-Ben-nessir committed Nov 4, 2024
1 parent c6d460f commit 659a837
Showing 3 changed files with 8 additions and 5 deletions.
2 changes: 1 addition & 1 deletion mava/configs/default/ff_ippo_sebulba.yaml
@@ -3,7 +3,7 @@ defaults:
- arch: sebulba
- system: ppo/ff_ippo
- network: mlp # [mlp, continuous_mlp, cnn]
- - env: smac_gym # [rware_gym, lbf_gym, smac_gym]
+ - env: lbf_gym # [rware_gym, lbf_gym, smac_gym]
- _self_

hydra:
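
This hunk only swaps the Hydra default; any of the bracketed envs can still be selected at composition time. A minimal sketch using Hydra's compose API, assuming Hydra >= 1.2 and that the config path and name above are consumed as-is (the usage itself is an assumption, not Mava's entry point):

import hydra
from hydra import compose, initialize

with initialize(config_path="mava/configs/default", version_base=None):
    # env=smac_gym overrides the lbf_gym default set in ff_ippo_sebulba.yaml
    cfg = compose(config_name="ff_ippo_sebulba", overrides=["env=smac_gym"])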
9 changes: 6 additions & 3 deletions mava/systems/ppo/sebulba/ff_ippo.py
@@ -329,7 +329,9 @@ def _critic_loss_fn(
return (new_params, new_opt_state, key), loss_info

params, opt_states, traj_batch, advantages, targets, key = update_state
+ key = jnp.squeeze(key, axis=0)  # Remove the learner_devices axis
key, shuffle_key, entropy_key = jax.random.split(key, 3)
+ key = jnp.expand_dims(key, axis=0)  # add the learner_devices axis for shape consistency
# Shuffle minibatches
batch_size = config.system.rollout_length * num_learner_envs
permutation = jax.random.permutation(shuffle_key, batch_size)
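
Under shard_map, each learner device sees a per-device slice of the stacked keys, so the local key carries a leading device axis of size 1 that must be removed before jax.random.split and restored afterwards to match the output spec. A standalone sketch of that shape round-trip (the single-device shapes here are an assumption):

import jax
import jax.numpy as jnp

key = jnp.expand_dims(jax.random.PRNGKey(0), axis=0)  # (1, 2): device-local slice
key = jnp.squeeze(key, axis=0)                        # (2,): a plain PRNG key again
key, shuffle_key, entropy_key = jax.random.split(key, 3)
key = jnp.expand_dims(key, axis=0)                    # (1, 2): restored for the out spec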
@@ -518,8 +520,8 @@ def learner_setup(
apply_fns = (actor_network.apply, critic_network.apply)
update_fns = (actor_optim.update, critic_optim.update)

- # defines how the learner state is sharded: params, opt and key = replicated, timestep = sharded
- learn_state_spec = LearnerState(model_spec, model_spec, model_spec, None, data_spec)
+ # defines how the learner state is sharded: params and opt = replicated, key and timestep = sharded
+ learn_state_spec = LearnerState(model_spec, model_spec, data_spec, None, data_spec)
learn = get_learner_step_fn(apply_fns, update_fns, config)
learn = jax.jit(
shard_map(
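
The spec swap is the heart of the fix: model_spec replicates a value onto every learner device, while data_spec shards it along the device axis, so each learner now receives a distinct key instead of a shared copy. A self-contained sketch of that difference; the mesh and the two spec definitions are assumptions modelled on the names in the diff, not Mava's actual code:

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("learner_devices",))
model_spec = P()                  # replicated: identical copy on every device
data_spec = P("learner_devices")  # sharded: each device gets its own slice

def per_device_sample(keys):
    key = jnp.squeeze(keys, axis=0)        # (1, 2) device slice -> (2,) key
    return jax.random.normal(key, (1, 4))  # a different draw on each device

sampler = shard_map(per_device_sample, mesh=mesh,
                    in_specs=(data_spec,), out_specs=data_spec)
keys = jax.random.split(jax.random.PRNGKey(0), len(jax.devices()))
print(sampler(keys))  # rows differ: every learner draws from its own key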
@@ -542,7 +544,8 @@ def learner_setup(
params = restored_params

# Define params to be replicated across devices and batches.
- key, step_keys = jax.random.split(key)
+ key, *step_keys = jax.random.split(key, len(learner_devices) + 1)
+ step_keys = jnp.stack(step_keys, 0)
opt_states = OptStates(actor_opt_state, critic_opt_state)

# Duplicate learner across Learner devices.
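
Previously a single jax.random.split(key) produced one step_keys array that was replicated to every learner, so all learners shared the same randomness; splitting len(learner_devices) + 1 ways and stacking puts one distinct key per device on axis 0, ready to be sharded. A quick standalone check, treating all local devices as learner devices (an assumption):

import jax
import jax.numpy as jnp

learner_devices = jax.devices()  # assumption: every local device is a learner
key = jax.random.PRNGKey(42)
key, *step_keys = jax.random.split(key, len(learner_devices) + 1)
step_keys = jnp.stack(step_keys, 0)  # (num_learner_devices, 2)

# All rows are distinct, so each learner gets an independent random stream.
assert jnp.unique(step_keys, axis=0).shape[0] == len(learner_devices)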
2 changes: 1 addition & 1 deletion mava/wrappers/gym.py
@@ -267,7 +267,7 @@ def _format_observation(
) -> Union[Observation, ObservationGlobalState]:
"""Create an observation from the raw observation and environment state."""

- # (num_agents, num_envs, ...) -> (num_envs, num_agents, ...)
+ # (N, B, O) -> (B, N, O)
obs = np.array(obs).swapaxes(0, 1)
action_mask = np.stack(info["actions_mask"])
obs_data = {"agents_view": obs, "action_mask": action_mask}
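
The shorthand reads N = num_agents, B = num_envs (batch), O = observation dims: the Gym wrapper stacks observations per agent, while Mava's Observation is env-major. A tiny sketch with assumed toy shapes:

import numpy as np

num_agents, num_envs, obs_dim = 3, 8, 5               # N, B, O
raw_obs = np.zeros((num_agents, num_envs, obs_dim))
obs = raw_obs.swapaxes(0, 1)
assert obs.shape == (num_envs, num_agents, obs_dim)   # (B, N, O)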
