Skip to content

Commit

Permalink
Bug fixes and improvements:
Browse files Browse the repository at this point in the history
- Fix bug in computing the PLR replay distribution and in the state reset after the first episode.
- Update hyperparameter configs for PLR-based methods
- Only log env metrics for passable levels
- Update package requirements
- Fix module path in output of minimax.make_cmd
- Update version number to 0.2.0
- Note: The technical report on arXiv has been updated to reflect improved results from these changes.
  • Loading branch information
minqi committed Aug 24, 2024
1 parent 2ae9e04 commit 3fb8b46
Show file tree
Hide file tree
Showing 23 changed files with 1,020 additions and 959 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ src/*.egg-info
**/.ipynb_checkpoints/
config/
!src/minimax/config
build/
91 changes: 46 additions & 45 deletions README.md

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
[project]
name = "minimax-lib"
version = "0.1.0"
version = "0.2.0"
authors = [
{name="Minqi Jiang"},
{email="msj@meta.com"},
]
description = "Efficient baselines for autocurricula in JAX"
readme = "README.md"
license = {text = "Apache 2.0"}
requires-python = ">=3.9"
requires-python = ">=3.10"
classifiers = [
"Programming Language :: Python :: 3",
"Operating System :: OS Independent"
]
dependencies = [
"numpy>=1.25,<1.26",
"pandas==1.5.3",
"jax>=0.4.19",
"jaxlib>=0.4.19",
"jax>=0.4.31",
"jaxlib>=0.4.31",
"flax>=0.7.4",
"optax>=0.1.7",
"tensorflow_probability==0.23.0",
"chex>=0.1.83",
"wandb>=0.13",
"ipython>=7.34.0",
"GitPython>=3.1.29"
"GitPython>=3.1.29",
"tqdm>=4.66.1"
]

[tool.hatchling.scripts]
Expand Down
2 changes: 1 addition & 1 deletion src/minimax/agents/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def _update_minibatch(carry, step):

return train_state, loss_stats

