From 187de7c8bcf836197b2c9eda8a7c90fea96ec231 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 17:20:41 -0800 Subject: [PATCH] [Feature] timeit.printevery ghstack-source-id: 19165bbfbea5cdc0a6b159493fb02571bab872f3 Pull Request resolved: https://github.com/pytorch/rl/pull/2653 --- sota-implementations/a2c/a2c_atari.py | 10 +++---- sota-implementations/a2c/a2c_mujoco.py | 10 +++---- sota-implementations/cql/cql_offline.py | 12 ++------ sota-implementations/cql/cql_online.py | 10 +++---- .../cql/discrete_cql_online.py | 11 +++---- sota-implementations/crossq/crossq.py | 10 +++---- sota-implementations/ddpg/ddpg.py | 9 +++--- .../decision_transformer/dt.py | 6 ++-- .../decision_transformer/online_dt.py | 9 ++---- .../discrete_sac/discrete_sac.py | 9 +++--- sota-implementations/dqn/dqn_atari.py | 9 +++--- sota-implementations/dqn/dqn_cartpole.py | 9 +++--- sota-implementations/dreamer/dreamer.py | 3 +- sota-implementations/gail/gail.py | 30 ++++++++++++------- sota-implementations/iql/discrete_iql.py | 6 ++-- sota-implementations/iql/iql_offline.py | 3 ++ sota-implementations/iql/iql_online.py | 5 +++- sota-implementations/ppo/ppo_atari.py | 4 ++- sota-implementations/ppo/ppo_mujoco.py | 4 ++- torchrl/__init__.py | 1 + torchrl/_utils.py | 21 ++++++++++++- 21 files changed, 104 insertions(+), 87 deletions(-) diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index c7f70308fd4..47e43125ea4 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -182,7 +182,10 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): lr = cfg.optim.lr c_iter = iter(collector) - for i in range(len(collector)): + total_iter = len(collector) + for i in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) + with timeit("collecting"): data = next(c_iter) @@ -261,10 +264,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): "test/reward": test_rewards.mean(), } ) - if i % 200 == 0: - log_info.update(timeit.todict(prefix="time")) - timeit.print() - timeit.erase() + log_info.update(timeit.todict(prefix="time")) if logger: for key, value in log_info.items(): diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index cf88e7db01a..07ad5197954 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -179,7 +179,10 @@ def update(batch): pbar = tqdm.tqdm(total=cfg.collector.total_frames) c_iter = iter(collector) - for i in range(len(collector)): + total_iter = len(collector) + for i in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) + with timeit("collecting"): data = next(c_iter) @@ -257,10 +260,7 @@ def update(batch): ) actor.train() - if i % 200 == 0: - log_info.update(timeit.todict(prefix="time")) - timeit.print() - timeit.erase() + log_info.update(timeit.todict(prefix="time")) if logger: for key, value in log_info.items(): diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py index e74997eb37f..c0030a1e9cc 100644 --- a/sota-implementations/cql/cql_offline.py +++ b/sota-implementations/cql/cql_offline.py @@ -11,7 +11,6 @@ """ from __future__ import annotations -import time import warnings import hydra @@ -21,7 +20,7 @@ import tqdm from tensordict.nn import CudaGraphModule -from torchrl._utils import logger as torchrl_logger, timeit +from torchrl._utils import timeit from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.objectives import group_optimizers from torchrl.record.loggers import generate_exp_name, get_logger @@ -156,9 +155,9 @@ def update(data, policy_eval_start, iteration): eval_steps = cfg.logger.eval_steps # Training loop - start_time = time.time() policy_eval_start = torch.tensor(policy_eval_start, device=device) for i in range(gradient_steps): + timeit.printevery(1000, gradient_steps, erase=True) pbar.update(1) # sample data with timeit("sample"): @@ -192,15 +191,10 @@ def update(data, policy_eval_start, iteration): to_log["evaluation_reward"] = eval_reward with timeit("log"): - if i % 200 == 0: - to_log.update(timeit.todict(prefix="time")) + to_log.update(timeit.todict(prefix="time")) log_metrics(logger, to_log, i) - if i % 200 == 0: - timeit.print() - timeit.erase() pbar.close() - torchrl_logger.info(f"Training time: {time.time() - start_time}") if not eval_env.is_closed: eval_env.close() diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py index 61a19894ce0..b61556874c3 100644 --- a/sota-implementations/cql/cql_online.py +++ b/sota-implementations/cql/cql_online.py @@ -170,7 +170,9 @@ def update(sampled_tensordict): eval_rollout_steps = cfg.logger.eval_steps c_iter = iter(collector) - for i in range(len(collector)): + total_iter = len(collector) + for i in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) with timeit("collecting"): tensordict = next(c_iter) pbar.update(tensordict.numel()) @@ -222,8 +224,7 @@ def update(sampled_tensordict): "loss_alpha_prime" ).mean() metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean() - if i % 10 == 0: - metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log.update(timeit.todict(prefix="time")) # Evaluation with timeit("eval"): @@ -245,9 +246,6 @@ def update(sampled_tensordict): metrics_to_log["eval/reward"] = eval_reward log_metrics(logger, metrics_to_log, collected_frames) - if i % 10 == 0: - timeit.print() - timeit.erase() collector.shutdown() if not eval_env.is_closed: diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py index c5a06b4b156..e6a710f1f4b 100644 --- a/sota-implementations/cql/discrete_cql_online.py +++ b/sota-implementations/cql/discrete_cql_online.py @@ -151,7 +151,9 @@ def update(sampled_tensordict): frames_per_batch = cfg.collector.frames_per_batch c_iter = iter(collector) - for i in range(len(collector)): + total_iter = len(collector) + for _ in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) with timeit("collecting"): torch.compiler.cudagraph_mark_step_begin() tensordict = next(c_iter) @@ -224,12 +226,7 @@ def update(sampled_tensordict): tds = torch.stack(tds, dim=0).mean() metrics_to_log["train/q_loss"] = tds["loss_qvalue"] metrics_to_log["train/cql_loss"] = tds["loss_cql"] - if i % 100 == 0: - metrics_to_log.update(timeit.todict(prefix="time")) - - if i % 100 == 0: - timeit.print() - timeit.erase() + metrics_to_log.update(timeit.todict(prefix="time")) if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py index a0068b6662e..5f6d762d644 100644 --- a/sota-implementations/crossq/crossq.py +++ b/sota-implementations/crossq/crossq.py @@ -192,7 +192,9 @@ def update(sampled_tensordict: TensorDict, update_actor: bool): update_counter = 0 delayed_updates = cfg.optim.policy_update_delay c_iter = iter(collector) - for i in range(len(collector)): + total_iter = len(collector) + for _ in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) with timeit("collecting"): torch.compiler.cudagraph_mark_step_begin() tensordict = next(c_iter) @@ -258,8 +260,7 @@ def update(sampled_tensordict: TensorDict, update_actor: bool): metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( episode_length ) - if i % 20 == 0: - metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log.update(timeit.todict(prefix="time")) if collected_frames >= init_random_frames: metrics_to_log["train/q_loss"] = tds["loss_qvalue"] metrics_to_log["train/actor_loss"] = tds["loss_actor"] @@ -267,9 +268,6 @@ def update(sampled_tensordict: TensorDict, update_actor: bool): if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) - if i % 20 == 0: - timeit.print() - timeit.erase() collector.shutdown() if not eval_env.is_closed: diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py index c3e3c9eb835..9d06dc2ff75 100644 --- a/sota-implementations/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -156,7 +156,9 @@ def update(sampled_tensordict): eval_rollout_steps = cfg.env.max_episode_steps c_iter = iter(collector) - for i in range(len(collector)): + total_iter = len(collector) + for _ in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) with timeit("collecting"): tensordict = next(c_iter) # Update exploration policy @@ -226,10 +228,7 @@ def update(sampled_tensordict): eval_env.apply(dump_video) eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward - if i % 20 == 0: - metrics_to_log.update(timeit.todict(prefix="time")) - timeit.print() - timeit.erase() + metrics_to_log.update(timeit.todict(prefix="time")) if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py index 6ac058b9843..57ba327b935 100644 --- a/sota-implementations/decision_transformer/dt.py +++ b/sota-implementations/decision_transformer/dt.py @@ -128,6 +128,7 @@ def update(data: TensorDict) -> TensorDict: # Pretraining pbar = tqdm.tqdm(range(pretrain_gradient_steps)) for i in pbar: + timeit.printevery(1000, pretrain_gradient_steps, erase=True) # Sample data with timeit("rb - sample"): data = offline_buffer.sample().to(model_device) @@ -151,10 +152,7 @@ def update(data: TensorDict) -> TensorDict: to_log["eval/reward"] = ( eval_td["next", "reward"].sum(1).mean().item() / reward_scaling ) - if i % 200 == 0: - to_log.update(timeit.todict(prefix="time")) - timeit.print() - timeit.erase() + to_log.update(timeit.todict(prefix="time")) if logger is not None: log_metrics(logger, to_log, i) diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py index 9f3ec5f8134..7c6c9968774 100644 --- a/sota-implementations/decision_transformer/online_dt.py +++ b/sota-implementations/decision_transformer/online_dt.py @@ -8,7 +8,6 @@ """ from __future__ import annotations -import time import warnings import hydra @@ -130,8 +129,8 @@ def update(data): torchrl_logger.info(" ***Pretraining*** ") # Pretraining - start_time = time.time() for i in range(pretrain_gradient_steps): + timeit.printevery(1000, pretrain_gradient_steps, erase=True) pbar.update(1) with timeit("sample"): # Sample data @@ -170,10 +169,7 @@ def update(data): eval_td["next", "reward"].sum(1).mean().item() / reward_scaling ) - if i % 200 == 0: - to_log.update(timeit.todict(prefix="time")) - timeit.print() - timeit.erase() + to_log.update(timeit.todict(prefix="time")) if logger is not None: log_metrics(logger, to_log, i) @@ -181,7 +177,6 @@ def update(data): pbar.close() if not test_env.is_closed: test_env.close() - torchrl_logger.info(f"Training time: {time.time() - start_time}") if __name__ == "__main__": diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py index c88206e1330..a5dad120a60 100644 --- a/sota-implementations/discrete_sac/discrete_sac.py +++ b/sota-implementations/discrete_sac/discrete_sac.py @@ -155,7 +155,9 @@ def update(sampled_tensordict): frames_per_batch = cfg.collector.frames_per_batch c_iter = iter(collector) - for i in range(len(collector)): + total_iter = len(collector) + for i in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) with timeit("collecting"): collected_data = next(c_iter) @@ -229,10 +231,7 @@ def update(sampled_tensordict): eval_env.apply(dump_video) eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward - if i % 50 == 0: - metrics_to_log.update(timeit.todict(prefix="time")) - timeit.print() - timeit.erase() + metrics_to_log.update(timeit.todict(prefix="time")) if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index d43ac25c822..b4236c3e89f 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -173,7 +173,9 @@ def update(sampled_tensordict): pbar = tqdm.tqdm(total=total_frames) c_iter = iter(collector) - for i in range(len(collector)): + total_iter = len(collector) + for i in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) with timeit("collecting"): data = next(c_iter) log_info = {} @@ -241,10 +243,7 @@ def update(sampled_tensordict): ) model.train() - if i % 200 == 0: - timeit.print() - log_info.update(timeit.todict(prefix="time")) - timeit.erase() + log_info.update(timeit.todict(prefix="time")) # Log all the information if logger: diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py index 873cf278d4b..57236337ced 100644 --- a/sota-implementations/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -156,7 +156,9 @@ def update(sampled_tensordict): q_losses = torch.zeros(num_updates, device=device) c_iter = iter(collector) - for i in range(len(collector)): + total_iter = len(collector) + for i in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) with timeit("collecting"): data = next(c_iter) @@ -226,10 +228,7 @@ def update(sampled_tensordict): } ) - if i % 200 == 0: - timeit.print() - log_info.update(timeit.todict(prefix="time")) - timeit.erase() + log_info.update(timeit.todict(prefix="time")) # Log all the information if logger: diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 0db55b3ee00..a197796e978 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -275,9 +275,8 @@ def compile_rssms(module): "t_sample": t_sample, "t_preproc": t_preproc, "t_collect": t_collect, - **timeit.todict(percall=False), + **timeit.todict(prefix="time"), } - timeit.erase() metrics_to_log.update(loss_metrics) if logger is not None: diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index a02845cfe4d..45d3acbb85f 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -22,7 +22,7 @@ from ppo_utils import eval_model, make_env, make_ppo_models from tensordict.nn import CudaGraphModule -from torchrl._utils import compile_with_warmup +from torchrl._utils import compile_with_warmup, timeit from torchrl.collectors import SyncDataCollector from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement @@ -256,19 +256,28 @@ def update(data, expert_data, num_network_updates=num_network_updates): cfg_logger_test_interval = cfg.logger.test_interval cfg_logger_num_test_episodes = cfg.logger.num_test_episodes - for i, data in enumerate(collector): + total_iter = len(collector) + collector_iter = iter(collector) + for i in range(total_iter): + + timeit.printevery(1000, total_iter, erase=True) + + with timeit("collection"): + data = next(collector_iter) log_info = {} frames_in_batch = data.numel() collected_frames += frames_in_batch pbar.update(data.numel()) - # Update discriminator - # Get expert data - expert_data = replay_buffer.sample() - expert_data = expert_data.to(device) + with timeit("rb - sample expert"): + # Get expert data + expert_data = replay_buffer.sample() + expert_data = expert_data.to(device) - metadata = update(data, expert_data) + with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() + metadata = update(data, expert_data) d_loss = metadata["dloss"] alpha = metadata["alpha"] @@ -287,8 +296,6 @@ def update(data, expert_data, num_network_updates=num_network_updates): log_info.update( { - # "train/actor_loss": actor_loss.item(), - # "train/critic_loss": critic_loss.item(), "train/discriminator_loss": d_loss["loss"], "train/lr": alpha * cfg_optim_lr, "train/clip_epsilon": ( @@ -300,7 +307,9 @@ def update(data, expert_data, num_network_updates=num_network_updates): ) # evaluation - with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): + with torch.no_grad(), set_exploration_type( + ExplorationType.DETERMINISTIC + ), timeit("eval"): if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < ( i * frames_in_batch ) // cfg_logger_test_interval: @@ -315,6 +324,7 @@ def update(data, expert_data, num_network_updates=num_network_updates): ) actor.train() if logger is not None: + log_info.update(timeit.todict(prefix="time")) log_metrics(logger, log_info, i) pbar.close() diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index e51bd25a8a8..805cfc6e23d 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -159,7 +159,10 @@ def update(sampled_tensordict): eval_rollout_steps = cfg.collector.max_frames_per_traj collector_iter = iter(collector) - for _ in range(len(collector)): + total_iter = len(collector) + for _ in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) + with timeit("collection"): tensordict = next(collector_iter) current_frames = tensordict.numel() @@ -230,7 +233,6 @@ def update(sampled_tensordict): metrics_to_log.update(timeit.todict(prefix="time")) if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) - timeit.erase() collector.shutdown() diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py index 1a270ee8ccc..00f4cb24f5a 100644 --- a/sota-implementations/iql/iql_offline.py +++ b/sota-implementations/iql/iql_offline.py @@ -133,6 +133,8 @@ def update(data): # Training loop for i in pbar: + timeit.printevery(1000, cfg.optim.gradient_steps, erase=True) + # sample data with timeit("sample"): data = replay_buffer.sample() @@ -155,6 +157,7 @@ def update(data): eval_reward = eval_td["next", "reward"].sum(1).mean().item() to_log["evaluation_reward"] = eval_reward if logger is not None: + to_log.update(timeit.todict(prefix="time")) log_metrics(logger, to_log, i) pbar.close() diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index 4f6c765d1e8..28b35099286 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -156,7 +156,10 @@ def update(sampled_tensordict): eval_rollout_steps = cfg.collector.max_frames_per_traj collector_iter = iter(collector) pbar = tqdm.tqdm(range(collector.total_frames)) - for _ in range(len(collector)): + total_iter = len(collector) + for _ in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) + with timeit("collection"): tensordict = next(collector_iter) current_frames = tensordict.numel() diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index cc42ef38f9d..153a1ad9515 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -207,8 +207,10 @@ def update(batch, num_network_updates): losses = TensorDict(batch_size=[cfg_loss_ppo_epochs, num_mini_batches]) collector_iter = iter(collector) + total_iter = len(collector) + for i in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) - for i in range(len(collector)): with timeit("collecting"): data = next(collector_iter) diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index a0cf2726aca..f8568be56e6 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -200,8 +200,10 @@ def update(batch, num_network_updates): losses = TensorDict(batch_size=[cfg_loss_ppo_epochs, num_mini_batches]) collector_iter = iter(collector) + total_iter = len(collector) + for i in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) - for i in range(len(collector)): with timeit("collecting"): data = next(collector_iter) diff --git a/torchrl/__init__.py b/torchrl/__init__.py index 7a41bf0ab8f..d4c75c85179 100644 --- a/torchrl/__init__.py +++ b/torchrl/__init__.py @@ -52,6 +52,7 @@ import torchrl.modules import torchrl.objectives import torchrl.trainers +from torchrl._utils import compile_with_warmup, timeit # Filter warnings in subprocesses: True by default given the multiple optional # deps of the library. This can be turned on via `torchrl.filter_warnings_subprocess = False`. diff --git a/torchrl/_utils.py b/torchrl/_utils.py index cc1621d8723..6a2f80aeffb 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -103,7 +103,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): val[2] = N @staticmethod - def print(prefix=None) -> str: # noqa: T202 + def print(prefix: str = None) -> str: # noqa: T202 """Prints the state of the timer. Returns: @@ -123,6 +123,25 @@ def print(prefix=None) -> str: # noqa: T202 logger.info(string[-1]) return "\n".join(string) + _printevery_count = 0 + + @classmethod + def printevery( + cls, + num_prints: int, + total_count: int, + *, + prefix: str = None, + erase: bool = False, + ) -> None: + """Prints the state of the timer at regular intervals.""" + interval = max(1, total_count // num_prints) + if cls._printevery_count % interval == 0: + cls.print(prefix=prefix) + if erase: + cls.erase() + cls._printevery_count += 1 + @classmethod def todict(cls, percall=True, prefix=None): def _make_key(key):