dqn_cartpole.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import warnings

import hydra
import torch.nn
import torch.optim
import tqdm
from tensordict.nn import CudaGraphModule, TensorDictSequential
from torchrl._utils import timeit
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.envs import ExplorationType, set_exploration_type
from torchrl.modules import EGreedyModule
from torchrl.objectives import DQNLoss, HardUpdate
from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger

from utils_cartpole import eval_model, make_dqn_model, make_env

torch.set_float32_matmul_precision("high")


@hydra.main(config_path="", config_name="config_cartpole", version_base="1.1")
def main(cfg: "DictConfig"):  # noqa: F821
    device = cfg.device
    if device in ("", None):
        if torch.cuda.is_available():
            device = "cuda:0"
        else:
            device = "cpu"
    device = torch.device(device)

    # Make the components
    model = make_dqn_model(cfg.env.env_name, device=device)
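
    # Exploration wrapper: epsilon-greedy action selection, with epsilon annealed
    # from cfg.collector.eps_start to cfg.collector.eps_end over
    # cfg.collector.annealing_frames collected frames.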
    greedy_module = EGreedyModule(
        annealing_num_steps=cfg.collector.annealing_frames,
        eps_init=cfg.collector.eps_start,
        eps_end=cfg.collector.eps_end,
        spec=model.spec,
        device=device,
    )
    model_explore = TensorDictSequential(
        model,
        greedy_module,
    )

    # Create the replay buffer
    replay_buffer = TensorDictReplayBuffer(
        pin_memory=False,
        storage=LazyTensorStorage(max_size=cfg.buffer.buffer_size, device=device),
        batch_size=cfg.buffer.batch_size,
    )
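
    # The loss below keeps a delayed (target) copy of the value network
    # (delay_value=True); HardUpdate syncs it with the online weights every
    # cfg.loss.hard_update_freq optimization steps.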
    # Create the loss module
    loss_module = DQNLoss(
        value_network=model,
        loss_function="l2",
        delay_value=True,
    )
    loss_module.make_value_estimator(gamma=cfg.loss.gamma, device=device)
    loss_module = loss_module.to(device)
    target_net_updater = HardUpdate(
        loss_module, value_network_update_interval=cfg.loss.hard_update_freq
    )

    # Create the optimizer
    optimizer = torch.optim.Adam(loss_module.parameters(), lr=cfg.optim.lr)

    # Create the logger
    logger = None
    if cfg.logger.backend:
        exp_name = generate_exp_name("DQN", f"CartPole_{cfg.env.env_name}")
        logger = get_logger(
            cfg.logger.backend,
            logger_name="dqn",
            experiment_name=exp_name,
            wandb_kwargs={
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )

    # Create the test environment
    test_env = make_env(cfg.env.env_name, "cpu", from_pixels=cfg.logger.video)
    if cfg.logger.video:
        test_env.insert_transform(
            0,
            VideoRecorder(
                logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"]
            ),
        )
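
    # One optimization step: compute the TD loss on a sampled batch, backpropagate,
    # update the online Q-network and step the target-network updater.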
    def update(sampled_tensordict):
        loss_td = loss_module(sampled_tensordict)
        q_loss = loss_td["loss"]
        optimizer.zero_grad()
        q_loss.backward()
        optimizer.step()
        target_net_updater.step()
        return q_loss.detach()
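
    # Optionally compile the update step with torch.compile and, if requested,
    # capture it in a CUDA graph via CudaGraphModule to reduce per-step overhead.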
    compile_mode = None
    if cfg.compile.compile:
        compile_mode = cfg.compile.compile_mode
        if compile_mode in ("", None):
            if cfg.compile.cudagraphs:
                compile_mode = "default"
            else:
                compile_mode = "reduce-overhead"
        update = torch.compile(update, mode=compile_mode)
    if cfg.compile.cudagraphs:
        warnings.warn(
            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
            category=UserWarning,
        )
        update = CudaGraphModule(update, warmup=50)
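
    # Synchronous data collection: the exploration policy runs on CPU and yields
    # batches of cfg.collector.frames_per_batch transitions until
    # cfg.collector.total_frames frames have been collected.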
    # Create the collector
    collector = SyncDataCollector(
        create_env_fn=make_env(cfg.env.env_name, "cpu"),
        policy=model_explore,
        frames_per_batch=cfg.collector.frames_per_batch,
        total_frames=cfg.collector.total_frames,
        device="cpu",
        storing_device="cpu",
        max_frames_per_traj=-1,
        init_random_frames=cfg.collector.init_random_frames,
        compile_policy={"mode": compile_mode, "fullgraph": True}
        if compile_mode is not None
        else False,
        cudagraph_policy=cfg.compile.cudagraphs,
    )
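
    # Training loop: alternate between collecting a batch of transitions and
    # running num_updates gradient steps on replay-buffer samples. Gradient steps
    # only start once init_random_frames frames have been collected.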
    # Main loop
    collected_frames = 0
    num_updates = cfg.loss.num_updates
    batch_size = cfg.buffer.batch_size
    test_interval = cfg.logger.test_interval
    num_test_episodes = cfg.logger.num_test_episodes
    frames_per_batch = cfg.collector.frames_per_batch
    pbar = tqdm.tqdm(total=cfg.collector.total_frames)
    init_random_frames = cfg.collector.init_random_frames
    q_losses = torch.zeros(num_updates, device=device)

    c_iter = iter(collector)
    total_iter = len(collector)
    for i in range(total_iter):
        timeit.printevery(1000, total_iter, erase=True)

        with timeit("collecting"):
            data = next(c_iter)

        metrics_to_log = {}
        pbar.update(data.numel())
        data = data.reshape(-1)
        current_frames = data.numel()

        with timeit("rb - extend"):
            replay_buffer.extend(data)
        collected_frames += current_frames
        greedy_module.step(current_frames)

        # Get and log training rewards and episode lengths
        episode_rewards = data["next", "episode_reward"][data["next", "done"]]
        if len(episode_rewards) > 0:
            episode_reward_mean = episode_rewards.mean().item()
            episode_length = data["next", "step_count"][data["next", "done"]]
            episode_length_mean = episode_length.sum().item() / len(episode_length)
            metrics_to_log.update(
                {
                    "train/episode_reward": episode_reward_mean,
                    "train/episode_length": episode_length_mean,
                }
            )

        if collected_frames < init_random_frames:
            if logger:
                for key, value in metrics_to_log.items():
                    logger.log_scalar(key, value, step=collected_frames)
            continue
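
        # Sample num_updates mini-batches from the replay buffer and update the
        # Q-network on each; keep the losses around for logging.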
        # optimization steps
        for j in range(num_updates):
            with timeit("rb - sample"):
                sampled_tensordict = replay_buffer.sample(batch_size)
                sampled_tensordict = sampled_tensordict.to(device)

            with timeit("update"):
                q_loss = update(sampled_tensordict)
            q_losses[j].copy_(q_loss)

        # Get and log q-values, loss, epsilon, sampling time and training time
        metrics_to_log.update(
            {
                "train/q_values": (data["action_value"] * data["action"]).sum().item()
                / frames_per_batch,
                "train/q_loss": q_losses.mean().item(),
                "train/epsilon": greedy_module.eps,
            }
        )
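
        # Evaluate the greedy (deterministic) policy every test_interval collected
        # frames, and once more at the end of training.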
        # Get and log evaluation rewards and eval time
        with torch.no_grad(), set_exploration_type(
            ExplorationType.DETERMINISTIC
        ), timeit("eval"):
            prev_test_frame = ((i - 1) * frames_per_batch) // test_interval
            cur_test_frame = (i * frames_per_batch) // test_interval
            final = current_frames >= collector.total_frames
            if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
                model.eval()
                test_rewards = eval_model(model, test_env, num_test_episodes)
                model.train()
                metrics_to_log.update(
                    {
                        "eval/reward": test_rewards,
                    }
                )

        # Log all the information
        if logger:
            metrics_to_log.update(timeit.todict(prefix="time"))
            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
            for key, value in metrics_to_log.items():
                logger.log_scalar(key, value, step=collected_frames)

        # update weights of the inference policy
        collector.update_policy_weights_()

    collector.shutdown()
    if not test_env.is_closed:
        test_env.close()


if __name__ == "__main__":
    main()