From 1ce25f19ab1349bb794a7ba77867bfec9cb55fd9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 18 Dec 2024 10:20:36 +0000 Subject: [PATCH] [Feature] Log pbar rate in SOTA implementations ghstack-source-id: 283cc1bb4ad2d60281296d2cfb78ec41c77f4129 Pull Request resolved: https://github.com/pytorch/rl/pull/2662 --- sota-implementations/a2c/a2c_atari.py | 15 +++---- sota-implementations/a2c/a2c_mujoco.py | 18 ++++----- sota-implementations/cql/cql_offline.py | 9 +++-- sota-implementations/cql/cql_online.py | 3 +- .../cql/discrete_cql_online.py | 5 ++- sota-implementations/cql/utils.py | 2 +- sota-implementations/crossq/crossq.py | 3 +- sota-implementations/ddpg/ddpg.py | 3 +- .../decision_transformer/dt.py | 13 +++--- .../decision_transformer/online_dt.py | 10 ++--- .../decision_transformer/utils.py | 4 +- .../discrete_sac/discrete_sac.py | 3 +- sota-implementations/dqn/config_atari.yaml | 6 +-- sota-implementations/dqn/config_cartpole.yaml | 4 +- sota-implementations/dqn/dqn_atari.py | 40 ++++++++++--------- sota-implementations/dqn/dqn_cartpole.py | 24 +++++------ sota-implementations/gail/gail.py | 13 +++--- .../impala/impala_multi_node_ray.py | 14 +++---- .../impala/impala_multi_node_submitit.py | 14 +++---- .../impala/impala_single_node.py | 14 +++---- sota-implementations/iql/discrete_iql.py | 3 +- sota-implementations/iql/iql_offline.py | 9 +++-- sota-implementations/iql/iql_online.py | 3 +- .../multiagent/utils/logging.py | 18 ++++----- sota-implementations/ppo/ppo_atari.py | 27 +++++++------ sota-implementations/ppo/ppo_mujoco.py | 27 +++++++------ sota-implementations/sac/sac.py | 1 + sota-implementations/td3/td3.py | 1 + sota-implementations/td3_bc/td3_bc.py | 13 +++--- 29 files changed, 168 insertions(+), 151 deletions(-) diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index 47e43125ea4..3279d6e0a2b 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -189,7 +189,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): with timeit("collecting"): data = next(c_iter) - log_info = {} + metrics_to_log = {} frames_in_batch = data.numel() collected_frames += frames_in_batch * frame_skip pbar.update(data.numel()) @@ -198,7 +198,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): episode_rewards = data["next", "episode_reward"][data["next", "terminated"]] if len(episode_rewards) > 0: episode_length = data["next", "step_count"][data["next", "terminated"]] - log_info.update( + metrics_to_log.update( { "train/reward": episode_rewards.mean().item(), "train/episode_length": episode_length.sum().item() @@ -242,8 +242,8 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): losses = torch.stack(losses).float().mean() for key, value in losses.items(): - log_info.update({f"train/{key}": value.item()}) - log_info.update( + metrics_to_log.update({f"train/{key}": value.item()}) + metrics_to_log.update( { "train/lr": lr * alpha, } @@ -259,15 +259,16 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): test_rewards = eval_model( actor_eval, test_env, num_episodes=cfg.logger.num_test_episodes ) - log_info.update( + metrics_to_log.update( { "test/reward": test_rewards.mean(), } ) - log_info.update(timeit.todict(prefix="time")) if logger: - for key, value in log_info.items(): + metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] + for key, value in metrics_to_log.items(): logger.log_scalar(key, value, collected_frames) collector.shutdown() diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index 07ad5197954..41e05dc1326 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -186,7 +186,7 @@ def update(batch): with timeit("collecting"): data = next(c_iter) - log_info = {} + metrics_to_log = {} frames_in_batch = data.numel() collected_frames += frames_in_batch pbar.update(data.numel()) @@ -195,7 +195,7 @@ def update(batch): episode_rewards = data["next", "episode_reward"][data["next", "done"]] if len(episode_rewards) > 0: episode_length = data["next", "step_count"][data["next", "done"]] - log_info.update( + metrics_to_log.update( { "train/reward": episode_rewards.mean().item(), "train/episode_length": episode_length.sum().item() @@ -236,8 +236,8 @@ def update(batch): # Get training losses losses = torch.stack(losses).float().mean() for key, value in losses.items(): - log_info.update({f"train/{key}": value.item()}) - log_info.update( + metrics_to_log.update({f"train/{key}": value.item()}) + metrics_to_log.update( { "train/lr": alpha * cfg.optim.lr, } @@ -253,21 +253,19 @@ def update(batch): test_rewards = eval_model( actor, test_env, num_episodes=cfg.logger.num_test_episodes ) - log_info.update( + metrics_to_log.update( { "test/reward": test_rewards.mean(), } ) actor.train() - log_info.update(timeit.todict(prefix="time")) - if logger: - for key, value in log_info.items(): + metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] + for key, value in metrics_to_log.items(): logger.log_scalar(key, value, collected_frames) - torch.compiler.cudagraph_mark_step_begin() - collector.shutdown() if not test_env.is_closed: test_env.close() diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py index c0030a1e9cc..2e1a20ad7a2 100644 --- a/sota-implementations/cql/cql_offline.py +++ b/sota-implementations/cql/cql_offline.py @@ -172,7 +172,7 @@ def update(data, policy_eval_start, iteration): ) # log metrics - to_log = { + metrics_to_log = { "loss": loss.cpu(), **loss_vals.cpu(), } @@ -188,11 +188,12 @@ def update(data, policy_eval_start, iteration): ) eval_env.apply(dump_video) eval_reward = eval_td["next", "reward"].sum(1).mean().item() - to_log["evaluation_reward"] = eval_reward + metrics_to_log["evaluation_reward"] = eval_reward with timeit("log"): - to_log.update(timeit.todict(prefix="time")) - log_metrics(logger, to_log, i) + metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] + log_metrics(logger, metrics_to_log, i) pbar.close() if not eval_env.is_closed: diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py index 03bdf6a493f..e992bdb5939 100644 --- a/sota-implementations/cql/cql_online.py +++ b/sota-implementations/cql/cql_online.py @@ -220,7 +220,6 @@ def update(sampled_tensordict): "loss_alpha_prime" ).mean() metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean() - metrics_to_log.update(timeit.todict(prefix="time")) # Evaluation with timeit("eval"): @@ -241,6 +240,8 @@ def update(sampled_tensordict): eval_env.apply(dump_video) metrics_to_log["eval/reward"] = eval_reward + metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] log_metrics(logger, metrics_to_log, collected_frames) collector.shutdown() diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py index 35238c5c6ab..d45ce3745fe 100644 --- a/sota-implementations/cql/discrete_cql_online.py +++ b/sota-implementations/cql/discrete_cql_online.py @@ -179,7 +179,7 @@ def update(sampled_tensordict): sampled_tensordict = sampled_tensordict.to(device) with timeit("update"): torch.compiler.cudagraph_mark_step_begin() - loss_dict = update(sampled_tensordict) + loss_dict = update(sampled_tensordict).clone() tds.append(loss_dict) # Update priority @@ -222,9 +222,10 @@ def update(sampled_tensordict): tds = torch.stack(tds, dim=0).mean() metrics_to_log["train/q_loss"] = tds["loss_qvalue"] metrics_to_log["train/cql_loss"] = tds["loss_cql"] - metrics_to_log.update(timeit.todict(prefix="time")) if logger is not None: + metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] log_metrics(logger, metrics_to_log, collected_frames) collector.shutdown() diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index 306f1cdb7f1..8bbc70a32c3 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -185,7 +185,7 @@ def make_offline_replay_buffer(rb_cfg): dataset_id=rb_cfg.dataset, split_trajs=False, batch_size=rb_cfg.batch_size, - sampler=SamplerWithoutReplacement(drop_last=False), + sampler=SamplerWithoutReplacement(drop_last=True), prefetch=4, direct_download=True, ) diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py index 07de3e26175..d84613e6876 100644 --- a/sota-implementations/crossq/crossq.py +++ b/sota-implementations/crossq/crossq.py @@ -256,13 +256,14 @@ def update(sampled_tensordict: TensorDict, update_actor: bool): metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( episode_length ) - metrics_to_log.update(timeit.todict(prefix="time")) if collected_frames >= init_random_frames: metrics_to_log["train/q_loss"] = tds["loss_qvalue"] metrics_to_log["train/actor_loss"] = tds["loss_actor"] metrics_to_log["train/alpha_loss"] = tds["loss_alpha"] if logger is not None: + metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] log_metrics(logger, metrics_to_log, collected_frames) collector.shutdown() diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py index 6e2a749c3f1..bcb7ee6ef54 100644 --- a/sota-implementations/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -224,9 +224,10 @@ def update(sampled_tensordict): eval_env.apply(dump_video) eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward - metrics_to_log.update(timeit.todict(prefix="time")) if logger is not None: + metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] log_metrics(logger, metrics_to_log, collected_frames) collector.shutdown() diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py index 57ba327b935..9e8446ed82f 100644 --- a/sota-implementations/decision_transformer/dt.py +++ b/sota-implementations/decision_transformer/dt.py @@ -76,7 +76,9 @@ def main(cfg: "DictConfig"): # noqa: F821 loss_module = make_dt_loss(cfg.loss, actor, device=model_device) # Create optimizer - transformer_optim, scheduler = make_dt_optimizer(cfg.optim, loss_module) + transformer_optim, scheduler = make_dt_optimizer( + cfg.optim, loss_module, model_device + ) # Create inference policy inference_policy = DecisionTransformerInferenceWrapper( @@ -136,7 +138,7 @@ def update(data: TensorDict) -> TensorDict: loss_vals = update(data) scheduler.step() # Log metrics - to_log = {"train/loss": loss_vals["loss"]} + metrics_to_log = {"train/loss": loss_vals["loss"]} # Evaluation with set_exploration_type( @@ -149,13 +151,14 @@ def update(data: TensorDict) -> TensorDict: auto_cast_to_device=True, ) test_env.apply(dump_video) - to_log["eval/reward"] = ( + metrics_to_log["eval/reward"] = ( eval_td["next", "reward"].sum(1).mean().item() / reward_scaling ) - to_log.update(timeit.todict(prefix="time")) if logger is not None: - log_metrics(logger, to_log, i) + metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] + log_metrics(logger, metrics_to_log, i) pbar.close() if not test_env.is_closed: diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py index 7c6c9968774..1404cb7ebc0 100644 --- a/sota-implementations/decision_transformer/online_dt.py +++ b/sota-implementations/decision_transformer/online_dt.py @@ -143,7 +143,7 @@ def update(data): scheduler.step() # Log metrics - to_log = { + metrics_to_log = { "train/loss_log_likelihood": loss_vals["loss_log_likelihood"], "train/loss_entropy": loss_vals["loss_entropy"], "train/loss_alpha": loss_vals["loss_alpha"], @@ -165,14 +165,14 @@ def update(data): ) test_env.apply(dump_video) inference_policy.train() - to_log["eval/reward"] = ( + metrics_to_log["eval/reward"] = ( eval_td["next", "reward"].sum(1).mean().item() / reward_scaling ) - to_log.update(timeit.todict(prefix="time")) - if logger is not None: - log_metrics(logger, to_log, i) + metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] + log_metrics(logger, metrics_to_log, i) pbar.close() if not test_env.is_closed: diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index 5f14734addd..d4a67e7d3a9 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -511,10 +511,10 @@ def make_odt_optimizer(optim_cfg, loss_module): return dt_optimizer, log_temp_optimizer, scheduler -def make_dt_optimizer(optim_cfg, loss_module): +def make_dt_optimizer(optim_cfg, loss_module, device): dt_optimizer = torch.optim.Adam( loss_module.actor_network_params.flatten_keys().values(), - lr=torch.as_tensor(optim_cfg.lr), + lr=torch.tensor(optim_cfg.lr, device=device), weight_decay=optim_cfg.weight_decay, eps=1.0e-8, ) diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py index b7910c4e578..9ff50902887 100644 --- a/sota-implementations/discrete_sac/discrete_sac.py +++ b/sota-implementations/discrete_sac/discrete_sac.py @@ -227,8 +227,9 @@ def update(sampled_tensordict): eval_env.apply(dump_video) eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward - metrics_to_log.update(timeit.todict(prefix="time")) if logger is not None: + metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] log_metrics(logger, metrics_to_log, collected_frames) collector.shutdown() diff --git a/sota-implementations/dqn/config_atari.yaml b/sota-implementations/dqn/config_atari.yaml index 021e7fd6132..85d513fbb2c 100644 --- a/sota-implementations/dqn/config_atari.yaml +++ b/sota-implementations/dqn/config_atari.yaml @@ -7,7 +7,7 @@ env: # collector collector: total_frames: 40_000_100 - frames_per_batch: 16 + frames_per_batch: 1600 eps_start: 1.0 eps_end: 0.01 annealing_frames: 4_000_000 @@ -38,9 +38,9 @@ optim: loss: gamma: 0.99 hard_update_freq: 10_000 - num_updates: 1 + num_updates: 100 compile: compile: False - compile_mode: + compile_mode: default cudagraphs: False diff --git a/sota-implementations/dqn/config_cartpole.yaml b/sota-implementations/dqn/config_cartpole.yaml index 58be7fb3bb5..199533ba9be 100644 --- a/sota-implementations/dqn/config_cartpole.yaml +++ b/sota-implementations/dqn/config_cartpole.yaml @@ -7,7 +7,7 @@ env: # collector collector: total_frames: 500_100 - frames_per_batch: 10 + frames_per_batch: 1000 eps_start: 1.0 eps_end: 0.05 annealing_frames: 250_000 @@ -37,7 +37,7 @@ optim: loss: gamma: 0.99 hard_update_freq: 50 - num_updates: 1 + num_updates: 100 compile: compile: False diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index b4236c3e89f..786e5d2ebb0 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -9,7 +9,7 @@ """ from __future__ import annotations -import tempfile +import functools import warnings import hydra @@ -64,20 +64,25 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create the replay buffer - if cfg.buffer.scratch_dir is None: - tempdir = tempfile.TemporaryDirectory() - scratch_dir = tempdir.name + if cfg.buffer.scratch_dir in ("", None): + storage_cls = LazyMemmapStorage else: - scratch_dir = cfg.buffer.scratch_dir + storage_cls = functools.partial( + LazyMemmapStorage, scratch_dir=cfg.buffer.scratch_dir + ) + + def transform(td): + return td.to(device) + replay_buffer = TensorDictReplayBuffer( pin_memory=False, - prefetch=3, - storage=LazyMemmapStorage( + storage=storage_cls( max_size=cfg.buffer.buffer_size, - scratch_dir=scratch_dir, ), batch_size=cfg.buffer.batch_size, ) + if transform is not None: + replay_buffer.append_transform(transform) # Create the loss module loss_module = DQNLoss( @@ -86,7 +91,7 @@ def main(cfg: "DictConfig"): # noqa: F821 delay_value=True, ) loss_module.set_keys(done="end-of-life", terminated="end-of-life") - loss_module.make_value_estimator(gamma=cfg.loss.gamma) + loss_module.make_value_estimator(gamma=cfg.loss.gamma, device=device) target_net_updater = HardUpdate( loss_module, value_network_update_interval=cfg.loss.hard_update_freq ) @@ -178,7 +183,7 @@ def update(sampled_tensordict): timeit.printevery(1000, total_iter, erase=True) with timeit("collecting"): data = next(c_iter) - log_info = {} + metrics_to_log = {} pbar.update(data.numel()) data = data.reshape(-1) current_frames = data.numel() * frame_skip @@ -193,7 +198,7 @@ def update(sampled_tensordict): episode_reward_mean = episode_rewards.mean().item() episode_length = data["next", "step_count"][data["next", "done"]] episode_length_mean = episode_length.sum().item() / len(episode_length) - log_info.update( + metrics_to_log.update( { "train/episode_reward": episode_reward_mean, "train/episode_length": episode_length_mean, @@ -202,7 +207,7 @@ def update(sampled_tensordict): if collected_frames < init_random_frames: if logger: - for key, value in log_info.items(): + for key, value in metrics_to_log.items(): logger.log_scalar(key, value, step=collected_frames) continue @@ -210,13 +215,12 @@ def update(sampled_tensordict): for j in range(num_updates): with timeit("rb - sample"): sampled_tensordict = replay_buffer.sample() - sampled_tensordict = sampled_tensordict.to(device) with timeit("update"): q_loss = update(sampled_tensordict) q_losses[j].copy_(q_loss) # Get and log q-values, loss, epsilon, sampling time and training time - log_info.update( + metrics_to_log.update( { "train/q_values": data["chosen_action_value"].sum() / frames_per_batch, "train/q_loss": q_losses.mean(), @@ -236,18 +240,18 @@ def update(sampled_tensordict): test_rewards = eval_model( model, test_env, num_episodes=num_test_episodes ) - log_info.update( + metrics_to_log.update( { "eval/reward": test_rewards, } ) model.train() - log_info.update(timeit.todict(prefix="time")) - # Log all the information if logger: - for key, value in log_info.items(): + metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] + for key, value in metrics_to_log.items(): logger.log_scalar(key, value, step=collected_frames) # update weights of the inference policy diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py index 57236337ced..4fde452fba9 100644 --- a/sota-implementations/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -55,11 +55,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create the replay buffer replay_buffer = TensorDictReplayBuffer( pin_memory=False, - prefetch=10, - storage=LazyTensorStorage( - max_size=cfg.buffer.buffer_size, - device="cpu", - ), + storage=LazyTensorStorage(max_size=cfg.buffer.buffer_size, device=device), batch_size=cfg.buffer.batch_size, ) @@ -69,7 +65,7 @@ def main(cfg: "DictConfig"): # noqa: F821 loss_function="l2", delay_value=True, ) - loss_module.make_value_estimator(gamma=cfg.loss.gamma) + loss_module.make_value_estimator(gamma=cfg.loss.gamma, device=device) loss_module = loss_module.to(device) target_net_updater = HardUpdate( loss_module, value_network_update_interval=cfg.loss.hard_update_freq @@ -162,7 +158,7 @@ def update(sampled_tensordict): with timeit("collecting"): data = next(c_iter) - log_info = {} + metrics_to_log = {} pbar.update(data.numel()) data = data.reshape(-1) current_frames = data.numel() @@ -178,7 +174,7 @@ def update(sampled_tensordict): episode_reward_mean = episode_rewards.mean().item() episode_length = data["next", "step_count"][data["next", "done"]] episode_length_mean = episode_length.sum().item() / len(episode_length) - log_info.update( + metrics_to_log.update( { "train/episode_reward": episode_reward_mean, "train/episode_length": episode_length_mean, @@ -188,7 +184,7 @@ def update(sampled_tensordict): if collected_frames < init_random_frames: if collected_frames < init_random_frames: if logger: - for key, value in log_info.items(): + for key, value in metrics_to_log.items(): logger.log_scalar(key, value, step=collected_frames) continue @@ -202,7 +198,7 @@ def update(sampled_tensordict): q_losses[j].copy_(q_loss) # Get and log q-values, loss, epsilon, sampling time and training time - log_info.update( + metrics_to_log.update( { "train/q_values": (data["action_value"] * data["action"]).sum().item() / frames_per_batch, @@ -222,17 +218,17 @@ def update(sampled_tensordict): model.eval() test_rewards = eval_model(model, test_env, num_test_episodes) model.train() - log_info.update( + metrics_to_log.update( { "eval/reward": test_rewards, } ) - log_info.update(timeit.todict(prefix="time")) - # Log all the information if logger: - for key, value in log_info.items(): + metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] + for key, value in metrics_to_log.items(): logger.log_scalar(key, value, step=collected_frames) # update weights of the inference policy diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index 45d3acbb85f..bdb8843aaf6 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -265,7 +265,7 @@ def update(data, expert_data, num_network_updates=num_network_updates): with timeit("collection"): data = next(collector_iter) - log_info = {} + metrics_to_log = {} frames_in_batch = data.numel() collected_frames += frames_in_batch pbar.update(data.numel()) @@ -286,7 +286,7 @@ def update(data, expert_data, num_network_updates=num_network_updates): if len(episode_rewards) > 0: episode_length = data["next", "step_count"][data["next", "done"]] - log_info.update( + metrics_to_log.update( { "train/reward": episode_rewards.mean().item(), "train/episode_length": episode_length.sum().item() @@ -294,7 +294,7 @@ def update(data, expert_data, num_network_updates=num_network_updates): } ) - log_info.update( + metrics_to_log.update( { "train/discriminator_loss": d_loss["loss"], "train/lr": alpha * cfg_optim_lr, @@ -317,15 +317,16 @@ def update(data, expert_data, num_network_updates=num_network_updates): test_rewards = eval_model( actor, test_env, num_episodes=cfg_logger_num_test_episodes ) - log_info.update( + metrics_to_log.update( { "eval/reward": test_rewards.mean(), } ) actor.train() if logger is not None: - log_info.update(timeit.todict(prefix="time")) - log_metrics(logger, log_info, i) + metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] + log_metrics(logger, metrics_to_log, i) pbar.close() diff --git a/sota-implementations/impala/impala_multi_node_ray.py b/sota-implementations/impala/impala_multi_node_ray.py index b2b724f6a6d..dcf908c2cd2 100644 --- a/sota-implementations/impala/impala_multi_node_ray.py +++ b/sota-implementations/impala/impala_multi_node_ray.py @@ -165,7 +165,7 @@ def main(cfg: "DictConfig"): # noqa: F821 start_time = sampling_start = time.time() for i, data in enumerate(collector): - log_info = {} + metrics_to_log = {} sampling_time = time.time() - sampling_start frames_in_batch = data.numel() collected_frames += frames_in_batch * frame_skip @@ -175,7 +175,7 @@ def main(cfg: "DictConfig"): # noqa: F821 episode_rewards = data["next", "episode_reward"][data["next", "terminated"]] if len(episode_rewards) > 0: episode_length = data["next", "step_count"][data["next", "terminated"]] - log_info.update( + metrics_to_log.update( { "train/reward": episode_rewards.mean().item(), "train/episode_length": episode_length.sum().item() @@ -186,7 +186,7 @@ def main(cfg: "DictConfig"): # noqa: F821 if len(accumulator) < batch_size: accumulator.append(data) if logger: - for key, value in log_info.items(): + for key, value in metrics_to_log.items(): logger.log_scalar(key, value, collected_frames) continue @@ -243,8 +243,8 @@ def main(cfg: "DictConfig"): # noqa: F821 training_time = time.time() - training_start losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) for key, value in losses.items(): - log_info.update({f"train/{key}": value.item()}) - log_info.update( + metrics_to_log.update({f"train/{key}": value.item()}) + metrics_to_log.update( { "train/lr": alpha * lr, "train/sampling_time": sampling_time, @@ -263,7 +263,7 @@ def main(cfg: "DictConfig"): # noqa: F821 actor, test_env, num_episodes=num_test_episodes ) eval_time = time.time() - eval_start - log_info.update( + metrics_to_log.update( { "eval/reward": test_reward, "eval/time": eval_time, @@ -272,7 +272,7 @@ def main(cfg: "DictConfig"): # noqa: F821 actor.train() if logger: - for key, value in log_info.items(): + for key, value in metrics_to_log.items(): logger.log_scalar(key, value, collected_frames) collector.update_policy_weights_() diff --git a/sota-implementations/impala/impala_multi_node_submitit.py b/sota-implementations/impala/impala_multi_node_submitit.py index 07d38604391..4d90e9053bd 100644 --- a/sota-implementations/impala/impala_multi_node_submitit.py +++ b/sota-implementations/impala/impala_multi_node_submitit.py @@ -157,7 +157,7 @@ def main(cfg: "DictConfig"): # noqa: F821 start_time = sampling_start = time.time() for i, data in enumerate(collector): - log_info = {} + metrics_to_log = {} sampling_time = time.time() - sampling_start frames_in_batch = data.numel() collected_frames += frames_in_batch * frame_skip @@ -167,7 +167,7 @@ def main(cfg: "DictConfig"): # noqa: F821 episode_rewards = data["next", "episode_reward"][data["next", "done"]] if len(episode_rewards) > 0: episode_length = data["next", "step_count"][data["next", "done"]] - log_info.update( + metrics_to_log.update( { "train/reward": episode_rewards.mean().item(), "train/episode_length": episode_length.sum().item() @@ -178,7 +178,7 @@ def main(cfg: "DictConfig"): # noqa: F821 if len(accumulator) < batch_size: accumulator.append(data) if logger: - for key, value in log_info.items(): + for key, value in metrics_to_log.items(): logger.log_scalar(key, value, collected_frames) continue @@ -235,8 +235,8 @@ def main(cfg: "DictConfig"): # noqa: F821 training_time = time.time() - training_start losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) for key, value in losses.items(): - log_info.update({f"train/{key}": value.item()}) - log_info.update( + metrics_to_log.update({f"train/{key}": value.item()}) + metrics_to_log.update( { "train/lr": alpha * lr, "train/sampling_time": sampling_time, @@ -255,7 +255,7 @@ def main(cfg: "DictConfig"): # noqa: F821 actor, test_env, num_episodes=num_test_episodes ) eval_time = time.time() - eval_start - log_info.update( + metrics_to_log.update( { "eval/reward": test_reward, "eval/time": eval_time, @@ -264,7 +264,7 @@ def main(cfg: "DictConfig"): # noqa: F821 actor.train() if logger: - for key, value in log_info.items(): + for key, value in metrics_to_log.items(): logger.log_scalar(key, value, collected_frames) collector.update_policy_weights_() diff --git a/sota-implementations/impala/impala_single_node.py b/sota-implementations/impala/impala_single_node.py index cd11ae467c3..cda63ac0919 100644 --- a/sota-implementations/impala/impala_single_node.py +++ b/sota-implementations/impala/impala_single_node.py @@ -134,7 +134,7 @@ def main(cfg: "DictConfig"): # noqa: F821 start_time = sampling_start = time.time() for i, data in enumerate(collector): - log_info = {} + metrics_to_log = {} sampling_time = time.time() - sampling_start frames_in_batch = data.numel() collected_frames += frames_in_batch * frame_skip @@ -144,7 +144,7 @@ def main(cfg: "DictConfig"): # noqa: F821 episode_rewards = data["next", "episode_reward"][data["next", "terminated"]] if len(episode_rewards) > 0: episode_length = data["next", "step_count"][data["next", "terminated"]] - log_info.update( + metrics_to_log.update( { "train/reward": episode_rewards.mean().item(), "train/episode_length": episode_length.sum().item() @@ -155,7 +155,7 @@ def main(cfg: "DictConfig"): # noqa: F821 if len(accumulator) < batch_size: accumulator.append(data) if logger: - for key, value in log_info.items(): + for key, value in metrics_to_log.items(): logger.log_scalar(key, value, collected_frames) continue @@ -212,8 +212,8 @@ def main(cfg: "DictConfig"): # noqa: F821 training_time = time.time() - training_start losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) for key, value in losses.items(): - log_info.update({f"train/{key}": value.item()}) - log_info.update( + metrics_to_log.update({f"train/{key}": value.item()}) + metrics_to_log.update( { "train/lr": alpha * lr, "train/sampling_time": sampling_time, @@ -232,7 +232,7 @@ def main(cfg: "DictConfig"): # noqa: F821 actor, test_env, num_episodes=num_test_episodes ) eval_time = time.time() - eval_start - log_info.update( + metrics_to_log.update( { "eval/reward": test_reward, "eval/time": eval_time, @@ -241,7 +241,7 @@ def main(cfg: "DictConfig"): # noqa: F821 actor.train() if logger: - for key, value in log_info.items(): + for key, value in metrics_to_log.items(): logger.log_scalar(key, value, collected_frames) collector.update_policy_weights_() diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index e56661acf0c..aa4cea04024 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -226,8 +226,9 @@ def update(sampled_tensordict): metrics_to_log["train/q_loss"] = metadata["q_loss"] metrics_to_log["train/actor_loss"] = metadata["actor_loss"] metrics_to_log["train/value_loss"] = metadata["value_loss"] - metrics_to_log.update(timeit.todict(prefix="time")) if logger is not None: + metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] log_metrics(logger, metrics_to_log, collected_frames) collector.shutdown() diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py index 00f4cb24f5a..eaf791438cc 100644 --- a/sota-implementations/iql/iql_offline.py +++ b/sota-implementations/iql/iql_offline.py @@ -145,7 +145,7 @@ def update(data): loss_info = update(data) # evaluation - to_log = loss_info.to_dict() + metrics_to_log = loss_info.to_dict() if i % evaluation_interval == 0: with set_exploration_type( ExplorationType.DETERMINISTIC @@ -155,10 +155,11 @@ def update(data): ) eval_env.apply(dump_video) eval_reward = eval_td["next", "reward"].sum(1).mean().item() - to_log["evaluation_reward"] = eval_reward + metrics_to_log["evaluation_reward"] = eval_reward if logger is not None: - to_log.update(timeit.todict(prefix="time")) - log_metrics(logger, to_log, i) + metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] + log_metrics(logger, metrics_to_log, i) pbar.close() if not eval_env.is_closed: diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index 7ec2a30dfd9..5b90f00c467 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -215,9 +215,10 @@ def update(sampled_tensordict): metrics_to_log["train/actor_loss"] = loss_info["loss_actor"] metrics_to_log["train/value_loss"] = loss_info["loss_value"] metrics_to_log["train/entropy"] = loss_info.get("entropy") - metrics_to_log.update(timeit.todict(prefix="time")) if logger is not None: + metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] log_metrics(logger, metrics_to_log, collected_frames) collector.shutdown() diff --git a/sota-implementations/multiagent/utils/logging.py b/sota-implementations/multiagent/utils/logging.py index e19ae8d78f7..40c9b70d578 100644 --- a/sota-implementations/multiagent/utils/logging.py +++ b/sota-implementations/multiagent/utils/logging.py @@ -56,13 +56,13 @@ def log_training( .unsqueeze(-1), ) - to_log = { + metrics_to_log = { f"train/learner/{key}": value.mean().item() for key, value in training_td.items() } if "info" in sampling_td.get("agents").keys(): - to_log.update( + metrics_to_log.update( { f"train/info/{key}": value.mean().item() for key, value in sampling_td.get(("agents", "info")).items() @@ -76,7 +76,7 @@ def log_training( episode_reward = sampling_td.get(("next", "agents", "episode_reward")).mean(-2)[ done ] - to_log.update( + metrics_to_log.update( { "train/reward/reward_min": reward.min().item(), "train/reward/reward_mean": reward.mean().item(), @@ -94,12 +94,12 @@ def log_training( } ) if isinstance(logger, WandbLogger): - logger.experiment.log(to_log, commit=False) + logger.experiment.log(metrics_to_log, commit=False) else: - for key, value in to_log.items(): + for key, value in metrics_to_log.items(): logger.log_scalar(key.replace("/", "_"), value, step=step) - return to_log + return metrics_to_log def log_evaluation( @@ -121,7 +121,7 @@ def log_evaluation( rollouts[k] = r[: done_index + 1] rewards = [td.get(("next", "agents", "reward")).sum(0).mean() for td in rollouts] - to_log = { + metrics_to_log = { "eval/episode_reward_min": min(rewards), "eval/episode_reward_max": max(rewards), "eval/episode_reward_mean": sum(rewards) / len(rollouts), @@ -138,7 +138,7 @@ def log_evaluation( if isinstance(logger, WandbLogger): import wandb - logger.experiment.log(to_log, commit=False) + logger.experiment.log(metrics_to_log, commit=False) logger.experiment.log( { "eval/video": wandb.Video(vid, fps=1 / env_test.world.dt, format="mp4"), @@ -146,6 +146,6 @@ def log_evaluation( commit=False, ) else: - for key, value in to_log.items(): + for key, value in metrics_to_log.items(): logger.log_scalar(key.replace("/", "_"), value, step=step) logger.log_video("eval_video", vid, step=step) diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index 153a1ad9515..8ecb675535b 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -67,12 +67,11 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create collector collector = SyncDataCollector( - create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, "cpu"), + create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device), policy=actor, frames_per_batch=frames_per_batch, total_frames=total_frames, device=device, - storing_device=device, max_frames_per_traj=-1, compile_policy={"mode": compile_mode, "warmup": 1} if compile_mode else False, cudagraph_policy=cfg.compile.cudagraphs, @@ -214,7 +213,7 @@ def update(batch, num_network_updates): with timeit("collecting"): data = next(collector_iter) - log_info = {} + metrics_to_log = {} frames_in_batch = data.numel() collected_frames += frames_in_batch * frame_skip pbar.update(frames_in_batch) @@ -223,7 +222,7 @@ def update(batch, num_network_updates): episode_rewards = data["next", "episode_reward"][data["next", "terminated"]] if len(episode_rewards) > 0: episode_length = data["next", "step_count"][data["next", "terminated"]] - log_info.update( + metrics_to_log.update( { "train/reward": episode_rewards.mean().item(), "train/episode_length": episode_length.sum().item() @@ -246,10 +245,11 @@ def update(batch, num_network_updates): data_buffer.extend(data_reshape) for k, batch in enumerate(data_buffer): - torch.compiler.cudagraph_mark_step_begin() - loss, num_network_updates = update( - batch, num_network_updates=num_network_updates - ) + with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() + loss, num_network_updates = update( + batch, num_network_updates=num_network_updates + ) loss = loss.clone() num_network_updates = num_network_updates.clone() losses[j, k] = loss.select( @@ -259,8 +259,8 @@ def update(batch, num_network_updates): # Get training losses and times losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[]) for key, value in losses_mean.items(): - log_info.update({f"train/{key}": value.item()}) - log_info.update( + metrics_to_log.update({f"train/{key}": value.item()}) + metrics_to_log.update( { "train/lr": loss["alpha"] * cfg_optim_lr, "train/clip_epsilon": loss["alpha"] * cfg_loss_clip_epsilon, @@ -278,15 +278,16 @@ def update(batch, num_network_updates): test_rewards = eval_model( actor, test_env, num_episodes=cfg_logger_num_test_episodes ) - log_info.update( + metrics_to_log.update( { "eval/reward": test_rewards.mean(), } ) actor.train() if logger: - log_info.update(timeit.todict(prefix="time")) - for key, value in log_info.items(): + metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] + for key, value in metrics_to_log.items(): logger.log_scalar(key, value, collected_frames) collector.update_policy_weights_() diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index f8568be56e6..27ae7e57848 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -72,7 +72,6 @@ def main(cfg: "DictConfig"): # noqa: F821 frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, device=device, - storing_device=device, max_frames_per_traj=-1, compile_policy={"mode": compile_mode, "warmup": 1} if compile_mode else False, cudagraph_policy=cfg.compile.cudagraphs, @@ -207,7 +206,7 @@ def update(batch, num_network_updates): with timeit("collecting"): data = next(collector_iter) - log_info = {} + metrics_to_log = {} frames_in_batch = data.numel() collected_frames += frames_in_batch pbar.update(frames_in_batch) @@ -216,7 +215,7 @@ def update(batch, num_network_updates): episode_rewards = data["next", "episode_reward"][data["next", "done"]] if len(episode_rewards) > 0: episode_length = data["next", "step_count"][data["next", "done"]] - log_info.update( + metrics_to_log.update( { "train/reward": episode_rewards.mean().item(), "train/episode_length": episode_length.sum().item() @@ -240,11 +239,12 @@ def update(batch, num_network_updates): data_buffer.extend(data_reshape) for k, batch in enumerate(data_buffer): - torch.compiler.cudagraph_mark_step_begin() - loss, num_network_updates = update( - batch, num_network_updates=num_network_updates - ) - loss = loss.clone() + with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() + loss, num_network_updates = update( + batch, num_network_updates=num_network_updates + ) + loss = loss.clone() num_network_updates = num_network_updates.clone() losses[j, k] = loss.select( "loss_critic", "loss_entropy", "loss_objective" @@ -253,8 +253,8 @@ def update(batch, num_network_updates): # Get training losses and times losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[]) for key, value in losses_mean.items(): - log_info.update({f"train/{key}": value.item()}) - log_info.update( + metrics_to_log.update({f"train/{key}": value.item()}) + metrics_to_log.update( { "train/lr": loss["alpha"] * cfg_optim_lr, "train/clip_epsilon": loss["alpha"] * cfg_loss_clip_epsilon @@ -274,7 +274,7 @@ def update(batch, num_network_updates): test_rewards = eval_model( actor, test_env, num_episodes=cfg_logger_num_test_episodes ) - log_info.update( + metrics_to_log.update( { "eval/reward": test_rewards.mean(), } @@ -282,8 +282,9 @@ def update(batch, num_network_updates): actor.train() if logger: - log_info.update(timeit.todict(prefix="time")) - for key, value in log_info.items(): + metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] + for key, value in metrics_to_log.items(): logger.log_scalar(key, value, collected_frames) collector.update_policy_weights_() diff --git a/sota-implementations/sac/sac.py b/sota-implementations/sac/sac.py index a1ec631fe39..e159824f9cd 100644 --- a/sota-implementations/sac/sac.py +++ b/sota-implementations/sac/sac.py @@ -230,6 +230,7 @@ def update(sampled_tensordict): metrics_to_log["eval/reward"] = eval_reward if logger is not None: metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] log_metrics(logger, metrics_to_log, collected_frames) collector.shutdown() diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py index bcbe6b879da..3a741735a1c 100644 --- a/sota-implementations/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -247,6 +247,7 @@ def update(sampled_tensordict, update_actor, prb=prb): metrics_to_log["eval/reward"] = eval_reward if logger is not None: metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] log_metrics(logger, metrics_to_log, collected_frames) collector.shutdown() diff --git a/sota-implementations/td3_bc/td3_bc.py b/sota-implementations/td3_bc/td3_bc.py index 35563777962..ac65f2875cf 100644 --- a/sota-implementations/td3_bc/td3_bc.py +++ b/sota-implementations/td3_bc/td3_bc.py @@ -151,11 +151,11 @@ def update(sampled_tensordict, update_actor): torch.compiler.cudagraph_mark_step_begin() metadata = update(sampled_tensordict, update_actor).clone() - to_log = {} + metrics_to_log = {} if update_actor: - to_log.update(metadata.to_dict()) + metrics_to_log.update(metadata.to_dict()) else: - to_log.update(metadata.exclude("actor_loss").to_dict()) + metrics_to_log.update(metadata.exclude("actor_loss").to_dict()) # evaluation if update_counter % evaluation_interval == 0: @@ -167,10 +167,11 @@ def update(sampled_tensordict, update_actor): ) eval_env.apply(dump_video) eval_reward = eval_td["next", "reward"].sum(1).mean().item() - to_log["evaluation_reward"] = eval_reward + metrics_to_log["evaluation_reward"] = eval_reward if logger is not None: - to_log.update(timeit.todict(prefix="time")) - log_metrics(logger, to_log, update_counter) + metrics_to_log.update(timeit.todict(prefix="time")) + metrics_to_log["time/speed"] = pbar.format_dict["rate"] + log_metrics(logger, metrics_to_log, update_counter) if not eval_env.is_closed: eval_env.close()