From 2cfc2abd6f98444982ff7413ec98dbb0079bdafb Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Sun, 15 Dec 2024 14:01:46 -0800
Subject: [PATCH] [Feature] IQL compatibility with compile

ghstack-source-id: 77bca166701d28dd69ef3964f55ab4f3e4b17fed
Pull Request resolved: https://github.com/pytorch/rl/pull/2649
---
 sota-implementations/iql/discrete_iql.py     | 174 +++++++++++--------
 sota-implementations/iql/discrete_iql.yaml   |   5 +
 sota-implementations/iql/iql_offline.py      |  90 ++++++----
 sota-implementations/iql/iql_online.py       | 162 +++++++++--------
 sota-implementations/iql/offline_config.yaml |   5 +
 sota-implementations/iql/online_config.yaml  |   5 +
 sota-implementations/iql/utils.py            |  83 +++++----
 torchrl/data/utils.py                        |   2 +-
 torchrl/objectives/iql.py                    |  45 +++--
 9 files changed, 327 insertions(+), 244 deletions(-)

diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py
index 79cf2114d40..e51bd25a8a8 100644
--- a/sota-implementations/iql/discrete_iql.py
+++ b/sota-implementations/iql/discrete_iql.py
@@ -13,16 +13,20 @@
 """
 from __future__ import annotations
 
-import time
+import warnings
 
 import hydra
 import numpy as np
 import torch
 import tqdm
-from torchrl._utils import logger as torchrl_logger
+from tensordict import TensorDict
+from tensordict.nn import CudaGraphModule
+
+from torchrl._utils import timeit
 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
 
 from utils import (
@@ -37,6 +41,9 @@
 )
 
 
+torch.set_float32_matmul_precision("high")
+
+
 @hydra.main(config_path="", config_name="discrete_iql")
 def main(cfg: "DictConfig"):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()
@@ -87,16 +94,54 @@ def main(cfg: "DictConfig"):  # noqa: F821
     # Create model
     model = make_discrete_iql_model(cfg, train_env, eval_env, device)
 
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+
     # Create collector
-    collector = make_collector(cfg, train_env, actor_model_explore=model[0])
+    collector = make_collector(
+        cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode
+    )
 
     # Create loss
-    loss_module, target_net_updater = make_discrete_loss(cfg.loss, model)
+    loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device=device)
 
     # Create optimizer
     optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer(
         cfg.optim, loss_module
     )
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_value)
+    del optimizer_actor, optimizer_critic, optimizer_value
+
+    def update(sampled_tensordict):
+        optimizer.zero_grad(set_to_none=True)
+        # compute losses
+        actor_loss, _ = loss_module.actor_loss(sampled_tensordict)
+        value_loss, _ = loss_module.value_loss(sampled_tensordict)
+        q_loss, metadata = loss_module.qvalue_loss(sampled_tensordict)
+        (actor_loss + value_loss + q_loss).backward()
+        optimizer.step()
+
+        # update qnet_target params
+        target_net_updater.step()
+        metadata.update(
+            {"actor_loss": actor_loss, "value_loss": value_loss, "q_loss": q_loss}
+        )
+        return TensorDict(metadata).detach()
+
+    if cfg.compile.compile:
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
 
     # Main loop
     collected_frames = 0
@@ -112,84 +157,53 @@ def main(cfg: "DictConfig"):  # noqa: F821
     eval_iter = cfg.logger.eval_iter
     frames_per_batch = cfg.collector.frames_per_batch
     eval_rollout_steps = cfg.collector.max_frames_per_traj
-    sampling_start = start_time = time.time()
-    for tensordict in collector:
-        sampling_time = time.time() - sampling_start
-        pbar.update(tensordict.numel())
+
+    collector_iter = iter(collector)
+    for _ in range(len(collector)):
+        with timeit("collection"):
+            tensordict = next(collector_iter)
+        current_frames = tensordict.numel()
+        pbar.update(current_frames)
+
         # update weights of the inference policy
         collector.update_policy_weights_()
-        tensordict = tensordict.reshape(-1)
-        current_frames = tensordict.numel()
-        # add to replay buffer
-        replay_buffer.extend(tensordict.cpu())
+        with timeit("buffer - extend"):
+            tensordict = tensordict.reshape(-1)
+
+            # add to replay buffer
+            replay_buffer.extend(tensordict)
         collected_frames += current_frames
 
         # optimization steps
-        training_start = time.time()
-        if collected_frames >= init_random_frames:
-            for _ in range(num_updates):
-                # sample from replay buffer
-                sampled_tensordict = replay_buffer.sample().clone()
-                if sampled_tensordict.device != device:
-                    sampled_tensordict = sampled_tensordict.to(
-                        device, non_blocking=True
-                    )
-                else:
-                    sampled_tensordict = sampled_tensordict
-                # compute losses
-                actor_loss, _ = loss_module.actor_loss(sampled_tensordict)
-                optimizer_actor.zero_grad()
-                actor_loss.backward()
-                optimizer_actor.step()
-
-                value_loss, _ = loss_module.value_loss(sampled_tensordict)
-                optimizer_value.zero_grad()
-                value_loss.backward()
-                optimizer_value.step()
-
-                q_loss, metadata = loss_module.qvalue_loss(sampled_tensordict)
-                optimizer_critic.zero_grad()
-                q_loss.backward()
-                optimizer_critic.step()
-
-                # update qnet_target params
-                target_net_updater.step()
-
-                # update priority
-                if prb:
-                    sampled_tensordict.set(
-                        loss_module.tensor_keys.priority,
-                        metadata.pop("td_error").detach().max(0).values,
-                    )
-                    replay_buffer.update_priority(sampled_tensordict)
-
-        training_time = time.time() - training_start
+        with timeit("training"):
+            if collected_frames >= init_random_frames:
+                for _ in range(num_updates):
+                    # sample from replay buffer
+                    with timeit("buffer - sample"):
+                        sampled_tensordict = replay_buffer.sample().to(device)
+
+                    with timeit("training - update"):
+                        torch.compiler.cudagraph_mark_step_begin()
+                        metadata = update(sampled_tensordict)
+                    # update priority
+                    if prb:
+                        sampled_tensordict.set(
+                            loss_module.tensor_keys.priority,
+                            metadata.pop("td_error").detach().max(0).values,
+                        )
+                        replay_buffer.update_priority(sampled_tensordict)
+
         episode_rewards = tensordict["next", "episode_reward"][
             tensordict["next", "done"]
         ]
 
-        # Logging
         metrics_to_log = {}
-        if len(episode_rewards) > 0:
-            episode_length = tensordict["next", "step_count"][
-                tensordict["next", "done"]
-            ]
-            metrics_to_log["train/reward"] = episode_rewards.mean().item()
-            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
-                episode_length
-            )
-        if collected_frames >= init_random_frames:
-            metrics_to_log["train/q_loss"] = q_loss.detach()
-            metrics_to_log["train/actor_loss"] = actor_loss.detach()
-            metrics_to_log["train/value_loss"] = value_loss.detach()
-            metrics_to_log["train/sampling_time"] = sampling_time
-            metrics_to_log["train/training_time"] = training_time
-
         # Evaluation
         if abs(collected_frames % eval_iter) < frames_per_batch:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
-                eval_start = time.time()
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
                 eval_rollout = eval_env.rollout(
                     eval_rollout_steps,
                     model[0],
@@ -197,18 +211,28 @@ def main(cfg: "DictConfig"):  # noqa: F821
                     break_when_any_done=True,
                 )
                 eval_env.apply(dump_video)
-                eval_time = time.time() - eval_start
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                 metrics_to_log["eval/reward"] = eval_reward
-                metrics_to_log["eval/time"] = eval_time
+
+        # Logging
+        if len(episode_rewards) > 0:
+            episode_length = tensordict["next", "step_count"][
+                tensordict["next", "done"]
+            ]
+            metrics_to_log["train/reward"] = episode_rewards.mean().item()
+            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                episode_length
+            )
+        if collected_frames >= init_random_frames:
+            metrics_to_log["train/q_loss"] = metadata["q_loss"]
+            metrics_to_log["train/actor_loss"] = metadata["actor_loss"]
+            metrics_to_log["train/value_loss"] = metadata["value_loss"]
+        metrics_to_log.update(timeit.todict(prefix="time"))
         if logger is not None:
             log_metrics(logger, metrics_to_log, collected_frames)
-        sampling_start = time.time()
+        timeit.erase()
 
     collector.shutdown()
-    end_time = time.time()
-    execution_time = end_time - start_time
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
 
 
 if __name__ == "__main__":
diff --git a/sota-implementations/iql/discrete_iql.yaml b/sota-implementations/iql/discrete_iql.yaml
index d28c02cf499..3f53ab9a68a 100644
--- a/sota-implementations/iql/discrete_iql.yaml
+++ b/sota-implementations/iql/discrete_iql.yaml
@@ -59,3 +59,8 @@ loss:
   # IQL specific hyperparameter
   temperature: 100
   expectile: 0.8
+
+compile:
+  compile: False
+  compile_mode: default
+  cudagraphs: False
diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py
index 09cf9954b86..1a270ee8ccc 100644
--- a/sota-implementations/iql/iql_offline.py
+++ b/sota-implementations/iql/iql_offline.py
@@ -11,16 +11,19 @@
 """
 from __future__ import annotations
 
-import time
+import warnings
 
 import hydra
 import numpy as np
 import torch
 import tqdm
-from torchrl._utils import logger as torchrl_logger
+from tensordict.nn import CudaGraphModule
+
+from torchrl._utils import timeit
 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
 
 from utils import (
@@ -34,6 +37,9 @@
 )
 
 
+torch.set_float32_matmul_precision("high")
+
+
 @hydra.main(config_path="", config_name="offline_config")
 def main(cfg: "DictConfig"):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()
@@ -79,60 +85,69 @@ def main(cfg: "DictConfig"):  # noqa: F821
     model = make_iql_model(cfg, train_env, eval_env, device)
 
     # Create loss
-    loss_module, target_net_updater = make_loss(cfg.loss, model)
+    loss_module, target_net_updater = make_loss(cfg.loss, model, device=device)
 
     # Create optimizer
     optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer(
         cfg.optim, loss_module
     )
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_value)
 
-    pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)
-
-    gradient_steps = cfg.optim.gradient_steps
-    evaluation_interval = cfg.logger.eval_iter
-    eval_steps = cfg.logger.eval_steps
-
-    # Training loop
-    start_time = time.time()
-    for i in range(gradient_steps):
-        pbar.update(1)
-        # sample data
-        data = replay_buffer.sample()
-
-        if data.device != device:
-            data = data.to(device, non_blocking=True)
-
+    def update(data):
+        optimizer.zero_grad(set_to_none=True)
         # compute losses
         loss_info = loss_module(data)
         actor_loss = loss_info["loss_actor"]
         value_loss = loss_info["loss_value"]
         q_loss = loss_info["loss_qvalue"]
 
