
Commit

tree map replace fixed

amacrutherford committed Dec 13, 2024
1 parent 69e1888 commit 050d3c8
Showing 29 changed files with 67 additions and 63 deletions.
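
Every hunk below makes the same one-line repair: the removed call sites reference jax.tree_util.tree.map, which is not a real attribute path (it looks like the result of an earlier blanket replace of tree_map with tree.map), and the added lines use jax.tree.map, the current spelling of jax.tree_util.tree_map in JAX 0.4.25 and later. A minimal sketch of the before/after behaviour, with hypothetical shapes rather than values taken from the repository:

import jax
import jax.numpy as jnp

# Hypothetical batch of (num_steps, num_actors, ...) arrays, shaped like
# the tuples the IPPO hunks below operate on.
batch = (jnp.ones((4, 2, 3)), jnp.zeros((4, 2)))
batch_size = 4 * 2

# Broken form left by the earlier replace -- jax.tree_util has no attribute
# "tree", so this line raises AttributeError:
#   batch = jax.tree_util.tree.map(lambda x: x.reshape((batch_size,) + x.shape[2:]), batch)

# Fixed form used throughout this commit: apply the reshape to every leaf.
batch = jax.tree.map(lambda x: x.reshape((batch_size,) + x.shape[2:]), batch)
print(jax.tree.map(lambda x: x.shape, batch))  # ((8, 3), (8,))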
6 changes: 3 additions & 3 deletions baselines/IPPO/ippo_cnn_overcooked.py
@@ -345,13 +345,13 @@ def _loss_fn(params, traj_batch, gae, targets):
 ), "batch size must be equal to number of steps * number of actors"
 permutation = jax.random.permutation(_rng, batch_size)
 batch = (traj_batch, advantages, targets)
-batch = jax.tree_util.tree.map(
+batch = jax.tree.map(
 lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
 )
-shuffled_batch = jax.tree_util.tree.map(
+shuffled_batch = jax.tree.map(
 lambda x: jnp.take(x, permutation, axis=0), batch
 )
-minibatches = jax.tree_util.tree.map(
+minibatches = jax.tree.map(
 lambda x: jnp.reshape(
 x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
 ),
4 changes: 2 additions & 2 deletions baselines/IPPO/ippo_ff_hanabi.py
@@ -258,11 +258,11 @@ def _loss_fn(params, traj_batch, gae, targets):
 batch = (traj_batch, advantages.squeeze(), targets.squeeze())
 permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])

-shuffled_batch = jax.tree_util.tree.map(
+shuffled_batch = jax.tree.map(
 lambda x: jnp.take(x, permutation, axis=1), batch
 )

-minibatches = jax.tree_util.tree.map(
+minibatches = jax.tree.map(
 lambda x: jnp.swapaxes(
 jnp.reshape(
 x,
6 changes: 3 additions & 3 deletions baselines/IPPO/ippo_ff_mabrax.py
@@ -258,13 +258,13 @@ def _loss_fn(params, traj_batch, gae, targets):
 ), "batch size must be equal to number of steps * number of actors"
 permutation = jax.random.permutation(_rng, batch_size)
 batch = (traj_batch, advantages, targets)
-batch = jax.tree_util.tree.map(
+batch = jax.tree.map(
 lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
 )
-shuffled_batch = jax.tree_util.tree.map(
+shuffled_batch = jax.tree.map(
 lambda x: jnp.take(x, permutation, axis=0), batch
 )
-minibatches = jax.tree_util.tree.map(
+minibatches = jax.tree.map(
 lambda x: jnp.reshape(
 x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
 ),
6 changes: 3 additions & 3 deletions baselines/IPPO/ippo_ff_mpe.py
@@ -255,13 +255,13 @@ def _loss_fn(params, traj_batch, gae, targets):
 ), "batch size must be equal to number of steps * number of actors"
 permutation = jax.random.permutation(_rng, batch_size)
 batch = (traj_batch, advantages, targets)
-batch = jax.tree_util.tree.map(
+batch = jax.tree.map(
 lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
 )
-shuffled_batch = jax.tree_util.tree.map(
+shuffled_batch = jax.tree.map(
 lambda x: jnp.take(x, permutation, axis=0), batch
 )
-minibatches = jax.tree_util.tree.map(
+minibatches = jax.tree.map(
 lambda x: jnp.reshape(
 x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
 ),
6 changes: 3 additions & 3 deletions baselines/IPPO/ippo_ff_mpe_facmac.py
@@ -252,13 +252,13 @@ def _loss_fn(params, traj_batch, gae, targets):
 ), "batch size must be equal to number of steps * number of actors"
 permutation = jax.random.permutation(_rng, batch_size)
 batch = (traj_batch, advantages, targets)
-batch = jax.tree_util.tree.map(
+batch = jax.tree.map(
 lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
 )
-shuffled_batch = jax.tree_util.tree.map(
+shuffled_batch = jax.tree.map(
 lambda x: jnp.take(x, permutation, axis=0), batch
 )
-minibatches = jax.tree_util.tree.map(
+minibatches = jax.tree.map(
 lambda x: jnp.reshape(
 x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
 ),
6 changes: 3 additions & 3 deletions baselines/IPPO/ippo_ff_overcooked.py
@@ -318,13 +318,13 @@ def _loss_fn(params, traj_batch, gae, targets):
 ), "batch size must be equal to number of steps * number of actors"
 permutation = jax.random.permutation(_rng, batch_size)
 batch = (traj_batch, advantages, targets)
-batch = jax.tree_util.tree.map(
+batch = jax.tree.map(
 lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
 )
-shuffled_batch = jax.tree_util.tree.map(
+shuffled_batch = jax.tree.map(
 lambda x: jnp.take(x, permutation, axis=0), batch
 )
-minibatches = jax.tree_util.tree.map(
+minibatches = jax.tree.map(
 lambda x: jnp.reshape(
 x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
 ),
6 changes: 3 additions & 3 deletions baselines/IPPO/ippo_ff_switch_riddle.py
@@ -247,13 +247,13 @@ def _loss_fn(params, traj_batch, gae, targets):
 ), "batch size must be equal to number of steps * number of actors"
 permutation = jax.random.permutation(_rng, batch_size)
 batch = (traj_batch, advantages, targets)
-batch = jax.tree_util.tree.map(
+batch = jax.tree.map(
 lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
 )
-shuffled_batch = jax.tree_util.tree.map(
+shuffled_batch = jax.tree.map(
 lambda x: jnp.take(x, permutation, axis=0), batch
 )
-minibatches = jax.tree_util.tree.map(
+minibatches = jax.tree.map(
 lambda x: jnp.reshape(
 x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
 ),
4 changes: 2 additions & 2 deletions baselines/IPPO/ippo_rnn_hanabi.py
@@ -312,11 +312,11 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
 batch = (init_hstate, traj_batch, advantages.squeeze(), targets.squeeze())
 permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])

-shuffled_batch = jax.tree_util.tree.map(
+shuffled_batch = jax.tree.map(
 lambda x: jnp.take(x, permutation, axis=1), batch
 )

-minibatches = jax.tree_util.tree.map(
+minibatches = jax.tree.map(
 lambda x: jnp.swapaxes(
 jnp.reshape(
 x,
4 changes: 2 additions & 2 deletions baselines/IPPO/ippo_rnn_mpe.py
@@ -334,11 +334,11 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
 )
 permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])

-shuffled_batch = jax.tree_util.tree.map(
+shuffled_batch = jax.tree.map(
 lambda x: jnp.take(x, permutation, axis=1), batch
 )

-minibatches = jax.tree_util.tree.map(
+minibatches = jax.tree.map(
 lambda x: jnp.swapaxes(
 jnp.reshape(
 x,
4 changes: 2 additions & 2 deletions baselines/IPPO/ippo_rnn_smax.py
@@ -346,11 +346,11 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
 )
 permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])

-shuffled_batch = jax.tree_util.tree.map(
+shuffled_batch = jax.tree.map(
 lambda x: jnp.take(x, permutation, axis=1), batch
 )

-minibatches = jax.tree_util.tree.map(
+minibatches = jax.tree.map(
 lambda x: jnp.swapaxes(
 jnp.reshape(
 x,
4 changes: 2 additions & 2 deletions baselines/MAPPO/mappo_ff_hanabi.py
@@ -406,11 +406,11 @@ def _critic_loss_fn(critic_params, traj_batch, targets):
 )
 permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])

-shuffled_batch = jax.tree_util.tree.map(
+shuffled_batch = jax.tree.map(
 lambda x: jnp.take(x, permutation, axis=1), batch
 )

-minibatches = jax.tree_util.tree.map(
+minibatches = jax.tree.map(
 lambda x: jnp.swapaxes(
 jnp.reshape(
 x,
4 changes: 2 additions & 2 deletions baselines/MAPPO/mappo_rnn_hanabi.py
@@ -460,11 +460,11 @@ def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets):
 )
 permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])

-shuffled_batch = jax.tree_util.tree.map(
+shuffled_batch = jax.tree.map(
 lambda x: jnp.take(x, permutation, axis=1), batch
 )

-minibatches = jax.tree_util.tree.map(
+minibatches = jax.tree.map(
 lambda x: jnp.swapaxes(
 jnp.reshape(
 x,
4 changes: 2 additions & 2 deletions baselines/MAPPO/mappo_rnn_mpe.py
@@ -448,11 +448,11 @@ def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets):
 )
 permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])

-shuffled_batch = jax.tree_util.tree.map(
+shuffled_batch = jax.tree.map(
 lambda x: jnp.take(x, permutation, axis=1), batch
 )

-minibatches = jax.tree_util.tree.map(
+minibatches = jax.tree.map(
 lambda x: jnp.swapaxes(
 jnp.reshape(
 x,
4 changes: 2 additions & 2 deletions baselines/MAPPO/mappo_rnn_smax.py
@@ -479,11 +479,11 @@ def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets):
 )
 permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])

-shuffled_batch = jax.tree_util.tree.map(
+shuffled_batch = jax.tree.map(
 lambda x: jnp.take(x, permutation, axis=1), batch
 )

-minibatches = jax.tree_util.tree.map(
+minibatches = jax.tree.map(
 lambda x: jnp.swapaxes(
 jnp.reshape(
 x,
2 changes: 1 addition & 1 deletion baselines/QLearning/iql_cnn_overcooked.py
@@ -291,7 +291,7 @@ def _step_env(carry, _):
 ) # update timesteps count

 # BUFFER UPDATE
-timesteps = jax.tree_util.tree.map(
+timesteps = jax.tree.map(
 lambda x: x.reshape(-1, *x.shape[2:]), timesteps
 ) # (num_envs*num_steps, ...)
 buffer_state = buffer.add(buffer_state, timesteps)
2 changes: 1 addition & 1 deletion baselines/QLearning/iql_rnn.py
@@ -319,7 +319,7 @@ def _step_env(carry, _):
 ) # update timesteps count

 # BUFFER UPDATE
-buffer_traj_batch = jax.tree_util.tree.map(
+buffer_traj_batch = jax.tree.map(
 lambda x: jnp.swapaxes(x, 0, 1)[
 :, np.newaxis
 ], # put the batch dim first and add a dummy sequence dim
2 changes: 1 addition & 1 deletion baselines/QLearning/pqn_vdn_cnn_overcooked.py
@@ -430,7 +430,7 @@ def preprocess_transition(x, rng):
 return x

 rng, _rng = jax.random.split(rng)
-minibatches = jax.tree_util.tree.map(
+minibatches = jax.tree.map(
 lambda x: preprocess_transition(x, _rng),
 transitions,
 ) # num_minibatches, num_agents, num_envs/num_minbatches ...
2 changes: 1 addition & 1 deletion baselines/QLearning/pqn_vdn_ff.py
@@ -383,7 +383,7 @@ def preprocess_transition(x, rng):
 return x

 rng, _rng = jax.random.split(rng)
-minibatches = jax.tree_util.tree.map(
+minibatches = jax.tree.map(
 lambda x: preprocess_transition(x, _rng),
 transitions,
 ) # num_minibatches, num_agents, num_envs/num_minbatches ...
4 changes: 2 additions & 2 deletions baselines/QLearning/pqn_vdn_rnn.py
@@ -336,7 +336,7 @@ def _learn_phase(carry, minibatch):
 minibatch.last_done,
 )
 # batchify the agent input: num_agents*batch_size
-agent_in = jax.tree_util.tree.map(
+agent_in = jax.tree.map(
 lambda x: x.reshape(x.shape[0], -1, *x.shape[3:]), agent_in
 ) # (num_steps, num_agents*batch_size, ...)

@@ -442,7 +442,7 @@ def preprocess_transition(x, rng):
 return x

 rng, _rng = jax.random.split(rng)
-minibatches = jax.tree_util.tree.map(
+minibatches = jax.tree.map(
 lambda x: preprocess_transition(x, _rng),
 memory_transitions,
 ) # num_minibatches, num_steps+memory_window, num_agents, batch_size/num_minbatches, num_agents, ...
2 changes: 1 addition & 1 deletion baselines/QLearning/qmix_rnn.py
@@ -407,7 +407,7 @@ def _step_env(carry, _):
 ) # update timesteps count

 # BUFFER UPDATE
-buffer_traj_batch = jax.tree_util.tree.map(
+buffer_traj_batch = jax.tree.map(
 lambda x: jnp.swapaxes(x, 0, 1)[
 :, np.newaxis
 ], # put the batch dim first and add a dummy sequence dim
10 changes: 5 additions & 5 deletions baselines/QLearning/shaq.py
@@ -454,7 +454,7 @@ def _env_step(step_state, unused):
 # get the q_values from the agent network
 hstate, q_vals = homogeneous_pass(params, hstate, obs_, dones_)
 # remove the dummy time_step dimension and index qs by the valid actions of each agent
-valid_q_vals = jax.tree_util.tree.map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, wrapped_env.valid_actions)
+valid_q_vals = jax.tree.map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, wrapped_env.valid_actions)
 # explore with epsilon greedy_exploration
 actions = explorer.choose_actions(valid_q_vals, t, key_a)

@@ -488,7 +488,7 @@ def _env_step(step_state, unused):
 )

 # BUFFER UPDATE: save the collected trajectory in the buffer
-buffer_traj_batch = jax.tree_util.tree.map(
+buffer_traj_batch = jax.tree.map(
 lambda x:jnp.swapaxes(x, 0, 1)[:, np.newaxis], # put the batch dim first and add a dummy sequence dim
 traj_batch
 ) # (num_envs, 1, time_steps, ...)

@@ -519,7 +519,7 @@ def _loss_fn(params_agent, params_mixer, target_network_params_agent, target_net
 )

 # get the target q value of the greedy actions for each agent
-valid_q_vals = jax.tree_util.tree.map(lambda q, valid_idx: q[..., valid_idx], q_vals, wrapped_env.valid_actions)
+valid_q_vals = jax.tree.map(lambda q, valid_idx: q[..., valid_idx], q_vals, wrapped_env.valid_actions)
 target_max_qvals = jax.tree.map(
 lambda t_q, q: q_of_action(t_q, jnp.argmax(q, axis=-1))[1:], # avoid first timestep
 target_q_vals,

@@ -642,7 +642,7 @@ def _td_lambda_target(ret, values):
 'timesteps': time_state['timesteps']*config['NUM_ENVS'],
 'updates' : time_state['updates'],
 'loss': loss,
-'rewards': jax.tree_util.tree.map(lambda x: jnp.sum(x, axis=0).mean(), traj_batch.rewards),
+'rewards': jax.tree.map(lambda x: jnp.sum(x, axis=0).mean(), traj_batch.rewards),
 'eps': explorer.get_epsilon(time_state['timesteps'])
 }
 metrics['test_metrics'] = test_metrics # add the test metrics dictionary

@@ -691,7 +691,7 @@ def _greedy_env_step(step_state, unused):
 obs_ = jax.tree.map(lambda x: x[np.newaxis, :], obs_)
 dones_ = jax.tree.map(lambda x: x[np.newaxis, :], last_dones)
 hstate, q_vals = homogeneous_pass(params, hstate, obs_, dones_)
-actions = jax.tree_util.tree.map(lambda q, valid_idx: jnp.argmax(q.squeeze(0)[..., valid_idx], axis=-1), q_vals, test_env.valid_actions)
+actions = jax.tree.map(lambda q, valid_idx: jnp.argmax(q.squeeze(0)[..., valid_idx], axis=-1), q_vals, test_env.valid_actions)
 obs, env_state, rewards, dones, infos = test_env.batch_step(key_s, env_state, actions)
 step_state = (params, env_state, obs, dones, hstate, rng)
 return step_state, (rewards, dones, infos)
14 changes: 7 additions & 7 deletions baselines/QLearning/transf_qmix.py
@@ -384,7 +384,7 @@ class Transition(NamedTuple):

 def tree_mean(tree):
 return jnp.array(
-jax.tree_util.tree_leaves(jax.tree.map(lambda x: x.mean(), tree))
+jax.tree_leaves(jax.tree.map(lambda x: x.mean(), tree))
 ).mean()


@@ -487,8 +487,8 @@ def _env_sample_step(env_state, unused):
 network_stats = {'agent':agent_params['batch_stats'],'mixer':mixer_params['batch_stats']}

 # print number of params
-agent_params = sum(x.size for x in jax.tree_util.tree_leaves(network_params['agent']))
-mixer_params = sum(x.size for x in jax.tree_util.tree_leaves(network_params['mixer']))
+agent_params = sum(x.size for x in jax.tree_leaves(network_params['agent']))
+mixer_params = sum(x.size for x in jax.tree_leaves(network_params['mixer']))
 jax.debug.print("Number of agent params: {x}", x=agent_params)
 jax.debug.print("Number of mixer params: {x}", x=mixer_params)

@@ -609,7 +609,7 @@ def _env_step(step_state, unused):
 # get the q_values from the agent netwoek
 _, hstate, q_vals = homogeneous_pass(env_params, env_batch_norm, hstate, obs_, dones_, train=False)
 # remove the dummy time_step dimension and index qs by the valid actions of each agent
-valid_q_vals = jax.tree_util.tree.map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, wrapped_env.valid_actions)
+valid_q_vals = jax.tree.map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, wrapped_env.valid_actions)
 # explore with epsilon greedy_exploration
 actions = explorer.choose_actions(valid_q_vals, t, key_a)

@@ -641,7 +641,7 @@ def _env_step(step_state, unused):
 )

 # BUFFER UPDATE: save the collected trajectory in the buffer
-buffer_traj_batch = jax.tree_util.tree.map(
+buffer_traj_batch = jax.tree.map(
 lambda x:jnp.swapaxes(x, 0, 1)[:, np.newaxis], # put the batch dim first and add a dummy sequence dim
 traj_batch
 ) # (num_envs, 1, time_steps, ...)

@@ -702,7 +702,7 @@ def _loss_fn(params, init_hs, learn_traj):
 )

 # get the target q value of the greedy actions for each agent
-valid_q_vals = jax.tree_util.tree.map(lambda q, valid_idx: q[..., valid_idx], q_vals, wrapped_env.valid_actions)
+valid_q_vals = jax.tree.map(lambda q, valid_idx: q[..., valid_idx], q_vals, wrapped_env.valid_actions)
 target_max_qvals = jax.tree.map(
 lambda t_q, q: q_of_action(t_q, jnp.argmax(q, axis=-1))[1:], # avoid first timestep
 target_q_vals,

@@ -867,7 +867,7 @@ def _greedy_env_step(step_state, unused):
 obs_ = jax.tree.map(lambda x: x[np.newaxis, :], obs_)
 dones_ = jax.tree.map(lambda x: x[np.newaxis, :], last_dones)
 _, hstate, q_vals = homogeneous_pass(env_params, env_batch_norm, hstate, obs_, dones_, train=False)
-actions = jax.tree_util.tree.map(lambda q, valid_idx: jnp.argmax(q.squeeze(0)[..., valid_idx], axis=-1), q_vals, test_env.valid_actions)
+actions = jax.tree.map(lambda q, valid_idx: jnp.argmax(q.squeeze(0)[..., valid_idx], axis=-1), q_vals, test_env.valid_actions)
 obs, env_state, rewards, dones, infos = test_env.batch_step(key_s, env_state, actions)
 step_state = (env_state, obs, dones, hstate, rng)
 return step_state, (rewards, dones, infos)
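
Several of the shaq.py and transf_qmix.py call sites above pass two pytrees to a single jax.tree.map: the per-agent Q-values together with the per-agent valid-action indices. A minimal sketch of that multi-tree form, using hypothetical agent names and action sets rather than anything from the repository:

import jax
import jax.numpy as jnp

# Hypothetical per-agent Q-values and valid-action index arrays; the real
# code takes these from q_vals and wrapped_env.valid_actions.
q_vals = {"agent_0": jnp.arange(6.0), "agent_1": jnp.arange(6.0) * 2}
valid_actions = {"agent_0": jnp.array([0, 2, 4]), "agent_1": jnp.array([1, 3, 5])}

# With two trees of matching structure, the lambda receives one leaf from
# each tree, mirroring the valid_q_vals pattern in the hunks above.
valid_q_vals = jax.tree.map(lambda q, idx: q[..., idx], q_vals, valid_actions)
print(valid_q_vals)  # {'agent_0': [0., 2., 4.], 'agent_1': [2., 6., 10.]}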