diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py
index 73155d9fa1a..36a9b2478d5 100644
--- a/sota-implementations/cql/cql_offline.py
+++ b/sota-implementations/cql/cql_offline.py
@@ -15,8 +15,11 @@
 import numpy as np
 import torch
 import tqdm
-from torchrl._utils import logger as torchrl_logger
+from tensordict.nn import CudaGraphModule
+
+from torchrl._utils import logger as torchrl_logger, timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
 
 from utils import (
@@ -69,6 +72,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
     # Create agent
     model = make_cql_model(cfg, train_env, eval_env, device)
     del train_env
+    if hasattr(eval_env, "start"):
+        # To set the number of threads to the definitive value
+        eval_env.start()
 
     # Create loss
     loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
@@ -81,81 +87,104 @@ def main(cfg: "DictConfig"):  # noqa: F821
         alpha_prime_optim,
     ) = make_continuous_cql_optimizer(cfg, loss_module)
 
-    pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)
+    # Group optimizers
+    optimizer = group_optimizers(
+        policy_optim, critic_optim, alpha_optim, alpha_prime_optim
+    )
 
-    gradient_steps = cfg.optim.gradient_steps
-    policy_eval_start = cfg.optim.policy_eval_start
-    evaluation_interval = cfg.logger.eval_iter
-    eval_steps = cfg.logger.eval_steps
-
-    # Training loop
-    start_time = time.time()
-    for i in range(gradient_steps):
-        pbar.update(1)
-        # sample data
-        data = replay_buffer.sample()
-        # compute loss
-        loss_vals = loss_module(data.clone().to(device))
+    def update(data, policy_eval_start, iteration):
+        loss_vals = loss_module(data.to(device))
 
         # official cql implementation uses behavior cloning loss for first few updating steps as it helps for some tasks
-        if i >= policy_eval_start:
-            actor_loss = loss_vals["loss_actor"]
-        else:
-            actor_loss = loss_vals["loss_actor_bc"]
+        actor_loss = torch.where(
+            iteration >= policy_eval_start,
+            loss_vals["loss_actor"],
+            loss_vals["loss_actor_bc"],
+        )
         q_loss = loss_vals["loss_qvalue"]
         cql_loss = loss_vals["loss_cql"]
         q_loss = q_loss + cql_loss
+        loss_vals["q_loss"] = q_loss
 
         # update model
         alpha_loss = loss_vals["loss_alpha"]
         alpha_prime_loss = loss_vals["loss_alpha_prime"]
+        if alpha_prime_loss is None:
+            alpha_prime_loss = 0
 
-        alpha_optim.zero_grad()
-        alpha_loss.backward()
-        alpha_optim.step()
+        loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss
 
-        policy_optim.zero_grad()
-        actor_loss.backward()
-        policy_optim.step()
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad(set_to_none=True)
 
-        if alpha_prime_optim is not None:
-            alpha_prime_optim.zero_grad()
-            alpha_prime_loss.backward(retain_graph=True)
-            alpha_prime_optim.step()
+        # update qnet_target params
+        target_net_updater.step()
 
-        critic_optim.zero_grad()
-        # TODO: we have the option to compute losses independently retain is not needed?
-        q_loss.backward(retain_graph=False)
-        critic_optim.step()
+        return loss.detach(), loss_vals.detach()
 