-        optimizer_actor.zero_grad()
-        actor_loss.backward()
-        optimizer_actor.step()
-
-        optimizer_value.zero_grad()
-        value_loss.backward()
-        optimizer_value.step()
-
-        optimizer_critic.zero_grad()
-        q_loss.backward()
-        optimizer_critic.step()
+        (actor_loss + value_loss + q_loss).backward()
+        optimizer.step()
 
         # update qnet_target params
         target_net_updater.step()
+        return loss_info.detach()
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+
+    if cfg.compile.compile:
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
+
+    pbar = tqdm.tqdm(range(cfg.optim.gradient_steps))
+
+    evaluation_interval = cfg.logger.eval_iter
+    eval_steps = cfg.logger.eval_steps
+
+    # Training loop
+    for i in pbar:
+        # sample data
+        with timeit("sample"):
+            data = replay_buffer.sample()
+            data = data.to(device)
 
-        # log metrics
-        to_log = {
-            "loss_actor": actor_loss.item(),
-            "loss_qvalue": q_loss.item(),
-            "loss_value": value_loss.item(),
-        }
+        with timeit("update"):
+            torch.compiler.cudagraph_mark_step_begin()
+            loss_info = update(data)
 
         # evaluation
+        to_log = loss_info.to_dict()
         if i % evaluation_interval == 0:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
                 eval_td = eval_env.rollout(
                     max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
                 )
