Skip to content

Commit

Permalink
[BugFix]: Fix examples (#290)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jul 18, 2022
1 parent bab8767 commit 3781512
Show file tree
Hide file tree
Showing 28 changed files with 219 additions and 324 deletions.
22 changes: 22 additions & 0 deletions examples/EXAMPLES.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Examples

We provide examples to train the following algorithms:
- [DQN](dqn/dqn.py)
- [DDPG](ddpg/ddpg.py)
- [SAC](sac/sac.py)
- [REDQ](redq/redq.py)
- [PPO](ppo/ppo.py)

To run these examples, make sure you have installed hydra:
```
pip install hydra-core
```

Then, go to the directory that interests you and run
```
python sac.py
```
or similar. Hyperparameters can be easily changed by providing the arguments to hydra:
```
python sac.py frames_per_batch=63
```
36 changes: 36 additions & 0 deletions examples/ddpg/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
env_name: HalfCheetah-v4
env_task: ""
env_library: gym
async_collection: 1
record_video: 0
normalize_rewards_online: 1
normalize_rewards_online_scale: 5
frame_skip: 1
frames_per_batch: 1024
optim_steps_per_batch: 128
batch_size: 256
total_frames: 1000000
prb: 1
lr: 3e-4
ou_exploration: 1
multi_step: 1
init_random_frames: 25000
activation: elu
gSDE: 0
from_pixels: 0
#collector_devices: [cuda:1,cuda:1,cuda:1,cuda:1]
collector_devices: [cpu,cpu,cpu,cpu]
env_per_collector: 8
num_workers: 32
lr_scheduler: ""
value_network_update_interval: 200
record_interval: 10
max_frames_per_traj: -1
weight_decay: 0.0
annealing_frames: 1000000
init_env_steps: 10000
record_frames: 10000
loss_function: smooth_l1
batch_transform: 1
buffer_prefetch: 64
norm_stats: 1
10 changes: 0 additions & 10 deletions examples/ddpg/configs/cheetah.txt

This file was deleted.

10 changes: 0 additions & 10 deletions examples/ddpg/configs/halfcheetah.txt

This file was deleted.

18 changes: 0 additions & 18 deletions examples/ddpg/configs/humanoid.txt

This file was deleted.

22 changes: 0 additions & 22 deletions examples/ddpg/configs/humanoid_pixels.txt

This file was deleted.

20 changes: 4 additions & 16 deletions examples/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
}


@hydra.main(version_base=None, config_path=None, config_name="config")
@hydra.main(version_base=None, config_path=".", config_name="config")
def main(cfg: "DictConfig"):
from torchrl.trainers.loggers import TensorboardLogger

Expand Down Expand Up @@ -96,7 +96,9 @@ def main(cfg: "DictConfig"):
if not cfg.vecnorm and cfg.norm_stats:
proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)()
stats = get_stats_random_rollout(
cfg, proof_env, key="next_pixels" if cfg.from_pixels else None
cfg,
proof_env,
key="next_pixels" if cfg.from_pixels else "next_observation_vector",
)
# make sure proof_env is closed
proof_env.close()
Expand Down Expand Up @@ -199,20 +201,6 @@ def main(cfg: "DictConfig"):
cfg,
)

def select_keys(batch):
return batch.select(
"reward",
"done",
"steps_to_next_obs",
"pixels",
"next_pixels",
"observation_vector",
"next_observation_vector",
"action",
)

trainer.register_op("batch_process", select_keys)

final_seed = collector.set_seed(cfg.seed)
print(f"init seed: {cfg.seed}, final seed: {final_seed}")

