[Feature] timeit.printevery
ghstack-source-id: 19165bbfbea5cdc0a6b159493fb02571bab872f3
Pull Request resolved: #2653
vmoens committed Dec 16, 2024
1 parent f5a187d commit 187de7c
Showing 21 changed files with 104 additions and 87 deletions.
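
Across the scripts below the change follows one pattern: the scattered `if i % N == 0: timeit.print(); timeit.erase()` bookkeeping is dropped in favour of a single `timeit.printevery(1000, total_iter, erase=True)` call at the top of each training loop, and the accumulated timings are folded into the logged metrics with `timeit.todict(prefix="time")` on every iteration. Below is a minimal sketch of the resulting loop shape; the `timeit` calls mirror the diffs, while `collector`, `logger` and `log_metrics` are placeholders standing in for the objects each script builds.

```python
from torchrl._utils import timeit

# Minimal sketch of the logging pattern introduced by this commit.
# `collector`, `logger` and `log_metrics` are placeholders for the objects
# each sota-implementation script constructs; only the timeit calls follow the diffs.
c_iter = iter(collector)
total_iter = len(collector)
for i in range(total_iter):
    # One call replaces the previous manual `timeit.print()` / `timeit.erase()`
    # blocks; the (1000, total_iter, erase=True) arguments mirror the diffs below.
    timeit.printevery(1000, total_iter, erase=True)

    with timeit("collecting"):
        data = next(c_iter)

    metrics_to_log = {}
    # ... training / evaluation code fills metrics_to_log ...

    # Timings are now exported on every iteration under a "time" prefix.
    metrics_to_log.update(timeit.todict(prefix="time"))
    if logger is not None:
        log_metrics(logger, metrics_to_log, i)
```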
10 changes: 5 additions & 5 deletions sota-implementations/a2c/a2c_atari.py
@@ -182,7 +182,10 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
     lr = cfg.optim.lr
 
     c_iter = iter(collector)
-    for i in range(len(collector)):
+    total_iter = len(collector)
+    for i in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
 
         with timeit("collecting"):
             data = next(c_iter)

@@ -261,10 +264,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
                         "test/reward": test_rewards.mean(),
                     }
                 )
-        if i % 200 == 0:
-            log_info.update(timeit.todict(prefix="time"))
-            timeit.print()
-            timeit.erase()
+        log_info.update(timeit.todict(prefix="time"))
 
         if logger:
             for key, value in log_info.items():
10 changes: 5 additions & 5 deletions sota-implementations/a2c/a2c_mujoco.py
@@ -179,7 +179,10 @@ def update(batch):
     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
 
     c_iter = iter(collector)
-    for i in range(len(collector)):
+    total_iter = len(collector)
+    for i in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
 
         with timeit("collecting"):
             data = next(c_iter)

@@ -257,10 +260,7 @@ def update(batch):
                 )
                 actor.train()
 
-        if i % 200 == 0:
-            log_info.update(timeit.todict(prefix="time"))
-            timeit.print()
-            timeit.erase()
+        log_info.update(timeit.todict(prefix="time"))
 
         if logger:
             for key, value in log_info.items():
12 changes: 3 additions & 9 deletions sota-implementations/cql/cql_offline.py
@@ -11,7 +11,6 @@
 """
 from __future__ import annotations
 
-import time
 import warnings
 
 import hydra
@@ -21,7 +20,7 @@
 import tqdm
 from tensordict.nn import CudaGraphModule
 
-from torchrl._utils import logger as torchrl_logger, timeit
+from torchrl._utils import timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
@@ -156,9 +155,9 @@ def update(data, policy_eval_start, iteration):
     eval_steps = cfg.logger.eval_steps
 
     # Training loop
-    start_time = time.time()
     policy_eval_start = torch.tensor(policy_eval_start, device=device)
     for i in range(gradient_steps):
+        timeit.printevery(1000, gradient_steps, erase=True)
         pbar.update(1)
         # sample data
         with timeit("sample"):
@@ -192,15 +191,10 @@ def update(data, policy_eval_start, iteration):
             to_log["evaluation_reward"] = eval_reward
 
         with timeit("log"):
-            if i % 200 == 0:
-                to_log.update(timeit.todict(prefix="time"))
+            to_log.update(timeit.todict(prefix="time"))
             log_metrics(logger, to_log, i)
-        if i % 200 == 0:
-            timeit.print()
-            timeit.erase()
 
     pbar.close()
-    torchrl_logger.info(f"Training time: {time.time() - start_time}")
     if not eval_env.is_closed:
         eval_env.close()

10 changes: 4 additions & 6 deletions sota-implementations/cql/cql_online.py
@@ -170,7 +170,9 @@ def update(sampled_tensordict):
     eval_rollout_steps = cfg.logger.eval_steps
 
     c_iter = iter(collector)
-    for i in range(len(collector)):
+    total_iter = len(collector)
+    for i in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
         with timeit("collecting"):
             tensordict = next(c_iter)
         pbar.update(tensordict.numel())
@@ -222,8 +224,7 @@ def update(sampled_tensordict):
                 "loss_alpha_prime"
             ).mean()
             metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean()
-            if i % 10 == 0:
-                metrics_to_log.update(timeit.todict(prefix="time"))
+        metrics_to_log.update(timeit.todict(prefix="time"))
 
         # Evaluation
         with timeit("eval"):
@@ -245,9 +246,6 @@ def update(sampled_tensordict):
             metrics_to_log["eval/reward"] = eval_reward
 
         log_metrics(logger, metrics_to_log, collected_frames)
-        if i % 10 == 0:
-            timeit.print()
-            timeit.erase()
 
     collector.shutdown()
     if not eval_env.is_closed:
11 changes: 4 additions & 7 deletions sota-implementations/cql/discrete_cql_online.py
@@ -151,7 +151,9 @@ def update(sampled_tensordict):
     frames_per_batch = cfg.collector.frames_per_batch
 
     c_iter = iter(collector)
-    for i in range(len(collector)):
+    total_iter = len(collector)
+    for _ in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
         with timeit("collecting"):
             torch.compiler.cudagraph_mark_step_begin()
             tensordict = next(c_iter)
@@ -224,12 +226,7 @@ def update(sampled_tensordict):
             tds = torch.stack(tds, dim=0).mean()
             metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
             metrics_to_log["train/cql_loss"] = tds["loss_cql"]
-            if i % 100 == 0:
-                metrics_to_log.update(timeit.todict(prefix="time"))
-
-        if i % 100 == 0:
-            timeit.print()
-            timeit.erase()
+        metrics_to_log.update(timeit.todict(prefix="time"))
 
         if logger is not None:
             log_metrics(logger, metrics_to_log, collected_frames)
10 changes: 4 additions & 6 deletions sota-implementations/crossq/crossq.py
@@ -192,7 +192,9 @@ def update(sampled_tensordict: TensorDict, update_actor: bool):
     update_counter = 0
     delayed_updates = cfg.optim.policy_update_delay
     c_iter = iter(collector)
-    for i in range(len(collector)):
+    total_iter = len(collector)
+    for _ in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
         with timeit("collecting"):
             torch.compiler.cudagraph_mark_step_begin()
             tensordict = next(c_iter)
@@ -258,18 +260,14 @@ def update(sampled_tensordict: TensorDict, update_actor: bool):
             metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
                 episode_length
             )
-        if i % 20 == 0:
-            metrics_to_log.update(timeit.todict(prefix="time"))
+        metrics_to_log.update(timeit.todict(prefix="time"))
         if collected_frames >= init_random_frames:
             metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
             metrics_to_log["train/actor_loss"] = tds["loss_actor"]
             metrics_to_log["train/alpha_loss"] = tds["loss_alpha"]
 
         if logger is not None:
             log_metrics(logger, metrics_to_log, collected_frames)
-        if i % 20 == 0:
-            timeit.print()
-            timeit.erase()
 
     collector.shutdown()
     if not eval_env.is_closed:
9 changes: 4 additions & 5 deletions sota-implementations/ddpg/ddpg.py
@@ -156,7 +156,9 @@ def update(sampled_tensordict):
     eval_rollout_steps = cfg.env.max_episode_steps
 
     c_iter = iter(collector)
-    for i in range(len(collector)):
+    total_iter = len(collector)
+    for _ in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
         with timeit("collecting"):
             tensordict = next(c_iter)
         # Update exploration policy
@@ -226,10 +228,7 @@ def update(sampled_tensordict):
             eval_env.apply(dump_video)
             eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
             metrics_to_log["eval/reward"] = eval_reward
-        if i % 20 == 0:
-            metrics_to_log.update(timeit.todict(prefix="time"))
-            timeit.print()
-            timeit.erase()
+        metrics_to_log.update(timeit.todict(prefix="time"))
 
         if logger is not None:
             log_metrics(logger, metrics_to_log, collected_frames)
6 changes: 2 additions & 4 deletions sota-implementations/decision_transformer/dt.py
@@ -128,6 +128,7 @@ def update(data: TensorDict) -> TensorDict:
     # Pretraining
     pbar = tqdm.tqdm(range(pretrain_gradient_steps))
     for i in pbar:
+        timeit.printevery(1000, pretrain_gradient_steps, erase=True)
         # Sample data
         with timeit("rb - sample"):
             data = offline_buffer.sample().to(model_device)
@@ -151,10 +152,7 @@ def update(data: TensorDict) -> TensorDict:
             to_log["eval/reward"] = (
                 eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
             )
-        if i % 200 == 0:
-            to_log.update(timeit.todict(prefix="time"))
-            timeit.print()
-            timeit.erase()
+        to_log.update(timeit.todict(prefix="time"))
 
         if logger is not None:
             log_metrics(logger, to_log, i)
9 changes: 2 additions & 7 deletions sota-implementations/decision_transformer/online_dt.py
@@ -8,7 +8,6 @@
 """
 from __future__ import annotations
 
-import time
 import warnings
 
 import hydra
@@ -130,8 +129,8 @@ def update(data):
 
     torchrl_logger.info(" ***Pretraining*** ")
     # Pretraining
-    start_time = time.time()
     for i in range(pretrain_gradient_steps):
+        timeit.printevery(1000, pretrain_gradient_steps, erase=True)
         pbar.update(1)
         with timeit("sample"):
             # Sample data
@@ -170,18 +169,14 @@ def update(data):
                 eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
             )
 
-        if i % 200 == 0:
-            to_log.update(timeit.todict(prefix="time"))
-            timeit.print()
-            timeit.erase()
+        to_log.update(timeit.todict(prefix="time"))
 
         if logger is not None:
             log_metrics(logger, to_log, i)
 
     pbar.close()
     if not test_env.is_closed:
         test_env.close()
-    torchrl_logger.info(f"Training time: {time.time() - start_time}")
 
 
 if __name__ == "__main__":
9 changes: 4 additions & 5 deletions sota-implementations/discrete_sac/discrete_sac.py
@@ -155,7 +155,9 @@ def update(sampled_tensordict):
     frames_per_batch = cfg.collector.frames_per_batch
 
     c_iter = iter(collector)
-    for i in range(len(collector)):
+    total_iter = len(collector)
+    for i in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
         with timeit("collecting"):
             collected_data = next(c_iter)

@@ -229,10 +231,7 @@ def update(sampled_tensordict):
             eval_env.apply(dump_video)
             eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
             metrics_to_log["eval/reward"] = eval_reward
-        if i % 50 == 0:
-            metrics_to_log.update(timeit.todict(prefix="time"))
-            timeit.print()
-            timeit.erase()
+        metrics_to_log.update(timeit.todict(prefix="time"))
         if logger is not None:
             log_metrics(logger, metrics_to_log, collected_frames)

9 changes: 4 additions & 5 deletions sota-implementations/dqn/dqn_atari.py
@@ -173,7 +173,9 @@ def update(sampled_tensordict):
     pbar = tqdm.tqdm(total=total_frames)
 
     c_iter = iter(collector)
-    for i in range(len(collector)):
+    total_iter = len(collector)
+    for i in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
         with timeit("collecting"):
             data = next(c_iter)
         log_info = {}
@@ -241,10 +243,7 @@ def update(sampled_tensordict):
                 )
                 model.train()
 
-        if i % 200 == 0:
-            timeit.print()
-            log_info.update(timeit.todict(prefix="time"))
-            timeit.erase()
+        log_info.update(timeit.todict(prefix="time"))
 
         # Log all the information
         if logger:
9 changes: 4 additions & 5 deletions sota-implementations/dqn/dqn_cartpole.py
@@ -156,7 +156,9 @@ def update(sampled_tensordict):
     q_losses = torch.zeros(num_updates, device=device)
 
     c_iter = iter(collector)
-    for i in range(len(collector)):
+    total_iter = len(collector)
+    for i in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
         with timeit("collecting"):
             data = next(c_iter)

@@ -226,10 +228,7 @@ def update(sampled_tensordict):
                 }
             )
 
-        if i % 200 == 0:
-            timeit.print()
-            log_info.update(timeit.todict(prefix="time"))
-            timeit.erase()
+        log_info.update(timeit.todict(prefix="time"))
 
         # Log all the information
         if logger:
3 changes: 1 addition & 2 deletions sota-implementations/dreamer/dreamer.py
@@ -275,9 +275,8 @@ def compile_rssms(module):
             "t_sample": t_sample,
             "t_preproc": t_preproc,
             "t_collect": t_collect,
-            **timeit.todict(percall=False),
+            **timeit.todict(prefix="time"),
         }
-        timeit.erase()
         metrics_to_log.update(loss_metrics)
 
         if logger is not None:

1 comment on commit 187de7c

@github-actions

⚠️ Performance Alert ⚠️

A possible performance regression was detected for the benchmark suite 'CPU Benchmark Results'.
The benchmark result of this commit is worse than the previous result by more than the alert threshold of 2 (the Ratio column below is previous throughput divided by current throughput).

| Benchmark suite | Current: 187de7c | Previous: f5a187d | Ratio |
| --- | --- | --- | --- |
| benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 39.40646741684915 iter/sec (stddev: 0.14686385600224153) | 247.01828794382322 iter/sec (stddev: 0.0006587454642556527) | 6.27 |

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens
