[Algorithm] Update DQN example #1512

Merged: 66 commits, Dec 7, 2023
Changes from 1 commit
Commits (66)
32262e3  cartpole (albertbou92, Sep 8, 2023)
7e44e5b  atari (albertbou92, Sep 8, 2023)
4641561  cartpole (albertbou92, Sep 8, 2023)
8ad0fc1  fix (albertbou92, Sep 8, 2023)
f4359e1  Merge remote-tracking branch 'origin/main' into update_dqn_example (vmoens, Sep 8, 2023)
ece9522  hydra config cartpole (albertbou92, Sep 8, 2023)
51b673a  fix (albertbou92, Sep 8, 2023)
6de66ed  fixes (vmoens, Sep 8, 2023)
c673a31  Merge remote-tracking branch 'PyTorchRL/update_dqn_example' into upda… (vmoens, Sep 8, 2023)
d1dfc1b  Merge pull request #1 from vmoens/update_dqn_example (vmoens, Sep 8, 2023)
617f9a3  params (albertbou92, Sep 8, 2023)
7b3e6a4  params (albertbou92, Sep 8, 2023)
1b5d32d  params (albertbou92, Sep 8, 2023)
7f57539  atari working script (albertbou92, Sep 8, 2023)
8ea99b2  fix (albertbou92, Sep 8, 2023)
d81504a  vecnorm (albertbou92, Sep 9, 2023)
2a27d09  conf (albertbou92, Sep 10, 2023)
43596ee  fix (albertbou92, Sep 12, 2023)
b20f756  eval time (albertbou92, Sep 12, 2023)
fa5a03a  original implementation (albertbou92, Sep 13, 2023)
c7b9255  logging (albertbou92, Sep 13, 2023)
3e3d7d0  format (albertbou92, Sep 13, 2023)
fc63456  fix (albertbou92, Sep 15, 2023)
0e0e401  Merge branch 'main' into update_dqn_example (albertbou92, Sep 18, 2023)
ef6d5f6  cleaner scripts (albertbou92, Sep 18, 2023)
441a67e  cleaner scripts (albertbou92, Sep 18, 2023)
75389db  fix (albertbou92, Sep 18, 2023)
d9d35d1  fix (albertbou92, Sep 18, 2023)
fa95119  format (albertbou92, Sep 18, 2023)
206ee3c  format (albertbou92, Sep 18, 2023)
d552d98  Merge branch 'main' into update_dqn_example (albertbou92, Sep 25, 2023)
0a7818f  env reader and merge main (albertbou92, Sep 25, 2023)
108e8fa  fix (albertbou92, Sep 25, 2023)
673e7dd  eol transform (albertbou92, Sep 25, 2023)
5089ca8  fixes (albertbou92, Sep 25, 2023)
f805db1  Merge branch 'main' into update_dqn_example (albertbou92, Oct 3, 2023)
048a8df  introduce feedback (albertbou92, Oct 6, 2023)
c981e22  script fixes (albertbou92, Oct 6, 2023)
1d961bb  Merge branch 'main' into update_dqn_example (albertbou92, Oct 6, 2023)
c099b31  script fixes (albertbou92, Oct 6, 2023)
4cbffa9  script fixes (albertbou92, Oct 6, 2023)
02d6c34  script fixes (albertbou92, Oct 6, 2023)
a488404  script fixes (albertbou92, Oct 6, 2023)
36e3076  script fixes (albertbou92, Oct 6, 2023)
f7c9bc6  fixes (albertbou92, Oct 6, 2023)
0fdd264  fixes (albertbou92, Oct 6, 2023)
a6a8d6e  fixes (albertbou92, Oct 6, 2023)
de19aed  fixes (albertbou92, Oct 6, 2023)
68b7a19  fixes (albertbou92, Oct 6, 2023)
af4c603  Merge branch 'main' into update_dqn_example (albertbou92, Oct 11, 2023)
807d9c5  Merge branch 'main' into update_dqn_example (albertbou92, Nov 3, 2023)
197b679  atari tqdm fix (albertbou92, Nov 6, 2023)
d61b160  Merge remote-tracking branch 'origin/main' into update_dqn_example (vmoens, Nov 7, 2023)
d739a80  Merge branch 'update_dqn_example' of https://github.com/PyTorchRL/rl … (vmoens, Nov 7, 2023)
a249366  Merge branch 'main' into update_dqn_example (Nov 28, 2023)
03a3a62  merge main (Nov 28, 2023)
53714c0  merge main (Nov 28, 2023)
6b60324  atari script (albertbou92, Nov 28, 2023)
06cc740  fixes (vmoens, Nov 30, 2023)
75625bd  fix end-of-life (vmoens, Nov 30, 2023)
82e253e  device in conf (albertbou92, Dec 5, 2023)
014cc62  skip initial collection frames (albertbou92, Dec 5, 2023)
fb6f61a  move logging (albertbou92, Dec 5, 2023)
2346870  move logging (albertbou92, Dec 5, 2023)
03bc5b0  Merge remote-tracking branch 'origin/main' into update_dqn_example (vmoens, Dec 6, 2023)
97e7aa5  CI examples fix (albertbou92, Dec 7, 2023)
Commit ece9522208dba32fad3d5f620e385aaa0923ef6b: hydra config cartpole
albertbou92 committed Sep 8, 2023
32 changes: 32 additions & 0 deletions examples/dqn/config_cartpole.yaml
@@ -0,0 +1,32 @@
+# Environment
+env:
+  env_name: CartPole-v1
+
+# Collector
+collector:
+  total_frames: 500_000
+  frames_per_batch: 10
+  eps_start: 1.0
+  eps_end: 0.05
+  annealing_frames: 250_000
+  init_random_frames: 10_000
+  num_updates: 1
+
+# Buffer
+buffer:
+  buffer_size: 10_000
+  batch_size: 128
+
+# Logger
+logger:
+  backend: csv
+  exp_name: DQN
+
+# Optim
+optim:
+  lr: 2.5e-4
+
+# Loss
+loss:
+  gamma: 0.99
+  hard_update_freq: 1
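For orientation, here is a minimal sketch of how Hydra resolves this file into the cfg object the script below consumes. It assumes hydra-core is installed and the YAML sits next to the script; the printed values simply echo the config above.

import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path=".", config_name="config_cartpole", version_base="1.1")
def main(cfg: DictConfig) -> None:
    # Hydra parses the YAML into a DictConfig; sections become nested attributes.
    print(OmegaConf.to_yaml(cfg))       # full resolved config
    print(cfg.collector.total_frames)   # 500000 (YAML ignores the underscores)
    print(cfg.optim.lr)                 # 0.00025


if __name__ == "__main__":
    main()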
70 changes: 31 additions & 39 deletions examples/dqn/dqn_carpole.py
@@ -2,6 +2,7 @@
 DQN Benchmarks: CartPole-v1
Inline review comment from a contributor:

Benchmarks in the repo are intended for runtime benchmarking.
I would call this evaluation. At the end of the day, I think what we will want is the following directory tree:

torchrl/
    ├── benchmarks/
    │   ├── component_x/
    │   │   ├── benchmark_workflow_1/
    │   │   ├── benchmark_workflow_2/
    │   │   └── ...
    │   ├── algorithm_y/
    │   │   ├── benchmark_workflow_1/
    │   │   ├── benchmark_workflow_2/
    │   │   └── ...
    │   └── ...
    ├── evaluation/
    │   ├── component_x/
    │   │   ├── evaluation_metrics_1/
    │   │   ├── evaluation_metrics_2/
    │   │   └── ...
    │   ├── algorithm_y/
    │   │   ├── evaluation_metrics_1/
    │   │   ├── evaluation_metrics_2/
    │   │   └── ...
    │   └── ...
    └── ...

cc @matteobettini @BY571

"""

import hydra
import tqdm
import time
import torch.nn
@@ -64,81 +65,68 @@ def make_dqn_model(env_name):
     return qvalue_module


-if __name__ == "__main__":
+@hydra.main(config_path=".", config_name="config_cartpole", version_base="1.1")
+def main(cfg: "DictConfig"):  # noqa: F821

     device = "cpu" if not torch.cuda.is_available() else "cuda"
-    env_name = "CartPole-v1"
-    total_frames = 500_000
-    record_interval = 500_000
-    frames_per_batch = 10
-    num_updates = 1
-    buffer_size = 10_000
-    init_random_frames = 10_000
-    annealing_frames = 250_000
-    gamma = 0.99
-    lr = 2.5e-4
-    batch_size = 128
-    hard_update_freq = 50
-    eps_end = 0.05
-    logger_backend = "csv"

     seed = 42
     torch.manual_seed(seed)

     # Make the components
-    model = make_dqn_model(env_name)
-    model_explore = EGreedyWrapper(model, annealing_num_steps=annealing_frames, eps_end=eps_end).to(device)
+    model = make_dqn_model(cfg.env.env_name)
+    model_explore = EGreedyWrapper(
+        policy=model,
+        annealing_num_steps=cfg.collector.annealing_frames,
+        eps_init=cfg.collector.eps_start,
+        eps_end=cfg.collector.eps_end,
+    ).to(device)

     # Create the collector
     collector_class = SyncDataCollector
     collector = SyncDataCollector(
-        make_env(env_name, device),
+        make_env(cfg.env.env_name, device),
         policy=model_explore,
-        frames_per_batch=frames_per_batch,
-        total_frames=total_frames,
+        frames_per_batch=cfg.collector.frames_per_batch,
+        total_frames=cfg.collector.total_frames,
         device=device,
         storing_device=device,
         max_frames_per_traj=-1,
     )
     collector.set_seed(seed)

     # Create the replay buffer
     replay_buffer = TensorDictReplayBuffer(
         pin_memory=False,
         prefetch=3,
         storage=LazyTensorStorage(
-            max_size=buffer_size,
+            max_size=cfg.buffer.buffer_size,
             device=device,
         ),
-        batch_size=batch_size,
+        batch_size=cfg.buffer.batch_size,
     )

     # Create the loss module
     loss_module = DQNLoss(
         value_network=model,
-        gamma=gamma,
+        gamma=cfg.loss.gamma,
         loss_function="l2",
         delay_value=True,
     )
-    loss_module.make_value_estimator(gamma=gamma)
-    target_net_updater = HardUpdate(loss_module, value_network_update_interval=hard_update_freq)
+    loss_module.make_value_estimator(gamma=cfg.loss.gamma)
+    target_net_updater = HardUpdate(loss_module, value_network_update_interval=cfg.loss.hard_update_freq)

     # Create the optimizer
-    optimizer = torch.optim.Adam(loss_module.parameters(), lr=lr)
+    optimizer = torch.optim.Adam(loss_module.parameters(), lr=cfg.optim.lr)

     # Create the logger
-    exp_name = generate_exp_name("DQN", f"CartPole_{env_name}")
-    logger = get_logger(logger_backend, logger_name="dqn", experiment_name=exp_name)
+    exp_name = generate_exp_name("DQN", f"CartPole_{cfg.env.env_name}")
+    logger = get_logger(cfg.logger.backend, logger_name="dqn", experiment_name=exp_name)

     # Main loop
     collected_frames = 0
     start_time = time.time()
-    pbar = tqdm.tqdm(total=total_frames)
+    pbar = tqdm.tqdm(total=cfg.collector.total_frames)

     for i, data in enumerate(collector):

         # Train logging
-        logger.log_scalar("q_values", (data["action_value"]*data["action"]).sum().item() / frames_per_batch, collected_frames)
+        logger.log_scalar("q_values", (data["action_value"]*data["action"]).sum().item() / cfg.collector.frames_per_batch, collected_frames)
         episode_rewards = data["next", "episode_reward"][data["next", "done"]]
         if len(episode_rewards) > 0:
             episode_length = data["next", "step_count"][data["next", "done"]]
@@ -153,10 +141,10 @@ def make_dqn_model(env_name):
         model_explore.step(current_frames)

         # optimization steps
-        if collected_frames >= init_random_frames:
-            q_losses = TensorDict({}, batch_size=[num_updates])
-            for j in range(num_updates):
-                sampled_tensordict = replay_buffer.sample(batch_size)
+        if collected_frames >= cfg.collector.init_random_frames:
+            q_losses = TensorDict({}, batch_size=[cfg.collector.num_updates])
+            for j in range(cfg.collector.num_updates):
+                sampled_tensordict = replay_buffer.sample(cfg.buffer.batch_size)
             loss_td = loss_module(sampled_tensordict)
             q_loss = loss_td["loss"]
             optimizer.zero_grad()
Expand All @@ -177,3 +165,7 @@ def make_dqn_model(env_name):
end_time = time.time()
execution_time = end_time - start_time
print(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
main()
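Because the entry point is now wrapped in @hydra.main, every field of config_cartpole.yaml can be overridden from the command line. A usage sketch (the override values are illustrative, not the PR's defaults, and the wandb backend assumes the wandb package is installed):

python dqn_carpole.py optim.lr=1e-4 collector.total_frames=100000 logger.backend=wandb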