From bfadce90bf069fddd934ffe78e8f9876409b9bbe Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 23 Apr 2024 15:18:23 +0200 Subject: [PATCH] [BugFix,Refactor] Dreamer refactor (#1918) Co-authored-by: Vincent Moens Co-authored-by: Vincent Moens --- .../linux_examples/scripts/run_test.sh | 48 +- docs/source/reference/envs.rst | 1 + sota-implementations/dreamer/config.yaml | 105 +- sota-implementations/dreamer/dreamer.py | 551 ++++----- sota-implementations/dreamer/dreamer_utils.py | 1079 +++++++++++------ torchrl/_utils.py | 8 +- torchrl/data/replay_buffers/storages.py | 43 +- torchrl/envs/__init__.py | 2 +- torchrl/envs/common.py | 1 + torchrl/envs/model_based/__init__.py | 1 + torchrl/envs/model_based/common.py | 17 +- torchrl/envs/model_based/dreamer.py | 45 +- torchrl/envs/transforms/transforms.py | 60 +- torchrl/envs/utils.py | 45 +- torchrl/modules/__init__.py | 1 + torchrl/modules/models/__init__.py | 9 +- torchrl/modules/models/model_based.py | 95 +- torchrl/objectives/dreamer.py | 39 +- torchrl/record/loggers/csv.py | 4 + 19 files changed, 1252 insertions(+), 902 deletions(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 1d11d481e3c..4587be88ddc 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -167,19 +167,17 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/di # logger.record_video=True \ # logger.record_frames=4 \ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \ - total_frames=200 \ - init_random_frames=10 \ - batch_size=10 \ - frames_per_batch=200 \ - num_workers=4 \ - env_per_collector=2 \ - collector_device=cuda:0 \ - model_device=cuda:0 \ - optim_steps_per_batch=1 \ - record_video=True \ - record_frames=4 \ - buffer_size=120 \ - rssm_hidden_dim=17 + collector.total_frames=200 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=200 \ + env.n_parallel_envs=4 \ + optimization.optim_steps_per_batch=1 \ + logger.video=True \ + logger.backend=csv \ + replay_buffer.buffer_size=120 \ + replay_buffer.batch_size=24 \ + replay_buffer.batch_length=12 \ + networks.rssm_hidden_dim=17 python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ @@ -223,19 +221,17 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq # With single envs python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \ - total_frames=200 \ - init_random_frames=10 \ - batch_size=10 \ - frames_per_batch=200 \ - num_workers=2 \ - env_per_collector=1 \ - collector_device=cuda:0 \ - model_device=cuda:0 \ - optim_steps_per_batch=1 \ - record_video=True \ - record_frames=4 \ - buffer_size=120 \ - rssm_hidden_dim=17 + collector.total_frames=200 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=200 \ + env.n_parallel_envs=1 \ + optimization.optim_steps_per_batch=1 \ + logger.backend=csv \ + logger.video=True \ + replay_buffer.buffer_size=120 \ + replay_buffer.batch_size=24 \ + replay_buffer.batch_length=12 \ + networks.rssm_hidden_dim=17 python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ddpg/ddpg.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 519865d1d00..fe72ea89a56 100644 --- 
a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -799,6 +799,7 @@ Domain-specific ModelBasedEnvBase model_based.dreamer.DreamerEnv + model_based.dreamer.DreamerDecoder Libraries diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index e81d74e08fa..ab101e8486a 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -1,39 +1,66 @@ -env_name: cheetah -env_task: run -env_library: dm_control -catframes: 1 -async_collection: True -record_video: 0 -frame_skip: 2 -batch_size: 50 -batch_length: 50 -total_frames: 5000000 -world_model_lr: 6e-4 -actor_value_lr: 8e-5 -from_pixels: True -# we want 50 frames / traj in the replay buffer. Given the frame_skip=2 this makes each traj 100 steps long -env_per_collector: 8 -num_workers: 8 -collector_device: cuda:1 -model_device: cuda:0 -frames_per_batch: 800 -optim_steps_per_batch: 80 -record_interval: 30 -max_frames_per_traj: 1000 -record_frames: 1000 -batch_transform: 1 -state_dim: 30 -rssm_hidden_dim: 200 -grad_clip: 100 -grayscale: False -image_size : 64 -buffer_size: 20000 -init_env_steps: 1000 -init_random_frames: 5000 -logger: csv -offline_logging: False -project_name: torchrl_example_dreamer -normalize_rewards_online: True -normalize_rewards_online_scale: 5.0 -normalize_rewards_online_decay: 0.99999 -reward_scaling: 1.0 +env: + name: cheetah + task: run + seed: 0 + backend: dm_control + frame_skip: 2 + from_pixels: True + grayscale: False + image_size : 64 + horizon: 500 + n_parallel_envs: 8 + device: + _target_: dreamer_utils._default_device + device: null + +collector: + total_frames: 5_000_000 + init_random_frames: 3000 + frames_per_batch: 1000 + device: + _target_: dreamer_utils._default_device + device: null + +optimization: + train_every: 1000 + grad_clip: 100 + + world_model_lr: 6e-4 + actor_lr: 8e-5 + value_lr: 8e-5 + kl_scale: 1.0 + free_nats: 3.0 + optim_steps_per_batch: 80 + gamma: 0.99 + lmbda: 0.95 + imagination_horizon: 15 + compile: False + compile_backend: inductor + use_autocast: True + +networks: + exploration_noise: 0.3 + device: + _target_: dreamer_utils._default_device + device: null + state_dim: 30 + rssm_hidden_dim: 200 + hidden_dim: 400 + activation: "elu" + + +replay_buffer: + batch_size: 2500 + buffer_size: 1000000 + batch_length: 50 + scratch_dir: null + +logger: + backend: wandb + project: dreamer-v1 + exp_name: ${env.name}-${env.task}-${env.seed} + mode: online + # eval interval, in collection counts + eval_iter: 10 + eval_rollout_steps: 500 + video: False diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index a1d8c8aec4e..e7b346b2b22 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -1,267 +1,194 @@ -import dataclasses -from pathlib import Path +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+import contextlib +import time import hydra import torch import torch.cuda import tqdm from dreamer_utils import ( - call_record, - EnvConfig, - grad_norm, - make_recorder_env, - parallel_env_constructor, - transformed_env_constructor, + dump_video, + log_metrics, + make_collector, + make_dreamer, + make_environments, + make_replay_buffer, ) -from hydra.core.config_store import ConfigStore +from hydra.utils import instantiate -# float16 -from torch.cuda.amp import autocast, GradScaler +# mixed precision training +from torch.cuda.amp import GradScaler from torch.nn.utils import clip_grad_norm_ -from torchrl._utils import logger as torchrl_logger +from torchrl._utils import logger as torchrl_logger, timeit +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import RSSMRollout -from torchrl.envs import EnvBase -from torchrl.modules.tensordict_module.exploration import ( - AdditiveGaussianWrapper, - OrnsteinUhlenbeckProcessWrapper, -) from torchrl.objectives.dreamer import ( DreamerActorLoss, DreamerModelLoss, DreamerValueLoss, ) from torchrl.record.loggers import generate_exp_name, get_logger -from torchrl.trainers.helpers.collectors import ( - make_collector_offpolicy, - OffPolicyCollectorConfig, -) -from torchrl.trainers.helpers.envs import ( - correct_for_frame_skip, - initialize_observation_norm_transforms, - retrieve_observation_norms_state_dict, -) -from torchrl.trainers.helpers.logger import LoggerConfig -from torchrl.trainers.helpers.models import DreamerConfig, make_dreamer -from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig -from torchrl.trainers.helpers.trainers import TrainerConfig -from torchrl.trainers.trainers import Recorder, RewardNormalizer - -config_fields = [ - (config_field.name, config_field.type, config_field) - for config_cls in ( - OffPolicyCollectorConfig, - EnvConfig, - LoggerConfig, - ReplayArgsConfig, - DreamerConfig, - TrainerConfig, - ) - for config_field in dataclasses.fields(config_cls) -] -Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields) -cs = ConfigStore.instance() -cs.store(name="config", node=Config) - - -def retrieve_stats_from_state_dict(obs_norm_state_dict): - return { - "loc": obs_norm_state_dict["loc"], - "scale": obs_norm_state_dict["scale"], - } @hydra.main(version_base="1.1", config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 + # cfg = correct_for_frame_skip(cfg) + + device = torch.device(instantiate(cfg.networks.device)) + + # Create logger + exp_name = generate_exp_name("Dreamer", cfg.logger.exp_name) + logger = None + if cfg.logger.backend: + logger = get_logger( + logger_type=cfg.logger.backend, + logger_name="dreamer_logging", + experiment_name=exp_name, + wandb_kwargs={"mode": cfg.logger.mode}, # "config": cfg}, + ) - cfg = correct_for_frame_skip(cfg) - - if not isinstance(cfg.reward_scaling, float): - cfg.reward_scaling = 1.0 - - if torch.cuda.is_available() and cfg.model_device == "": - device = torch.device("cuda:0") - elif cfg.model_device: - device = torch.device(cfg.model_device) - else: - device = torch.device("cpu") - torchrl_logger.info(f"Using device {device}") - - exp_name = generate_exp_name("Dreamer", cfg.exp_name) - logger = get_logger( - logger_type=cfg.logger, - logger_name="dreamer", - experiment_name=exp_name, - wandb_kwargs={ - "project": cfg.project_name, - "group": f"Dreamer_{cfg.env_name}", - "offline": cfg.offline_logging, - }, - ) - video_tag = f"Dreamer_{cfg.env_name}_policy_test" if 
cfg.record_video else "" - - key, init_env_steps, stats = None, None, None - if not cfg.vecnorm and cfg.norm_stats: - if not hasattr(cfg, "init_env_steps"): - raise AttributeError("init_env_steps missing from arguments.") - key = ("next", "pixels") if cfg.from_pixels else ("next", "observation_vector") - init_env_steps = cfg.init_env_steps - stats = {"loc": None, "scale": None} - elif cfg.from_pixels: - stats = {"loc": 0.5, "scale": 0.5} - proof_env = transformed_env_constructor( - cfg=cfg, use_env_creator=False, stats=stats - )() - initialize_observation_norm_transforms( - proof_environment=proof_env, num_iter=init_env_steps, key=key + train_env, test_env = make_environments( + cfg=cfg, + parallel_envs=cfg.env.n_parallel_envs, + logger=logger, ) - _, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0] - proof_env.close() - # Create the different components of dreamer - world_model, model_based_env, actor_model, value_model, policy = make_dreamer( - obs_norm_state_dict=obs_norm_state_dict, + # Make dreamer components + action_key = "action" + value_key = "state_value" + ( + world_model, + model_based_env, + model_based_env_eval, + actor_model, + value_model, + policy, + ) = make_dreamer( cfg=cfg, device=device, - use_decoder_in_env=True, - action_key="action", - value_key="state_value", - proof_environment=transformed_env_constructor( - cfg, stats={"loc": 0.0, "scale": 1.0} - )(), + action_key=action_key, + value_key=value_key, + use_decoder_in_env=cfg.logger.video, + logger=logger, ) - - # reward normalization - if cfg.normalize_rewards_online: - # if used the running statistics of the rewards are computed and the - # rewards used for training will be normalized based on these. - reward_normalizer = RewardNormalizer( - scale=cfg.normalize_rewards_online_scale, - decay=cfg.normalize_rewards_online_decay, - ) - else: - reward_normalizer = None - # Losses world_model_loss = DreamerModelLoss(world_model) + # Adapt loss keys to gym backend + if cfg.env.backend == "gym": + world_model_loss.set_keys(pixels="observation", reco_pixels="reco_observation") + actor_loss = DreamerActorLoss( actor_model, value_model, model_based_env, - imagination_horizon=cfg.imagination_horizon, + imagination_horizon=cfg.optimization.imagination_horizon, + discount_loss=True, ) - value_loss = DreamerValueLoss(value_model) - - # Exploration noise to be added to the actions - if cfg.exploration == "additive_gaussian": - exploration_policy = AdditiveGaussianWrapper( - policy, - sigma_init=0.3, - sigma_end=0.3, - ).to(device) - elif cfg.exploration == "ou_exploration": - exploration_policy = OrnsteinUhlenbeckProcessWrapper( - policy, - annealing_num_steps=cfg.total_frames, - ).to(device) - elif cfg.exploration == "": - exploration_policy = policy.to(device) - - action_dim_gsde, state_dim_gsde = None, None - create_env_fn = parallel_env_constructor( - cfg=cfg, - obs_norm_state_dict=obs_norm_state_dict, - action_dim_gsde=action_dim_gsde, - state_dim_gsde=state_dim_gsde, - ) - if isinstance(create_env_fn, EnvBase): - create_env_fn.rollout(2) - else: - create_env_fn().rollout(2) - - # Create the replay buffer - collector = make_collector_offpolicy( - make_env=create_env_fn, - actor_model_explore=exploration_policy, - cfg=cfg, + actor_loss.make_value_estimator( + gamma=cfg.optimization.gamma, lmbda=cfg.optimization.lmbda ) - torchrl_logger.info(f"collector: {collector}") - - replay_buffer = make_replay_buffer("cpu", cfg) - - record = Recorder( - record_frames=cfg.record_frames, - 
frame_skip=cfg.frame_skip, - policy_exploration=policy, - environment=make_recorder_env( - cfg=cfg, - video_tag=video_tag, - obs_norm_state_dict=obs_norm_state_dict, - logger=logger, - create_env_fn=create_env_fn, - ), - record_interval=cfg.record_interval, - log_keys=cfg.recorder_log_keys, + value_loss = DreamerValueLoss( + value_model, discount_loss=True, gamma=cfg.optimization.gamma + ) + + # Make collector + collector = make_collector(cfg, train_env, policy) + + # Make replay buffer + batch_size = cfg.replay_buffer.batch_size + batch_length = cfg.replay_buffer.batch_length + buffer_size = cfg.replay_buffer.buffer_size + scratch_dir = cfg.replay_buffer.scratch_dir + replay_buffer = make_replay_buffer( + batch_size=batch_size, + batch_seq_len=batch_length, + buffer_size=buffer_size, + buffer_scratch_dir=scratch_dir, + device=device, + pixel_obs=cfg.env.from_pixels, + grayscale=cfg.env.grayscale, + image_size=cfg.env.image_size, + use_autocast=cfg.optimization.use_autocast, ) - final_seed = collector.set_seed(cfg.seed) - torchrl_logger.info(f"init seed: {cfg.seed}, final seed: {final_seed}") # Training loop collected_frames = 0 - pbar = tqdm.tqdm(total=cfg.total_frames) - path = Path("./log") - path.mkdir(exist_ok=True) + pbar = tqdm.tqdm(total=cfg.collector.total_frames) - # optimizers - world_model_opt = torch.optim.Adam(world_model.parameters(), lr=cfg.world_model_lr) - actor_opt = torch.optim.Adam(actor_model.parameters(), lr=cfg.actor_value_lr) - value_opt = torch.optim.Adam(value_model.parameters(), lr=cfg.actor_value_lr) + # Make optimizer + world_model_opt = torch.optim.Adam( + world_model.parameters(), lr=cfg.optimization.world_model_lr + ) + actor_opt = torch.optim.Adam(actor_model.parameters(), lr=cfg.optimization.actor_lr) + value_opt = torch.optim.Adam(value_model.parameters(), lr=cfg.optimization.value_lr) + + # Grad scaler for mixed precision training https://pytorch.org/docs/stable/amp.html + use_autocast = cfg.optimization.use_autocast + if use_autocast: + scaler1 = GradScaler() + scaler2 = GradScaler() + scaler3 = GradScaler() + + init_random_frames = cfg.collector.init_random_frames + optim_steps_per_batch = cfg.optimization.optim_steps_per_batch + grad_clip = cfg.optimization.grad_clip + eval_iter = cfg.logger.eval_iter + eval_rollout_steps = cfg.logger.eval_rollout_steps + + if cfg.optimization.compile: + torch._dynamo.config.capture_scalar_outputs = True + + torchrl_logger.info("Compiling") + backend = cfg.optimization.compile_backend + + def compile_rssms(module): + if isinstance(module, RSSMRollout) and not getattr( + module, "_compiled", False + ): + module._compiled = True + module.rssm_prior.module = torch.compile( + module.rssm_prior.module, backend=backend + ) + module.rssm_posterior.module = torch.compile( + module.rssm_posterior.module, backend=backend + ) - scaler1 = GradScaler() - scaler2 = GradScaler() - scaler3 = GradScaler() + world_model_loss.apply(compile_rssms) + t_collect_init = time.time() for i, tensordict in enumerate(collector): - cmpt = 0 - if reward_normalizer is not None: - reward_normalizer.update_reward_stats(tensordict) + t_collect = time.time() - t_collect_init + + t_preproc_init = time.time() pbar.update(tensordict.numel()) current_frames = tensordict.numel() collected_frames += current_frames - # Compared to the original paper, the replay buffer is not temporally - # sampled. We fill it with trajectories of length batch_length. 
- # To be closer to the paper, we would need to fill it with trajectories - # of length 1000 and then sample subsequences of length batch_length. - - tensordict = tensordict.reshape(-1, cfg.batch_length) + ep_reward = tensordict.get("episode_reward")[..., -1, 0] replay_buffer.extend(tensordict.cpu()) - logger.log_scalar( - "r_training", - tensordict["next", "reward"].mean().detach().item(), - step=collected_frames, - ) - - if (i % cfg.record_interval) == 0: - do_log = True - else: - do_log = False + t_preproc = time.time() - t_preproc_init - if collected_frames >= cfg.init_random_frames: - if i % cfg.record_interval == 0: - logger.log_scalar("cmpt", cmpt) - for j in range(cfg.optim_steps_per_batch): - cmpt += 1 + if collected_frames >= init_random_frames: + t_loss_actor = 0.0 + t_loss_critic = 0.0 + t_loss_model = 0.0 + for _ in range(optim_steps_per_batch): # sample from replay buffer - sampled_tensordict = replay_buffer.sample(cfg.batch_size).to( - device, non_blocking=True - ) - if reward_normalizer is not None: - sampled_tensordict = reward_normalizer.normalize_reward( - sampled_tensordict - ) + t_sample_init = time.time() + sampled_tensordict = replay_buffer.sample().reshape(-1, batch_length) + t_sample = time.time() - t_sample_init + + t_loss_model_init = time.time() # update world model - with autocast(dtype=torch.float16): + with torch.autocast( + device_type=device.type, + dtype=torch.bfloat16, + ) if use_autocast else contextlib.nullcontext(): model_loss_td, sampled_tensordict = world_model_loss( sampled_tensordict ) @@ -270,113 +197,125 @@ def main(cfg: "DictConfig"): # noqa: F821 + model_loss_td["loss_model_reco"] + model_loss_td["loss_model_reward"] ) - # If we are logging videos, we keep some frames. - if ( - cfg.record_video - and (record._count + 1) % cfg.record_interval == 0 - ): - sampled_tensordict_save = ( - sampled_tensordict.select( - "next" "state", - "belief", - )[:4] - .detach() - .to_tensordict() - ) - else: - sampled_tensordict_save = None + world_model_opt.zero_grad() + if use_autocast: scaler1.scale(loss_world_model).backward() scaler1.unscale_(world_model_opt) - clip_grad_norm_(world_model.parameters(), cfg.grad_clip) + else: + loss_world_model.backward() + world_model_grad = clip_grad_norm_(world_model.parameters(), grad_clip) + if use_autocast: scaler1.step(world_model_opt) - if j == cfg.optim_steps_per_batch - 1 and do_log: - logger.log_scalar( - "loss_world_model", - loss_world_model.detach().item(), - step=collected_frames, - ) - logger.log_scalar( - "grad_world_model", - grad_norm(world_model_opt), - step=collected_frames, - ) - logger.log_scalar( - "loss_model_kl", - model_loss_td["loss_model_kl"].detach().item(), - step=collected_frames, - ) - logger.log_scalar( - "loss_model_reco", - model_loss_td["loss_model_reco"].detach().item(), - step=collected_frames, - ) - logger.log_scalar( - "loss_model_reward", - model_loss_td["loss_model_reward"].detach().item(), - step=collected_frames, - ) - world_model_opt.zero_grad() scaler1.update() + else: + world_model_opt.step() + t_loss_model += time.time() - t_loss_model_init # update actor network - with autocast(dtype=torch.float16): + t_loss_actor_init = time.time() + with torch.autocast( + device_type=device.type, dtype=torch.bfloat16 + ) if use_autocast else contextlib.nullcontext(): actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict) - scaler2.scale(actor_loss_td["loss_actor"]).backward() - scaler2.unscale_(actor_opt) - clip_grad_norm_(actor_model.parameters(), cfg.grad_clip) - 
scaler2.step(actor_opt) - if j == cfg.optim_steps_per_batch - 1 and do_log: - logger.log_scalar( - "loss_actor", - actor_loss_td["loss_actor"].detach().item(), - step=collected_frames, - ) - logger.log_scalar( - "grad_actor", - grad_norm(actor_opt), - step=collected_frames, - ) + actor_opt.zero_grad() - scaler2.update() + if use_autocast: + scaler2.scale(actor_loss_td["loss_actor"]).backward() + scaler2.unscale_(actor_opt) + else: + actor_loss_td["loss_actor"].backward() + actor_model_grad = clip_grad_norm_(actor_model.parameters(), grad_clip) + if use_autocast: + scaler2.step(actor_opt) + scaler2.update() + else: + actor_opt.step() + t_loss_actor += time.time() - t_loss_actor_init # update value network - with autocast(dtype=torch.float16): + t_loss_critic_init = time.time() + with torch.autocast( + device_type=device.type, dtype=torch.bfloat16 + ) if use_autocast else contextlib.nullcontext(): value_loss_td, sampled_tensordict = value_loss(sampled_tensordict) - scaler3.scale(value_loss_td["loss_value"]).backward() - scaler3.unscale_(value_opt) - clip_grad_norm_(value_model.parameters(), cfg.grad_clip) - scaler3.step(value_opt) - if j == cfg.optim_steps_per_batch - 1 and do_log: - logger.log_scalar( - "loss_value", - value_loss_td["loss_value"].detach().item(), - step=collected_frames, - ) - logger.log_scalar( - "grad_value", - grad_norm(value_opt), - step=collected_frames, - ) + value_opt.zero_grad() - scaler3.update() - if j == cfg.optim_steps_per_batch - 1: - do_log = False - - stats = retrieve_stats_from_state_dict(obs_norm_state_dict) - call_record( - logger, - record, - collected_frames, - sampled_tensordict_save, - stats, - model_based_env, - actor_model, - cfg, - ) - if cfg.exploration != "": - exploration_policy.step(current_frames) + if use_autocast: + scaler3.scale(value_loss_td["loss_value"]).backward() + scaler3.unscale_(value_opt) + else: + value_loss_td["loss_value"].backward() + critic_model_grad = clip_grad_norm_(value_model.parameters(), grad_clip) + if use_autocast: + scaler3.step(value_opt) + scaler3.update() + else: + value_opt.step() + t_loss_critic += time.time() - t_loss_critic_init + + metrics_to_log = {"reward": ep_reward.mean().item()} + if collected_frames >= init_random_frames: + loss_metrics = { + "loss_model_kl": model_loss_td["loss_model_kl"].item(), + "loss_model_reco": model_loss_td["loss_model_reco"].item(), + "loss_model_reward": model_loss_td["loss_model_reward"].item(), + "loss_actor": actor_loss_td["loss_actor"].item(), + "loss_value": value_loss_td["loss_value"].item(), + "world_model_grad": world_model_grad, + "actor_model_grad": actor_model_grad, + "critic_model_grad": critic_model_grad, + "t_loss_actor": t_loss_actor, + "t_loss_critic": t_loss_critic, + "t_loss_model": t_loss_model, + "t_sample": t_sample, + "t_preproc": t_preproc, + "t_collect": t_collect, + **timeit.todict(percall=False), + } + timeit.erase() + metrics_to_log.update(loss_metrics) + + if logger is not None: + log_metrics(logger, metrics_to_log, collected_frames) + + policy.step(current_frames) collector.update_policy_weights_() + # Evaluation + if (i % eval_iter) == 0: + # Real env + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_rollout = test_env.rollout( + eval_rollout_steps, + policy, + auto_cast_to_device=True, + break_when_any_done=True, + ) + test_env.apply(dump_video) + eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + eval_metrics = {"eval/reward": eval_reward} + if logger is not None: + log_metrics(logger, eval_metrics, 
collected_frames) + # Simulated env + if model_based_env_eval is not None: + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_rollout = model_based_env_eval.rollout( + eval_rollout_steps, + policy, + auto_cast_to_device=True, + break_when_any_done=True, + auto_reset=False, + tensordict=eval_rollout[..., 0] + .exclude("next", "action") + .to(device), + ) + model_based_env_eval.apply(dump_video) + eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + eval_metrics = {"eval/simulated_reward": eval_reward} + if logger is not None: + log_metrics(logger, eval_metrics, collected_frames) + + t_collect_init = time.time() if __name__ == "__main__": diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 51593a33caa..ff14871b011 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -2,435 +2,720 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass, field as dataclass_field -from typing import Any, Callable, Optional, Sequence, Union - -from torchrl.data import UnboundedContinuousTensorSpec -from torchrl.envs import ParallelEnv -from torchrl.envs.common import EnvBase -from torchrl.envs.env_creator import env_creator, EnvCreator -from torchrl.envs.libs.dm_control import DMControlEnv -from torchrl.envs.libs.gym import GymEnv -from torchrl.envs.transforms import ( - CatFrames, - CenterCrop, +import functools +import tempfile +from contextlib import nullcontext + +import torch + +import torch.nn as nn +from hydra.utils import instantiate +from tensordict import NestedKey +from tensordict.nn import ( + InteractionType, + ProbabilisticTensorDictModule, + ProbabilisticTensorDictSequential, + TensorDictModule, + TensorDictSequential, +) +from torchrl.collectors import SyncDataCollector + +from torchrl.data import ( + CompositeSpec, + LazyMemmapStorage, + SliceSampler, + TensorDictReplayBuffer, + UnboundedContinuousTensorSpec, +) + +from torchrl.envs import ( + Compose, + DeviceCastTransform, + DMControlEnv, DoubleToFloat, + DreamerDecoder, + DreamerEnv, + EnvCreator, + ExcludeTransform, + # ExcludeTransform, + FrameSkipTransform, GrayScale, - NoopResetEnv, - ObservationNorm, + GymEnv, + ParallelEnv, + RenameTransform, Resize, - RewardScaling, + RewardSum, + set_gym_backend, + StepCounter, + TensorDictPrimer, ToTensorImage, TransformedEnv, ) -from torchrl.envs.transforms.transforms import FlattenObservation, TensorDictPrimer -from torchrl.record.loggers import Logger -from torchrl.record.recorder import VideoRecorder - -__all__ = [ - "transformed_env_constructor", - "parallel_env_constructor", -] - -LIBS = { - "gym": GymEnv, - "dm_control": DMControlEnv, -} -import torch -from torch.cuda.amp import autocast +from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type +from torchrl.modules import ( + AdditiveGaussianWrapper, + DreamerActor, + IndependentNormal, + MLP, + ObsDecoder, + ObsEncoder, + RSSMPosterior, + RSSMPrior, + RSSMRollout, + SafeModule, + SafeProbabilisticModule, + SafeProbabilisticTensorDictSequential, + SafeSequential, + TanhNormal, + WorldModelWrapper, +) +from torchrl.record import VideoRecorder + + +def _make_env(cfg, device, from_pixels=False): + lib = cfg.env.backend + if lib in ("gym", "gymnasium"): + with set_gym_backend(lib): + env = GymEnv( + cfg.env.name, + device=device, + 
from_pixels=cfg.env.from_pixels or from_pixels, + pixels_only=cfg.env.from_pixels, + ) + elif lib == "dm_control": + env = DMControlEnv( + cfg.env.name, + cfg.env.task, + from_pixels=cfg.env.from_pixels or from_pixels, + pixels_only=cfg.env.from_pixels, + ) + else: + raise NotImplementedError(f"Unknown lib {lib}.") + default_dict = { + "state": UnboundedContinuousTensorSpec(shape=(cfg.networks.state_dim,)), + "belief": UnboundedContinuousTensorSpec(shape=(cfg.networks.rssm_hidden_dim,)), + } + env = env.append_transform( + TensorDictPrimer(random=False, default_value=0, **default_dict) + ) + assert env is not None + return env -def make_env_transforms( - env, - cfg, - video_tag, - logger, - env_name, - stats, - norm_obs_only, - env_library, - action_dim_gsde, - state_dim_gsde, - batch_dims=0, - obs_norm_state_dict=None, -): - env = TransformedEnv(env) - - from_pixels = cfg.from_pixels - vecnorm = cfg.vecnorm - norm_rewards = vecnorm and cfg.norm_rewards - reward_scaling = cfg.reward_scaling - reward_loc = cfg.reward_loc - - if len(video_tag): - center_crop = cfg.center_crop - if center_crop: - center_crop = center_crop[0] +def transform_env(cfg, env): + if not isinstance(env, TransformedEnv): + env = TransformedEnv(env) + if cfg.env.from_pixels: + # transforms pixel from 0-255 to 0-1 (uint8 to float32) env.append_transform( - VideoRecorder( - logger=logger, - tag=f"{video_tag}_{env_name}_video", - center_crop=center_crop, - ), + RenameTransform(in_keys=["pixels"], out_keys=["pixels_int"]) ) + env.append_transform( + ToTensorImage(from_int=True, in_keys=["pixels_int"], out_keys=["pixels"]) + ) + if cfg.env.grayscale: + env.append_transform(GrayScale()) - if cfg.noops: - env.append_transform(NoopResetEnv(cfg.noops)) + image_size = cfg.env.image_size + env.append_transform(Resize(image_size, image_size)) - if from_pixels: - if not cfg.catframes: - raise RuntimeError( - "this env builder currently only accepts positive catframes values" - "when pixels are being used." 
- ) - env.append_transform(ToTensorImage()) - if cfg.center_crop: - env.append_transform(CenterCrop(*cfg.center_crop)) - env.append_transform(Resize(cfg.image_size, cfg.image_size)) - if cfg.grayscale: - env.append_transform(GrayScale()) - env.append_transform(FlattenObservation(0, -3, allow_positive_dim=True)) - env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"], dim=-3)) - if stats is None and obs_norm_state_dict is None: - obs_stats = { - "loc": torch.zeros(()), - "scale": torch.ones(()), - } - elif stats is None and obs_norm_state_dict is not None: - obs_stats = obs_norm_state_dict - else: - obs_stats = stats - obs_stats["standard_normal"] = True - obs_norm = ObservationNorm(**obs_stats, in_keys=["pixels"]) - env.append_transform(obs_norm) - if norm_rewards: - reward_scaling = 1.0 - reward_loc = 0.0 - if norm_obs_only: - reward_scaling = 1.0 - reward_loc = 0.0 - if reward_scaling is not None: - env.append_transform(RewardScaling(reward_loc, reward_scaling)) - - double_to_float_list = [] - float_to_double_list = [] - if env_library is DMControlEnv: - double_to_float_list += [ - "reward", - ] - float_to_double_list += ["action"] # DMControl requires double-precision env.append_transform(DoubleToFloat()) + env.append_transform(RewardSum()) + env.append_transform(FrameSkipTransform(cfg.env.frame_skip)) + env.append_transform(StepCounter(cfg.env.horizon)) - default_dict = { - "state": UnboundedContinuousTensorSpec(shape=(*env.batch_size, cfg.state_dim)), - "belief": UnboundedContinuousTensorSpec( - shape=(*env.batch_size, cfg.rssm_hidden_dim) - ), - } - env.append_transform( - TensorDictPrimer(random=False, default_value=0, **default_dict) - ) return env -def transformed_env_constructor( - cfg: "DictConfig", # noqa: F821 - video_tag: str = "", - logger: Optional[Logger] = None, - stats: Optional[dict] = None, - norm_obs_only: bool = False, - use_env_creator: bool = False, - custom_env_maker: Optional[Callable] = None, - custom_env: Optional[EnvBase] = None, - return_transformed_envs: bool = True, - action_dim_gsde: Optional[int] = None, - state_dim_gsde: Optional[int] = None, - batch_dims: Optional[int] = 0, - obs_norm_state_dict: Optional[dict] = None, - ignore_device: bool = False, -) -> Union[Callable, EnvCreator]: - """ - Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor. - - Args: - cfg (DictConfig): a DictConfig containing the arguments of the script. - video_tag (str, optional): video tag to be passed to the Logger object - logger (Logger, optional): logger associated with the script - stats (dict, optional): a dictionary containing the `loc` and `scale` for the `ObservationNorm` transform - norm_obs_only (bool, optional): If `True` and `VecNorm` is used, the reward won't be normalized online. - Default is `False`. - use_env_creator (bool, optional): wheter the `EnvCreator` class should be used. By using `EnvCreator`, - one can make sure that running statistics will be put in shared memory and accessible for all workers - when using a `VecNorm` transform. Default is `True`. - custom_env_maker (callable, optional): if your env maker is not part - of torchrl env wrappers, a custom callable - can be passed instead. In this case it will override the - constructor retrieved from `args`. - custom_env (EnvBase, optional): if an existing environment needs to be - transformed_in, it can be passed directly to this helper. `custom_env_maker` - and `custom_env` are exclusive features. 
- return_transformed_envs (bool, optional): if True, a transformed_in environment - is returned. - action_dim_gsde (int, Optional): if gSDE is used, this can present the action dim to initialize the noise. - Make sure this is indicated in environment executed in parallel. - state_dim_gsde: if gSDE is used, this can present the state dim to initialize the noise. - Make sure this is indicated in environment executed in parallel. - batch_dims (int, optional): number of dimensions of a batch of data. If a single env is - used, it should be 0 (default). If multiple envs are being transformed in parallel, - it should be set to 1 (or the number of dims of the batch). - obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform to be loaded - into the environment - ignore_device (bool, optional): if True, the device is ignored. - """ - - def make_transformed_env(**kwargs) -> TransformedEnv: - env_name = cfg.env_name - env_task = cfg.env_task - env_library = LIBS[cfg.env_library] - frame_skip = cfg.frame_skip - from_pixels = cfg.from_pixels - - if custom_env is None and custom_env_maker is None: - if not ignore_device: - if isinstance(cfg.collector_device, str): - device = cfg.collector_device - elif isinstance(cfg.collector_device, Sequence): - device = cfg.collector_device[0] - else: - raise ValueError( - "collector_device must be either a string or a sequence of strings" - ) - else: - device = None - env_kwargs = { - "env_name": env_name, - "device": device, - "frame_skip": frame_skip, - "from_pixels": from_pixels or len(video_tag), - "pixels_only": from_pixels, - } - if env_name == "quadruped": - # hard code camera_id for quadruped - camera_id = "x" - env_kwargs["camera_id"] = camera_id - if env_library is DMControlEnv: - env_kwargs.update({"task_name": env_task}) - env_kwargs.update(kwargs) - env = env_library(**env_kwargs) - elif custom_env is None and custom_env_maker is not None: - env = custom_env_maker(**kwargs) - elif custom_env_maker is None and custom_env is not None: - env = custom_env - else: - raise RuntimeError("cannot provive both custom_env and custom_env_maker") - - if not return_transformed_envs: - return env - - return make_env_transforms( - env, - cfg, - video_tag, - logger, - env_name, - stats, - norm_obs_only, - env_library, - action_dim_gsde, - state_dim_gsde, - batch_dims=batch_dims, - obs_norm_state_dict=obs_norm_state_dict, - ) - - if use_env_creator: - return env_creator(make_transformed_env) - return make_transformed_env - - -def parallel_env_constructor( - cfg: "DictConfig", **kwargs # noqa: F821 -) -> Union[ParallelEnv, EnvCreator]: - """Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor. - - Args: - cfg (DictConfig): config containing user-defined arguments - kwargs: keyword arguments for the `transformed_env_constructor` method. 
- """ - batch_transform = cfg.batch_transform - kwargs.update({"cfg": cfg, "use_env_creator": True}) - if cfg.env_per_collector == 1: - make_transformed_env = transformed_env_constructor(**kwargs) - return make_transformed_env - make_transformed_env = transformed_env_constructor( - return_transformed_envs=not batch_transform, ignore_device=True, **kwargs +def make_environments(cfg, parallel_envs=1, logger=None): + """Make environments for training and evaluation.""" + func = functools.partial(_make_env, cfg=cfg, device=cfg.env.device) + train_env = ParallelEnv( + parallel_envs, + EnvCreator(func), + serial_for_single=True, + ) + train_env = transform_env(cfg, train_env) + train_env.set_seed(cfg.env.seed) + func = functools.partial( + _make_env, cfg=cfg, device=cfg.env.device, from_pixels=cfg.logger.video ) - parallel_env = ParallelEnv( - num_workers=cfg.env_per_collector, - create_env_fn=make_transformed_env, - create_env_kwargs=None, - pin_memory=cfg.pin_memory, - device=cfg.collector_device, + eval_env = ParallelEnv( + 1, + EnvCreator(func), serial_for_single=True, ) - if batch_transform: - kwargs.update( - { - "cfg": cfg, - "use_env_creator": False, - "custom_env": parallel_env, - "batch_dims": 1, - } + eval_env = transform_env(cfg, eval_env) + eval_env.set_seed(cfg.env.seed + 1) + if cfg.logger.video: + eval_env.insert_transform(0, VideoRecorder(logger, tag="eval/video")) + check_env_specs(train_env) + check_env_specs(eval_env) + return train_env, eval_env + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() + + +def make_dreamer( + cfg, + device, + action_key: str = "action", + value_key: str = "state_value", + use_decoder_in_env: bool = False, + compile: bool = True, + logger=None, +): + test_env = _make_env(cfg, device="cpu") + test_env = transform_env(cfg, test_env) + # Make encoder and decoder + if cfg.env.from_pixels: + encoder = ObsEncoder() + decoder = ObsDecoder() + observation_in_key = "pixels" + observation_out_key = "reco_pixels" + else: + encoder = MLP( + out_features=1024, + depth=2, + num_cells=cfg.networks.hidden_dim, + activation_class=get_activation(cfg.networks.activation), + ) + decoder = MLP( + out_features=test_env.observation_spec["observation"].shape[-1], + depth=2, + num_cells=cfg.networks.hidden_dim, + activation_class=get_activation(cfg.networks.activation), + ) + observation_in_key = "observation" + observation_out_key = "reco_observation" + + # Make RSSM + rssm_prior = RSSMPrior( + hidden_dim=cfg.networks.rssm_hidden_dim, + rnn_hidden_dim=cfg.networks.rssm_hidden_dim, + state_dim=cfg.networks.state_dim, + action_spec=test_env.action_spec, + ) + rssm_posterior = RSSMPosterior( + hidden_dim=cfg.networks.rssm_hidden_dim, state_dim=cfg.networks.state_dim + ) + # Make reward module + reward_module = MLP( + out_features=1, + depth=2, + num_cells=cfg.networks.hidden_dim, + activation_class=get_activation(cfg.networks.activation), + ) + + # Make combined world model + world_model = _dreamer_make_world_model( + encoder, + decoder, + rssm_prior, + rssm_posterior, + reward_module, + observation_in_key=observation_in_key, + observation_out_key=observation_out_key, + ) + world_model.to(device) + + # Initialize world model + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): + tensordict = ( + test_env.rollout(5, auto_cast_to_device=True) + .unsqueeze(-1) + .to(world_model.device) ) - env = transformed_env_constructor(**kwargs)() - return env - return parallel_env + tensordict = tensordict.to_tensordict() + 
world_model(tensordict) + + # Create model-based environment + model_based_env = _dreamer_make_mbenv( + reward_module=reward_module, + rssm_prior=rssm_prior, + decoder=decoder, + observation_out_key=observation_out_key, + test_env=test_env, + use_decoder_in_env=use_decoder_in_env, + state_dim=cfg.networks.state_dim, + rssm_hidden_dim=cfg.networks.rssm_hidden_dim, + ) + + # def detach_state_and_belief(data): + # data.set("state", data.get("state").detach()) + # data.set("belief", data.get("belief").detach()) + # return data + # + # model_based_env = model_based_env.append_transform(detach_state_and_belief) + check_env_specs(model_based_env) + + # Make actor + actor_simulator, actor_realworld = _dreamer_make_actors( + encoder=encoder, + observation_in_key=observation_in_key, + rssm_prior=rssm_prior, + rssm_posterior=rssm_posterior, + mlp_num_units=cfg.networks.hidden_dim, + activation=get_activation(cfg.networks.activation), + action_key=action_key, + test_env=test_env, + ) + # Exploration noise to be added to the actor_realworld + actor_realworld = AdditiveGaussianWrapper( + actor_realworld, + sigma_init=1.0, + sigma_end=1.0, + annealing_num_steps=1, + mean=0.0, + std=cfg.networks.exploration_noise, + ) + # Make Critic + value_model = _dreamer_make_value_model( + hidden_dim=cfg.networks.hidden_dim, + activation=cfg.networks.activation, + value_key=value_key, + ) + + actor_simulator.to(device) + value_model.to(device) + actor_realworld.to(device) + model_based_env.to(device) + + # Initialize model-based environment, actor and critic + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): + tensordict = ( + model_based_env.fake_tensordict().unsqueeze(-1).to(value_model.device) + ) + tensordict = tensordict + tensordict = actor_simulator(tensordict) + value_model(tensordict) + + if cfg.logger.video: + model_based_env_eval = model_based_env.append_transform(DreamerDecoder()) + + def float_to_int(data): + reco_pixels_float = data.get("reco_pixels") + reco_pixels = (reco_pixels_float * 255).floor() + # assert (reco_pixels < 256).all() and (reco_pixels > 0).all(), (reco_pixels.min(), reco_pixels.max()) + reco_pixels = reco_pixels.to(torch.uint8) + data.set("reco_pixels_float", reco_pixels_float) + return data.set("reco_pixels", reco_pixels) + + model_based_env_eval.append_transform(float_to_int) + model_based_env_eval.append_transform( + VideoRecorder( + logger=logger, tag="eval/simulated_rendering", in_keys=["reco_pixels"] + ) + ) -def recover_pixels(pixels, stats): + else: + model_based_env_eval = None return ( - (255 * (pixels * stats["scale"] + stats["loc"])) - .clamp(min=0, max=255) - .to(torch.uint8) + world_model, + model_based_env, + model_based_env_eval, + actor_simulator, + value_model, + actor_realworld, ) -@torch.inference_mode() -def call_record( - logger, - record, - collected_frames, - sampled_tensordict, - stats, - model_based_env, - actor_model, - cfg, +def make_collector(cfg, train_env, actor_model_explore): + """Make collector.""" + collector = SyncDataCollector( + train_env, + actor_model_explore, + init_random_frames=cfg.collector.init_random_frames, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + policy_device=instantiate(cfg.collector.device), + env_device=train_env.device, + storing_device="cpu", + ) + collector.set_seed(cfg.env.seed) + + return collector + + +def make_replay_buffer( + *, + batch_size, + batch_seq_len, + buffer_size=1000000, + buffer_scratch_dir=None, + device=None, + prefetch=3, + 
pixel_obs=True, + grayscale=True, + image_size, + use_autocast, ): - td_record = record(None) - if td_record is not None and logger is not None: - for key, value in td_record.items(): - if key in ["r_evaluation", "total_r_evaluation"]: - logger.log_scalar( - key, - value.detach().item(), - step=collected_frames, - ) - # Compute observation reco - if cfg.record_video and record._count % cfg.record_interval == 0: - world_model_td = sampled_tensordict - - true_pixels = recover_pixels(world_model_td[("next", "pixels")], stats) - - reco_pixels = recover_pixels(world_model_td["next", "reco_pixels"], stats) - with autocast(dtype=torch.float16): - world_model_td = world_model_td.select("state", "belief", "reward") - world_model_td = model_based_env.rollout( - max_steps=true_pixels.shape[1], - policy=actor_model, - auto_reset=False, - tensordict=world_model_td[:, 0], + with ( + tempfile.TemporaryDirectory() + if buffer_scratch_dir is None + else nullcontext(buffer_scratch_dir) + ) as scratch_dir: + transforms = Compose() + if pixel_obs: + + def check_no_pixels(data): + assert "pixels" not in data.keys() + return data + + transforms = Compose( + ExcludeTransform("pixels", ("next", "pixels"), inverse=True), + check_no_pixels, # will be called only during forward + ToTensorImage( + in_keys=["pixels_int", ("next", "pixels_int")], + out_keys=["pixels", ("next", "pixels")], + ), ) - imagine_pxls = recover_pixels( - model_based_env.decode_obs(world_model_td)["next", "reco_pixels"], - stats, + if grayscale: + transforms.append(GrayScale(in_keys=["pixels", ("next", "pixels")])) + transforms.append( + Resize(image_size, image_size, in_keys=["pixels", ("next", "pixels")]) + ) + transforms.append(DeviceCastTransform(device=device)) + + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=scratch_dir, + device="cpu", + ndim=2, + ), + sampler=SliceSampler( + slice_len=batch_seq_len, + strict_length=False, + traj_key=("collector", "traj_ids"), + cache_values=True, + compile=True, + ), + transform=transforms, + batch_size=batch_size, ) + return replay_buffer - stacked_pixels = torch.cat([true_pixels, reco_pixels, imagine_pxls], dim=-1) - if logger is not None: - logger.log_video( - "pixels_rec_and_imag", - stacked_pixels.detach().cpu(), - ) + +def _dreamer_make_value_model( + hidden_dim: int = 400, activation: str = "elu", value_key: str = "state_value" +): + value_model = MLP( + out_features=1, + depth=3, + num_cells=hidden_dim, + activation_class=get_activation(activation), + ) + value_model = ProbabilisticTensorDictSequential( + TensorDictModule( + value_model, + in_keys=["state", "belief"], + out_keys=["loc"], + ), + ProbabilisticTensorDictModule( + in_keys=["loc"], + out_keys=[value_key], + distribution_class=IndependentNormal, + distribution_kwargs={"scale": 1.0, "event_dim": 1}, + ), + ) + + return value_model + + +def _dreamer_make_actors( + encoder, + observation_in_key, + rssm_prior, + rssm_posterior, + mlp_num_units, + activation, + action_key, + test_env, +): + actor_module = DreamerActor( + out_features=test_env.action_spec.shape[-1], + depth=3, + num_cells=mlp_num_units, + activation_class=activation, + ) + actor_simulator = _dreamer_make_actor_sim(action_key, test_env, actor_module) + actor_realworld = _dreamer_make_actor_real( + encoder, + observation_in_key, + rssm_prior, + rssm_posterior, + actor_module, + action_key, + test_env, + ) + return actor_simulator, actor_realworld + + +def 
_dreamer_make_actor_sim(action_key, proof_environment, actor_module): + actor_simulator = SafeProbabilisticTensorDictSequential( + SafeModule( + actor_module, + in_keys=["state", "belief"], + out_keys=["loc", "scale"], + spec=CompositeSpec( + **{ + "loc": UnboundedContinuousTensorSpec( + proof_environment.action_spec.shape, + device=proof_environment.action_spec.device, + ), + "scale": UnboundedContinuousTensorSpec( + proof_environment.action_spec.shape, + device=proof_environment.action_spec.device, + ), + } + ), + ), + SafeProbabilisticModule( + in_keys=["loc", "scale"], + out_keys=[action_key], + default_interaction_type=InteractionType.RANDOM, + distribution_class=TanhNormal, + distribution_kwargs={"tanh_loc": True}, + spec=CompositeSpec(**{action_key: proof_environment.action_spec}), + ), + ) + return actor_simulator -def grad_norm(optimizer: torch.optim.Optimizer): - sum_of_sq = 0.0 - for pg in optimizer.param_groups: - for p in pg["params"]: - sum_of_sq += p.grad.pow(2).sum() - return sum_of_sq.sqrt().detach().item() - - -def make_recorder_env(cfg, video_tag, obs_norm_state_dict, logger, create_env_fn): - recorder = transformed_env_constructor( - cfg, - video_tag=video_tag, - norm_obs_only=True, - obs_norm_state_dict=obs_norm_state_dict, - logger=logger, - use_env_creator=False, - )() - - # remove video recorder from recorder to have matching state_dict keys - if cfg.record_video: - recorder_rm = TransformedEnv(recorder.base_env) - for transform in recorder.transform: - if not isinstance(transform, VideoRecorder): - recorder_rm.append_transform(transform.clone()) +def _dreamer_make_actor_real( + encoder, + observation_in_key, + rssm_prior, + rssm_posterior, + actor_module, + action_key, + proof_environment, +): + # actor for real world: interacts with states ~ posterior + # Out actor differs from the original paper where first they compute prior and posterior and then act on it + # but we found that this approach worked better. 
+ actor_realworld = SafeSequential( + SafeModule( + encoder, + in_keys=[observation_in_key], + out_keys=["encoded_latents"], + ), + SafeModule( + rssm_posterior, + in_keys=["belief", "encoded_latents"], + out_keys=[ + "_", + "_", + "state", + ], + ), + SafeProbabilisticTensorDictSequential( + SafeModule( + actor_module, + in_keys=["state", "belief"], + out_keys=["loc", "scale"], + spec=CompositeSpec( + **{ + "loc": UnboundedContinuousTensorSpec( + proof_environment.action_spec.shape, + ), + "scale": UnboundedContinuousTensorSpec( + proof_environment.action_spec.shape, + ), + } + ), + ), + SafeProbabilisticModule( + in_keys=["loc", "scale"], + out_keys=[action_key], + default_interaction_type=InteractionType.MODE, + distribution_class=TanhNormal, + distribution_kwargs={"tanh_loc": True}, + spec=CompositeSpec( + **{action_key: proof_environment.action_spec.to("cpu")} + ), + ), + ), + SafeModule( + rssm_prior, + in_keys=["state", "belief", action_key], + out_keys=[ + "_", + "_", + "_", # we don't need the prior state + ("next", "belief"), + ], + ), + ) + return actor_realworld + + +def _dreamer_make_mbenv( + reward_module, + rssm_prior, + test_env, + decoder, + observation_out_key: str = "reco_pixels", + use_decoder_in_env: bool = False, + state_dim: int = 30, + rssm_hidden_dim: int = 200, +): + # MB environment + if use_decoder_in_env: + mb_env_obs_decoder = SafeModule( + decoder, + in_keys=["state", "belief"], + out_keys=[observation_out_key], + ) else: - recorder_rm = recorder - - if isinstance(create_env_fn, ParallelEnv): - sd = create_env_fn.state_dict()["worker0"] - elif isinstance(create_env_fn, EnvCreator): - _env = create_env_fn() - _env.rollout(2) - sd = _env.state_dict() - del _env + mb_env_obs_decoder = None + + transition_model = SafeSequential( + SafeModule( + rssm_prior, + in_keys=["state", "belief", "action"], + out_keys=[ + "_", + "_", + "state", + "belief", + ], + ), + ) + + reward_model = SafeProbabilisticTensorDictSequential( + SafeModule( + reward_module, + in_keys=["state", "belief"], + out_keys=["loc"], + ), + SafeProbabilisticModule( + in_keys=["loc"], + out_keys=["reward"], + distribution_class=IndependentNormal, + distribution_kwargs={"scale": 1.0, "event_dim": 1}, + ), + ) + + model_based_env = DreamerEnv( + world_model=WorldModelWrapper( + transition_model, + reward_model, + ), + prior_shape=torch.Size([state_dim]), + belief_shape=torch.Size([rssm_hidden_dim]), + obs_decoder=mb_env_obs_decoder, + ) + + model_based_env.set_specs_from_env(test_env) + return model_based_env + + +def _dreamer_make_world_model( + encoder, + decoder, + rssm_prior, + rssm_posterior, + reward_module, + observation_in_key: NestedKey = "pixels", + observation_out_key: NestedKey = "reco_pixels", +): + # World Model and reward model + rssm_rollout = RSSMRollout( + TensorDictModule( + rssm_prior, + in_keys=["state", "belief", "action"], + out_keys=[ + ("next", "prior_mean"), + ("next", "prior_std"), + "_", + ("next", "belief"), + ], + ), + TensorDictModule( + rssm_posterior, + in_keys=[("next", "belief"), ("next", "encoded_latents")], + out_keys=[ + ("next", "posterior_mean"), + ("next", "posterior_std"), + ("next", "state"), + ], + ), + ) + event_dim = 3 if observation_out_key == "reco_pixels" else 1 # 3 for RGB + decoder = ProbabilisticTensorDictSequential( + TensorDictModule( + decoder, + in_keys=[("next", "state"), ("next", "belief")], + out_keys=["loc"], + ), + ProbabilisticTensorDictModule( + in_keys=["loc"], + out_keys=[("next", observation_out_key)], + 
distribution_class=IndependentNormal, + distribution_kwargs={"scale": 1.0, "event_dim": event_dim}, + ), + ) + + transition_model = TensorDictSequential( + TensorDictModule( + encoder, + in_keys=[("next", observation_in_key)], + out_keys=[("next", "encoded_latents")], + ), + rssm_rollout, + decoder, + ) + + reward_model = ProbabilisticTensorDictSequential( + TensorDictModule( + reward_module, + in_keys=[("next", "state"), ("next", "belief")], + out_keys=[("next", "loc")], + ), + ProbabilisticTensorDictModule( + in_keys=[("next", "loc")], + out_keys=[("next", "reward")], + distribution_class=IndependentNormal, + distribution_kwargs={"scale": 1.0, "event_dim": 1}, + ), + ) + + world_model = WorldModelWrapper( + transition_model, + reward_model, + ) + return world_model + + +def log_metrics(logger, metrics, step): + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) + + +def get_activation(name): + if name == "relu": + return nn.ReLU + elif name == "tanh": + return nn.Tanh + elif name == "leaky_relu": + return nn.LeakyReLU + elif name == "elu": + return nn.ELU else: - sd = create_env_fn.state_dict() - sd = { - key: val - for key, val in sd.items() - if key.endswith("loc") or key.endswith("scale") - } - if not len(sd): - raise ValueError("Empty state dict") - recorder_rm.load_state_dict(sd, strict=False) - # reset reward scaling - for t in recorder.transform: - if isinstance(t, RewardScaling): - t.scale.fill_(1.0) - t.loc.fill_(0.0) - return recorder - - -@dataclass -class EnvConfig: - env_library: str = "gym" - # env_library used for the simulated environment. Default=gym - env_name: str = "Humanoid-v2" - # name of the environment to be created. Default=Humanoid-v2 - env_task: str = "" - # task (if any) for the environment. Default=run - from_pixels: bool = False - # whether the environment output should be state vector(s) (default) or the pixels. - frame_skip: int = 1 - # frame_skip for the environment. Note that this value does NOT impact the buffer size, - # maximum steps per trajectory, frames per batch or any other factor in the algorithm, - # e.g. if the total number of frames that has to be computed is 50e6 and the frame skip is 4 - # the actual number of frames retrieved will be 200e6. Default=1. - reward_scaling: Optional[float] = None - # scale of the reward. - reward_loc: float = 0.0 - # location of the reward. - init_env_steps: int = 1000 - # number of random steps to compute normalizing constants - vecnorm: bool = False - # Normalizes the environment observation and reward outputs with the running statistics obtained across processes. - norm_rewards: bool = False - # If True, rewards will be normalized on the fly. This may interfere with SAC update rule and should be used cautiously. - norm_stats: bool = True - # Deactivates the normalization based on random collection of data. - noops: int = 0 - # number of random steps to do after reset. Default is 0 - catframes: int = 0 - # Number of frames to concatenate through time. Default is 0 (do not use CatFrames). - center_crop: Any = dataclass_field(default_factory=lambda: []) - # center crop size. - grayscale: bool = True - # Disables grayscale transform. - max_frames_per_traj: int = 1000 - # Number of steps before a reset of the environment is called (if it has not been flagged as done before). 
- batch_transform: bool = True - # if True, the transforms will be applied to the parallel env, and not to each individual env.\ - image_size: int = 84 + raise NotImplementedError + + +def _default_device(device=None): + if device in ("", None): + if torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + return torch.device(device) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 161ae04573a..b8af95f1657 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -32,7 +32,7 @@ from tensordict.utils import NestedKey from torch import multiprocessing as mp -LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "DEBUG") +LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO") logger = logging.getLogger("torchrl") logger.setLevel(getattr(logging, LOGGING_LEVEL)) # Disable propagation to the root logger @@ -98,6 +98,12 @@ def print(prefix=None): # noqa: T202 ) logger.info(" -- ".join(strings)) + @classmethod + def todict(cls, percall=True): + if percall: + return {key: val[0] for key, val in cls._REG.items()} + return {key: val[1] for key, val in cls._REG.items()} + @staticmethod def erase(): for k in timeit._REG: diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index e4234dadeea..86cd88043da 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -25,12 +25,7 @@ from torch.utils._pytree import LeafSpec, tree_flatten, tree_map, tree_unflatten -from torchrl._utils import ( - _CKPT_BACKEND, - implement_for, - logger as torchrl_logger, - VERBOSE, -) +from torchrl._utils import _CKPT_BACKEND, implement_for, logger as torchrl_logger from torchrl.data.replay_buffers.utils import _is_int, INT_CLASSES try: @@ -913,8 +908,7 @@ def _init( self, data: Union[TensorDictBase, torch.Tensor, "PyTree"], # noqa: F821 ) -> None: - if VERBOSE: - torchrl_logger.info("Creating a TensorStorage...") + torchrl_logger.debug("Creating a TensorStorage...") if self.device == "auto": self.device = data.device @@ -1090,8 +1084,7 @@ def load_state_dict(self, state_dict): self._len = state_dict["_len"] def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: - if VERBOSE: - torchrl_logger.info("Creating a MemmapStorage...") + torchrl_logger.debug("Creating a MemmapStorage...") if self.device == "auto": self.device = data.device if self.device.type != "cpu": @@ -1116,14 +1109,13 @@ def max_size_along_dim0(data_shape): for key, tensor in sorted( out.items(include_nested=True, leaves_only=True), key=str ): - if VERBOSE: - try: - filesize = os.path.getsize(tensor.filename) / 1024 / 1024 - torchrl_logger.info( - f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})." - ) - except RuntimeError: - pass + try: + filesize = os.path.getsize(tensor.filename) / 1024 / 1024 + torchrl_logger.debug( + f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})." + ) + except RuntimeError: + pass else: out = _init_pytree(self.scratch_dir, max_size_along_dim0, data) self._storage = out @@ -1479,14 +1471,13 @@ def _init_pytree_common(tensor_path, scratch_dir, max_size_fn, tensor): filename=total_tensor_path, dtype=tensor.dtype, ) - if VERBOSE: - try: - filesize = os.path.getsize(out.filename) / 1024 / 1024 - torchrl_logger.info( - f"The storage was created in {out.filename} and occupies {filesize} Mb of storage." 
- ) - except RuntimeError: - pass + try: + filesize = os.path.getsize(tensor.filename) / 1024 / 1024 + torchrl_logger.debug( + f"The storage was created in {out.filename} and occupies {filesize} Mb of storage." + ) + except RuntimeError: + pass return out diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index e8f7fbe3ff2..84676728aa5 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -36,7 +36,7 @@ VmasEnv, VmasWrapper, ) -from .model_based import ModelBasedEnvBase +from .model_based import DreamerDecoder, DreamerEnv, ModelBasedEnvBase from .transforms import ( ActionMask, AutoResetEnv, diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index c3cb58897f2..f5d4625fd07 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2627,6 +2627,7 @@ def _rollout_stop_early( sync_func() else: tensordict.clear_device_() + tensordict = policy(tensordict) if auto_cast_to_device: if env_device is not None: diff --git a/torchrl/envs/model_based/__init__.py b/torchrl/envs/model_based/__init__.py index 5f628079173..437146a4909 100644 --- a/torchrl/envs/model_based/__init__.py +++ b/torchrl/envs/model_based/__init__.py @@ -4,3 +4,4 @@ # LICENSE file in the root directory of this source tree. from .common import ModelBasedEnvBase +from .dreamer import DreamerDecoder, DreamerEnv diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index a04607829c6..c1940f75a8f 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -139,11 +139,15 @@ def __new__(cls, *args, **kwargs): def set_specs_from_env(self, env: EnvBase): """Sets the specs of the environment from the specs of the given environment.""" - self.observation_spec = env.observation_spec.clone().to(self.device) - self.reward_spec = env.reward_spec.clone().to(self.device) - self.action_spec = env.action_spec.clone().to(self.device) - self.done_spec = env.done_spec.clone().to(self.device) - self.state_spec = env.state_spec.clone().to(self.device) + device = self.device + output_spec = env.output_spec.clone() + input_spec = env.input_spec.clone() + if device is not None: + output_spec = output_spec.to(device) + input_spec = input_spec.to(device) + self.__dict__["_output_spec"] = output_spec + self.__dict__["_input_spec"] = input_spec + self.empty_cache() def _step( self, @@ -161,12 +165,13 @@ def _step( else: tensordict_out = self.world_model(tensordict_out) # done can be missing, it will be filled by `step` - return tensordict_out.select( + tensordict_out = tensordict_out.select( *self.observation_spec.keys(), *self.full_done_spec.keys(), *self.full_reward_spec.keys(), strict=False, ) + return tensordict_out @abc.abstractmethod def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict: diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index e36ddf9e02a..f44c4aa025c 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -14,6 +14,7 @@ from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase from torchrl.envs.model_based import ModelBasedEnvBase +from torchrl.envs.transforms.transforms import Transform class DreamerEnv(ModelBasedEnvBase): @@ -39,14 +40,6 @@ def __init__( def set_specs_from_env(self, env: EnvBase): """Sets the specs of the environment from the specs of the given environment.""" super().set_specs_from_env(env) - # self.observation_spec = CompositeSpec( - # next_state=UnboundedContinuousTensorSpec( - # 
shape=self.prior_shape, device=self.device - # ), - # next_belief=UnboundedContinuousTensorSpec( - # shape=self.belief_shape, device=self.device - # ), - # ) self.action_spec = self.action_spec.to(self.device) self.state_spec = CompositeSpec( state=self.observation_spec["state"], @@ -57,10 +50,20 @@ def set_specs_from_env(self, env: EnvBase): def _reset(self, tensordict=None, **kwargs) -> TensorDict: batch_size = tensordict.batch_size if tensordict is not None else [] device = tensordict.device if tensordict is not None else self.device - td = self.state_spec.rand(shape=batch_size).to(device) - td.set("action", self.action_spec.rand(shape=batch_size).to(device)) - td[("next", "reward")] = self.reward_spec.rand(shape=batch_size).to(device) - td.update(self.observation_spec.rand(shape=batch_size).to(device)) + if tensordict is None: + td = self.state_spec.rand(shape=batch_size) + # why don't we reuse actions taken at those steps? + td.set("action", self.action_spec.rand(shape=batch_size)) + td[("next", "reward")] = self.reward_spec.rand(shape=batch_size) + td.update(self.observation_spec.rand(shape=batch_size)) + if device is not None: + td = td.to(device, non_blocking=True) + if torch.cuda.is_available() and device.type == "cpu": + torch.cuda.synchronize() + elif torch.backends.mps.is_available(): + torch.mps.synchronize() + else: + td = tensordict.clone() return td def decode_obs(self, tensordict: TensorDict, compute_latents=False) -> TensorDict: @@ -69,3 +72,21 @@ def decode_obs(self, tensordict: TensorDict, compute_latents=False) -> TensorDic if compute_latents: tensordict = self.world_model(tensordict) return self.obs_decoder(tensordict) + + +class DreamerDecoder(Transform): + """A transform to record the decoded observations in Dreamer. + + Examples: + >>> model_based_env = DreamerEnv(...) + >>> model_based_env_eval = model_based_env.append_transform(DreamerDecoder()) + """ + + def _call(self, tensordict): + return self.parent.base_env.obs_decoder(tensordict) + + def _reset(self, tensordict, tensordict_reset): + return self._call(tensordict_reset) + + def transform_observation_spec(self, observation_spec): + return observation_spec diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 4c61dd82f88..8eb338bc074 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3505,10 +3505,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "this functionality is not covered. Consider passing the in_keys " "or not passing any out_keys." ) - for in_key, item in list(tensordict.items(True, True)): + + def func(item): if item.dtype == self.dtype_in: item = self._apply_transform(item) - tensordict.set(in_key, item) + return item + + tensordict = tensordict._fast_apply(func) else: # we made sure that if in_keys is not None, out_keys is not None either for in_key, out_key in zip(in_keys, out_keys): @@ -5777,6 +5780,8 @@ class ExcludeTransform(Transform): Args: *excluded_keys (iterable of NestedKey): The name of the keys to exclude. If the key is not present, it is simply ignored. + inverse (bool, optional): if ``True``, the exclusion will occur during the ``inv`` call. + Defaults to ``False``. 
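A minimal usage sketch for the new ``inverse`` flag (the env and key names below are illustrative, not taken from this diff): with ``inverse=True`` the listed keys are stripped from the tensordict handed to the base env during the ``inv`` pass, rather than from the env output.

    # Illustrative only: "scratch" is a hypothetical key.
    from torchrl.envs import GymEnv, TransformedEnv
    from torchrl.envs.transforms import ExcludeTransform

    env = TransformedEnv(
        GymEnv("Pendulum-v1"),
        ExcludeTransform("scratch", inverse=True),  # dropped on the way in, outputs untouched
    )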
Examples: >>> import gymnasium @@ -5805,7 +5810,7 @@ class ExcludeTransform(Transform): """ - def __init__(self, *excluded_keys): + def __init__(self, *excluded_keys, inverse: bool = False): super().__init__() try: excluded_keys = unravel_key_list(excluded_keys) @@ -5814,35 +5819,46 @@ def __init__(self, *excluded_keys): "excluded keys must be a list or tuple of strings or tuples of strings." ) self.excluded_keys = excluded_keys + self.inverse = inverse def _call(self, tensordict: TensorDictBase) -> TensorDictBase: - return tensordict.exclude(*self.excluded_keys) + if not self.inverse: + return tensordict.exclude(*self.excluded_keys) + return tensordict + + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: + if self.inverse: + return tensordict.exclude(*self.excluded_keys) + return tensordict forward = _call def _reset( self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase ) -> TensorDictBase: - return tensordict_reset.exclude(*self.excluded_keys) + if not self.inverse: + return tensordict.exclude(*self.excluded_keys) + return tensordict def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: - full_done_spec = output_spec["full_done_spec"] - full_reward_spec = output_spec["full_reward_spec"] - full_observation_spec = output_spec["full_observation_spec"] - for key in self.excluded_keys: - # done_spec - if unravel_key(key) in list(full_done_spec.keys(True, True)): - del full_done_spec[key] - continue - # reward_spec - if unravel_key(key) in list(full_reward_spec.keys(True, True)): - del full_reward_spec[key] - continue - # observation_spec - if unravel_key(key) in list(full_observation_spec.keys(True, True)): - del full_observation_spec[key] - continue - raise KeyError(f"Key {key} not found in the environment outputs.") + if not self.inverse: + full_done_spec = output_spec["full_done_spec"] + full_reward_spec = output_spec["full_reward_spec"] + full_observation_spec = output_spec["full_observation_spec"] + for key in self.excluded_keys: + # done_spec + if unravel_key(key) in list(full_done_spec.keys(True, True)): + del full_done_spec[key] + continue + # reward_spec + if unravel_key(key) in list(full_reward_spec.keys(True, True)): + del full_reward_spec[key] + continue + # observation_spec + if unravel_key(key) in list(full_observation_spec.keys(True, True)): + del full_observation_spec[key] + continue + raise KeyError(f"Key {key} not found in the environment outputs.") return output_spec diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 67636523e46..49cf58f8103 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -159,6 +159,10 @@ def __init__( self.keys_from_next = self._repr_key_list_as_tree(self.keys_from_next) self.validated = None + # Model based envs can have missing keys + # TODO: do we want to always allow this? 
check_env_specs should catch these or downstream ops + self._allow_absent_keys = True + def validate(self, tensordict): if self.validated: return True @@ -204,7 +208,11 @@ def _repr_key_list_as_tree(key_list): @classmethod def _grab_and_place( - cls, nested_key_dict: dict, data_in: TensorDictBase, data_out: TensorDictBase + cls, + nested_key_dict: dict, + data_in: TensorDictBase, + data_out: TensorDictBase, + _allow_absent_keys: bool, ): for key, subdict in nested_key_dict.items(): val = data_in._get_str(key, NO_DEFAULT) @@ -216,7 +224,12 @@ def _grab_and_place( val = LazyStackedTensorDict( *( - cls._grab_and_place(subdict, _val, _val_out) + cls._grab_and_place( + subdict, + _val, + _val_out, + _allow_absent_keys=_allow_absent_keys, + ) for (_val, _val_out) in zip( val.unbind(val.stack_dim), val_out.unbind(val_out.stack_dim), @@ -225,10 +238,16 @@ def _grab_and_place( stack_dim=val.stack_dim, ) else: - val = cls._grab_and_place(subdict, val, val_out) - data_out._set_str( - key, val, validated=True, inplace=False, non_blocking=False - ) + val = cls._grab_and_place( + subdict, val, val_out, _allow_absent_keys=_allow_absent_keys + ) + if val is NO_DEFAULT: + if not _allow_absent_keys: + raise KeyError(f"key {key} not found.") + else: + data_out._set_str( + key, val, validated=True, inplace=False, non_blocking=False + ) return data_out @classmethod @@ -275,8 +294,18 @@ def __call__(self, tensordict): out = self._exclude(self.exclude_from_root, tensordict, out=None) else: out = next_td.empty() - self._grab_and_place(self.keys_from_root, tensordict, out) - self._grab_and_place(self.keys_from_next, next_td, out) + self._grab_and_place( + self.keys_from_root, + tensordict, + out, + _allow_absent_keys=self._allow_absent_keys, + ) + self._grab_and_place( + self.keys_from_next, + next_td, + out, + _allow_absent_keys=self._allow_absent_keys, + ) return out else: out = next_td.empty() diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 8f20f53fe5b..a987e701672 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -42,6 +42,7 @@ reset_noise, RSSMPosterior, RSSMPrior, + RSSMRollout, Squeeze2dLayer, SqueezeLayer, VDNMixer, diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 7e8ace40dcd..7b11cae9515 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -8,7 +8,14 @@ from .decision_transformer import DecisionTransformer from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise -from .model_based import DreamerActor, ObsDecoder, ObsEncoder, RSSMPosterior, RSSMPrior +from .model_based import ( + DreamerActor, + ObsDecoder, + ObsEncoder, + RSSMPosterior, + RSSMPrior, + RSSMRollout, +) from .models import ( Conv2dNet, Conv3dNet, diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 6196d69c543..f8ee69363d9 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -6,13 +6,22 @@ import torch from packaging import version -from tensordict.nn import TensorDictModule, TensorDictModuleBase +from tensordict import LazyStackedTensorDict +from tensordict.nn import ( + NormalParamExtractor, + TensorDictModule, + TensorDictModuleBase, + TensorDictSequential, +) from torch import nn -from torchrl.envs.utils import step_mdp -from torchrl.modules.distributions import NormalParamWrapper +# from torchrl.modules.tensordict_module.rnn import GRUCell +from torch.nn import GRUCell +from torchrl._utils import timeit 
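The recurring pattern in this file's refactor is to drop ``NormalParamWrapper`` in favour of an ``MLP`` (or plain ``nn.Sequential``) followed by ``NormalParamExtractor``. A minimal sketch of that pattern, with arbitrary sizes (nothing here is copied verbatim from the modules below):

    import torch
    from tensordict.nn import NormalParamExtractor
    from torchrl.modules import MLP

    # The MLP emits 2 * state_dim features; NormalParamExtractor splits them into
    # a location and a strictly positive scale (softplus mapping, lower-bounded).
    net = torch.nn.Sequential(
        MLP(out_features=2 * 30, depth=2, num_cells=200),
        NormalParamExtractor(scale_mapping="softplus", scale_lb=0.1),
    )
    loc, scale = net(torch.randn(8, 64))  # loc: [8, 30], scale: [8, 30]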
+ from torchrl.modules.models.models import MLP -from torchrl.modules.tensordict_module.sequence import SafeSequential + +UNSQUEEZE_RNN_INPUT = version.parse(torch.__version__) < version.parse("1.11") class DreamerActor(nn.Module): @@ -49,14 +58,17 @@ def __init__( std_min_val=1e-4, ): super().__init__() - self.backbone = NormalParamWrapper( - MLP( - out_features=2 * out_features, - depth=depth, - num_cells=num_cells, - activation_class=activation_class, + self.backbone = MLP( + out_features=2 * out_features, + depth=depth, + num_cells=num_cells, + activation_class=activation_class, + ) + self.backbone.append( + NormalParamExtractor( + scale_mapping=f"biased_softplus_{std_bias}_{std_min_val}", + # scale_mapping="relu", ), - scale_mapping=f"biased_softplus_{std_bias}_{std_min_val}", ) def forward(self, state, belief): @@ -67,7 +79,7 @@ def forward(self, state, belief): class ObsEncoder(nn.Module): """Observation encoder network. - Takes an pixel observation and encodes it into a latent space. + Takes a pixel observation and encodes it into a latent space. Reference: https://arxiv.org/abs/1803.10122 @@ -205,7 +217,7 @@ class RSSMRollout(TensorDictModuleBase): def __init__(self, rssm_prior: TensorDictModule, rssm_posterior: TensorDictModule): super().__init__() - _module = SafeSequential(rssm_prior, rssm_posterior) + _module = TensorDictSequential(rssm_prior, rssm_posterior) self.in_keys = _module.in_keys self.out_keys = _module.out_keys self.rssm_prior = rssm_prior @@ -231,26 +243,30 @@ def forward(self, tensordict): """ tensordict_out = [] *batch, time_steps = tensordict.shape - _tensordict = tensordict[..., 0] - update_values = tensordict.exclude(*self.out_keys) + update_values = tensordict.exclude(*self.out_keys).unbind(-1) + _tensordict = update_values[0] for t in range(time_steps): # samples according to p(s_{t+1} | s_t, a_t, b_t) # ["state", "belief", "action"] -> [("next", "prior_mean"), ("next", "prior_std"), "_", ("next", "belief")] - self.rssm_prior(_tensordict) + with timeit("rssm_rollout/time-rssm_prior"): + self.rssm_prior(_tensordict) # samples according to p(s_{t+1} | s_t, a_t, o_{t+1}) = p(s_t | b_t, o_t) # [("next", "belief"), ("next", "encoded_latents")] -> [("next", "posterior_mean"), ("next", "posterior_std"), ("next", "state")] - self.rssm_posterior(_tensordict) + with timeit("rssm_rollout/time-rssm_posterior"): + self.rssm_posterior(_tensordict) tensordict_out.append(_tensordict) if t < time_steps - 1: - _tensordict = step_mdp( - _tensordict.select(*self.out_keys, strict=False), keep_other=False - ) - _tensordict = update_values[..., t + 1].update(_tensordict) + _tensordict = _tensordict.select(*self.in_keys, strict=False) + _tensordict = update_values[t + 1].update(_tensordict) - return torch.stack(tensordict_out, tensordict.ndimension() - 1).contiguous() + out = torch.stack(tensordict_out, tensordict.ndim - 1) + assert not any( + isinstance(val, LazyStackedTensorDict) for val in out.values(True) + ), out + return out class RSSMPrior(nn.Module): @@ -287,30 +303,27 @@ def __init__( super().__init__() # Prior - self.rnn = nn.GRUCell(hidden_dim, rnn_hidden_dim) + self.rnn = GRUCell(hidden_dim, rnn_hidden_dim) self.action_state_projector = nn.Sequential(nn.LazyLinear(hidden_dim), nn.ELU()) - self.rnn_to_prior_projector = NormalParamWrapper( - nn.Sequential( - nn.Linear(hidden_dim, hidden_dim), - nn.ELU(), - nn.Linear(hidden_dim, 2 * state_dim), + self.rnn_to_prior_projector = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.ELU(), + nn.Linear(hidden_dim, 2 * 
state_dim), + NormalParamExtractor( + scale_lb=scale_lb, + scale_mapping="softplus", ), - scale_lb=scale_lb, - scale_mapping="softplus", ) self.state_dim = state_dim self.rnn_hidden_dim = rnn_hidden_dim self.action_shape = action_spec.shape - self._unsqueeze_rnn_input = version.parse(torch.__version__) < version.parse( - "1.11" - ) def forward(self, state, belief, action): projector_input = torch.cat([state, action], dim=-1) action_state = self.action_state_projector(projector_input) unsqueeze = False - if self._unsqueeze_rnn_input and action_state.ndimension() == 1: + if UNSQUEEZE_RNN_INPUT and action_state.ndimension() == 1: if belief is not None: belief = belief.unsqueeze(0) action_state = action_state.unsqueeze(0) @@ -344,14 +357,14 @@ class RSSMPosterior(nn.Module): def __init__(self, hidden_dim=200, state_dim=30, scale_lb=0.1): super().__init__() - self.obs_rnn_to_post_projector = NormalParamWrapper( - nn.Sequential( - nn.LazyLinear(hidden_dim), - nn.ELU(), - nn.Linear(hidden_dim, 2 * state_dim), + self.obs_rnn_to_post_projector = nn.Sequential( + nn.LazyLinear(hidden_dim), + nn.ELU(), + nn.Linear(hidden_dim, 2 * state_dim), + NormalParamExtractor( + scale_lb=scale_lb, + scale_mapping="softplus", ), - scale_lb=scale_lb, - scale_mapping="softplus", ) self.hidden_dim = hidden_dim diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 0c6ce1418ed..30f6dd10699 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -12,6 +12,7 @@ from tensordict.nn import TensorDictModule from tensordict.utils import NestedKey +from torchrl._utils import timeit from torchrl.envs.model_based.dreamer import DreamerEnv from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp from torchrl.objectives.common import LossModule @@ -19,6 +20,7 @@ _GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, distance_loss, + # distance_loss, hold_out_net, ValueEstimators, ) @@ -112,6 +114,8 @@ def __init__( self.free_nats = free_nats self.delayed_clamp = delayed_clamp self.global_average = global_average + self.__dict__["decoder"] = self.world_model[0][-1] + self.__dict__["reward_model"] = self.world_model[1] def _forward_value_estimator_keys(self, **kwargs) -> None: pass @@ -148,6 +152,7 @@ def forward(self, tensordict: TensorDict) -> torch.Tensor: reward_loss = reward_loss.squeeze(-1) reward_loss = reward_loss.mean().unsqueeze(-1) # import ipdb; ipdb.set_trace() + return ( TensorDict( { @@ -160,6 +165,12 @@ def forward(self, tensordict: TensorDict) -> torch.Tensor: tensordict.detach(), ) + @staticmethod + def normal_log_probability(x, mean, std): + return ( + -0.5 * ((x.to(mean.dtype) - mean) / std).pow(2) - std.log() + ) # - 0.5 * math.log(2 * math.pi) + def kl_loss( self, prior_mean: torch.Tensor, @@ -237,13 +248,13 @@ def __init__( model_based_env: DreamerEnv, *, imagination_horizon: int = 15, - discount_loss: bool = False, # for consistency with paper + discount_loss: bool = True, # for consistency with paper gamma: int = None, lmbda: int = None, ): super().__init__() self.actor_model = actor_model - self.value_model = value_model + self.__dict__["value_model"] = value_model self.model_based_env = model_based_env self.imagination_horizon = imagination_horizon self.discount_loss = discount_loss @@ -259,14 +270,13 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: ) def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: - with torch.no_grad(): - tensordict = tensordict.select("state", self.tensor_keys.belief) - 
tensordict = tensordict.reshape(-1) - - with hold_out_net(self.model_based_env), set_exploration_type( - ExplorationType.RANDOM - ): - tensordict = self.model_based_env.reset(tensordict.clone(recurse=False)) + tensordict = tensordict.select("state", self.tensor_keys.belief).detach() + tensordict = tensordict.reshape(-1) + + with timeit("actor_loss/time-rollout"), hold_out_net( + self.model_based_env + ), set_exploration_type(ExplorationType.RANDOM): + tensordict = self.model_based_env.reset(tensordict.copy()) fake_data = self.model_based_env.rollout( max_steps=self.imagination_horizon, policy=self.actor_model, @@ -274,10 +284,7 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: tensordict=tensordict, ) - next_tensordict = step_mdp( - fake_data, - keep_other=True, - ) + next_tensordict = step_mdp(fake_data, keep_other=True) with hold_out_net(self.value_model): next_tensordict = self.value_model(next_tensordict) @@ -342,6 +349,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self._value_estimator = TDLambdaEstimator( **hp, value_network=value_net, + vectorized=True, # TODO: vectorized version seems not to be similar to the non vectorised ) else: raise NotImplementedError(f"Unknown value type {value_type}") @@ -391,7 +399,7 @@ def __init__( self, value_model: TensorDictModule, value_loss: Optional[str] = None, - discount_loss: bool = False, # for consistency with paper + discount_loss: bool = True, # for consistency with paper gamma: int = 0.99, ): super().__init__() @@ -435,6 +443,5 @@ def forward(self, fake_data) -> torch.Tensor: .sum((-1, -2)) .mean() ) - loss_tensordict = TensorDict({"loss_value": value_loss}, []) return loss_tensordict, fake_data diff --git a/torchrl/record/loggers/csv.py b/torchrl/record/loggers/csv.py index 6bcd3f50c86..3f188a02a61 100644 --- a/torchrl/record/loggers/csv.py +++ b/torchrl/record/loggers/csv.py @@ -40,6 +40,8 @@ def add_scalar(self, name: str, value: float, global_step: Optional[int] = None) value = float(value) self.scalars[name].append((global_step, value)) filepath = os.path.join(self.log_dir, "scalars", "".join([name, ".csv"])) + if not os.path.isfile(filepath): + os.makedirs(Path(filepath).parent, exist_ok=True) if filepath not in self.files: self.files[filepath] = open(filepath, "a") fd = self.files[filepath] @@ -95,6 +97,8 @@ def add_text(self, tag, text, global_step: Optional[int] = None): filepath = os.path.join( self.log_dir, "texts", "".join([tag, str(global_step)]) + ".txt" ) + if not os.path.isfile(filepath): + os.makedirs(Path(filepath).parent, exist_ok=True) if filepath not in self.files: self.files[filepath] = open(filepath, "w+") fd = self.files[filepath]
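Two usage notes on the profiling and loss changes above, with small sketches whose names and sizes are illustrative rather than taken from the diff.

``timeit`` is used both as a context manager around hot sections (as in the RSSM rollout and the actor loss above) and, through the new ``todict`` classmethod, to export the accumulated timings as a plain dict that can be passed to ``log_metrics``:

    from torchrl._utils import timeit

    for _ in range(10):
        with timeit("train/step"):   # accumulates wall-clock time under this key
            ...                      # one training iteration
    # per-call averages; suitable for log_metrics(logger, metrics, step)
    metrics = {f"timings/{k}": v for k, v in timeit.todict(percall=True).items()}

``discount_loss`` now defaults to ``True`` in both the actor and value losses, for consistency with the paper: imagined steps are geometrically down-weighted across the imagination horizon. A sketch of the assumed weighting (not copied from the loss code):

    import torch

    gamma, horizon = 0.99, 15
    discount = gamma * torch.ones(horizon)
    discount[0] = 1.0
    discount = discount.cumprod(0)   # [1, gamma, gamma**2, ...]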