Skip to content

Commit

Permalink
[Feature] CROSSQ compatibility with compile
Browse files Browse the repository at this point in the history
ghstack-source-id: 98a2b30e8f6a1b0bc583a9f3c51adc2634eb8028
Pull Request resolved: #2554
  • Loading branch information
vmoens committed Dec 14, 2024
1 parent e2be42e commit 01a421e
Show file tree
Hide file tree
Showing 56 changed files with 299 additions and 102 deletions.
8 changes: 8 additions & 0 deletions sota-implementations/a2c/a2c_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import warnings

import hydra
import torch

Expand Down Expand Up @@ -149,6 +153,10 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
adv_module = torch.compile(adv_module, mode=compile_mode)

if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
adv_module = CudaGraphModule(adv_module)

Expand Down
8 changes: 8 additions & 0 deletions sota-implementations/a2c/a2c_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import warnings

import hydra
import torch

Expand Down Expand Up @@ -145,6 +149,10 @@ def update(batch):
adv_module = torch.compile(adv_module, mode=compile_mode)

if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=20)
adv_module = CudaGraphModule(adv_module, warmup=20)

Expand Down
1 change: 1 addition & 0 deletions sota-implementations/a2c/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import numpy as np
import torch.nn
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/a2c/utils_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import numpy as np
import torch.nn
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/bandits/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import argparse

Expand Down
15 changes: 14 additions & 1 deletion sota-implementations/cql/cql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
The helper functions are coded in the utils.py associated with this script.
"""
from __future__ import annotations

import time
import warnings

import hydra
import numpy as np

import torch
import tqdm
from tensordict.nn import CudaGraphModule
Expand All @@ -32,6 +36,8 @@
make_offline_replay_buffer,
)

torch.set_float32_matmul_precision("high")


@hydra.main(config_path="", config_name="offline_config", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821
Expand Down Expand Up @@ -77,7 +83,9 @@ def main(cfg: "DictConfig"): # noqa: F821
eval_env.start()

# Create loss
loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
loss_module, target_net_updater = make_continuous_loss(
cfg.loss, model, device=device
)

# Create Optimizer
(
Expand Down Expand Up @@ -134,6 +142,10 @@ def update(data, policy_eval_start, iteration):
compile_mode = "reduce-overhead"
update = torch.compile(update, mode=compile_mode)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)
Expand All @@ -154,6 +166,7 @@ def update(data, policy_eval_start, iteration):

with timeit("update"):
# compute loss
torch.compiler.cudagraph_mark_step_begin()
i_device = torch.tensor(i, device=device)
loss, loss_vals = update(
data.to(device), policy_eval_start=policy_eval_start, iteration=i_device
Expand Down
14 changes: 13 additions & 1 deletion sota-implementations/cql/cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
The helper functions are coded in the utils.py associated with this script.
"""
from __future__ import annotations

import warnings

import hydra
import numpy as np
import torch
Expand All @@ -34,6 +38,8 @@
make_replay_buffer,
)

torch.set_float32_matmul_precision("high")


@hydra.main(version_base="1.1", config_path="", config_name="online_config")
def main(cfg: "DictConfig"): # noqa: F821
Expand Down Expand Up @@ -103,7 +109,9 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create loss
loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
loss_module, target_net_updater = make_continuous_loss(
cfg.loss, model, device=device
)

# Create optimizer
(
Expand Down Expand Up @@ -140,6 +148,10 @@ def update(sampled_tensordict):
if compile_mode:
update = torch.compile(update, mode=compile_mode)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Main loop
Expand Down
13 changes: 12 additions & 1 deletion sota-implementations/cql/discrete_cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@
The helper functions are coded in the utils.py associated with this script.
"""
from __future__ import annotations

import warnings

import hydra
import numpy as np

import torch
import torch.cuda
import tqdm
Expand All @@ -33,6 +37,8 @@
make_replay_buffer,
)

torch.set_float32_matmul_precision("high")


@hydra.main(version_base="1.1", config_path="", config_name="discrete_cql_config")
def main(cfg: "DictConfig"): # noqa: F821
Expand Down Expand Up @@ -70,7 +76,7 @@ def main(cfg: "DictConfig"): # noqa: F821
model, explore_policy = make_discretecql_model(cfg, train_env, eval_env, device)

# Create loss
loss_module, target_net_updater = make_discrete_loss(cfg.loss, model)
loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device=device)

compile_mode = None
if cfg.compile.compile:
Expand Down Expand Up @@ -123,6 +129,10 @@ def update(sampled_tensordict):
if compile_mode:
update = torch.compile(update, mode=compile_mode)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Main loop
Expand Down Expand Up @@ -170,6 +180,7 @@ def update(sampled_tensordict):
sampled_tensordict = replay_buffer.sample()
sampled_tensordict = sampled_tensordict.to(device)
with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
loss_dict = update(sampled_tensordict)
tds.append(loss_dict)

Expand Down
14 changes: 8 additions & 6 deletions sota-implementations/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import functools

import torch.nn
Expand Down Expand Up @@ -221,8 +223,8 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"):
# distribution_kwargs=TensorDictParams(
# TensorDict(
# {
# "low": action_spec.space.low,
# "high": action_spec.space.high,
# "low": torch.as_tensor(action_spec.space.low, device=device),
# "high": torch.as_tensor(action_spec.space.high, device=device),
# "tanh_loc": NonTensorData(False),
# }
# ),
Expand Down Expand Up @@ -326,7 +328,7 @@ def make_cql_modules_state(model_cfg, proof_environment):
# ---------


def make_continuous_loss(loss_cfg, model):
def make_continuous_loss(loss_cfg, model, device: torch.device | None = None):
loss_module = CQLLoss(
model[0],
model[1],
Expand All @@ -339,19 +341,19 @@ def make_continuous_loss(loss_cfg, model):
with_lagrange=loss_cfg.with_lagrange,
lagrange_thresh=loss_cfg.lagrange_thresh,
)
loss_module.make_value_estimator(gamma=loss_cfg.gamma)
loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)

return loss_module, target_net_updater


def make_discrete_loss(loss_cfg, model):
def make_discrete_loss(loss_cfg, model, device: torch.device | None = None):
loss_module = DiscreteCQLLoss(
model,
loss_function=loss_cfg.loss_function,
delay_value=True,
)
loss_module.make_value_estimator(gamma=loss_cfg.gamma)
loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)

return loss_module, target_net_updater
Expand Down
9 changes: 7 additions & 2 deletions sota-implementations/crossq/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ collector:
init_random_frames: 25000
frames_per_batch: 1000
init_env_steps: 1000
device: cpu
device:
env_per_collector: 1
reset_at_each_iter: False

Expand Down Expand Up @@ -46,7 +46,12 @@ network:
actor_activation: relu
default_policy_scale: 1.0
scale_lb: 0.1
device: "cuda:0"
device:

compile:
compile: False
compile_mode:
cudagraphs: False

# logging
logger:
Expand Down
Loading

0 comments on commit 01a421e

Please sign in to comment.