@@ -147,7 +162,6 @@ def main(cfg: "DictConfig"):  # noqa: F821
         eval_env.close()
     if not train_env.is_closed:
         train_env.close()
-    torchrl_logger.info(f"Training time: {time.time() - start_time}")
 
 
 if __name__ == "__main__":
diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py
index 8497d24f106..4f6c765d1e8 100644
--- a/sota-implementations/iql/iql_online.py
+++ b/sota-implementations/iql/iql_online.py
@@ -13,16 +13,19 @@
 """
 from __future__ import annotations
 
-import time
+import warnings
 
 import hydra
 import numpy as np
 import torch
 import tqdm
-from torchrl._utils import logger as torchrl_logger
+from tensordict.nn import CudaGraphModule
+
+from torchrl._utils import timeit
 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
 
 from utils import (
@@ -37,6 +40,9 @@
 )
 
 
+torch.set_float32_matmul_precision("high")
+
+
 @hydra.main(config_path="", config_name="online_config")
 def main(cfg: "DictConfig"):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()
@@ -87,20 +93,56 @@ def main(cfg: "DictConfig"):  # noqa: F821
     # Create model
     model = make_iql_model(cfg, train_env, eval_env, device)
 
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+
     # Create collector
-    collector = make_collector(cfg, train_env, actor_model_explore=model[0])
+    collector = make_collector(
+        cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode
+    )
 
     # Create loss
-    loss_module, target_net_updater = make_loss(cfg.loss, model)
+    loss_module, target_net_updater = make_loss(cfg.loss, model, device=device)
 
     # Create optimizer
     optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer(
         cfg.optim, loss_module
     )
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_value)
+    del optimizer_actor, optimizer_critic, optimizer_value
+
+    def update(sampled_tensordict):
+        optimizer.zero_grad(set_to_none=True)
+        # compute losses
+        loss_info = loss_module(sampled_tensordict)
+        actor_loss = loss_info["loss_actor"]
+        value_loss = loss_info["loss_value"]
+        q_loss = loss_info["loss_qvalue"]
+
+        (actor_loss + value_loss + q_loss).backward()
+        optimizer.step()
+
+        # update qnet_target params
+        target_net_updater.step()
+        return loss_info.detach()
+
+    if cfg.compile.compile:
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
 
     # Main loop
     collected_frames = 0
-    pbar = tqdm.tqdm(total=cfg.collector.total_frames)
 
     init_random_frames = cfg.collector.init_random_frames
     num_updates = int(
@@ -112,82 +154,46 @@ def main(cfg: "DictConfig"):  # noqa: F821
     eval_iter = cfg.logger.eval_iter
     frames_per_batch = cfg.collector.frames_per_batch
     eval_rollout_steps = cfg.collector.max_frames_per_traj
-    sampling_start = start_time = time.time()
-    for tensordict in collector:
-        sampling_time = time.time() - sampling_start
-        pbar.update(tensordict.numel())
+    collector_iter = iter(collector)
+    pbar = tqdm.tqdm(range(collector.total_frames))
+    for _ in range(len(collector)):
+        with timeit("collection"):
+            tensordict = next(collector_iter)
+        current_frames = tensordict.numel()
+        pbar.update(current_frames)
 
         # update weights of the inference policy
         collector.update_policy_weights_()
-        tensordict = tensordict.view(-1)
-        current_frames = tensordict.numel()
-        # add to replay buffer
-        replay_buffer.extend(tensordict.cpu())
+        with timeit("rb - extend"):
+            # add to replay buffer
+            tensordict = tensordict.reshape(-1)
+            replay_buffer.extend(tensordict.cpu())
         collected_frames += current_frames
 
         # optimization steps
-        training_start = time.time()
-        if collected_frames >= init_random_frames:
-            for _ in range(num_updates):
-                # sample from replay buffer
-                sampled_tensordict = replay_buffer.sample().clone()
-                if sampled_tensordict.device != device:
-                    sampled_tensordict = sampled_tensordict.to(
-                        device, non_blocking=True
-                    )
-                else:
-                    sampled_tensordict = sampled_tensordict
-                # compute losses
-                loss_info = loss_module(sampled_tensordict)
-                actor_loss = loss_info["loss_actor"]
-                value_loss = loss_info["loss_value"]
-                q_loss = loss_info["loss_qvalue"]
-
-                optimizer_actor.zero_grad()
-                actor_loss.backward()
-                optimizer_actor.step()
-
-                optimizer_value.zero_grad()
-                value_loss.backward()
-                optimizer_value.step()
-
-                optimizer_critic.zero_grad()
-                q_loss.backward()
-                optimizer_critic.step()
-
-                # update qnet_target params
-                target_net_updater.step()
-
-                # update priority
-                if prb:
-                    replay_buffer.update_priority(sampled_tensordict)
-        training_time = time.time() - training_start
+        with timeit("training"):
+            if collected_frames >= init_random_frames:
+                for _ in range(num_updates):
+                    with timeit("rb - sampling"):
+                        # sample from replay buffer
+                        sampled_tensordict = replay_buffer.sample().to(device)
+                    with timeit("update"):
+                        torch.compiler.cudagraph_mark_step_begin()
+                        loss_info = update(sampled_tensordict)
+                    # update priority
+                    if prb:
+                        replay_buffer.update_priority(sampled_tensordict)
 
         episode_rewards = tensordict["next", "episode_reward"][
             tensordict["next", "done"]
         ]
 
         # Logging
         metrics_to_log = {}
-        if len(episode_rewards) > 0:
-            episode_length = tensordict["next", "step_count"][
-                tensordict["next", "done"]
-            ]
-            metrics_to_log["train/reward"] = episode_rewards.mean().item()
-            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
-                episode_length
-            )
-        if collected_frames >= init_random_frames:
-            metrics_to_log["train/q_loss"] = q_loss.detach()
-            metrics_to_log["train/actor_loss"] = actor_loss.detach()
-            metrics_to_log["train/value_loss"] = value_loss.detach()
-            metrics_to_log["train/entropy"] = loss_info.get("entropy").detach()
-            metrics_to_log["train/sampling_time"] = sampling_time
-            metrics_to_log["train/training_time"] = training_time
-
         # Evaluation
         if abs(collected_frames % eval_iter) < frames_per_batch:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
-                eval_start = time.time()
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("evaluating"):
                 eval_rollout = eval_env.rollout(
                     eval_rollout_steps,
                     model[0],
@@ -195,25 +201,33 @@ def main(cfg: "DictConfig"):  # noqa: F821
                     break_when_any_done=True,
                 )
                 eval_env.apply(dump_video)
-                eval_time = time.time() - eval_start
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                 metrics_to_log["eval/reward"] = eval_reward
-                metrics_to_log["eval/time"] = eval_time
+
+        if len(episode_rewards) > 0:
+            episode_length = tensordict["next", "step_count"][
+                tensordict["next", "done"]
+            ]
+            metrics_to_log["train/reward"] = episode_rewards.mean().item()
+            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                episode_length
+            )
+        if collected_frames >= init_random_frames:
+            metrics_to_log["train/q_loss"] = loss_info["loss_qvalue"]
+            metrics_to_log["train/actor_loss"] = loss_info["loss_actor"]
+            metrics_to_log["train/value_loss"] = loss_info["loss_value"]
+            metrics_to_log["train/entropy"] = loss_info.get("entropy")
+        metrics_to_log.update(timeit.todict(prefix="time"))
+
         if logger is not None:
             log_metrics(logger, metrics_to_log, collected_frames)
-        sampling_start = time.time()
 
     collector.shutdown()
-    end_time = time.time()
-    execution_time = end_time - start_time
     if not eval_env.is_closed:
         eval_env.close()
    if not train_env.is_closed:
         train_env.close()
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
-
 
 if __name__ == "__main__":
     main()
diff --git a/sota-implementations/iql/offline_config.yaml b/sota-implementations/iql/offline_config.yaml
index 5f34fa5651a..ff739387c9d 100644
--- a/sota-implementations/iql/offline_config.yaml
+++ b/sota-implementations/iql/offline_config.yaml
@@ -47,3 +47,8 @@ loss:
   # IQL specific hyperparameter
   temperature: 3.0
   expectile: 0.7
+
+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False
diff --git a/sota-implementations/iql/online_config.yaml b/sota-implementations/iql/online_config.yaml
index 64ad7466192..070740a8707 100644
--- a/sota-implementations/iql/online_config.yaml
+++ b/sota-implementations/iql/online_config.yaml
@@ -61,3 +61,8 @@ loss:
   # IQL specific hyperparameter
   temperature: 3.0
   expectile: 0.7
+
+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False
diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py
index 261cb912de0..519d4350536 100644
--- a/sota-implementations/iql/utils.py
+++ b/sota-implementations/iql/utils.py
@@ -10,6 +10,7 @@
 import torch.optim
 from tensordict.nn import InteractionType, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
+from torch.distributions import Categorical
 from torchrl.collectors import SyncDataCollector
 from torchrl.data import (
@@ -36,7 +37,6 @@
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules import (
     MLP,
-    OneHotCategorical,
     ProbabilisticActor,
     SafeModule,
     TanhNormal,
@@ -44,7 +44,6 @@
 )
 from torchrl.objectives import DiscreteIQLLoss, HardUpdate, IQLLoss, SoftUpdate
 from torchrl.record import VideoRecorder
-
 from torchrl.trainers.helpers.models import ACTIVATIONS
 
 
@@ -58,7 +57,11 @@ def env_maker(cfg, device="cpu", from_pixels=False):
     if lib in ("gym", "gymnasium"):
         with set_gym_backend(lib):
             return GymEnv(
-                cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False
+                cfg.env.name,
+                device=device,
+                from_pixels=from_pixels,
+                pixels_only=False,
+                categorical_action_encoding=True,
             )
     elif lib == "dm_control":
         env = DMControlEnv(
@@ -118,7 +121,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None):
 # ---------------------------
 
 
-def make_collector(cfg, train_env, actor_model_explore):
+def make_collector(cfg, train_env, actor_model_explore, compile_mode):
     """Make collector."""
     device = cfg.collector.device
     if device in ("", None):
@@ -134,6 +137,8 @@ def make_collector(cfg, train_env, actor_model_explore):
         max_frames_per_traj=cfg.collector.max_frames_per_traj,
         total_frames=cfg.collector.total_frames,
         device=device,
+        compile_policy={"mode": compile_mode} if compile_mode else False,
+        cudagraph_policy=cfg.compile.cudagraphs,
     )
     collector.set_seed(cfg.env.seed)
     return collector
@@ -179,7 +184,8 @@ def make_offline_replay_buffer(rb_cfg):
         dataset_id=rb_cfg.dataset,
         split_trajs=False,
         batch_size=rb_cfg.batch_size,
-        sampler=SamplerWithoutReplacement(drop_last=False),
+        # We use drop_last to avoid recompiles (and dynamic shapes)
+        sampler=SamplerWithoutReplacement(drop_last=True),
         prefetch=4,
         direct_download=True,
     )
@@ -219,8 +225,8 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"):
         spec=action_spec,
         distribution_class=TanhNormal,
         distribution_kwargs={
-            "low": action_spec.space.low,
-            "high": action_spec.space.high,
+            "low": action_spec.space.low.to(device),
+            "high": action_spec.space.high.to(device),
             "tanh_loc": False,
         },
         default_interaction_type=ExplorationType.RANDOM,
@@ -244,12 +250,10 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"):
     model = torch.nn.ModuleList([actor, qvalue, value_net]).to(device)
 
     # init nets
     with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
-        td = eval_env.reset()
+        td = eval_env.fake_tensordict()
         td = td.to(device)
         for net in model:
             net(td)
-    del td
-    eval_env.close()
 
     return model
 
@@ -292,19 +296,16 @@ def make_discrete_iql_model(cfg, train_env, eval_env, device):
     """Make discrete IQL agent."""
     # Define Actor Network
     in_keys = ["observation"]
-    action_spec = train_env.action_spec
-    if train_env.batch_size:
-        action_spec = action_spec[(0,) * len(train_env.batch_size)]
+    action_spec = train_env.action_spec_unbatched
     # Define Actor Network
     in_keys = ["observation"]
-    actor_net_kwargs = {
-        "num_cells": cfg.model.hidden_sizes,
-        "out_features": action_spec.shape[-1],
-        "activation_class": ACTIVATIONS[cfg.model.activation],
-    }
-
-    actor_net = MLP(**actor_net_kwargs)
+    actor_net = MLP(
+        num_cells=cfg.model.hidden_sizes,
+        out_features=action_spec.space.n,
+        activation_class=ACTIVATIONS[cfg.model.activation],
+        device=device,
+    )
 
     actor_module = SafeModule(
         module=actor_net,
@@ -312,26 +313,23 @@ def make_discrete_iql_model(cfg, train_env, eval_env, device):
         out_keys=["logits"],
     )
     actor = ProbabilisticActor(
-        spec=Composite(action=eval_env.action_spec),
+        spec=Composite(action=eval_env.action_spec_unbatched).to(device),
         module=actor_module,
         in_keys=["logits"],
         out_keys=["action"],
-        distribution_class=OneHotCategorical,
+        distribution_class=Categorical,
         distribution_kwargs={},
         default_interaction_type=InteractionType.RANDOM,
         return_log_prob=False,
     )
 
     # Define Critic Network
-    qvalue_net_kwargs = {
-        "num_cells": cfg.model.hidden_sizes,
-        "out_features": action_spec.shape[-1],
-        "activation_class": ACTIVATIONS[cfg.model.activation],
-    }
     qvalue_net = MLP(
-        **qvalue_net_kwargs,
+        num_cells=cfg.model.hidden_sizes,
+        out_features=action_spec.space.n,
+        activation_class=ACTIVATIONS[cfg.model.activation],
+        device=device,
     )
-
     qvalue = TensorDictModule(
         in_keys=["observation"],
         out_keys=["state_action_value"],
@@ -339,27 +337,25 @@ def make_discrete_iql_model(cfg, train_env, eval_env, device):
     )
 
     # Define Value Network
-    value_net_kwargs = {
-        "num_cells": cfg.model.hidden_sizes,
-        "out_features": 1,
-        "activation_class": ACTIVATIONS[cfg.model.activation],
-    }
-    value_net = MLP(**value_net_kwargs)
+    value_net = MLP(
+        num_cells=cfg.model.hidden_sizes,
+        out_features=1,
+        activation_class=ACTIVATIONS[cfg.model.activation],
+        device=device,
+    )
     value_net = TensorDictModule(
         in_keys=["observation"],
         out_keys=["state_value"],
         module=value_net,
     )
 
-    model = torch.nn.ModuleList([actor, qvalue, value_net]).to(device)
+    model = torch.nn.ModuleList([actor, qvalue, value_net])
 
     # init nets
     with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
-        td = eval_env.reset()
+        td = eval_env.fake_tensordict()
         td = td.to(device)
         for net in model:
             net(td)
-    del td
-    eval_env.close()
 
     return model
 
@@ -369,7 +365,7 @@ def make_discrete_iql_model(cfg, train_env, eval_env, device):
 # ---------
 
 
-def make_loss(loss_cfg, model):
+def make_loss(loss_cfg, model, device):
     loss_module = IQLLoss(
         model[0],
         model[1],
@@ -378,13 +374,13 @@ def make_loss(loss_cfg, model):
         temperature=loss_cfg.temperature,
         expectile=loss_cfg.expectile,
     )
-    loss_module.make_value_estimator(gamma=loss_cfg.gamma)
+    loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
     target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)
 
     return loss_module, target_net_updater
 
 
-def make_discrete_loss(loss_cfg, model):
+def make_discrete_loss(loss_cfg, model, device):
     loss_module = DiscreteIQLLoss(
         model[0],
         model[1],
@@ -392,8 +388,9 @@ def make_discrete_loss(loss_cfg, model):
         loss_function=loss_cfg.loss_function,
         temperature=loss_cfg.temperature,
         expectile=loss_cfg.expectile,
+        action_space="categorical",
     )
-    loss_module.make_value_estimator(gamma=loss_cfg.gamma)
+    loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
     target_net_updater = HardUpdate(
         loss_module, value_network_update_interval=loss_cfg.hard_update_interval
     )
diff --git a/torchrl/data/utils.py b/torchrl/data/utils.py
index db2c8afca10..d43cbd7810d 100644
--- a/torchrl/data/utils.py
+++ b/torchrl/data/utils.py
@@ -307,7 +307,7 @@ def _process_action_space_spec(action_space, spec):
     return action_space, spec
 
 
-def _find_action_space(action_space):
+def _find_action_space(action_space) -> str:
     if isinstance(action_space, TensorSpec):
         if isinstance(action_space, Composite):
             if "action" in action_space.keys():
diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py
index 71d1a22e17b..039d5fc1c34 100644
--- a/torchrl/objectives/iql.py
+++ b/torchrl/objectives/iql.py
@@ -785,15 +785,20 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
         state_action_value = td_q.get(self.tensor_keys.state_action_value)
         action = tensordict.get(self.tensor_keys.action)
         if self.action_space == "categorical":
-            if action.shape != state_action_value.shape:
+            if action.ndim < (state_action_value.ndim - (td_q.ndim - tensordict.ndim)):
                 # unsqueeze the action if it lacks on trailing singleton dim
                 action = action.unsqueeze(-1)
-            chosen_state_action_value = torch.gather(
-                state_action_value, -1, index=action
-            ).squeeze(-1)
-        else:
+            chosen_state_action_value = torch.vmap(
+                lambda state_action_value, action: torch.gather(
+                    state_action_value, -1, index=action
+                ).squeeze(-1),
+                (0, None),
+            )(state_action_value, action)
+        elif self.action_space == "one_hot":
             action = action.to(torch.float)
             chosen_state_action_value = (state_action_value * action).sum(-1)
+        else:
+            raise RuntimeError(f"Unknown action space {self.action_space}.")
         min_Q, _ = torch.min(chosen_state_action_value, dim=0)
         if log_prob.shape != min_Q.shape:
             raise RuntimeError(
@@ -828,15 +833,22 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
         state_action_value = td_q.get(self.tensor_keys.state_action_value)
         action = tensordict.get(self.tensor_keys.action)
         if self.action_space == "categorical":
-            if action.shape != state_action_value.shape:
+            if action.ndim < (
+                state_action_value.ndim - (td_q.ndim - tensordict.ndim)
+            ):
                 # unsqueeze the action if it lacks on trailing singleton dim
                 action = action.unsqueeze(-1)
-            chosen_state_action_value = torch.gather(
-                state_action_value, -1, index=action
-            ).squeeze(-1)
-        else:
+            chosen_state_action_value = torch.vmap(
+                lambda state_action_value, action: torch.gather(
+                    state_action_value, -1, index=action
+                ).squeeze(-1),
+                (0, None),
+            )(state_action_value, action)
+        elif self.action_space == "one_hot":
             action = action.to(torch.float)
             chosen_state_action_value = (state_action_value * action).sum(-1)
+        else:
+            raise RuntimeError(f"Unknown action space {self.action_space}.")
         min_Q, _ = torch.min(chosen_state_action_value, dim=0)
         # state value
         td_copy = tensordict.select(*self.value_network.in_keys, strict=False)
@@ -863,13 +875,20 @@ def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
         state_action_value = td_q.get(self.tensor_keys.state_action_value)
         action = tensordict.get(self.tensor_keys.action)
         if self.action_space == "categorical":
-            if action.shape != state_action_value.shape:
+            if action.ndim < (state_action_value.ndim - (td_q.ndim - tensordict.ndim)):
                 # unsqueeze the action if it lacks on trailing singleton dim
                 action = action.unsqueeze(-1)
-            pred_val = torch.gather(state_action_value, -1, index=action).squeeze(-1)
-        else:
+            pred_val = torch.vmap(
+                lambda state_action_value, action: torch.gather(
+                    state_action_value, -1, index=action
+                ).squeeze(-1),
+                (0, None),
+            )(state_action_value, action)
+        elif self.action_space == "one_hot":
             action = action.to(torch.float)
             pred_val = (state_action_value * action).sum(-1)
+        else:
+            raise RuntimeError(f"Unknown action space {self.action_space}.")
         td_error = (pred_val - target_value.expand_as(pred_val)).pow(2)
         loss_qval = distance_loss(