-        loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss
+    compile_mode = None
+    if cfg.compile.compile:
+        if cfg.compile.compile_mode not in (None, ""):
+            compile_mode = cfg.compile.compile_mode
+        elif cfg.compile.cudagraphs:
+            compile_mode = "default"
+        else:
+            compile_mode = "reduce-overhead"
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        update = CudaGraphModule(update, warmup=50)
+
+    pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)
+
+    gradient_steps = cfg.optim.gradient_steps
+    policy_eval_start = cfg.optim.policy_eval_start
+    evaluation_interval = cfg.logger.eval_iter
+    eval_steps = cfg.logger.eval_steps
+
+    # Training loop
+    start_time = time.time()
+    policy_eval_start = torch.tensor(policy_eval_start, device=device)
+    for i in range(gradient_steps):
+        pbar.update(1)
+        # sample data
+        with timeit("sample"):
+            data = replay_buffer.sample()
+
+        with timeit("update"):
+            # compute loss
+            i_device = torch.tensor(i, device=device)
+            loss, loss_vals = update(
+                data.to(device), policy_eval_start=policy_eval_start, iteration=i_device
+            )
 
         # log metrics
         to_log = {
-            "loss": loss.item(),
-            "loss_actor_bc": loss_vals["loss_actor_bc"].item(),
-            "loss_actor": loss_vals["loss_actor"].item(),
-            "loss_qvalue": q_loss.item(),
-            "loss_cql": cql_loss.item(),
-            "loss_alpha": alpha_loss.item(),
-            "loss_alpha_prime": alpha_prime_loss.item(),
+            "loss": loss.cpu(),
+            **loss_vals.cpu(),
         }
 
-        # update qnet_target params
-        target_net_updater.step()
-
         # evaluation
-        if i % evaluation_interval == 0:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
-                eval_td = eval_env.rollout(
-                    max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
-                )
-                eval_env.apply(dump_video)
-            eval_reward = eval_td["next", "reward"].sum(1).mean().item()
-            to_log["evaluation_reward"] = eval_reward
-
-        log_metrics(logger, to_log, i)
+        with timeit("log/eval"):
+            if i % evaluation_interval == 0:
+                with set_exploration_type(
+                    ExplorationType.DETERMINISTIC
+                ), torch.no_grad():
+                    eval_td = eval_env.rollout(
+                        max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
+                    )
+                    eval_env.apply(dump_video)
+                eval_reward = eval_td["next", "reward"].sum(1).mean().item()
+                to_log["evaluation_reward"] = eval_reward
+
+        with timeit("log"):
+            if i % 200 == 0:
+                to_log.update(timeit.todict(prefix="time"))
+            log_metrics(logger, to_log, i)
+            if i % 200 == 0:
+                timeit.print()
+                timeit.erase()
 
     pbar.close()
     torchrl_logger.info(f"Training time: {time.time() - start_time}")
diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py
index 215514d5bc7..b45340b60b2 100644
--- a/sota-implementations/cql/cql_online.py
+++ b/sota-implementations/cql/cql_online.py
@@ -11,15 +11,16 @@ The helper functions are coded in the utils.py associated with this script.
""" -import time - import hydra import numpy as np import torch import tqdm from tensordict import TensorDict -from torchrl._utils import logger as torchrl_logger +from tensordict.nn import CudaGraphModule + +from torchrl._utils import timeit from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.objectives import group_optimizers from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( @@ -82,8 +83,24 @@ def main(cfg: "DictConfig"): # noqa: F821 # create agent model = make_cql_model(cfg, train_env, eval_env, device) + compile_mode = None + if cfg.compile.compile: + if cfg.compile.compile_mode not in (None, ""): + compile_mode = cfg.compile.compile_mode + elif cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + # Create collector - collector = make_collector(cfg, train_env, actor_model_explore=model[0]) + collector = make_collector( + cfg, + train_env, + actor_model_explore=model[0], + compile=cfg.compile.compile, + compile_mode=compile_mode, + cudagraph=cfg.compile.cudagraphs, + ) # Create loss loss_module, target_net_updater = make_continuous_loss(cfg.loss, model) @@ -95,8 +112,37 @@ def main(cfg: "DictConfig"): # noqa: F821 alpha_optim, alpha_prime_optim, ) = make_continuous_cql_optimizer(cfg, loss_module) + optimizer = group_optimizers( + policy_optim, critic_optim, alpha_optim, alpha_prime_optim + ) + + def update(sampled_tensordict): + + loss_td = loss_module(sampled_tensordict) + + actor_loss = loss_td["loss_actor"] + q_loss = loss_td["loss_qvalue"] + cql_loss = loss_td["loss_cql"] + q_loss = q_loss + cql_loss + alpha_loss = loss_td["loss_alpha"] + alpha_prime_loss = loss_td["loss_alpha_prime"] + + total_loss = alpha_loss + actor_loss + alpha_prime_loss + q_loss + total_loss.backward() + optimizer.step() + optimizer.zero_grad(set_to_none=True) + + # update qnet_target params + target_net_updater.step() + + return loss_td.detach() + + if compile_mode: + update = torch.compile(update, mode=compile_mode) + if cfg.compile.cudagraphs: + update = CudaGraphModule(update, warmup=50) + # Main loop - start_time = time.time() collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) @@ -111,69 +157,39 @@ def main(cfg: "DictConfig"): # noqa: F821 evaluation_interval = cfg.logger.log_interval eval_rollout_steps = cfg.logger.eval_steps - sampling_start = time.time() - for i, tensordict in enumerate(collector): - sampling_time = time.time() - sampling_start + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + torch.compiler.cudagraph_mark_step_begin() + tensordict = next(c_iter) pbar.update(tensordict.numel()) # update weights of the inference policy collector.update_policy_weights_() - tensordict = tensordict.view(-1) - current_frames = tensordict.numel() - # add to replay buffer - replay_buffer.extend(tensordict.cpu()) - collected_frames += current_frames + with timeit("rb - extend"): + tensordict = tensordict.view(-1) + current_frames = tensordict.numel() + # add to replay buffer + replay_buffer.extend(tensordict) + collected_frames += current_frames - # optimization steps - training_start = time.time() if collected_frames >= init_random_frames: - log_loss_td = TensorDict(batch_size=[num_updates]) + log_loss_td = TensorDict(batch_size=[num_updates], device=device) for j in range(num_updates): - # sample from replay buffer - sampled_tensordict = replay_buffer.sample() - if sampled_tensordict.device != device: - sampled_tensordict = 
sampled_tensordict.to( - device, non_blocking=True - ) - else: - sampled_tensordict = sampled_tensordict.clone() - - loss_td = loss_module(sampled_tensordict) - - actor_loss = loss_td["loss_actor"] - q_loss = loss_td["loss_qvalue"] - cql_loss = loss_td["loss_cql"] - q_loss = q_loss + cql_loss - alpha_loss = loss_td["loss_alpha"] - alpha_prime_loss = loss_td["loss_alpha_prime"] - - alpha_optim.zero_grad() - alpha_loss.backward() - alpha_optim.step() - - policy_optim.zero_grad() - actor_loss.backward() - policy_optim.step() - - if alpha_prime_optim is not None: - alpha_prime_optim.zero_grad() - alpha_prime_loss.backward(retain_graph=True) - alpha_prime_optim.step() - - critic_optim.zero_grad() - q_loss.backward(retain_graph=False) - critic_optim.step() - + pbar.set_description(f"optim iter {j}") + with timeit("rb - sample"): + # sample from replay buffer + sampled_tensordict = replay_buffer.sample().to(device) + + with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() + loss_td = update(sampled_tensordict) log_loss_td[j] = loss_td.detach() - - # update qnet_target params - target_net_updater.step() - # update priority if prb: - replay_buffer.update_priority(sampled_tensordict) + with timeit("rb - update priority"): + replay_buffer.update_priority(sampled_tensordict) - training_time = time.time() - training_start episode_rewards = tensordict["next", "episode_reward"][ tensordict["next", "done"] ] @@ -195,36 +211,32 @@ def main(cfg: "DictConfig"): # noqa: F821 "loss_alpha_prime" ).mean() metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean() - metrics_to_log["train/sampling_time"] = sampling_time - metrics_to_log["train/training_time"] = training_time + if i % 10 == 0: + metrics_to_log.update(timeit.todict(prefix="time")) # Evaluation - - prev_test_frame = ((i - 1) * frames_per_batch) // evaluation_interval - cur_test_frame = (i * frames_per_batch) // evaluation_interval - final = current_frames >= collector.total_frames - if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): - eval_start = time.time() - eval_rollout = eval_env.rollout( - eval_rollout_steps, - model[0], - auto_cast_to_device=True, - break_when_any_done=True, - ) - eval_time = time.time() - eval_start - eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() - eval_env.apply(dump_video) - metrics_to_log["eval/reward"] = eval_reward - metrics_to_log["eval/time"] = eval_time + with timeit("eval"): + prev_test_frame = ((i - 1) * frames_per_batch) // evaluation_interval + cur_test_frame = (i * frames_per_batch) // evaluation_interval + final = current_frames >= collector.total_frames + if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(): + eval_rollout = eval_env.rollout( + eval_rollout_steps, + model[0], + auto_cast_to_device=True, + break_when_any_done=True, + ) + eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + eval_env.apply(dump_video) + metrics_to_log["eval/reward"] = eval_reward log_metrics(logger, metrics_to_log, collected_frames) - sampling_start = time.time() - - collector.shutdown() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") + if i % 10 == 0: + timeit.print() + timeit.erase() collector.shutdown() if not eval_env.is_closed: diff --git a/sota-implementations/cql/discrete_cql_config.yaml 
b/sota-implementations/cql/discrete_cql_config.yaml index 644b8ec624e..6db31a9aa81 100644 --- a/sota-implementations/cql/discrete_cql_config.yaml +++ b/sota-implementations/cql/discrete_cql_config.yaml @@ -10,7 +10,7 @@ env: # Collector collector: frames_per_batch: 200 - total_frames: 20000 + total_frames: 1_000_000 multi_step: 0 init_random_frames: 1000 env_per_collector: 1 @@ -57,3 +57,8 @@ loss: loss_function: l2 gamma: 0.99 tau: 0.005 + +compile: + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py index d0d6693eb97..bde67d14e78 100644 --- a/sota-implementations/cql/discrete_cql_online.py +++ b/sota-implementations/cql/discrete_cql_online.py @@ -10,14 +10,15 @@ The helper functions are coded in the utils.py associated with this script. """ -import time import hydra import numpy as np import torch import torch.cuda import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict.nn import CudaGraphModule + +from torchrl._utils import timeit from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -71,8 +72,24 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create loss loss_module, target_net_updater = make_discrete_loss(cfg.loss, model) + compile_mode = None + if cfg.compile.compile: + if cfg.compile.compile_mode not in (None, ""): + compile_mode = cfg.compile.compile_mode + elif cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + # Create off-policy collector - collector = make_collector(cfg, train_env, explore_policy) + collector = make_collector( + cfg, + train_env, + explore_policy, + compile=cfg.compile.compile, + compile_mode=compile_mode, + cudagraph=cfg.compile.cudagraphs, + ) # Create replay buffer replay_buffer = make_replay_buffer( @@ -86,6 +103,28 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create optimizers optimizer = make_discrete_cql_optimizer(cfg, loss_module) + def update(sampled_tensordict): + # Compute loss + optimizer.zero_grad(set_to_none=True) + loss_dict = loss_module(sampled_tensordict) + + q_loss = loss_dict["loss_qvalue"] + cql_loss = loss_dict["loss_cql"] + loss = q_loss + cql_loss + + # Update model + loss.backward() + optimizer.step() + + # Update target params + target_net_updater.step() + return loss_dict.detach() + + if compile_mode: + update = torch.compile(update, mode=compile_mode) + if cfg.compile.cudagraphs: + update = CudaGraphModule(update, warmup=50) + # Main loop collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) @@ -101,9 +140,11 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch - start_time = sampling_start = time.time() - for tensordict in collector: - sampling_time = time.time() - sampling_start + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + torch.compiler.cudagraph_mark_step_begin() + tensordict = next(c_iter) # Update exploration policy explore_policy[1].step(tensordict.numel()) @@ -111,53 +152,31 @@ def main(cfg: "DictConfig"): # noqa: F821 # Update weights of the inference policy collector.update_policy_weights_() - pbar.update(tensordict.numel()) + current_frames = tensordict.numel() + pbar.update(current_frames) tensordict = tensordict.reshape(-1) - current_frames = tensordict.numel() - # Add to replay buffer - replay_buffer.extend(tensordict.cpu()) + with timeit("rb - extend"): + # Add to replay buffer + 
replay_buffer.extend(tensordict) collected_frames += current_frames # Optimization steps - training_start = time.time() if collected_frames >= init_random_frames: - ( - q_losses, - cql_losses, - ) = ([], []) + tds = [] for _ in range(num_updates): - # Sample from replay buffer - sampled_tensordict = replay_buffer.sample() - if sampled_tensordict.device != device: - sampled_tensordict = sampled_tensordict.to( - device, non_blocking=True - ) - else: - sampled_tensordict = sampled_tensordict.clone() - - # Compute loss - loss_dict = loss_module(sampled_tensordict) - - q_loss = loss_dict["loss_qvalue"] - cql_loss = loss_dict["loss_cql"] - loss = q_loss + cql_loss + with timeit("rb - sample"): + sampled_tensordict = replay_buffer.sample() + sampled_tensordict = sampled_tensordict.to(device) + with timeit("update"): + loss_dict = update(sampled_tensordict) + tds.append(loss_dict) - # Update model - optimizer.zero_grad() - loss.backward() - optimizer.step() - q_losses.append(q_loss.item()) - cql_losses.append(cql_loss.item()) - - # Update target params - target_net_updater.step() # Update priority if prb: replay_buffer.update_priority(sampled_tensordict) - training_time = time.time() - training_start episode_end = ( tensordict["next", "done"] if tensordict["next", "done"].any() @@ -165,8 +184,23 @@ def main(cfg: "DictConfig"): # noqa: F821 ) episode_rewards = tensordict["next", "episode_reward"][episode_end] - # Logging metrics_to_log = {} + # Evaluation + with timeit("eval"): + if collected_frames % eval_iter < frames_per_batch: + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(): + eval_rollout = eval_env.rollout( + eval_rollout_steps, + model, + auto_cast_to_device=True, + break_when_any_done=True, + ) + eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + metrics_to_log["eval/reward"] = eval_reward + + # Logging if len(episode_rewards) > 0: episode_length = tensordict["next", "step_count"][episode_end] metrics_to_log["train/reward"] = episode_rewards.mean().item() @@ -176,33 +210,20 @@ def main(cfg: "DictConfig"): # noqa: F821 metrics_to_log["train/epsilon"] = explore_policy[1].eps if collected_frames >= init_random_frames: - metrics_to_log["train/q_loss"] = np.mean(q_losses) - metrics_to_log["train/cql_loss"] = np.mean(cql_losses) - metrics_to_log["train/sampling_time"] = sampling_time - metrics_to_log["train/training_time"] = training_time + tds = torch.stack(tds, dim=0).mean() + metrics_to_log["train/q_loss"] = tds["loss_qvalue"] + metrics_to_log["train/cql_loss"] = tds["loss_cql"] + if i % 100 == 0: + metrics_to_log.update(timeit.todict(prefix="time")) + + if i % 100 == 0: + timeit.print() + timeit.erase() - # Evaluation - if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): - eval_start = time.time() - eval_rollout = eval_env.rollout( - eval_rollout_steps, - model, - auto_cast_to_device=True, - break_when_any_done=True, - ) - eval_time = time.time() - eval_start - eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() - metrics_to_log["eval/reward"] = eval_reward - metrics_to_log["eval/time"] = eval_time if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) - sampling_start = time.time() collector.shutdown() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git 
a/sota-implementations/cql/offline_config.yaml b/sota-implementations/cql/offline_config.yaml index bf213d4e3c5..a14604251c0 100644 --- a/sota-implementations/cql/offline_config.yaml +++ b/sota-implementations/cql/offline_config.yaml @@ -54,3 +54,8 @@ loss: num_random: 10 with_lagrange: True lagrange_thresh: 5.0 # tau + +compile: + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/cql/online_config.yaml b/sota-implementations/cql/online_config.yaml index 00db1d6bb62..5a8be9616a0 100644 --- a/sota-implementations/cql/online_config.yaml +++ b/sota-implementations/cql/online_config.yaml @@ -11,7 +11,7 @@ env: # Collector collector: frames_per_batch: 1000 - total_frames: 20000 + total_frames: 1_000_000 multi_step: 0 init_random_frames: 5_000 env_per_collector: 1 @@ -66,3 +66,8 @@ loss: num_random: 10 with_lagrange: True lagrange_thresh: 10.0 + +compile: + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index 51134b6828d..2dc280b03eb 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -113,7 +113,14 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None): # --------------------------- -def make_collector(cfg, train_env, actor_model_explore): +def make_collector( + cfg, + train_env, + actor_model_explore, + compile=False, + compile_mode=None, + cudagraph=False, +): """Make collector.""" collector = SyncDataCollector( train_env, @@ -123,6 +130,8 @@ def make_collector(cfg, train_env, actor_model_explore): max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, device=cfg.collector.device, + compile_policy={"mode": compile_mode} if compile else False, + cudagraph_policy=cudagraph, ) collector.set_seed(cfg.env.seed) return collector @@ -207,11 +216,21 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"): in_keys=["loc", "scale"], spec=action_spec, distribution_class=TanhNormal, + # Wrapping the kwargs in a TensorDictParams such that these items are + # send to device when necessary - not compatible with compile yet + # distribution_kwargs=TensorDictParams( + # TensorDict( + # { + # "low": action_spec.space.low, + # "high": action_spec.space.high, + # "tanh_loc": NonTensorData(False), + # } + # ), + # no_convert=True, + # ), distribution_kwargs={ - "low": action_spec.space.low[len(train_env.batch_size) :], - "high": action_spec.space.high[ - len(train_env.batch_size) : - ], # remove batch-size + "low": action_spec.space.low.to(device), + "high": action_spec.space.high.to(device), "tanh_loc": False, }, default_interaction_type=ExplorationType.RANDOM, diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index d54671f569b..c2627770de9 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -47,10 +47,15 @@ def _updater_check_forward_prehook(module, *args, **kwargs): def _forward_wrapper(func): @functools.wraps(func) def new_forward(self, *args, **kwargs): - with set_exploration_type(self.deterministic_sampling_mode), set_recurrent_mode( - True - ): + em = set_exploration_type(self.deterministic_sampling_mode) + em.__enter__() + rm = set_recurrent_mode(True) + rm.__enter__() + try: return func(self, *args, **kwargs) + finally: + em.__exit__(None, None, None) + rm.__exit__(None, None, None) return new_forward diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 191096e7492..375e3834dfc 100644 --- 
a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -610,16 +610,13 @@ def filter_and_repeat(name, x): tensordict = data.named_apply( filter_and_repeat, batch_size=batch_size, filter_empty=True ) - with torch.no_grad(): - with set_exploration_type(ExplorationType.RANDOM), actor_params.to_module( - self.actor_network - ): - dist = self.actor_network.get_dist(tensordict) - action = dist.rsample() - tensordict.set(self.tensor_keys.action, action) - sample_log_prob = dist.log_prob(action) - # tensordict.del_("loc") - # tensordict.del_("scale") + with set_exploration_type(ExplorationType.RANDOM), actor_params.data.to_module( + self.actor_network + ): + dist = self.actor_network.get_dist(tensordict) + action = dist.rsample() + tensordict.set(self.tensor_keys.action, action) + sample_log_prob = dist.log_prob(action) return ( tensordict.select( @@ -631,59 +628,59 @@ def filter_and_repeat(name, x): def _get_value_v(self, tensordict, _alpha, actor_params, qval_params): tensordict = tensordict.clone(False) # get actions and log-probs - with torch.no_grad(): - with set_exploration_type(ExplorationType.RANDOM), actor_params.to_module( - self.actor_network + # TODO: wait for compile to handle this properly + actor_data = actor_params.data.to_module(self.actor_network) + with set_exploration_type(ExplorationType.RANDOM): + next_tensordict = tensordict.get("next").clone(False) + next_dist = self.actor_network.get_dist(next_tensordict) + next_action = next_dist.rsample() + next_tensordict.set(self.tensor_keys.action, next_action) + next_sample_log_prob = next_dist.log_prob(next_action) + actor_data.to_module(self.actor_network, return_swap=False) + + # get q-values + if not self.max_q_backup: + next_tensordict_expand = self._vmap_qvalue_networkN0( + next_tensordict, qval_params.data + ) + next_state_value = next_tensordict_expand.get( + self.tensor_keys.state_action_value + ).min(0)[0] + if ( + next_state_value.shape[-len(next_sample_log_prob.shape) :] + != next_sample_log_prob.shape ): - next_tensordict = tensordict.get("next").clone(False) - next_dist = self.actor_network.get_dist(next_tensordict) - next_action = next_dist.rsample() - next_tensordict.set(self.tensor_keys.action, next_action) - next_sample_log_prob = next_dist.log_prob(next_action) - - # get q-values - if not self.max_q_backup: - next_tensordict_expand = self._vmap_qvalue_networkN0( - next_tensordict, qval_params - ) - next_state_value = next_tensordict_expand.get( - self.tensor_keys.state_action_value - ).min(0)[0] - if ( - next_state_value.shape[-len(next_sample_log_prob.shape) :] - != next_sample_log_prob.shape - ): - next_sample_log_prob = next_sample_log_prob.unsqueeze(-1) - if not self.deterministic_backup: - next_state_value = next_state_value - _alpha * next_sample_log_prob - - if self.max_q_backup: - next_tensordict, _ = self._get_policy_actions( - tensordict.get("next").copy(), - actor_params, - num_actions=self.num_random, - ) - next_tensordict_expand = self._vmap_qvalue_networkN0( - next_tensordict, qval_params - ) + next_sample_log_prob = next_sample_log_prob.unsqueeze(-1) + if not self.deterministic_backup: + next_state_value = next_state_value - _alpha * next_sample_log_prob + + if self.max_q_backup: + next_tensordict, _ = self._get_policy_actions( + tensordict.get("next").copy(), + actor_params, + num_actions=self.num_random, + ) + next_tensordict_expand = self._vmap_qvalue_networkN0( + next_tensordict, qval_params.data + ) - state_action_value = next_tensordict_expand.get( - self.tensor_keys.state_action_value 
+            state_action_value = next_tensordict_expand.get(
+                self.tensor_keys.state_action_value
+            )
+            # take max over actions
+            state_action_value = state_action_value.reshape(
+                torch.Size(
+                    [self.num_qvalue_nets, *tensordict.shape, self.num_random, -1]
                 )
-                # take max over actions
-                state_action_value = state_action_value.reshape(
-                    torch.Size(
-                        [self.num_qvalue_nets, *tensordict.shape, self.num_random, -1]
-                    )
-                ).max(-2)[0]
-                # take min over qvalue nets
-                next_state_value = state_action_value.min(0)[0]
+            ).max(-2)[0]
+            # take min over qvalue nets
+            next_state_value = state_action_value.min(0)[0]
 
