[Algorithm] IMPALA and VTrace module (#1506)
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
albertbou92 and vmoens authored Nov 23, 2023
1 parent c2edf35 commit b38d4b7
Showing 20 changed files with 2,140 additions and 187 deletions.
6 changes: 6 additions & 0 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -52,6 +52,12 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/decision_trans
# ==================================================================================== #
# ================================ Gymnasium ========================================= #

python .github/unittest/helpers/coverage_run_parallel.py examples/impala/impala_single_node.py \
  collector.total_frames=80 \
  collector.frames_per_batch=20 \
  collector.num_workers=1 \
  logger.backend= \
  logger.test_interval=10
python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_mujoco.py \
env.env_name=HalfCheetah-v4 \
collector.total_frames=40 \
2 changes: 1 addition & 1 deletion examples/distributed/collectors/multi_nodes/ray_train.py
@@ -117,7 +117,7 @@
"object_store_memory": 1024**3,
}
collector = RayCollector(
- env_makers=[env] * num_collectors,
+ create_env_fn=[env] * num_collectors,
policy=policy_module,
collector_class=SyncDataCollector,
collector_kwargs={
33 changes: 33 additions & 0 deletions examples/impala/README.md
@@ -0,0 +1,33 @@
## Reproducing Importance Weighted Actor-Learner Architecture (IMPALA) Algorithm Results

This repository contains scripts that enable training agents using the IMPALA algorithm on Atari environments. We follow the original paper [IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures](https://arxiv.org/abs/1802.01561) by Espeholt et al. (2018).
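
The core of the algorithm is the V-trace off-policy correction. As a quick reference, below is a minimal PyTorch sketch of the V-trace target recursion from the paper; it is an illustration only, not the API of the TorchRL `VTrace` module added in this commit, and names such as `vtrace_targets` and `log_rhos` are made up for the example.

```python
import torch

def vtrace_targets(rewards, values, bootstrap_value, log_rhos,
                   gamma=0.99, rho_bar=1.0, c_bar=1.0):
    """V-trace targets for a single trajectory (Espeholt et al., 2018).

    rewards, values, log_rhos: 1-D tensors of length T;
    log_rhos = log(pi(a_t|x_t) / mu(a_t|x_t)) for target pi, behavior mu;
    bootstrap_value: scalar tensor V(x_T).
    """
    rhos = torch.clamp(log_rhos.exp(), max=rho_bar)  # truncated IS weights
    cs = torch.clamp(log_rhos.exp(), max=c_bar)      # trace-cutting coefficients
    next_values = torch.cat([values[1:], bootstrap_value.unsqueeze(0)])
    deltas = rhos * (rewards + gamma * next_values - values)

    # Backward recursion: v_s - V(x_s) = delta_s + gamma * c_s * (v_{s+1} - V(x_{s+1}))
    acc = torch.zeros_like(bootstrap_value)
    vs_minus_v = torch.zeros_like(values)
    for t in reversed(range(rewards.shape[0])):
        acc = deltas[t] + gamma * cs[t] * acc
        vs_minus_v[t] = acc
    return values + vs_minus_v

# With log_rhos == 0 (on-policy), the targets reduce to n-step TD targets.
targets = vtrace_targets(torch.rand(5), torch.rand(5), torch.rand(()), torch.zeros(5))
```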

## Examples Structure

Note that we provide two examples: one for single-node training and one for distributed training. Both examples rely on the same utils file but are otherwise independent. Each example contains the following files:

1. **Main Script:** The definition of algorithm components and the training loop can be found in the main script (e.g. impala_single_node_ray.py).

2. **Utils File:** A utility file is provided to contain various helper functions, generally to create the environment and the models (e.g. utils.py).

3. **Configuration File:** This file includes the default hyperparameters specified in the original paper. For the multi-node case, it also includes the Ray cluster configuration. Users can modify these hyperparameters to customize their experiments (e.g. config_single_node.yaml).


## Running the Examples

You can execute the single-node IMPALA algorithm on Atari environments by running the following command:

```bash
python impala_single_node.py
```

You can execute the multi-node IMPALA algorithm on Atari environments by running one of the following commands:

```bash
python impala_single_node_ray.py
```
or

```bash
python impala_single_node_submitit.py
```
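
As the CI test script above shows, any configuration entry can be overridden from the command line using Hydra syntax, e.g. `python impala_single_node.py collector.num_workers=1 logger.backend=`.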
65 changes: 65 additions & 0 deletions examples/impala/config_multi_node_ray.yaml
@@ -0,0 +1,65 @@
# Environment
env:
  env_name: PongNoFrameskip-v4

# Ray init kwargs - https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html
ray_init_config:
  address: null
  num_cpus: null
  num_gpus: null
  resources: null
  object_store_memory: null
  local_mode: False
  ignore_reinit_error: False
  include_dashboard: null
  dashboard_host: 127.0.0.1
  dashboard_port: null
  job_config: null
  configure_logging: True
  logging_level: info
  logging_format: null
  log_to_driver: True
  namespace: null
  runtime_env: null
  storage: null

# Device for the forward and backward passes
local_device: "cuda:0"

# Resources assigned to each IMPALA rollout collection worker
remote_worker_resources:
  num_cpus: 1
  num_gpus: 0.25
  memory: 1073741824 # 1*1024**3 - 1GB

# collector
collector:
  frames_per_batch: 80
  total_frames: 200_000_000
  num_workers: 12

# logger
logger:
  backend: wandb
  exp_name: Atari_IMPALA
  test_interval: 200_000_000
  num_test_episodes: 3

# Optim
optim:
  lr: 0.0006
  eps: 1e-8
  weight_decay: 0.0
  momentum: 0.0
  alpha: 0.99
  max_grad_norm: 40.0
  anneal_lr: True

# loss
loss:
  gamma: 0.99
  batch_size: 32
  sgd_updates: 1
  critic_coef: 0.5
  entropy_coef: 0.01
  loss_critic_type: l2
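
For context, the `ray_init_config` block above maps one-to-one onto `ray.init` keyword arguments (see the linked Ray docs). A minimal sketch of consuming it, assuming the YAML is loaded with omegaconf; this is not necessarily how the example script wires it:

```python
import ray
from omegaconf import OmegaConf

cfg = OmegaConf.load("config_multi_node_ray.yaml")

# Every key under ray_init_config is a valid ray.init keyword argument,
# so the whole section can be splatted into the call.
ray.init(**OmegaConf.to_container(cfg.ray_init_config, resolve=True))
```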
46 changes: 46 additions & 0 deletions examples/impala/config_multi_node_submitit.yaml
@@ -0,0 +1,46 @@
# Environment
env:
  env_name: PongNoFrameskip-v4

# Device for the forward and backward passes
local_device: "cuda:0"

# SLURM config
slurm_config:
  timeout_min: 10
  slurm_partition: train
  slurm_cpus_per_task: 1
  slurm_gpus_per_node: 1

# collector
collector:
  backend: gloo
  frames_per_batch: 80
  total_frames: 200_000_000
  num_workers: 1

# logger
logger:
  backend: wandb
  exp_name: Atari_IMPALA
  test_interval: 200_000_000
  num_test_episodes: 3

# Optim
optim:
  lr: 0.0006
  eps: 1e-8
  weight_decay: 0.0
  momentum: 0.0
  alpha: 0.99
  max_grad_norm: 40.0
  anneal_lr: True

# loss
loss:
  gamma: 0.99
  batch_size: 32
  sgd_updates: 1
  critic_coef: 0.5
  entropy_coef: 0.01
  loss_critic_type: l2
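
For context, the `slurm_config` block matches the parameters accepted by a submitit executor. A hedged sketch of the mapping; the log folder and `train` function below are placeholders, not the example script's actual wiring:

```python
import submitit

def train():
    """Placeholder for the IMPALA training entry point."""
    print("training...")

executor = submitit.AutoExecutor(folder="impala_logs")  # placeholder log dir
executor.update_parameters(
    timeout_min=10,           # slurm_config.timeout_min
    slurm_partition="train",
    slurm_cpus_per_task=1,
    slurm_gpus_per_node=1,
)
job = executor.submit(train)
print(job.job_id)
```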
38 changes: 38 additions & 0 deletions examples/impala/config_single_node.yaml
@@ -0,0 +1,38 @@
# Environment
env:
  env_name: PongNoFrameskip-v4

# Device for the forward and backward passes
device: "cuda:0"

# collector
collector:
  frames_per_batch: 80
  total_frames: 200_000_000
  num_workers: 12

# logger
logger:
  backend: wandb
  exp_name: Atari_IMPALA
  test_interval: 200_000_000
  num_test_episodes: 3

# Optim
optim:
  lr: 0.0006
  eps: 1e-8
  weight_decay: 0.0
  momentum: 0.0
  alpha: 0.99
  max_grad_norm: 40.0
  anneal_lr: True

# loss
loss:
  gamma: 0.99
  batch_size: 32
  sgd_updates: 1
  critic_coef: 0.5
  entropy_coef: 0.01
  loss_critic_type: l2
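
For context, the `optim` keys line up with the signature of `torch.optim.RMSprop`, the optimizer used in the IMPALA paper. A minimal sketch under that assumption; `model` is a stand-in for the actor-critic network built in utils.py:

```python
import torch

model = torch.nn.Linear(4, 2)  # stand-in for the actor-critic network

optimizer = torch.optim.RMSprop(
    model.parameters(),
    lr=0.0006,         # optim.lr
    alpha=0.99,        # optim.alpha (squared-gradient smoothing)
    eps=1e-8,          # optim.eps
    weight_decay=0.0,  # optim.weight_decay
    momentum=0.0,      # optim.momentum
)

# optim.max_grad_norm is applied as gradient-norm clipping before each step.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=40.0)
```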

1 comment on commit b38d4b7

@github-actions


⚠️ Performance Alert ⚠️

A possible performance regression was detected for benchmark 'GPU Benchmark Results'.
The benchmark result of this commit is worse than the previous result by more than the threshold factor of 2.

| Benchmark suite | Current: b38d4b7 | Previous: c2edf35 | Ratio |
|---|---|---|---|
| `benchmarks/test_objectives_benchmarks.py::test_values[vec_generalized_advantage_estimate-True-True]` | 100.48021464232674 iter/sec (stddev: 0.05903550706063675) | 305.2798305395709 iter/sec (stddev: 0.012019387037625095) | 3.04 |

This comment was automatically generated by a workflow using github-action-benchmark.

CC: @vmoens