@partial(jax.jit, static_argnums=(0,2,4))
@partial(jax.jit, static_argnums=(0,2))
def _loss(
self,
params,
Expand Down
144 changes: 72 additions & 72 deletions src/minimax/config/configs/maze/accel.json
Original file line number Diff line number Diff line change
@@ -1,73 +1,73 @@
{
"args": {
"seed": [1],
"agent_rl_algo": ["ppo"],
"n_total_updates": [30000],
"train_runner": ["plr"],
"n_devices": [1],
"student_model_name": ["default_student_cnn"],
"env_name": ["Maze"],
"verbose": [false],
"log_dir": ["~/logs/minimax"],
"log_interval": [10],
"from_last_checkpoint": [true],
"checkpoint_interval": [1000],
"archive_interval": [0],
"archive_init_checkpoint": [false],
"test_interval": [100],
"n_students": [1],
"n_parallel": [32],
"n_eval": [1],
"n_rollout_steps": [256],
"lr": [0.0001],
"lr_anneal_steps": [0],
"max_grad_norm": [0.5],
"adam_eps": [1e-05],
"track_env_metrics": [true],
"discount": [0.995],
"n_unroll_rollout": [10],
"render": [false],
"ued_score": ["max_mc"],
"plr_replay_prob": [0.8],
"plr_buffer_size": [4000],
"plr_staleness_coef": [0.3],
"plr_temp": [0.1],
"plr_use_score_ranks": [true],
"plr_min_fill_ratio": [0.5],
"plr_use_robust_plr": [true],
"plr_use_parallel_eval": [false],
"plr_force_unique": [true],
"plr_mutation_fn": ["default"],
"plr_n_mutations": [20],
"plr_mutation_criterion": ["batch"],
"plr_mutation_subsample_size": [4],
"student_gae_lambda": [0.98],
"student_entropy_coef": [0.0],
"student_value_loss_coef": [0.5],
"student_n_unroll_update": [5],
"student_ppo_n_epochs": [5],
"student_ppo_n_minibatches": [1],
"student_ppo_clip_eps": [0.2],
"student_ppo_clip_value_loss": [true],
"student_recurrent_arch": ["lstm"],
"student_recurrent_hidden_dim": [256],
"student_hidden_dim": [32],
"student_n_hidden_layers": [1],
"student_n_conv_filters": [16],
"student_n_scalar_embeddings": [4],
"student_scalar_embed_dim": [5],
"maze_height": [13],
"maze_width": [13],
"maze_n_walls": [0],
"maze_replace_wall_pos": [true],
"maze_sample_n_walls": [false],
"maze_see_agent": [false],
"maze_normalize_obs": [true],
"maze_obs_agent_pos": [false],
"maze_max_episode_steps": [250],
"test_n_episodes": [10],
"test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"],
"maze_test_see_agent": [false],
"maze_test_normalize_obs": [true]
}
}
"args": {
"seed": [1],
"agent_rl_algo": ["ppo"],
"n_total_updates": [30000],
"train_runner": ["plr"],
"n_devices": [1],
"student_model_name": ["default_student_cnn"],
"env_name": ["Maze"],
"verbose": [false],
"log_dir": ["~/logs/minimax"],
"log_interval": [10],
"from_last_checkpoint": [true],
"checkpoint_interval": [1000],
"archive_interval": [0],
"archive_init_checkpoint": [false],
"test_interval": [100],
"n_students": [1],
"n_parallel": [32],
"n_eval": [1],
"n_rollout_steps": [256],
"lr": [0.0003],
"lr_anneal_steps": [0],
"max_grad_norm": [0.5],
"adam_eps": [1e-5],
"track_env_metrics": [true],
"discount": [0.999],
"n_unroll_rollout": [10],
"render": [false],
"ued_score": ["max_mc"],
"plr_replay_prob": [0.8],
"plr_buffer_size": [4000],
"plr_staleness_coef": [0.5],
"plr_temp": [0.3],
"plr_use_score_ranks": [true],
"plr_min_fill_ratio": [0.5],
"plr_use_robust_plr": [true],
"plr_use_parallel_eval": [false],
"plr_force_unique": [true],
"plr_mutation_fn": ["default"],
"plr_n_mutations": [20],
"plr_mutation_criterion": ["batch"],
"plr_mutation_subsample_size": [4],
"student_gae_lambda": [0.98],
"student_entropy_coef": [0.0],
"student_value_loss_coef": [0.5],
"student_n_unroll_update": [5],
"student_ppo_n_epochs": [5],
"student_ppo_n_minibatches": [1],
"student_ppo_clip_eps": [0.2],
"student_ppo_clip_value_loss": [true],
"student_recurrent_arch": ["lstm"],
"student_recurrent_hidden_dim": [256],
"student_hidden_dim": [32],
"student_n_hidden_layers": [1],
"student_n_conv_filters": [16],
"student_n_scalar_embeddings": [4],
"student_scalar_embed_dim": [5],
"maze_height": [13],
"maze_width": [13],
"maze_n_walls": [0],
"maze_replace_wall_pos": [true],
"maze_sample_n_walls": [false],
"maze_see_agent": [false],
"maze_normalize_obs": [true],
"maze_obs_agent_pos": [false],
"maze_max_episode_steps": [250],
"test_n_episodes": [10],
"test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"],
"maze_test_see_agent": [false],
"maze_test_normalize_obs": [true]
}
}
116 changes: 58 additions & 58 deletions src/minimax/config/configs/maze/dr.json
Original file line number Diff line number Diff line change
@@ -1,59 +1,59 @@
{
"args": {
"seed": [1],
"agent_rl_algo": ["ppo"],
"n_total_updates": [30000],
"train_runner": ["dr"],
"n_devices": [1],
"student_model_name": ["default_student_cnn"],
"env_name": ["Maze"],
"verbose": [false],
"log_dir": ["~/logs/minimax"],
"log_interval": [10],
"from_last_checkpoint": [true],
"checkpoint_interval": [1000],
"archive_interval": [0],
"archive_init_checkpoint": [false],
"test_interval": [100],
"n_students": [1],
"n_parallel": [32],
"n_eval": [1],
"n_rollout_steps": [256],
"lr": [0.0001],
"lr_anneal_steps": [0],
"max_grad_norm": [0.5],
"adam_eps": [1e-05],
"track_env_metrics": [true],
"discount": [0.995],
"n_unroll_rollout": [10],
"render": [false],
"student_gae_lambda": [0.98],
"student_entropy_coef": [0.001],
"student_value_loss_coef": [0.5],
"student_n_unroll_update": [5],
"student_ppo_n_epochs": [5],
"student_ppo_n_minibatches": [1],
"student_ppo_clip_eps": [0.2],
"student_ppo_clip_value_loss": [true],
"student_recurrent_arch": ["lstm"],
"student_recurrent_hidden_dim": [256],
"student_hidden_dim": [32],
"student_n_hidden_layers": [1],
"student_n_conv_filters": [16],
"student_n_scalar_embeddings": [4],
"student_scalar_embed_dim": [5],
"maze_height": [13],
"maze_width": [13],
"maze_n_walls": [60],
"maze_replace_wall_pos": [true],
"maze_sample_n_walls": [false],
"maze_see_agent": [false],
"maze_normalize_obs": [true],
"maze_obs_agent_pos": [false],
"maze_max_episode_steps": [250],
"test_n_episodes": [10],
"test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"],
"maze_test_see_agent": [false],
"maze_test_normalize_obs": [true]
}
}
"args": {
"seed": [1],
"agent_rl_algo": ["ppo"],
"n_total_updates": [30000],
"train_runner": ["dr"],
"n_devices": [1],
"student_model_name": ["default_student_cnn"],
"env_name": ["Maze"],
"verbose": [false],
"log_dir": ["~/logs/minimax"],
"log_interval": [10],
"from_last_checkpoint": [true],
"checkpoint_interval": [1000],
"archive_interval": [0],
"archive_init_checkpoint": [false],
"test_interval": [100],
"n_students": [1],
"n_parallel": [32],
"n_eval": [1],
"n_rollout_steps": [256],
"lr": [0.0001],
"lr_anneal_steps": [0],
"max_grad_norm": [0.5],
"adam_eps": [1e-5],
"track_env_metrics": [true],
"discount": [0.995],
"n_unroll_rollout": [10],
"render": [false],
"student_gae_lambda": [0.98],
"student_entropy_coef": [0.001],
"student_value_loss_coef": [0.5],
"student_n_unroll_update": [5],
"student_ppo_n_epochs": [5],
"student_ppo_n_minibatches": [1],
"student_ppo_clip_eps": [0.2],
"student_ppo_clip_value_loss": [true],
"student_recurrent_arch": ["lstm"],
"student_recurrent_hidden_dim": [256],
"student_hidden_dim": [32],
"student_n_hidden_layers": [1],
"student_n_conv_filters": [16],
"student_n_scalar_embeddings": [4],
"student_scalar_embed_dim": [5],
"maze_height": [13],
"maze_width": [13],
"maze_n_walls": [60],
"maze_replace_wall_pos": [true],
"maze_sample_n_walls": [false],
"maze_see_agent": [false],
"maze_normalize_obs": [true],
"maze_obs_agent_pos": [false],
"maze_max_episode_steps": [250],
"test_n_episodes": [10],
"test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"],
"maze_test_see_agent": [false],
"maze_test_normalize_obs": [true]
}
}
Loading

0 comments on commit 3fb8b46

Please sign in to comment.