-            tensordict.set(
-                ("next", self.value_estimator.tensor_keys.value), next_state_value
-            )
-            target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
-            return target_value
+        tensordict.set(
+            ("next", self.value_estimator.tensor_keys.value), next_state_value
+        )
+        target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
+        return target_value
 
     def q_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
         # we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first.
@@ -897,8 +894,7 @@ def alpha_loss(self, tensordict: TensorDictBase) -> Tensor:
     def _alpha(self):
         if self.min_log_alpha is not None:
             self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
-        with torch.no_grad():
-            alpha = self.log_alpha.exp()
+        alpha = self.log_alpha.data.exp()
         return alpha
 
 
@@ -1188,14 +1184,12 @@ def value_loss(
         pred_val_index = (pred_val * action).sum(-1)
 
         # calculate target value
-        with torch.no_grad():
-            target_value = self.value_estimator.value_estimate(
-                td_copy, params=self._cached_detached_target_value_params
-            ).squeeze(-1)
-
-        with torch.no_grad():
-            td_error = (pred_val_index - target_value).pow(2)
-            td_error = td_error.unsqueeze(-1)
+        target_value = self.value_estimator.value_estimate(
+            td_copy, params=self._cached_detached_target_value_params
+        ).squeeze(-1)
+
+        td_error = (pred_val_index - target_value).pow(2)
+        td_error = td_error.unsqueeze(-1)
 
         tensordict.set(
             self.tensor_keys.priority,