From 01a421e76a6f208c4eec6e407b2563c28d493540 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 13 Dec 2024 16:02:29 -0800 Subject: [PATCH] [Feature] CROSSQ compatibility with compile ghstack-source-id: 98a2b30e8f6a1b0bc583a9f3c51adc2634eb8028 Pull Request resolved: https://github.com/pytorch/rl/pull/2554 --- sota-implementations/a2c/a2c_atari.py | 8 + sota-implementations/a2c/a2c_mujoco.py | 8 + sota-implementations/a2c/utils_atari.py | 1 + sota-implementations/a2c/utils_mujoco.py | 1 + sota-implementations/bandits/dqn.py | 1 + sota-implementations/cql/cql_offline.py | 15 +- sota-implementations/cql/cql_online.py | 14 +- .../cql/discrete_cql_online.py | 13 +- sota-implementations/cql/utils.py | 14 +- sota-implementations/crossq/config.yaml | 9 +- sota-implementations/crossq/crossq.py | 207 +++++++++++------- sota-implementations/crossq/utils.py | 22 +- sota-implementations/ddpg/ddpg.py | 2 + sota-implementations/ddpg/utils.py | 2 + .../decision_transformer/dt.py | 2 + .../decision_transformer/lamb.py | 2 + .../decision_transformer/online_dt.py | 2 + .../decision_transformer/utils.py | 1 + .../discrete_sac/discrete_sac.py | 2 + sota-implementations/discrete_sac/utils.py | 2 + sota-implementations/dqn/dqn_atari.py | 2 + sota-implementations/dqn/dqn_cartpole.py | 2 + sota-implementations/dqn/utils_atari.py | 1 + sota-implementations/dqn/utils_cartpole.py | 1 + sota-implementations/dreamer/dreamer.py | 2 + sota-implementations/dreamer/dreamer_utils.py | 2 + sota-implementations/gail/gail.py | 2 + sota-implementations/gail/gail_utils.py | 1 + sota-implementations/gail/ppo_utils.py | 1 + .../impala/impala_multi_node_ray.py | 2 + .../impala/impala_multi_node_submitit.py | 2 + .../impala/impala_single_node.py | 2 + sota-implementations/impala/utils.py | 1 + sota-implementations/iql/discrete_iql.py | 2 + sota-implementations/iql/iql_offline.py | 2 + sota-implementations/iql/iql_online.py | 2 + sota-implementations/iql/utils.py | 2 + sota-implementations/multiagent/iql.py | 2 + .../multiagent/maddpg_iddpg.py | 2 + sota-implementations/multiagent/mappo_ippo.py | 2 + sota-implementations/multiagent/qmix_vdn.py | 2 + sota-implementations/multiagent/sac.py | 2 + .../multiagent/utils/logging.py | 2 + .../multiagent/utils/utils.py | 2 + sota-implementations/ppo/ppo_atari.py | 2 + sota-implementations/ppo/ppo_mujoco.py | 2 + sota-implementations/ppo/utils_atari.py | 1 + sota-implementations/ppo/utils_mujoco.py | 1 + sota-implementations/redq/redq.py | 1 + sota-implementations/sac/sac.py | 2 + sota-implementations/sac/utils.py | 2 + sota-implementations/td3/td3.py | 2 + sota-implementations/td3/utils.py | 2 + sota-implementations/td3_bc/td3_bc.py | 2 + sota-implementations/td3_bc/utils.py | 2 + torchrl/objectives/value/advantages.py | 11 +- 56 files changed, 299 insertions(+), 102 deletions(-) diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index f6401b9946c..c7f70308fd4 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -2,6 +2,10 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import warnings + import hydra import torch @@ -149,6 +153,10 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): adv_module = torch.compile(adv_module, mode=compile_mode) if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5) adv_module = CudaGraphModule(adv_module) diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index b75a5224bc5..cf88e7db01a 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -2,6 +2,10 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import warnings + import hydra import torch @@ -145,6 +149,10 @@ def update(batch): adv_module = torch.compile(adv_module, mode=compile_mode) if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=20) adv_module = CudaGraphModule(adv_module, warmup=20) diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py index a0cea48b510..167a14e8796 100644 --- a/sota-implementations/a2c/utils_atari.py +++ b/sota-implementations/a2c/utils_atari.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import numpy as np import torch.nn diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py index 645bc806265..8606506da15 100644 --- a/sota-implementations/a2c/utils_mujoco.py +++ b/sota-implementations/a2c/utils_mujoco.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import numpy as np import torch.nn diff --git a/sota-implementations/bandits/dqn.py b/sota-implementations/bandits/dqn.py index 55ba34f5010..37cde0e2c62 100644 --- a/sota-implementations/bandits/dqn.py +++ b/sota-implementations/bandits/dqn.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import argparse diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py index 36a9b2478d5..e74997eb37f 100644 --- a/sota-implementations/cql/cql_offline.py +++ b/sota-implementations/cql/cql_offline.py @@ -9,10 +9,14 @@ The helper functions are coded in the utils.py associated with this script. """ +from __future__ import annotations + import time +import warnings import hydra import numpy as np + import torch import tqdm from tensordict.nn import CudaGraphModule @@ -32,6 +36,8 @@ make_offline_replay_buffer, ) +torch.set_float32_matmul_precision("high") + @hydra.main(config_path="", config_name="offline_config", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 @@ -77,7 +83,9 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_env.start() # Create loss - loss_module, target_net_updater = make_continuous_loss(cfg.loss, model) + loss_module, target_net_updater = make_continuous_loss( + cfg.loss, model, device=device + ) # Create Optimizer ( @@ -134,6 +142,10 @@ def update(data, policy_eval_start, iteration): compile_mode = "reduce-overhead" update = torch.compile(update, mode=compile_mode) if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) update = CudaGraphModule(update, warmup=50) pbar = tqdm.tqdm(total=cfg.optim.gradient_steps) @@ -154,6 +166,7 @@ def update(data, policy_eval_start, iteration): with timeit("update"): # compute loss + torch.compiler.cudagraph_mark_step_begin() i_device = torch.tensor(i, device=device) loss, loss_vals = update( data.to(device), policy_eval_start=policy_eval_start, iteration=i_device diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py index b45340b60b2..f9a0a89776f 100644 --- a/sota-implementations/cql/cql_online.py +++ b/sota-implementations/cql/cql_online.py @@ -11,6 +11,10 @@ The helper functions are coded in the utils.py associated with this script. """ +from __future__ import annotations + +import warnings + import hydra import numpy as np import torch @@ -34,6 +38,8 @@ make_replay_buffer, ) +torch.set_float32_matmul_precision("high") + @hydra.main(version_base="1.1", config_path="", config_name="online_config") def main(cfg: "DictConfig"): # noqa: F821 @@ -103,7 +109,9 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create loss - loss_module, target_net_updater = make_continuous_loss(cfg.loss, model) + loss_module, target_net_updater = make_continuous_loss( + cfg.loss, model, device=device + ) # Create optimizer ( @@ -140,6 +148,10 @@ def update(sampled_tensordict): if compile_mode: update = torch.compile(update, mode=compile_mode) if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) update = CudaGraphModule(update, warmup=50) # Main loop diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py index bde67d14e78..c5a06b4b156 100644 --- a/sota-implementations/cql/discrete_cql_online.py +++ b/sota-implementations/cql/discrete_cql_online.py @@ -10,9 +10,13 @@ The helper functions are coded in the utils.py associated with this script. """ +from __future__ import annotations + +import warnings import hydra import numpy as np + import torch import torch.cuda import tqdm @@ -33,6 +37,8 @@ make_replay_buffer, ) +torch.set_float32_matmul_precision("high") + @hydra.main(version_base="1.1", config_path="", config_name="discrete_cql_config") def main(cfg: "DictConfig"): # noqa: F821 @@ -70,7 +76,7 @@ def main(cfg: "DictConfig"): # noqa: F821 model, explore_policy = make_discretecql_model(cfg, train_env, eval_env, device) # Create loss - loss_module, target_net_updater = make_discrete_loss(cfg.loss, model) + loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device=device) compile_mode = None if cfg.compile.compile: @@ -123,6 +129,10 @@ def update(sampled_tensordict): if compile_mode: update = torch.compile(update, mode=compile_mode) if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) update = CudaGraphModule(update, warmup=50) # Main loop @@ -170,6 +180,7 @@ def update(sampled_tensordict): sampled_tensordict = replay_buffer.sample() sampled_tensordict = sampled_tensordict.to(device) with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() loss_dict = update(sampled_tensordict) tds.append(loss_dict) diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index 2dc280b03eb..ed0ca5476c5 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import functools import torch.nn @@ -221,8 +223,8 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"): # distribution_kwargs=TensorDictParams( # TensorDict( # { - # "low": action_spec.space.low, - # "high": action_spec.space.high, + # "low": torch.as_tensor(action_spec.space.low, device=device), + # "high": torch.as_tensor(action_spec.space.high, device=device), # "tanh_loc": NonTensorData(False), # } # ), @@ -326,7 +328,7 @@ def make_cql_modules_state(model_cfg, proof_environment): # --------- -def make_continuous_loss(loss_cfg, model): +def make_continuous_loss(loss_cfg, model, device: torch.device | None = None): loss_module = CQLLoss( model[0], model[1], @@ -339,19 +341,19 @@ def make_continuous_loss(loss_cfg, model): with_lagrange=loss_cfg.with_lagrange, lagrange_thresh=loss_cfg.lagrange_thresh, ) - loss_module.make_value_estimator(gamma=loss_cfg.gamma) + loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device) target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau) return loss_module, target_net_updater -def make_discrete_loss(loss_cfg, model): +def make_discrete_loss(loss_cfg, model, device: torch.device | None = None): loss_module = DiscreteCQLLoss( model, loss_function=loss_cfg.loss_function, delay_value=True, ) - loss_module.make_value_estimator(gamma=loss_cfg.gamma) + loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device) target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau) return loss_module, target_net_updater diff --git a/sota-implementations/crossq/config.yaml b/sota-implementations/crossq/config.yaml index 1dcbd3db92d..bd6276a6dcf 100644 --- a/sota-implementations/crossq/config.yaml +++ b/sota-implementations/crossq/config.yaml @@ -12,7 +12,7 @@ collector: init_random_frames: 25000 frames_per_batch: 1000 init_env_steps: 1000 - device: cpu + device: env_per_collector: 1 reset_at_each_iter: False @@ -46,7 +46,12 @@ network: actor_activation: relu default_policy_scale: 1.0 scale_lb: 0.1 - device: "cuda:0" + device: + +compile: + compile: False + compile_mode: + cudagraphs: False # logging logger: diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py index b07ae880046..a0068b6662e 100644 --- a/sota-implementations/crossq/crossq.py +++ b/sota-implementations/crossq/crossq.py @@ -10,16 +10,23 @@ The helper functions are coded in the utils.py associated with this script. """ -import time +from __future__ import annotations + +import warnings import hydra import numpy as np + import torch import torch.cuda import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict import TensorDict +from tensordict.nn import CudaGraphModule + +from torchrl._utils import timeit from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.objectives import group_optimizers from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( @@ -32,6 +39,8 @@ make_replay_buffer, ) +torch.set_float32_matmul_precision("high") + @hydra.main(version_base="1.1", config_path=".", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 @@ -69,10 +78,27 @@ def main(cfg: "DictConfig"): # noqa: F821 model, exploration_policy = make_crossQ_agent(cfg, train_env, device) # Create CrossQ loss - loss_module = make_loss_module(cfg, model) + loss_module = make_loss_module(cfg, model, device=device) + + compile_mode = None + if cfg.compile.compile: + if cfg.compile.compile_mode not in (None, ""): + compile_mode = cfg.compile.compile_mode + elif cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" # Create off-policy collector - collector = make_collector(cfg, train_env, exploration_policy.eval(), device=device) + collector = make_collector( + cfg, + train_env, + exploration_policy.eval(), + device=device, + compile=cfg.compile.compile, + compile_mode=compile_mode, + cudagraph=cfg.compile.cudagraphs, + ) # Create replay buffer replay_buffer = make_replay_buffer( @@ -89,9 +115,66 @@ def main(cfg: "DictConfig"): # noqa: F821 optimizer_critic, optimizer_alpha, ) = make_crossQ_optimizer(cfg, loss_module) + optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha) + del optimizer_actor, optimizer_critic, optimizer_alpha + + def update_qloss(sampled_tensordict): + optimizer.zero_grad(set_to_none=True) + td_loss = {} + q_loss, value_meta = loss_module.qvalue_loss(sampled_tensordict) + sampled_tensordict.set(loss_module.tensor_keys.priority, value_meta["td_error"]) + q_loss = q_loss.mean() + + # Update critic + q_loss.backward() + optimizer.step() + td_loss["loss_qvalue"] = q_loss + td_loss["loss_actor"] = float("nan") + td_loss["loss_alpha"] = float("nan") + return TensorDict(td_loss, device=device).detach() + + def update_all(sampled_tensordict: TensorDict): + optimizer.zero_grad(set_to_none=True) + + td_loss = {} + q_loss, value_meta = loss_module.qvalue_loss(sampled_tensordict) + sampled_tensordict.set(loss_module.tensor_keys.priority, value_meta["td_error"]) + q_loss = q_loss.mean() + + actor_loss, metadata_actor = loss_module.actor_loss(sampled_tensordict) + actor_loss = actor_loss.mean() + alpha_loss = loss_module.alpha_loss( + log_prob=metadata_actor["log_prob"].detach() + ).mean() + + # Updates + (q_loss + actor_loss + actor_loss).backward() + optimizer.step() + + # Update critic + td_loss["loss_qvalue"] = q_loss + td_loss["loss_actor"] = actor_loss + td_loss["loss_alpha"] = alpha_loss + + return TensorDict(td_loss, device=device).detach() + + if compile_mode: + update_all = torch.compile(update_all, mode=compile_mode) + update_qloss = torch.compile(update_qloss, mode=compile_mode) + if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update_all = CudaGraphModule(update_all, warmup=50) + update_qloss = CudaGraphModule(update_qloss, warmup=50) + + def update(sampled_tensordict: TensorDict, update_actor: bool): + if update_actor: + return update_all(sampled_tensordict) + return update_qloss(sampled_tensordict) # Main loop - start_time = time.time() collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) @@ -106,79 +189,45 @@ def main(cfg: "DictConfig"): # noqa: F821 frames_per_batch = cfg.collector.frames_per_batch eval_rollout_steps = cfg.env.max_episode_steps - sampling_start = time.time() update_counter = 0 delayed_updates = cfg.optim.policy_update_delay - for _, tensordict in enumerate(collector): - sampling_time = time.time() - sampling_start + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + torch.compiler.cudagraph_mark_step_begin() + tensordict = next(c_iter) # Update weights of the inference policy collector.update_policy_weights_() - pbar.update(tensordict.numel()) - - tensordict = tensordict.reshape(-1) current_frames = tensordict.numel() - # Add to replay buffer - replay_buffer.extend(tensordict.cpu()) + pbar.update(current_frames) + tensordict = tensordict.reshape(-1) + + with timeit("rb - extend"): + # Add to replay buffer + replay_buffer.extend(tensordict) collected_frames += current_frames # Optimization steps - training_start = time.time() if collected_frames >= init_random_frames: - ( - actor_losses, - alpha_losses, - q_losses, - ) = ([], [], []) + tds = [] for _ in range(num_updates): - # Update actor every delayed_updates update_counter += 1 update_actor = update_counter % delayed_updates == 0 # Sample from replay buffer - sampled_tensordict = replay_buffer.sample() - if sampled_tensordict.device != device: - sampled_tensordict = sampled_tensordict.to(device) - else: - sampled_tensordict = sampled_tensordict.clone() - - # Compute loss - q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict) - q_loss = q_loss.mean() - # Update critic - optimizer_critic.zero_grad() - q_loss.backward() - optimizer_critic.step() - q_losses.append(q_loss.detach().item()) - - if update_actor: - actor_loss, metadata_actor = loss_module.actor_loss( - sampled_tensordict - ) - actor_loss = actor_loss.mean() - alpha_loss = loss_module.alpha_loss( - log_prob=metadata_actor["log_prob"] - ).mean() - - # Update actor - optimizer_actor.zero_grad() - actor_loss.backward() - optimizer_actor.step() - - # Update alpha - optimizer_alpha.zero_grad() - alpha_loss.backward() - optimizer_alpha.step() - - actor_losses.append(actor_loss.detach().item()) - alpha_losses.append(alpha_loss.detach().item()) - + with timeit("rb - sample"): + sampled_tensordict = replay_buffer.sample().to(device) + with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() + td_loss = update(sampled_tensordict, update_actor=update_actor) + tds.append(td_loss.clone()) # Update priority if prb: replay_buffer.update_priority(sampled_tensordict) - training_time = time.time() - training_start + tds = TensorDict.stack(tds).nanmean() episode_end = ( tensordict["next", "done"] if tensordict["next", "done"].any() @@ -186,47 +235,47 @@ def main(cfg: "DictConfig"): # noqa: F821 ) episode_rewards = tensordict["next", "episode_reward"][episode_end] - # Logging metrics_to_log = {} - if len(episode_rewards) > 0: - episode_length = tensordict["next", "step_count"][episode_end] - metrics_to_log["train/reward"] = episode_rewards.mean().item() - metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( - episode_length - ) - if collected_frames >= init_random_frames: - metrics_to_log["train/q_loss"] = np.mean(q_losses).item() - metrics_to_log["train/actor_loss"] = np.mean(actor_losses).item() - metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses).item() - metrics_to_log["train/sampling_time"] = sampling_time - metrics_to_log["train/training_time"] = training_time # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): - eval_start = time.time() + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(), timeit("eval"): eval_rollout = eval_env.rollout( eval_rollout_steps, model[0], auto_cast_to_device=True, break_when_any_done=True, ) - eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward - metrics_to_log["eval/time"] = eval_time + + # Logging + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][episode_end] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length + ) + if i % 20 == 0: + metrics_to_log.update(timeit.todict(prefix="time")) + if collected_frames >= init_random_frames: + metrics_to_log["train/q_loss"] = tds["loss_qvalue"] + metrics_to_log["train/actor_loss"] = tds["loss_actor"] + metrics_to_log["train/alpha_loss"] = tds["loss_alpha"] + if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) - sampling_start = time.time() + if i % 20 == 0: + timeit.print() + timeit.erase() collector.shutdown() if not eval_env.is_closed: eval_env.close() if not train_env.is_closed: train_env.close() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py index 483bf257c63..b124a619ea0 100644 --- a/sota-implementations/crossq/utils.py +++ b/sota-implementations/crossq/utils.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import torch from tensordict.nn import InteractionType, TensorDictModule @@ -90,7 +91,15 @@ def make_environment(cfg): # --------------------------- -def make_collector(cfg, train_env, actor_model_explore, device): +def make_collector( + cfg, + train_env, + actor_model_explore, + device, + compile=False, + compile_mode=None, + cudagraph=False, +): """Make collector.""" collector = SyncDataCollector( train_env, @@ -99,6 +108,8 @@ def make_collector(cfg, train_env, actor_model_explore, device): frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, device=device, + compile_policy={"mode": compile_mode} if compile else False, + cudagraph_policy=cudagraph, ) collector.set_seed(cfg.env.seed) return collector @@ -164,9 +175,10 @@ def make_crossQ_agent(cfg, train_env, device): dist_class = TanhNormal dist_kwargs = { - "low": action_spec.space.low, - "high": action_spec.space.high, + "low": torch.as_tensor(action_spec.space.low, device=device), + "high": torch.as_tensor(action_spec.space.high, device=device), "tanh_loc": False, + "safe_tanh": not cfg.compile.compile, } actor_extractor = NormalParamExtractor( @@ -236,7 +248,7 @@ def make_crossQ_agent(cfg, train_env, device): # --------- -def make_loss_module(cfg, model): +def make_loss_module(cfg, model, device: torch.device | None = None): """Make loss module and target network updater.""" # Create CrossQ loss loss_module = CrossQLoss( @@ -246,7 +258,7 @@ def make_loss_module(cfg, model): loss_function=cfg.optim.loss_function, alpha_init=cfg.optim.alpha_init, ) - loss_module.make_value_estimator(gamma=cfg.optim.gamma) + loss_module.make_value_estimator(gamma=cfg.optim.gamma, device=device) return loss_module diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py index cebc3685625..01198808fec 100644 --- a/sota-implementations/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -10,6 +10,8 @@ The helper functions are coded in the utils.py associated with this script. """ +from __future__ import annotations + import time import hydra diff --git a/sota-implementations/ddpg/utils.py b/sota-implementations/ddpg/utils.py index 9495fd038f2..e9495aa2b93 100644 --- a/sota-implementations/ddpg/utils.py +++ b/sota-implementations/ddpg/utils.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import functools import torch diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py index b892462339c..8093617ba9e 100644 --- a/sota-implementations/decision_transformer/dt.py +++ b/sota-implementations/decision_transformer/dt.py @@ -6,6 +6,8 @@ This is a self-contained example of an offline Decision Transformer training script. The helper functions are coded in the utils.py associated with this script. """ +from __future__ import annotations + import time import hydra diff --git a/sota-implementations/decision_transformer/lamb.py b/sota-implementations/decision_transformer/lamb.py index 69468d1ad86..5118f8a2721 100644 --- a/sota-implementations/decision_transformer/lamb.py +++ b/sota-implementations/decision_transformer/lamb.py @@ -3,6 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # Lamb optimizer directly copied from https://github.com/facebookresearch/online-dt +from __future__ import annotations + import math import torch diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py index 184c850b626..3577217f296 100644 --- a/sota-implementations/decision_transformer/online_dt.py +++ b/sota-implementations/decision_transformer/online_dt.py @@ -6,6 +6,8 @@ This is a self-contained example of an Online Decision Transformer training script. The helper functions are coded in the utils.py associated with this script. """ +from __future__ import annotations + import time import hydra diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index 7f905c72366..6bc1946b0a4 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import torch.nn diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py index a9a08827f5d..cb39d3ad06e 100644 --- a/sota-implementations/discrete_sac/discrete_sac.py +++ b/sota-implementations/discrete_sac/discrete_sac.py @@ -10,6 +10,8 @@ The helper functions are coded in the utils.py associated with this script. """ +from __future__ import annotations + import time import hydra diff --git a/sota-implementations/discrete_sac/utils.py b/sota-implementations/discrete_sac/utils.py index 8051f07fe95..bd4e13cc13e 100644 --- a/sota-implementations/discrete_sac/utils.py +++ b/sota-implementations/discrete_sac/utils.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import functools import tempfile from contextlib import nullcontext diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index 5d0162080e2..4f37502ab76 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -7,6 +7,8 @@ DQN: Reproducing experimental results from Mnih et al. 2015 for the Deep Q-Learning Algorithm on Atari Environments. """ +from __future__ import annotations + import tempfile import time diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py index 8149c700958..b97d8c904fd 100644 --- a/sota-implementations/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import time import hydra diff --git a/sota-implementations/dqn/utils_atari.py b/sota-implementations/dqn/utils_atari.py index 6f39e824c60..1e5440a54b6 100644 --- a/sota-implementations/dqn/utils_atari.py +++ b/sota-implementations/dqn/utils_atari.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import torch.nn import torch.optim diff --git a/sota-implementations/dqn/utils_cartpole.py b/sota-implementations/dqn/utils_cartpole.py index c7f7491ad15..d378f1ec76b 100644 --- a/sota-implementations/dqn/utils_cartpole.py +++ b/sota-implementations/dqn/utils_cartpole.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import torch.nn import torch.optim diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 1b9823c1dd1..0db55b3ee00 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import contextlib import time diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 41ea170ac76..9a99d86150e 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import functools import tempfile from contextlib import nullcontext diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index a3c64693fb3..b4856fa7d0d 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -9,6 +9,8 @@ The helper functions for gail are coded in the gail_utils.py and helper functions for ppo in ppo_utils. """ +from __future__ import annotations + import hydra import numpy as np import torch diff --git a/sota-implementations/gail/gail_utils.py b/sota-implementations/gail/gail_utils.py index 067e9c8c927..ce09292cc47 100644 --- a/sota-implementations/gail/gail_utils.py +++ b/sota-implementations/gail/gail_utils.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import torch.nn as nn import torch.optim diff --git a/sota-implementations/gail/ppo_utils.py b/sota-implementations/gail/ppo_utils.py index 63310113e98..5669d93ce85 100644 --- a/sota-implementations/gail/ppo_utils.py +++ b/sota-implementations/gail/ppo_utils.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import torch.nn import torch.optim diff --git a/sota-implementations/impala/impala_multi_node_ray.py b/sota-implementations/impala/impala_multi_node_ray.py index 0dc033d6dd1..ba40de1acde 100644 --- a/sota-implementations/impala/impala_multi_node_ray.py +++ b/sota-implementations/impala/impala_multi_node_ray.py @@ -7,6 +7,8 @@ This script reproduces the IMPALA Algorithm results from Espeholt et al. 2018 for the on Atari Environments. """ +from __future__ import annotations + import hydra from torchrl._utils import logger as torchrl_logger diff --git a/sota-implementations/impala/impala_multi_node_submitit.py b/sota-implementations/impala/impala_multi_node_submitit.py index 33df035c20e..5f77008a12b 100644 --- a/sota-implementations/impala/impala_multi_node_submitit.py +++ b/sota-implementations/impala/impala_multi_node_submitit.py @@ -7,6 +7,8 @@ This script reproduces the IMPALA Algorithm results from Espeholt et al. 2018 for the on Atari Environments. """ +from __future__ import annotations + import hydra from torchrl._utils import logger as torchrl_logger diff --git a/sota-implementations/impala/impala_single_node.py b/sota-implementations/impala/impala_single_node.py index cc37df6c783..130d0d30dd7 100644 --- a/sota-implementations/impala/impala_single_node.py +++ b/sota-implementations/impala/impala_single_node.py @@ -7,6 +7,8 @@ This script reproduces the IMPALA Algorithm results from Espeholt et al. 2018 for the on Atari Environments. """ +from __future__ import annotations + import hydra from torchrl._utils import logger as torchrl_logger diff --git a/sota-implementations/impala/utils.py b/sota-implementations/impala/utils.py index 30293940377..248a98a389d 100644 --- a/sota-implementations/impala/utils.py +++ b/sota-implementations/impala/utils.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import torch.nn import torch.optim diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index ae1894379fd..79cf2114d40 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -11,6 +11,8 @@ The helper functions are coded in the utils.py associated with this script. """ +from __future__ import annotations + import time import hydra diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py index 53581782d20..09cf9954b86 100644 --- a/sota-implementations/iql/iql_offline.py +++ b/sota-implementations/iql/iql_offline.py @@ -9,6 +9,8 @@ The helper functions are coded in the utils.py associated with this script. """ +from __future__ import annotations + import time import hydra diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index 3cdff06ffa2..8497d24f106 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -11,6 +11,8 @@ The helper functions are coded in the utils.py associated with this script. """ +from __future__ import annotations + import time import hydra diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index ff84d0d8138..d7d9e1a2d2f 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import functools import torch.nn diff --git a/sota-implementations/multiagent/iql.py b/sota-implementations/multiagent/iql.py index 66cc3b6659e..2692c1c24b5 100644 --- a/sota-implementations/multiagent/iql.py +++ b/sota-implementations/multiagent/iql.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import time import hydra diff --git a/sota-implementations/multiagent/maddpg_iddpg.py b/sota-implementations/multiagent/maddpg_iddpg.py index 1485e3e8c0b..f04ccb19071 100644 --- a/sota-implementations/multiagent/maddpg_iddpg.py +++ b/sota-implementations/multiagent/maddpg_iddpg.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import time import hydra diff --git a/sota-implementations/multiagent/mappo_ippo.py b/sota-implementations/multiagent/mappo_ippo.py index 06cc2cd1fce..924ea12272a 100644 --- a/sota-implementations/multiagent/mappo_ippo.py +++ b/sota-implementations/multiagent/mappo_ippo.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import time import hydra diff --git a/sota-implementations/multiagent/qmix_vdn.py b/sota-implementations/multiagent/qmix_vdn.py index 1bcc2dbd10e..a832a29e6dd 100644 --- a/sota-implementations/multiagent/qmix_vdn.py +++ b/sota-implementations/multiagent/qmix_vdn.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import time import hydra diff --git a/sota-implementations/multiagent/sac.py b/sota-implementations/multiagent/sac.py index 694083e5b0f..31106bdd2a0 100644 --- a/sota-implementations/multiagent/sac.py +++ b/sota-implementations/multiagent/sac.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import time import hydra diff --git a/sota-implementations/multiagent/utils/logging.py b/sota-implementations/multiagent/utils/logging.py index cb6df4de7ea..e19ae8d78f7 100644 --- a/sota-implementations/multiagent/utils/logging.py +++ b/sota-implementations/multiagent/utils/logging.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import os import numpy as np diff --git a/sota-implementations/multiagent/utils/utils.py b/sota-implementations/multiagent/utils/utils.py index d21bafdf691..e2513f30aa7 100644 --- a/sota-implementations/multiagent/utils/utils.py +++ b/sota-implementations/multiagent/utils/utils.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + from tensordict import unravel_key from torchrl.envs import Transform diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index 30a19a64d6e..7878a0286e3 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -7,6 +7,8 @@ This script reproduces the Proximal Policy Optimization (PPO) Algorithm results from Schulman et al. 2017 for the Atari Environments. """ +from __future__ import annotations + import hydra from torchrl._utils import logger as torchrl_logger from torchrl.record import VideoRecorder diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index b98285f0726..c1d6fe52585 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -7,6 +7,8 @@ This script reproduces the Proximal Policy Optimization (PPO) Algorithm results from Schulman et al. 2017 for the on MuJoCo Environments. """ +from __future__ import annotations + import hydra from torchrl._utils import logger as torchrl_logger from torchrl.record import VideoRecorder diff --git a/sota-implementations/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py index debc8f9e211..9be451331d8 100644 --- a/sota-implementations/ppo/utils_atari.py +++ b/sota-implementations/ppo/utils_atari.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import torch.nn import torch.optim diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py index 6c7a1b80fd7..ebbc6f7916d 100644 --- a/sota-implementations/ppo/utils_mujoco.py +++ b/sota-implementations/ppo/utils_mujoco.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import torch.nn import torch.optim diff --git a/sota-implementations/redq/redq.py b/sota-implementations/redq/redq.py index 0732bf5f3b4..3dec888145c 100644 --- a/sota-implementations/redq/redq.py +++ b/sota-implementations/redq/redq.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import uuid from datetime import datetime diff --git a/sota-implementations/sac/sac.py b/sota-implementations/sac/sac.py index a99094cf715..ee3e7d08df0 100644 --- a/sota-implementations/sac/sac.py +++ b/sota-implementations/sac/sac.py @@ -10,6 +10,8 @@ The helper functions are coded in the utils.py associated with this script. """ +from __future__ import annotations + import time import hydra diff --git a/sota-implementations/sac/utils.py b/sota-implementations/sac/utils.py index d1dbb2db791..9760793c9cd 100644 --- a/sota-implementations/sac/utils.py +++ b/sota-implementations/sac/utils.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import functools import torch diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py index 01a59686ac9..70333f56cd9 100644 --- a/sota-implementations/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -10,6 +10,8 @@ The helper functions are coded in the utils.py associated with this script. """ +from __future__ import annotations + import time import hydra diff --git a/sota-implementations/td3/utils.py b/sota-implementations/td3/utils.py index 665c2e0c674..a9bc8140291 100644 --- a/sota-implementations/td3/utils.py +++ b/sota-implementations/td3/utils.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import functools import tempfile from contextlib import nullcontext diff --git a/sota-implementations/td3_bc/td3_bc.py b/sota-implementations/td3_bc/td3_bc.py index 930ff509488..75be949df90 100644 --- a/sota-implementations/td3_bc/td3_bc.py +++ b/sota-implementations/td3_bc/td3_bc.py @@ -9,6 +9,8 @@ The helper functions are coded in the utils.py associated with this script. """ +from __future__ import annotations + import time import hydra diff --git a/sota-implementations/td3_bc/utils.py b/sota-implementations/td3_bc/utils.py index 582afaaac04..d0c3161861d 100644 --- a/sota-implementations/td3_bc/utils.py +++ b/sota-implementations/td3_bc/utils.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import functools import torch diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index bbd6a23bfdd..3b08780e24c 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -905,7 +905,6 @@ def value_estimate( ): reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - if self.gamma.device != device: self.gamma = self.gamma.to(device) gamma = self.gamma @@ -1372,13 +1371,12 @@ def forward( ) reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - if self.gamma.device != device: self.gamma = self.gamma.to(device) + gamma = self.gamma if self.lmbda.device != device: self.lmbda = self.lmbda.to(device) - gamma, lmbda = self.gamma, self.lmbda - + lmbda = self.lmbda steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -1459,13 +1457,12 @@ def value_estimate( ) reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - if self.gamma.device != device: self.gamma = self.gamma.to(device) + gamma = self.gamma if self.lmbda.device != device: self.lmbda = self.lmbda.to(device) - gamma, lmbda = self.gamma, self.lmbda - + lmbda = self.lmbda steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward)