diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py index 167a14e8796..0c9fdebac36 100644 --- a/sota-implementations/a2c/utils_atari.py +++ b/sota-implementations/a2c/utils_atari.py @@ -8,7 +8,6 @@ import torch.nn import torch.optim from tensordict.nn import TensorDictModule -from torchrl.data import Composite from torchrl.data.tensor_specs import CategoricalBox from torchrl.envs import ( CatFrames, @@ -94,12 +93,12 @@ def make_ppo_modules_pixels(proof_environment, device): input_shape = proof_environment.observation_spec["pixels"].shape # Define distribution class and kwargs - if isinstance(proof_environment.action_spec.space, CategoricalBox): - num_outputs = proof_environment.action_spec.space.n + if isinstance(proof_environment.single_action_spec.space, CategoricalBox): + num_outputs = proof_environment.single_action_spec.space.n distribution_class = OneHotCategorical distribution_kwargs = {} else: # is ContinuousBox - num_outputs = proof_environment.action_spec.shape + num_outputs = proof_environment.single_action_spec.shape distribution_class = TanhNormal distribution_kwargs = { "low": proof_environment.action_spec_unbatched.space.low.to(device), @@ -153,7 +152,7 @@ def make_ppo_modules_pixels(proof_environment, device): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=Composite(action=proof_environment.action_spec.to(device)), + spec=proof_environment.single_full_action_spec.to(device), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py index 8606506da15..89118927427 100644 --- a/sota-implementations/a2c/utils_mujoco.py +++ b/sota-implementations/a2c/utils_mujoco.py @@ -9,7 +9,6 @@ import torch.optim from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule -from torchrl.data import Composite from torchrl.envs import ( ClipTransform, DoubleToFloat, @@ -55,7 +54,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False): input_shape = proof_environment.observation_spec["observation"].shape # Define policy output distribution class - num_outputs = proof_environment.action_spec.shape[-1] + num_outputs = proof_environment.single_action_spec.shape[-1] distribution_class = TanhNormal distribution_kwargs = { "low": proof_environment.action_spec_unbatched.space.low.to(device), @@ -83,7 +82,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False): policy_mlp = torch.nn.Sequential( policy_mlp, AddStateIndependentNormalScale( - proof_environment.action_spec.shape[-1], device=device + proof_environment.single_action_spec.shape[-1], device=device ), ) @@ -95,7 +94,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=Composite(action=proof_environment.action_spec.to(device)), + spec=proof_environment.single_full_action_spec.to(device), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index ed0ca5476c5..ffc3de9d941 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -298,7 +298,7 @@ def make_discretecql_model(cfg, train_env, eval_env, device="cpu"): def make_cql_modules_state(model_cfg, proof_environment): - action_spec = proof_environment.action_spec + action_spec = proof_environment.single_action_spec actor_net_kwargs = { "num_cells": model_cfg.hidden_sizes, diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py index 8093617ba9e..6ac058b9843 100644 --- a/sota-implementations/decision_transformer/dt.py +++ b/sota-implementations/decision_transformer/dt.py @@ -6,15 +6,18 @@ This is a self-contained example of an offline Decision Transformer training script. The helper functions are coded in the utils.py associated with this script. """ + from __future__ import annotations -import time +import warnings import hydra import numpy as np import torch import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict import TensorDict +from tensordict.nn import CudaGraphModule +from torchrl._utils import logger as torchrl_logger, timeit from torchrl.envs.libs.gym import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -67,58 +70,77 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create policy model - actor = make_dt_model(cfg) - policy = actor.to(model_device) + actor = make_dt_model(cfg, device=model_device) # Create loss - loss_module = make_dt_loss(cfg.loss, actor) + loss_module = make_dt_loss(cfg.loss, actor, device=model_device) # Create optimizer transformer_optim, scheduler = make_dt_optimizer(cfg.optim, loss_module) # Create inference policy inference_policy = DecisionTransformerInferenceWrapper( - policy=policy, + policy=actor, inference_context=cfg.env.inference_context, - ).to(model_device) + device=model_device, + ) inference_policy.set_tensor_keys( observation="observation_cat", action="action_cat", return_to_go="return_to_go_cat", ) - pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps) - pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps clip_grad = cfg.optim.clip_grad - eval_steps = cfg.logger.eval_steps - pretrain_log_interval = cfg.logger.pretrain_log_interval - reward_scaling = cfg.env.reward_scaling - torchrl_logger.info(" ***Pretraining*** ") - # Pretraining - start_time = time.time() - for i in range(pretrain_gradient_steps): - pbar.update(1) - - # Sample data - data = offline_buffer.sample() + def update(data: TensorDict) -> TensorDict: + transformer_optim.zero_grad(set_to_none=True) # Compute loss - loss_vals = loss_module(data.to(model_device)) + loss_vals = loss_module(data) transformer_loss = loss_vals["loss"] - transformer_optim.zero_grad() - torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad) transformer_loss.backward() + torch.nn.utils.clip_grad_norm_(actor.parameters(), clip_grad) transformer_optim.step() - scheduler.step() + return loss_vals + + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + update = torch.compile(update, mode=compile_mode, dynamic=True) + if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) + + eval_steps = cfg.logger.eval_steps + pretrain_log_interval = cfg.logger.pretrain_log_interval + reward_scaling = cfg.env.reward_scaling + torchrl_logger.info(" ***Pretraining*** ") + # Pretraining + pbar = tqdm.tqdm(range(pretrain_gradient_steps)) + for i in pbar: + # Sample data + with timeit("rb - sample"): + data = offline_buffer.sample().to(model_device) + with timeit("update"): + loss_vals = update(data) + scheduler.step() # Log metrics to_log = {"train/loss": loss_vals["loss"]} # Evaluation - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(), timeit("eval"): if i % pretrain_log_interval == 0: eval_td = test_env.rollout( max_steps=eval_steps, @@ -129,13 +151,17 @@ def main(cfg: "DictConfig"): # noqa: F821 to_log["eval/reward"] = ( eval_td["next", "reward"].sum(1).mean().item() / reward_scaling ) + if i % 200 == 0: + to_log.update(timeit.todict(prefix="time")) + timeit.print() + timeit.erase() + if logger is not None: log_metrics(logger, to_log, i) pbar.close() if not test_env.is_closed: test_env.close() - torchrl_logger.info(f"Training time: {time.time() - start_time}") if __name__ == "__main__": diff --git a/sota-implementations/decision_transformer/dt_config.yaml b/sota-implementations/decision_transformer/dt_config.yaml index 4805785a62c..b0070fa4377 100644 --- a/sota-implementations/decision_transformer/dt_config.yaml +++ b/sota-implementations/decision_transformer/dt_config.yaml @@ -55,7 +55,12 @@ optim: # loss loss: loss_function: "l2" - + +compile: + compile: False + compile_mode: + cudagraphs: False + # transformer model transformer: n_embd: 128 diff --git a/sota-implementations/decision_transformer/odt_config.yaml b/sota-implementations/decision_transformer/odt_config.yaml index eec2b455fb3..5d82cd75bef 100644 --- a/sota-implementations/decision_transformer/odt_config.yaml +++ b/sota-implementations/decision_transformer/odt_config.yaml @@ -42,6 +42,7 @@ replay_buffer: # optimizer optim: + optimizer: lamb device: null lr: 1.0e-4 weight_decay: 5.0e-4 @@ -56,6 +57,11 @@ loss: alpha_init: 0.1 target_entropy: auto +compile: + compile: False + compile_mode: + cudagraphs: False + # transformer model transformer: n_embd: 512 diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py index 3577217f296..9f3ec5f8134 100644 --- a/sota-implementations/decision_transformer/online_dt.py +++ b/sota-implementations/decision_transformer/online_dt.py @@ -9,14 +9,15 @@ from __future__ import annotations import time +import warnings import hydra import numpy as np import torch import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict.nn import CudaGraphModule +from torchrl._utils import logger as torchrl_logger, timeit from torchrl.envs.libs.gym import set_gym_backend - from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper from torchrl.record import VideoRecorder @@ -65,8 +66,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create policy model - actor = make_odt_model(cfg) - policy = actor.to(model_device) + policy = make_odt_model(cfg, device=model_device) # Create loss loss_module = make_odt_loss(cfg.loss, policy) @@ -80,13 +80,46 @@ def main(cfg: "DictConfig"): # noqa: F821 inference_policy = DecisionTransformerInferenceWrapper( policy=policy, inference_context=cfg.env.inference_context, - ).to(model_device) + device=model_device, + ) inference_policy.set_tensor_keys( observation="observation_cat", action="action_cat", return_to_go="return_to_go_cat", ) + def update(data): + transformer_optim.zero_grad(set_to_none=True) + temperature_optim.zero_grad(set_to_none=True) + # Compute loss + loss_vals = loss_module(data.to(model_device)) + transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"] + temperature_loss = loss_vals["loss_alpha"] + + (temperature_loss + transformer_loss).backward() + torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad) + + transformer_optim.step() + temperature_optim.step() + + return loss_vals.detach() + + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + compile_mode = "default" + update = torch.compile(update, mode=compile_mode, dynamic=False) + if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + if cfg.optim.optimizer == "lamb": + raise ValueError( + "cudagraphs isn't compatible with the Lamb optimizer. Use optim.optimizer=Adam instead." + ) + update = CudaGraphModule(update, warmup=50) + pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps) pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps @@ -100,35 +133,29 @@ def main(cfg: "DictConfig"): # noqa: F821 start_time = time.time() for i in range(pretrain_gradient_steps): pbar.update(1) - # Sample data - data = offline_buffer.sample() - # Compute loss - loss_vals = loss_module(data.to(model_device)) - transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"] - temperature_loss = loss_vals["loss_alpha"] - - transformer_optim.zero_grad() - torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad) - transformer_loss.backward() - transformer_optim.step() + with timeit("sample"): + # Sample data + data = offline_buffer.sample() - temperature_optim.zero_grad() - temperature_loss.backward() - temperature_optim.step() + with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() + loss_vals = update(data.to(model_device)) scheduler.step() # Log metrics to_log = { - "train/loss_log_likelihood": loss_vals["loss_log_likelihood"].item(), - "train/loss_entropy": loss_vals["loss_entropy"].item(), - "train/loss_alpha": loss_vals["loss_alpha"].item(), - "train/alpha": loss_vals["alpha"].item(), - "train/entropy": loss_vals["entropy"].item(), + "train/loss_log_likelihood": loss_vals["loss_log_likelihood"], + "train/loss_entropy": loss_vals["loss_entropy"], + "train/loss_alpha": loss_vals["loss_alpha"], + "train/alpha": loss_vals["alpha"], + "train/entropy": loss_vals["entropy"], } # Evaluation - with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): + with torch.no_grad(), set_exploration_type( + ExplorationType.DETERMINISTIC + ), timeit("eval"): inference_policy.eval() if i % pretrain_log_interval == 0: eval_td = test_env.rollout( @@ -143,6 +170,11 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_td["next", "reward"].sum(1).mean().item() / reward_scaling ) + if i % 200 == 0: + to_log.update(timeit.todict(prefix="time")) + timeit.print() + timeit.erase() + if logger is not None: log_metrics(logger, to_log, i) diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index 6bc1946b0a4..5f14734addd 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -4,6 +4,9 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +import os +from pathlib import Path + import torch.nn import torch.optim @@ -156,6 +159,7 @@ def make_env(): obs_std, train, ) + env.start() return env @@ -262,6 +266,7 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling): direct_download=True, prefetch=4, writer=RoundRobinWriter(), + root=Path(os.environ["HOME"]) / ".cache" / "torchrl" / "data" / "d4rl", ) # since we're not extending the data, adding keys can only be done via @@ -335,14 +340,14 @@ def make_online_replay_buffer(offline_buffer, rb_cfg, reward_scaling=0.001): # ----- -def make_odt_model(cfg): +def make_odt_model(cfg, device: torch.device | None = None) -> TensorDictModule: env_cfg = cfg.env proof_environment = make_transformed_env( make_base_env(env_cfg), env_cfg, obs_loc=0, obs_std=1 ) - action_spec = proof_environment.action_spec - for key, value in proof_environment.observation_spec.items(): + action_spec = proof_environment.action_spec_unbatched + for key, value in proof_environment.observation_spec_unbatched.items(): if key == "observation": state_dim = value.shape[-1] in_keys = [ @@ -355,6 +360,7 @@ def make_odt_model(cfg): state_dim=state_dim, action_dim=action_spec.shape[-1], transformer_config=cfg.transformer, + device=device, ) actor_module = TensorDictModule( @@ -366,7 +372,13 @@ def make_odt_model(cfg): ], ) dist_class = TanhNormal - dist_kwargs = {"low": -1.0, "high": 1.0, "tanh_loc": False, "upscale": 5.0} + dist_kwargs = { + "low": -torch.ones((), device=device), + "high": torch.ones((), device=device), + "tanh_loc": False, + "upscale": torch.full((), 5, device=device), + # "safe_tanh": not cfg.compile.compile, + } actor = ProbabilisticActor( spec=action_spec, @@ -383,21 +395,18 @@ def make_odt_model(cfg): with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): td = proof_environment.rollout(max_steps=100) td["action"] = td["next", "action"] - actor(td) + actor(td.to(device)) return actor -def make_dt_model(cfg): +def make_dt_model(cfg, device: torch.device | None = None): env_cfg = cfg.env proof_environment = make_transformed_env( make_base_env(env_cfg), env_cfg, obs_loc=0, obs_std=1 ) action_spec = proof_environment.action_spec_unbatched - for key, value in proof_environment.observation_spec.items(): - if key == "observation": - state_dim = value.shape[-1] in_keys = [ "observation_cat", "action_cat", @@ -405,9 +414,10 @@ def make_dt_model(cfg): ] actor_net = DTActor( - state_dim=state_dim, + state_dim=proof_environment.observation_spec_unbatched["observation"].shape[-1], action_dim=action_spec.shape[-1], transformer_config=cfg.transformer, + device=device, ) actor_module = TensorDictModule( @@ -417,12 +427,13 @@ def make_dt_model(cfg): ) dist_class = TanhDelta dist_kwargs = { - "low": action_spec.space.low, - "high": action_spec.space.high, + "low": action_spec.space.low.to(device), + "high": action_spec.space.high.to(device), + "safe": not cfg.compile.compile, } actor = ProbabilisticActor( - spec=action_spec, + spec=action_spec.to(device), in_keys=["param"], out_keys=["action"], module=actor_module, @@ -434,9 +445,10 @@ def make_dt_model(cfg): # init the lazy layers with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): - td = proof_environment.rollout(max_steps=100) + td = proof_environment.fake_tensordict() + td = td.expand((100, *td.shape)) td["action"] = td["next", "action"] - actor(td) + actor(td.to(device)) return actor @@ -456,29 +468,43 @@ def make_odt_loss(loss_cfg, actor_network): return loss -def make_dt_loss(loss_cfg, actor_network): +def make_dt_loss(loss_cfg, actor_network, device: torch.device | None = None): loss = DTLoss( actor_network, loss_function=loss_cfg.loss_function, + device=device, ) loss.set_keys(action_target="action_cat") return loss def make_odt_optimizer(optim_cfg, loss_module): - dt_optimizer = Lamb( - loss_module.actor_network_params.flatten_keys().values(), - lr=optim_cfg.lr, - weight_decay=optim_cfg.weight_decay, - eps=1.0e-8, - ) + if optim_cfg.optimizer == "lamb": + dt_optimizer = Lamb( + loss_module.actor_network_params.flatten_keys().values(), + lr=torch.as_tensor( + optim_cfg.lr, device=next(loss_module.parameters()).device + ), + weight_decay=optim_cfg.weight_decay, + eps=1.0e-8, + ) + elif optim_cfg.optimizer == "adam": + dt_optimizer = torch.optim.Adam( + loss_module.actor_network_params.flatten_keys().values(), + lr=torch.as_tensor( + optim_cfg.lr, device=next(loss_module.parameters()).device + ), + weight_decay=optim_cfg.weight_decay, + eps=1.0e-8, + ) + scheduler = torch.optim.lr_scheduler.LambdaLR( dt_optimizer, lambda steps: min((steps + 1) / optim_cfg.warmup_steps, 1) ) log_temp_optimizer = torch.optim.Adam( [loss_module.log_alpha], - lr=1e-4, + lr=torch.as_tensor(1e-4, device=next(loss_module.parameters()).device), betas=[0.9, 0.999], ) @@ -488,7 +514,7 @@ def make_odt_optimizer(optim_cfg, loss_module): def make_dt_optimizer(optim_cfg, loss_module): dt_optimizer = torch.optim.Adam( loss_module.actor_network_params.flatten_keys().values(), - lr=optim_cfg.lr, + lr=torch.as_tensor(optim_cfg.lr), weight_decay=optim_cfg.weight_decay, eps=1.0e-8, ) diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 9a99d86150e..9b57e0c3c75 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -475,12 +475,12 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): spec=Composite( **{ "loc": Unbounded( - proof_environment.action_spec.shape, - device=proof_environment.action_spec.device, + proof_environment.single_action_spec.shape, + device=proof_environment.single_action_spec.device, ), "scale": Unbounded( - proof_environment.action_spec.shape, - device=proof_environment.action_spec.device, + proof_environment.single_action_spec.shape, + device=proof_environment.single_action_spec.device, ), } ), @@ -491,7 +491,7 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): default_interaction_type=InteractionType.RANDOM, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, - spec=Composite(**{action_key: proof_environment.action_spec}), + spec=Composite(**{action_key: proof_environment.single_action_spec}), ), ) return actor_simulator @@ -532,10 +532,10 @@ def _dreamer_make_actor_real( spec=Composite( **{ "loc": Unbounded( - proof_environment.action_spec.shape, + proof_environment.single_action_spec.shape, ), "scale": Unbounded( - proof_environment.action_spec.shape, + proof_environment.single_action_spec.shape, ), } ), @@ -546,7 +546,7 @@ def _dreamer_make_actor_real( default_interaction_type=InteractionType.DETERMINISTIC, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, - spec=Composite(**{action_key: proof_environment.action_spec.to("cpu")}), + spec=proof_environment.single_full_action_spec.to("cpu"), ), ), SafeModule( diff --git a/sota-implementations/gail/ppo_utils.py b/sota-implementations/gail/ppo_utils.py index 5669d93ce85..50dfe9e4c45 100644 --- a/sota-implementations/gail/ppo_utils.py +++ b/sota-implementations/gail/ppo_utils.py @@ -8,7 +8,6 @@ import torch.optim from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule -from torchrl.data import CompositeSpec from torchrl.envs import ( ClipTransform, DoubleToFloat, @@ -50,7 +49,7 @@ def make_ppo_models_state(proof_environment): input_shape = proof_environment.observation_spec["observation"].shape # Define policy output distribution class - num_outputs = proof_environment.action_spec.shape[-1] + num_outputs = proof_environment.single_action_spec.shape[-1] distribution_class = TanhNormal distribution_kwargs = { "low": proof_environment.action_spec_unbatched.space.low, @@ -76,7 +75,7 @@ def make_ppo_models_state(proof_environment): policy_mlp = torch.nn.Sequential( policy_mlp, AddStateIndependentNormalScale( - proof_environment.action_spec.shape[-1], scale_lb=1e-8 + proof_environment.single_action_spec.shape[-1], scale_lb=1e-8 ), ) @@ -88,7 +87,7 @@ def make_ppo_models_state(proof_environment): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=CompositeSpec(action=proof_environment.action_spec), + spec=proof_environment.single_full_action_spec, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/impala/utils.py b/sota-implementations/impala/utils.py index 248a98a389d..dca07a33570 100644 --- a/sota-implementations/impala/utils.py +++ b/sota-implementations/impala/utils.py @@ -7,7 +7,6 @@ import torch.nn import torch.optim from tensordict.nn import TensorDictModule -from torchrl.data import Composite from torchrl.envs import ( CatFrames, DoubleToFloat, @@ -70,7 +69,7 @@ def make_ppo_modules_pixels(proof_environment): input_shape = proof_environment.observation_spec["pixels"].shape # Define distribution class and kwargs - num_outputs = proof_environment.action_spec.space.n + num_outputs = proof_environment.single_action_spec.space.n distribution_class = OneHotCategorical distribution_kwargs = {} @@ -118,7 +117,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=Composite(action=proof_environment.action_spec), + spec=proof_environment.single_full_action_spec, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index d7d9e1a2d2f..c2b01c67c4d 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -249,7 +249,7 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"): def make_iql_modules_state(model_cfg, proof_environment): - action_spec = proof_environment.action_spec + action_spec = proof_environment.single_action_spec actor_net_kwargs = { "num_cells": model_cfg.hidden_sizes, diff --git a/sota-implementations/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py index 9be451331d8..885754f5ac1 100644 --- a/sota-implementations/ppo/utils_atari.py +++ b/sota-implementations/ppo/utils_atari.py @@ -7,7 +7,6 @@ import torch.nn import torch.optim from tensordict.nn import TensorDictModule -from torchrl.data import Composite from torchrl.data.tensor_specs import CategoricalBox from torchrl.envs import ( CatFrames, @@ -93,12 +92,12 @@ def make_ppo_modules_pixels(proof_environment): input_shape = proof_environment.observation_spec["pixels"].shape # Define distribution class and kwargs - if isinstance(proof_environment.action_spec.space, CategoricalBox): - num_outputs = proof_environment.action_spec.space.n + if isinstance(proof_environment.single_action_spec.space, CategoricalBox): + num_outputs = proof_environment.single_action_spec.space.n distribution_class = OneHotCategorical distribution_kwargs = {} else: # is ContinuousBox - num_outputs = proof_environment.action_spec.shape + num_outputs = proof_environment.single_action_spec.shape distribution_class = TanhNormal distribution_kwargs = { "low": proof_environment.action_spec_unbatched.space.low, @@ -149,7 +148,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=Composite(action=proof_environment.action_spec), + spec=proof_environment.single_full_action_spec, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py index ebbc6f7916d..50dfe9e4c45 100644 --- a/sota-implementations/ppo/utils_mujoco.py +++ b/sota-implementations/ppo/utils_mujoco.py @@ -8,7 +8,6 @@ import torch.optim from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule -from torchrl.data import Composite from torchrl.envs import ( ClipTransform, DoubleToFloat, @@ -50,7 +49,7 @@ def make_ppo_models_state(proof_environment): input_shape = proof_environment.observation_spec["observation"].shape # Define policy output distribution class - num_outputs = proof_environment.action_spec.shape[-1] + num_outputs = proof_environment.single_action_spec.shape[-1] distribution_class = TanhNormal distribution_kwargs = { "low": proof_environment.action_spec_unbatched.space.low, @@ -76,7 +75,7 @@ def make_ppo_models_state(proof_environment): policy_mlp = torch.nn.Sequential( policy_mlp, AddStateIndependentNormalScale( - proof_environment.action_spec.shape[-1], scale_lb=1e-8 + proof_environment.single_action_spec.shape[-1], scale_lb=1e-8 ), ) @@ -88,7 +87,7 @@ def make_ppo_models_state(proof_environment): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=Composite(action=proof_environment.action_spec), + spec=proof_environment.single_full_action_spec, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index e34f1be8ff9..c44be57cca6 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -695,10 +695,15 @@ def __init__( ): minmax_msg = "high value has been found to be equal or less than low value" if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): - if not (high > low).all(): - raise ValueError(minmax_msg) + if is_dynamo_compiling(): + assert (high > low).all() + else: + if not (high > low).all(): + raise ValueError(minmax_msg) elif isinstance(high, Number) and isinstance(low, Number): - if high <= low: + if is_dynamo_compiling(): + assert high > low + elif high <= low: raise ValueError(minmax_msg) else: if not all(high > low): diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index 8a20ad2eba8..cb35521f26c 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -7,6 +7,7 @@ import dataclasses import importlib +from contextlib import nullcontext from dataclasses import dataclass from typing import Any @@ -92,9 +93,6 @@ def __init__( config: dict | DTConfig = None, device: torch.device | None = None, ): - if device is not None: - with torch.device(device): - return self.__init__(state_dim, action_dim, config) if not _has_transformers: raise ImportError( @@ -117,28 +115,29 @@ def __init__( super(DecisionTransformer, self).__init__() - gpt_config = transformers.GPT2Config( - n_embd=config["n_embd"], - n_layer=config["n_layer"], - n_head=config["n_head"], - n_inner=config["n_inner"], - activation_function=config["activation"], - n_positions=config["n_positions"], - resid_pdrop=config["resid_pdrop"], - attn_pdrop=config["attn_pdrop"], - vocab_size=1, - ) - self.state_dim = state_dim - self.action_dim = action_dim - self.hidden_size = config["n_embd"] + with torch.device(device) if device is not None else nullcontext(): + gpt_config = transformers.GPT2Config( + n_embd=config["n_embd"], + n_layer=config["n_layer"], + n_head=config["n_head"], + n_inner=config["n_inner"], + activation_function=config["activation"], + n_positions=config["n_positions"], + resid_pdrop=config["resid_pdrop"], + attn_pdrop=config["attn_pdrop"], + vocab_size=1, + ) + self.state_dim = state_dim + self.action_dim = action_dim + self.hidden_size = config["n_embd"] - self.transformer = GPT2Model(config=gpt_config) + self.transformer = GPT2Model(config=gpt_config) - self.embed_return = torch.nn.Linear(1, self.hidden_size) - self.embed_state = torch.nn.Linear(self.state_dim, self.hidden_size) - self.embed_action = torch.nn.Linear(self.action_dim, self.hidden_size) + self.embed_return = torch.nn.Linear(1, self.hidden_size) + self.embed_state = torch.nn.Linear(self.state_dim, self.hidden_size) + self.embed_action = torch.nn.Linear(self.action_dim, self.hidden_size) - self.embed_ln = nn.LayerNorm(self.hidden_size) + self.embed_ln = nn.LayerNorm(self.hidden_size) def forward( self, @@ -162,13 +161,9 @@ def forward( # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...) # which works nice in an autoregressive sense since states predict actions - stacked_inputs = ( - torch.stack( - (returns_embeddings, state_embeddings, action_embeddings), dim=-3 - ) - .permute(*range(len(batch_size)), -2, -3, -1) - .reshape(*batch_size, 3 * seq_length, self.hidden_size) - ) + stacked_inputs = torch.stack( + (returns_embeddings, state_embeddings, action_embeddings), dim=-2 + ).reshape(*batch_size, 3 * seq_length, self.hidden_size) stacked_inputs = self.embed_ln(stacked_inputs) # we feed in the input embeddings (not word indices as in NLP) to the model @@ -179,9 +174,7 @@ def forward( # reshape x so that the second dimension corresponds to the original # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t - x = x.reshape(*batch_size, seq_length, 3, self.hidden_size).permute( - *range(len(batch_size)), -2, -3, -1 - ) + x = x.reshape(*batch_size, seq_length, 3, self.hidden_size).transpose(-3, -2) if batch_size_orig is batch_size: return x[..., 1, :, :] # only state tokens - return x[..., 1, :, :].view(*batch_size_orig, *x.shape[-2:]) + return x[..., 1, :, :].reshape(*batch_size_orig, *x.shape[-2:]) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index cad4065f54a..9c25636091d 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -1558,6 +1558,7 @@ def __init__( state_dim=state_dim, action_dim=action_dim, config=transformer_config, + device=device, ) self.action_layer_mean = nn.Linear( transformer_config["n_embd"], action_dim, device=device @@ -1656,6 +1657,7 @@ def __init__( state_dim=state_dim, action_dim=action_dim, config=transformer_config, + device=device, ) self.action_layer = nn.Linear( transformer_config["n_embd"], action_dim, device=device diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 1b1f0aa4e0b..013e28713bf 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -214,7 +214,7 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Compute the loss for the Online Decision Transformer.""" # extract action targets - tensordict = tensordict.clone(False) + tensordict = tensordict.copy() target_actions = tensordict.get(self.tensor_keys.action_target) if target_actions.requires_grad: raise RuntimeError("target action cannot be part of a graph.")