forked from pytorch/rl
[Algorithm] Update DQN example (pytorch#1512)
Co-authored-by: vmoens <vincentmoens@gmail.com> Co-authored-by: albert bou <albertbo@kth.se>
1 parent
ee89728
commit 4a6cc52
Showing 12 changed files with 731 additions and 225 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from torchrl.envs.libs.gym import GymEnv | ||
|
||
env = GymEnv("ALE/Pong-v5") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
## Reproducing Deep Q-Learning (DQN) Algorithm Results | ||
|
||
This repository contains scripts for training agents with the Deep Q-Learning (DQN) algorithm on CartPole and Atari environments. For Atari, we follow the original paper [Playing Atari with Deep Reinforcement Learning](https://arxiv.org/abs/1312.5602) by Mnih et al. (2013). | ||
|
||
|
||
## Examples Structure | ||
|
||
For simplicity, the examples are independent of one another. Each example contains the following files: | ||
|
||
1. **Main Script:** The definition of algorithm components and the training loop can be found in the main script (e.g. dqn_atari.py). | ||
|
||
2. **Utils File:** A utility file is provided to contain various helper functions, generally to create the environment and the models (e.g. utils_atari.py). | ||
|
||
3. **Configuration File:** This file includes default hyperparameters specified in the original paper. Users can modify these hyperparameters to customize their experiments (e.g. config_atari.yaml). | ||
|
||
|
||
## Running the Examples | ||
|
||
You can execute the DQN algorithm on the CartPole environment by running the following command: | ||
|
||
```bash | ||
python dqn_cartpole.py | ||
``` | ||
|
||
You can execute the DQN algorithm on Atari environments by running the following command: | ||
|
||
```bash | ||
python dqn_atari.py | ||
``` |
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
device: cuda:0 | ||
|
||
# Environment | ||
env: | ||
  env_name: PongNoFrameskip-v4 | ||
|
||
# collector | ||
collector: | ||
  total_frames: 40_000_100 | ||
  frames_per_batch: 16 | ||
  eps_start: 1.0 | ||
  eps_end: 0.01 | ||
  annealing_frames: 4_000_000 | ||
  init_random_frames: 200_000 | ||
|
||
# buffer | ||
buffer: | ||
  buffer_size: 1_000_000 | ||
  batch_size: 32 | ||
  scratch_dir: null | ||
|
||
# logger | ||
logger: | ||
  backend: null | ||
  exp_name: DQN | ||
  test_interval: 1_000_000 | ||
  num_test_episodes: 3 | ||
|
||
# Optim | ||
optim: | ||
  lr: 0.00025 | ||
  max_grad_norm: 10 | ||
|
||
# loss | ||
loss: | ||
  gamma: 0.99 | ||
  hard_update_freq: 10_000 | ||
  num_updates: 1 |
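The `eps_start`, `eps_end`, and `annealing_frames` values above describe an epsilon-greedy exploration schedule. The following is a minimal sketch of how such a schedule could behave, assuming linear annealing measured in collected frames; the function name `epsilon_at` is hypothetical and is not taken from the commit.

```python
def epsilon_at(frame, eps_start=1.0, eps_end=0.01, annealing_frames=4_000_000):
    """Return the exploration rate after `frame` collected frames.

    Assumption: epsilon decays linearly from eps_start to eps_end over
    annealing_frames frames and is then held constant at eps_end.
    """
    # Fraction of the annealing period that has elapsed, clipped to [0, 1].
    fraction = min(max(frame, 0) / annealing_frames, 1.0)
    return eps_start + fraction * (eps_end - eps_start)


if __name__ == "__main__":
    # Print epsilon at a few points of the 40M-frame Atari run.
    for frame in (0, 2_000_000, 4_000_000, 40_000_000):
        print(f"frame={frame:>10,d}  epsilon={epsilon_at(frame):.4f}")
```

Under this assumption the agent acts fully at random at the start, explores about half the time after 2M frames, and stays at `eps_end` for the remaining frames of the run.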
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
device: cuda:0 | ||
|
||
# Environment | ||
env: | ||
  env_name: CartPole-v1 | ||
|
||
# collector | ||
collector: | ||
  total_frames: 500_100 | ||
  frames_per_batch: 10 | ||
  eps_start: 1.0 | ||
  eps_end: 0.05 | ||
  annealing_frames: 250_000 | ||
  init_random_frames: 10_000 | ||
|
||
# buffer | ||
buffer: | ||
  buffer_size: 10_000 | ||
  batch_size: 128 | ||
|
||
# logger | ||
logger: | ||
  backend: null | ||
  exp_name: DQN | ||
  test_interval: 50_000 | ||
  num_test_episodes: 5 | ||
|
||
# Optim | ||
optim: | ||
  lr: 2.5e-4 | ||
  max_grad_norm: 10 | ||
|
||
# loss | ||
loss: | ||
  gamma: 0.99 | ||
  hard_update_freq: 50 | ||
  num_updates: 1 |
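The `loss` section above sets the discount factor and the target-network update interval. As a rough illustration of what these two values control in DQN, the sketch below computes a one-step TD target and decides when a hard target update is due. The helper names `td_target` and `should_hard_update` are hypothetical, and treating `hard_update_freq` as a count of gradient steps is an assumption, not something the config itself states.

```python
def td_target(reward, next_q_values, done, gamma=0.99):
    """One-step DQN target: r + gamma * max_a Q_target(s', a).

    `next_q_values` holds the target network's Q-values for the next state.
    No bootstrapping is applied when the episode has terminated.
    """
    if done:
        return reward
    return reward + gamma * max(next_q_values)


def should_hard_update(num_gradient_steps, hard_update_freq=50):
    """Return True when the target network should copy the online weights.

    Assumption: the interval is counted in gradient steps.
    """
    return num_gradient_steps > 0 and num_gradient_steps % hard_update_freq == 0


if __name__ == "__main__":
    # Non-terminal transition: bootstrap from the best next-state value.
    print(td_target(1.0, [0.5, 2.0], done=False))
    # Terminal transition: the target is the reward alone.
    print(td_target(1.0, [0.5, 2.0], done=True))
    # Steps at which a hard update would fire within the first 200 steps.
    print([step for step in range(201) if should_hard_update(step)])
```

A smaller `hard_update_freq` (50 for CartPole versus 10_000 for Atari) makes the target network track the online network more closely, which suits the much shorter CartPole run.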
This file was deleted.