[Algorithm] Update DQN example (pytorch#1512)
Co-authored-by: vmoens <vincentmoens@gmail.com>
Co-authored-by: albert bou <albertbo@kth.se>
3 people authored Dec 7, 2023
1 parent ee89728 commit 4a6cc52
Showing 12 changed files with 731 additions and 225 deletions.
40 changes: 16 additions & 24 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -100,18 +100,14 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c_atari.
loss.mini_batch_size=20 \
logger.backend= \
logger.test_interval=40
python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \
total_frames=48 \
init_random_frames=10 \
batch_size=10 \
frames_per_batch=16 \
num_workers=4 \
env_per_collector=2 \
collector_device=cuda:0 \
optim_steps_per_batch=1 \
record_video=True \
record_frames=4 \
buffer_size=120
python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn_atari.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
buffer.batch_size=10 \
device=cuda:0 \
loss.num_updates=1 \
buffer.buffer_size=120
python .github/unittest/helpers/coverage_run_parallel.py examples/cql/discrete_cql_online.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
@@ -215,18 +211,14 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
logger.backend=
# record_video=True \
# record_frames=4 \
python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \
total_frames=48 \
init_random_frames=10 \
batch_size=10 \
frames_per_batch=16 \
num_workers=2 \
env_per_collector=1 \
collector_device=cuda:0 \
optim_steps_per_batch=1 \
record_video=True \
record_frames=4 \
buffer_size=120
python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn_atari.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
buffer.batch_size=10 \
device=cuda:0 \
loss.num_updates=1 \
buffer.buffer_size=120
python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \
num_workers=2 \
collector.total_frames=48 \
2 changes: 1 addition & 1 deletion README.md
@@ -491,7 +491,7 @@ If you would like to contribute to new features, check our [call for contributio
## Examples, tutorials and demos

A series of [examples](examples/) are provided with an illustrative purpose:
- [DQN and Rainbow](examples/dqn/dqn.py)
- [DQN](examples/dqn)
- [DDPG](examples/ddpg/ddpg.py)
- [IQL](examples/iql/iql.py)
- [CQL](examples/iql/cql.py)
3 changes: 3 additions & 0 deletions examples/distributed/collectors/multi_nodes/lol.py
@@ -0,0 +1,3 @@
from torchrl.envs.libs.gym import GymEnv

env = GymEnv("ALE/Pong-v5")
30 changes: 30 additions & 0 deletions examples/dqn/README.md
@@ -0,0 +1,30 @@
## Reproducing Deep Q-Learning (DQN) Algorithm Results

This repository contains scripts for training agents with the Deep Q-Learning (DQN) algorithm on CartPole and Atari environments. For Atari, we follow the original paper [Playing Atari with Deep Reinforcement Learning](https://arxiv.org/abs/1312.5602) by Mnih et al. (2013).


## Examples Structure

Please note that, for the sake of simplicity, each example is independent of the others. Each example contains the following files:

1. **Main Script:** The definition of algorithm components and the training loop can be found in the main script (e.g. dqn_atari.py).

2. **Utils File:** A utility file is provided to contain various helper functions, generally to create the environment and the models (e.g. utils_atari.py).

3. **Configuration File:** This file includes the default hyperparameters specified in the original paper. Users can modify these hyperparameters to customize their experiments (e.g. config_atari.yaml); see the sketch after this list.
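
Any value in the configuration file can be overridden at launch time with the `key=value` syntax exercised by the CI script in this commit (e.g. `collector.total_frames=48`). The values below are purely illustrative, not recommended settings:

```bash
# Override selected hyperparameters from config_atari.yaml at launch time.
python dqn_atari.py collector.total_frames=500_000 optim.lr=1e-4 buffer.batch_size=64
```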


## Running the Examples

You can execute the DQN algorithm on the CartPole environment by running the following command:

```bash
python dqn_cartpole.py
```

You can execute the DQN algorithm on Atari environments by running the following command:

```bash
python dqn_atari.py
```
32 changes: 0 additions & 32 deletions examples/dqn/config.yaml

This file was deleted.

38 changes: 38 additions & 0 deletions examples/dqn/config_atari.yaml
@@ -0,0 +1,38 @@
device: cuda:0

# Environment
env:
env_name: PongNoFrameskip-v4

# collector
collector:
total_frames: 40_000_100
frames_per_batch: 16
eps_start: 1.0
eps_end: 0.01
annealing_frames: 4_000_000
init_random_frames: 200_000

# buffer
buffer:
buffer_size: 1_000_000
batch_size: 32
scratch_dir: null

# logger
logger:
backend: null
exp_name: DQN
test_interval: 1_000_000
num_test_episodes: 3

# Optim
optim:
lr: 0.00025
max_grad_norm: 10

# loss
loss:
gamma: 0.99
hard_update_freq: 10_000
num_updates: 1
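
The `eps_start`, `eps_end`, and `annealing_frames` fields in the collector section describe a linearly annealed exploration schedule. A minimal sketch of how such a schedule is computed (illustrative only; the example script itself presumably delegates this to TorchRL's exploration utilities):

```python
def linear_epsilon(frame: int,
                   eps_start: float = 1.0,
                   eps_end: float = 0.01,
                   annealing_frames: int = 4_000_000) -> float:
    """Linearly anneal epsilon from eps_start to eps_end over annealing_frames."""
    frac = min(frame / annealing_frames, 1.0)
    return eps_start + frac * (eps_end - eps_start)

# After annealing completes, epsilon stays at eps_end:
assert abs(linear_epsilon(4_000_000) - 0.01) < 1e-9
```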
37 changes: 37 additions & 0 deletions examples/dqn/config_cartpole.yaml
@@ -0,0 +1,37 @@
device: cuda:0

# Environment
env:
env_name: CartPole-v1

# collector
collector:
total_frames: 500_100
frames_per_batch: 10
eps_start: 1.0
eps_end: 0.05
annealing_frames: 250_000
init_random_frames: 10_000

# buffer
buffer:
buffer_size: 10_000
batch_size: 128

# logger
logger:
backend: null
exp_name: DQN
test_interval: 50_000
num_test_episodes: 5

# Optim
optim:
lr: 2.5e-4
max_grad_norm: 10

# loss
loss:
gamma: 0.99
hard_update_freq: 50
num_updates: 1
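
Here, `loss.hard_update_freq` controls how often the target network is synchronized with the online Q-network. A minimal PyTorch sketch of the idea (the example itself would use TorchRL's target-update helpers; `q_net` below is a toy stand-in, not the actual CartPole model):

```python
import copy

import torch.nn as nn

q_net = nn.Linear(4, 2)            # toy stand-in for the CartPole Q-network
target_net = copy.deepcopy(q_net)  # target network starts as a copy

def maybe_hard_update(step: int, hard_update_freq: int = 50) -> None:
    # Copy online weights into the target every hard_update_freq updates,
    # matching loss.hard_update_freq above.
    if step % hard_update_freq == 0:
        target_net.load_state_dict(q_net.state_dict())
```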
168 changes: 0 additions & 168 deletions examples/dqn/dqn.py

This file was deleted.