Expand Down
34 changes: 34 additions & 0 deletions examples/dqn/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
env_name: ALE/Pong-v5
env_task: ""
env_library: gym
async_collection: 1
record_video: 1
normalize_rewards_online: 1
normalize_rewards_online_scale: 5
frame_skip: 4
noops: 30
frames_per_batch: 1024
optim_steps_per_batch: 128
batch_size: 256
total_frames: 1000000
prb: 1
lr: 3e-4
multi_step: 1
init_random_frames: 25000
from_pixels: 1
#collector_devices: [cuda:1,cuda:1,cuda:1,cuda:1]
collector_devices: [cpu,cpu,cpu,cpu]
env_per_collector: 8
num_workers: 32
lr_scheduler: ""
value_network_update_interval: 200
record_interval: 10
max_frames_per_traj: -1
weight_decay: 0.0
annealing_frames: 1000000
init_env_steps: 10000
record_frames: 50000
loss_function: smooth_l1
batch_transform: 1
buffer_prefetch: 64
catframes: 4
20 changes: 0 additions & 20 deletions examples/dqn/configs/pong.txt

This file was deleted.

24 changes: 0 additions & 24 deletions examples/dqn/configs/pong_smoketest.txt

This file was deleted.

20 changes: 4 additions & 16 deletions examples/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
cs.store(name="config", node=Config)


@hydra.main(version_base=None, config_path=None, config_name="config")
@hydra.main(version_base=None, config_path=".", config_name="config")
def main(cfg: "DictConfig"):

from torchrl.trainers.loggers import TensorboardLogger
Expand Down Expand Up @@ -87,7 +87,9 @@ def main(cfg: "DictConfig"):
if not cfg.vecnorm and cfg.norm_stats:
proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)()
stats = get_stats_random_rollout(
cfg, proof_env, key="next_pixels" if cfg.from_pixels else None
cfg,
proof_env,
key="next_pixels" if cfg.from_pixels else "next_observation_vector",
)
# make sure proof_env is closed
proof_env.close()
Expand Down Expand Up @@ -169,20 +171,6 @@ def main(cfg: "DictConfig"):
cfg,
)

def select_keys(batch):
return batch.select(
"reward",
"done",
"steps_to_next_obs",
"pixels",
"next_pixels",
"observation_vector",
"next_observation_vector",
"action",
)

trainer.register_op("batch_process", select_keys)

final_seed = collector.set_seed(cfg.seed)
print(f"init seed: {cfg.seed}, final seed: {final_seed}")

Expand Down
26 changes: 26 additions & 0 deletions examples/ppo/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
env_name: HalfCheetah-v4
env_task: ""
env_library: gym
async_collection: 0
record_video: 1
normalize_rewards_online: 1
normalize_rewards_online_scale: 5
frame_skip: 1
frames_per_batch: 1000
optim_steps_per_batch: 10
batch_size: 256
total_frames: 1000000
lr: 3e-4
from_pixels: 0
#collector_devices: [cuda:1]
collector_devices: [cpu]
env_per_collector: 4
num_workers: 4
lr_scheduler: ""
record_interval: 100
max_frames_per_traj: -1
weight_decay: 0.0
init_env_steps: 10000
record_frames: 50000
loss_function: smooth_l1
batch_transform: 1
15 changes: 0 additions & 15 deletions examples/ppo/configs/cheetah.txt

This file was deleted.

12 changes: 0 additions & 12 deletions examples/ppo/configs/cheetah_pixels.txt

This file was deleted.

17 changes: 0 additions & 17 deletions examples/ppo/configs/cheetah_smoketest.txt

This file was deleted.

13 changes: 0 additions & 13 deletions examples/ppo/configs/humanoid.txt

This file was deleted.

6 changes: 4 additions & 2 deletions examples/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
cs.store(name="config", node=Config)


@hydra.main(version_base=None, config_path=None, config_name="config")
@hydra.main(version_base=None, config_path=".", config_name="config")
def main(cfg: "DictConfig"):
from torchrl.trainers.loggers import TensorboardLogger

Expand Down Expand Up @@ -81,7 +81,9 @@ def main(cfg: "DictConfig"):
if not cfg.vecnorm and cfg.norm_stats:
proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)()
stats = get_stats_random_rollout(
cfg, proof_env, key="next_pixels" if cfg.from_pixels else None
cfg,
proof_env,
key="next_pixels" if cfg.from_pixels else "next_observation_vector",
)
# make sure proof_env is closed
proof_env.close()
Expand Down
Loading

0 comments on commit 3781512

Please sign in to comment.