[Algorithm] IMPALA and VTrace module (#1506)
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
1 parent c2edf35, commit b38d4b7
Showing 20 changed files with 2,140 additions and 187 deletions.
@@ -0,0 +1,33 @@
## Reproducing Importance Weighted Actor-Learner Architecture (IMPALA) Algorithm Results

This repository contains scripts for training agents with the IMPALA algorithm on MuJoCo and Atari environments. We follow the original paper [IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures](https://arxiv.org/abs/1802.01561) by Espeholt et al. 2018.

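IMPALA corrects for the lag between the actors' behaviour policy and the learner's policy with V-trace targets (Espeholt et al. 2018). The block below is a minimal sketch of that target recursion for a single trajectory, written in plain Python. It is not the VTrace module added in this commit: the function name, argument names, and the default clipping thresholds `rho_bar` and `c_bar` are hypothetical, and episode terminations are ignored.

```python
from typing import List, Sequence


def vtrace_targets(
    rewards: Sequence[float],
    values: Sequence[float],
    bootstrap_value: float,
    ratios: Sequence[float],
    gamma: float = 0.99,
    rho_bar: float = 1.0,
    c_bar: float = 1.0,
) -> List[float]:
    """Return V-trace targets v_s for one trajectory (no episode ends).

    ratios[t] is the importance weight pi(a_t | x_t) / mu(a_t | x_t).
    Hypothetical helper for illustration only.
    """
    n = len(rewards)
    # V(x_{t+1}) for every step; the last step bootstraps from bootstrap_value.
    next_values = list(values[1:]) + [bootstrap_value]
    targets = [0.0] * n
    acc = 0.0  # running value of v_{t+1} - V(x_{t+1})
    for t in reversed(range(n)):
        rho = min(rho_bar, ratios[t])  # clipped weight on the TD error
        c = min(c_bar, ratios[t])      # clipped trace coefficient
        delta = rho * (rewards[t] + gamma * next_values[t] - values[t])
        acc = delta + gamma * c * acc
        targets[t] = values[t] + acc
    return targets


# On-policy data (all ratios equal to 1) reduces to discounted returns.
on_policy = vtrace_targets(
    rewards=[1.0, 1.0, 1.0],
    values=[0.0, 0.0, 0.0],
    bootstrap_value=0.0,
    ratios=[1.0, 1.0, 1.0],
    gamma=0.5,
)
print(on_policy)  # [1.75, 1.5, 1.0]
```

When every ratio is zero the targets collapse to the current value estimates, which is one way to sanity-check an implementation.
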
## Examples Structure

We provide two examples, one for single-node training and one for distributed (multi-node) training. Both rely on the same utils file but are otherwise independent. Each example contains the following files:

1. **Main Script:** Defines the algorithm components and the training loop (e.g. impala_single_node.py).

2. **Utils File:** Contains helper functions, mainly for creating the environment and the models (e.g. utils.py).

3. **Configuration File:** Holds the default hyperparameters specified in the original paper. For the multi-node case, the file also includes the configuration of the Ray cluster or of the SLURM job. Users can modify these hyperparameters to customize their experiments (e.g. config_single_node.yaml).

## Running the Examples

You can run the single-node IMPALA algorithm on Atari environments with the following command:

```bash
python impala_single_node.py
```

You can run the multi-node IMPALA algorithm on Atari environments with either of the following commands:

```bash
python impala_multi_node_ray.py
```
or

```bash
python impala_multi_node_submitit.py
```
@@ -0,0 +1,65 @@
# Environment
env:
  env_name: PongNoFrameskip-v4

# Ray init kwargs - https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html
ray_init_config:
  address: null
  num_cpus: null
  num_gpus: null
  resources: null
  object_store_memory: null
  local_mode: False
  ignore_reinit_error: False
  include_dashboard: null
  dashboard_host: 127.0.0.1
  dashboard_port: null
  job_config: null
  configure_logging: True
  logging_level: info
  logging_format: null
  log_to_driver: True
  namespace: null
  runtime_env: null
  storage: null

# Device for the forward and backward passes
local_device: "cuda:0"

# Resources assigned to each IMPALA rollout collection worker
remote_worker_resources:
  num_cpus: 1
  num_gpus: 0.25
  memory: 1073741824  # 1*1024**3 - 1GB

# collector
collector:
  frames_per_batch: 80
  total_frames: 200_000_000
  num_workers: 12

# logger
logger:
  backend: wandb
  exp_name: Atari_IMPALA
  test_interval: 200_000_000
  num_test_episodes: 3

# Optim
optim:
  lr: 0.0006
  eps: 1e-8
  weight_decay: 0.0
  momentum: 0.0
  alpha: 0.99
  max_grad_norm: 40.0
  anneal_lr: True

# loss
loss:
  gamma: 0.99
  batch_size: 32
  sgd_updates: 1
  critic_coef: 0.5
  entropy_coef: 0.01
  loss_critic_type: l2
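
As a rough sizing aid for the Ray configuration above, the sketch below multiplies the per-worker entries of `remote_worker_resources` by `collector.num_workers`. It assumes that all workers are scheduled at the same time and that the cluster reserves exactly the requested amounts; the helper name `total_worker_resources` is hypothetical, and the learner's own device (`local_device`) is not counted.

```python
# Values copied from the configuration above.
remote_worker_resources = {
    "num_cpus": 1,
    "num_gpus": 0.25,
    "memory": 1 * 1024**3,  # bytes, i.e. 1073741824
}
num_workers = 12


def total_worker_resources(per_worker: dict, workers: int) -> dict:
    """Scale every per-worker resource request by the number of workers."""
    return {name: amount * workers for name, amount in per_worker.items()}


totals = total_worker_resources(remote_worker_resources, num_workers)
print(totals["num_cpus"])             # CPUs requested by all workers
print(totals["num_gpus"])             # GPUs requested by all workers
print(totals["memory"] / 1024**3)     # memory requested, in units of 1024**3 bytes
```

Under these assumptions the twelve workers ask for 12 CPUs, 3 GPUs (four workers sharing each GPU), and twelve times the per-worker memory, in addition to whatever the learner process uses.
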
@@ -0,0 +1,46 @@
# Environment
env:
  env_name: PongNoFrameskip-v4

# Device for the forward and backward passes
local_device: "cuda:0"

# SLURM config
slurm_config:
  timeout_min: 10
  slurm_partition: train
  slurm_cpus_per_task: 1
  slurm_gpus_per_node: 1

# collector
collector:
  backend: gloo
  frames_per_batch: 80
  total_frames: 200_000_000
  num_workers: 1

# logger
logger:
  backend: wandb
  exp_name: Atari_IMPALA
  test_interval: 200_000_000
  num_test_episodes: 3

# Optim
optim:
  lr: 0.0006
  eps: 1e-8
  weight_decay: 0.0
  momentum: 0.0
  alpha: 0.99
  max_grad_norm: 40.0
  anneal_lr: True

# loss
loss:
  gamma: 0.99
  batch_size: 32
  sgd_updates: 1
  critic_coef: 0.5
  entropy_coef: 0.01
  loss_critic_type: l2
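
The `loss` section above suggests how the individual loss terms are weighted. The sketch below shows one common way such coefficients are combined in actor-critic methods. It is an assumption about the weighting scheme, not the loss module from this commit: `l2_critic_loss` and `total_loss` are hypothetical helpers, and `l2` is taken to mean a mean squared error between value estimates and targets.

```python
from typing import Sequence


def l2_critic_loss(values: Sequence[float], targets: Sequence[float]) -> float:
    """Mean squared error between value estimates and their targets."""
    return sum((t - v) ** 2 for v, t in zip(values, targets)) / len(values)


def total_loss(
    policy_loss: float,
    critic_loss: float,
    entropy: float,
    critic_coef: float = 0.5,     # loss.critic_coef
    entropy_coef: float = 0.01,   # loss.entropy_coef
) -> float:
    """Weighted sum: the entropy term is subtracted to encourage exploration."""
    return policy_loss + critic_coef * critic_loss - entropy_coef * entropy


critic = l2_critic_loss(values=[0.0, 0.0], targets=[1.0, 3.0])
combined = total_loss(policy_loss=1.0, critic_loss=critic, entropy=2.0)
print(critic, combined)
```

With these defaults the critic term contributes half of its raw value, while the entropy bonus is scaled down by a factor of one hundred.
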
@@ -0,0 +1,38 @@
# Environment
env:
  env_name: PongNoFrameskip-v4

# Device for the forward and backward passes
device: "cuda:0"

# collector
collector:
  frames_per_batch: 80
  total_frames: 200_000_000
  num_workers: 12

# logger
logger:
  backend: wandb
  exp_name: Atari_IMPALA
  test_interval: 200_000_000
  num_test_episodes: 3

# Optim
optim:
  lr: 0.0006
  eps: 1e-8
  weight_decay: 0.0
  momentum: 0.0
  alpha: 0.99
  max_grad_norm: 40.0
  anneal_lr: True

# loss
loss:
  gamma: 0.99
  batch_size: 32
  sgd_updates: 1
  critic_coef: 0.5
  entropy_coef: 0.01
  loss_critic_type: l2
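
To give a feel for the scale implied by the single-node configuration above, the sketch below estimates the number of learner updates and a linearly annealed learning rate. Several things here are assumptions rather than facts from the commit: that one update consumes `loss.batch_size` collector batches of `collector.frames_per_batch` frames each, that any frame skipping is ignored, and that `anneal_lr: True` means a linear decay to zero. The function `annealed_lr` is hypothetical.

```python
# Values copied from the configuration above.
frames_per_batch = 80
total_frames = 200_000_000
batch_size = 32
sgd_updates = 1
base_lr = 0.0006

# Assumption: each update uses batch_size rollouts of frames_per_batch frames.
frames_per_update = frames_per_batch * batch_size
total_updates = (total_frames // frames_per_update) * sgd_updates


def annealed_lr(update_idx: int, total: int, lr: float = base_lr) -> float:
    """Linear decay from lr at update 0 to zero at the final update."""
    return lr * (1.0 - update_idx / total)


print(frames_per_update)                                 # 2560
print(total_updates)                                     # 78125
print(annealed_lr(total_updates // 2, total_updates))    # learning rate near the midpoint
```

If the training script divides frame counts by a frame-skip factor, the update count shrinks by the same factor, so these figures are an upper bound under the stated assumptions.
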
Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite | Current | Previous | Ratio |
---|---|---|---|
.benchmarks/test_objectives_benchmarks.py::test_values[vec_generalized_advantage_estimate-True-True] | 100.48021464232674 iter/sec (stddev: 0.05903550706063675) | 305.2798305395709 iter/sec (stddev: 0.012019387037625095) | 3.04 |

This comment was automatically generated by workflow using github-action-benchmark.
CC: @vmoens
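
The reported ratio can be checked by hand. The sketch below assumes the ratio is the previous throughput divided by the current one (so larger means slower), which matches the numbers in the comment above; the variable names are illustrative.

```python
# Throughput figures copied from the benchmark comment above (iterations per second).
current_iter_per_sec = 100.48021464232674
previous_iter_per_sec = 305.2798305395709
threshold = 2.0

# Assumption: ratio = previous / current for a higher-is-better metric.
ratio = previous_iter_per_sec / current_iter_per_sec
is_regression = ratio > threshold

print(f"{ratio:.2f}")   # 3.04
print(is_regression)    # True
```
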