-
Notifications
You must be signed in to change notification settings - Fork 326
/
iql_offline.py
172 lines (139 loc) · 4.97 KB
/
iql_offline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""IQL Example.
This is a self-contained example of an offline IQL training script.
The helper functions are coded in the utils.py associated with this script.
"""
from __future__ import annotations
import warnings
import hydra
import numpy as np
import torch
import tqdm
from tensordict.nn import CudaGraphModule
from torchrl._utils import timeit
from torchrl.envs import set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
dump_video,
log_metrics,
make_environment,
make_iql_model,
make_iql_optimizer,
make_loss,
make_offline_replay_buffer,
)
# Per the PyTorch docs, "high" lets float32 matmuls use faster kernels with
# reduced internal precision (e.g. TF32) where the hardware supports them.
torch.set_float32_matmul_precision(precision="high")
def _resolve_device(requested) -> torch.device:
    """Return the torch device to train on.

    An empty or missing ``requested`` value selects ``cuda:0`` when CUDA is
    available and ``cpu`` otherwise; any other value is passed to
    ``torch.device`` unchanged.
    """
    if requested in ("", None):
        requested = "cuda:0" if torch.cuda.is_available() else "cpu"
    return torch.device(requested)


def _resolve_compile_mode(compile_cfg):
    """Return the ``torch.compile`` mode to use, or ``None`` if compilation is off.

    An empty/missing ``compile_mode`` defaults to ``"default"`` when the
    CudaGraphModule wrapper is enabled and to ``"reduce-overhead"`` otherwise.
    """
    if not compile_cfg.compile:
        return None
    mode = compile_cfg.compile_mode
    if mode in ("", None):
        # NOTE(review): "reduce-overhead" uses CUDA graphs internally (per the
        # torch.compile docs); presumably "default" is picked when
        # CudaGraphModule is also applied so the two are not layered.
        mode = "default" if compile_cfg.cudagraphs else "reduce-overhead"
    return mode


def _run_training(cfg, train_env, eval_env, logger, device) -> None:
    """Build the IQL components and run the gradient-step loop.

    Evaluates the policy every ``cfg.logger.eval_iter`` steps (a falsy value
    disables evaluation) and, when ``logger`` is not None, logs the loss
    values, timings and evaluation reward at every step.
    """
    # Create replay buffer
    replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)

    # Create agent
    model = make_iql_model(cfg, train_env, eval_env, device)

    # Create loss
    loss_module, target_net_updater = make_loss(cfg.loss, model, device=device)

    # Create optimizer
    optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer(
        cfg.optim, loss_module
    )
    # Combine the three optimizers so update() can zero/step them with one call.
    optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_value)

    def update(data):
        """Run one gradient step on ``data`` and return the detached loss values."""
        optimizer.zero_grad(set_to_none=True)
        # compute losses
        loss_info = loss_module(data)
        actor_loss = loss_info["loss_actor"]
        value_loss = loss_info["loss_value"]
        q_loss = loss_info["loss_qvalue"]
        # One backward pass over the summed losses.
        (actor_loss + value_loss + q_loss).backward()
        optimizer.step()
        # update qnet_target params
        target_net_updater.step()
        # Detach so the returned metrics do not keep the autograd graph alive.
        return loss_info.detach()

    compile_mode = _resolve_compile_mode(cfg.compile)
    if cfg.compile.compile:
        update = torch.compile(update, mode=compile_mode)
    if cfg.compile.cudagraphs:
        warnings.warn(
            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
            category=UserWarning,
        )
        update = CudaGraphModule(update, warmup=50)

    evaluation_interval = cfg.logger.eval_iter
    eval_steps = cfg.logger.eval_steps

    # Training loop; the context manager closes the progress bar even on error.
    with tqdm.tqdm(range(cfg.optim.gradient_steps)) as pbar:
        for i in pbar:
            timeit.printevery(1000, cfg.optim.gradient_steps, erase=True)

            # sample data
            with timeit("sample"):
                data = replay_buffer.sample()
                data = data.to(device)

            with timeit("update"):
                # Signal PyTorch's cudagraph machinery that a new iteration begins.
                torch.compiler.cudagraph_mark_step_begin()
                loss_info = update(data)

            # evaluation
            metrics_to_log = loss_info.to_dict()
            # A falsy eval_iter (0/None) disables evaluation instead of raising
            # in the modulo below.
            if evaluation_interval and i % evaluation_interval == 0:
                # Deterministic, gradient-free rollout of model[0] as the policy.
                with set_exploration_type(
                    ExplorationType.DETERMINISTIC
                ), torch.no_grad(), timeit("eval"):
                    eval_td = eval_env.rollout(
                        max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
                    )
                    eval_env.apply(dump_video)
                # Sum rewards over dim 1, then average the rest.
                # NOTE(review): assumes dim 1 of the rollout is the time axis — confirm.
                eval_reward = eval_td["next", "reward"].sum(1).mean().item()
                metrics_to_log["evaluation_reward"] = eval_reward

            if logger is not None:
                metrics_to_log.update(timeit.todict(prefix="time"))
                # tqdm reports rate=None until it has timed an iteration (and
                # when disabled); skip it rather than logging None as a scalar.
                rate = pbar.format_dict.get("rate")
                if rate is not None:
                    metrics_to_log["time/speed"] = rate
                log_metrics(logger, metrics_to_log, i)


@hydra.main(config_path="", config_name="offline_config")
def main(cfg: "DictConfig"):  # noqa: F821
    """Train an IQL agent offline from a replay buffer, with periodic evaluation.

    Args:
        cfg: Hydra config composed from ``offline_config``; the ``env``,
            ``logger``, ``optim``, ``replay_buffer``, ``loss`` and ``compile``
            sections are read.
    """
    set_gym_backend(cfg.env.backend).set()

    # Create logger
    exp_name = generate_exp_name("IQL-offline", cfg.logger.exp_name)
    logger = None
    if cfg.logger.backend:
        logger = get_logger(
            logger_type=cfg.logger.backend,
            logger_name="iql_logging",
            experiment_name=exp_name,
            wandb_kwargs={
                "mode": cfg.logger.mode,
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )

    # Set seeds
    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)
    device = _resolve_device(cfg.optim.device)

    # Create env
    train_env, eval_env = make_environment(
        cfg,
        cfg.logger.eval_envs,
        logger=logger,
    )
    try:
        _run_training(cfg, train_env, eval_env, logger, device)
    finally:
        # Release the environments even if training raises or is interrupted.
        if not eval_env.is_closed:
            eval_env.close()
        if not train_env.is_closed:
            train_env.close()
if __name__ == "__main__":
    # The @hydra.main decorator builds cfg (config file + CLI overrides), so
    # the entry point is invoked without arguments.
    main()