diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index c84c19406f3..163fec36721 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -45,9 +45,15 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/decision_trans optim.updates_per_episode=3 \ optim.warmup_steps=10 \ optim.device=cuda:0 \ - logger.backend= \ - env.backend=gymnasium \ - env.name=HalfCheetah-v4 + logger.backend= +python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_offline.py \ + optim.gradient_steps=55 \ + optim.device=cuda:0 \ + logger.backend= +python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_offline.py \ + optim.gradient_steps=55 \ + optim.device=cuda:0 \ + logger.backend= # ==================================================================================== # # ================================ Gymnasium ========================================= # @@ -115,7 +121,6 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/cql/discrete_c collector.frames_per_batch=16 \ collector.env_per_collector=2 \ collector.device=cuda:0 \ - optim.optim_steps_per_batch=1 \ replay_buffer.size=120 \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \ @@ -174,11 +179,20 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \ collector.total_frames=48 \ - buffer.batch_size=10 \ + optim.batch_size=10 \ collector.frames_per_batch=16 \ - collector.env_per_collector=2 \ + env.train_num_envs=2 \ + optim.device=cuda:0 \ collector.device=cuda:0 \ - network.device=cuda:0 \ + logger.mode=offline \ + logger.backend= + python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_online.py \ + collector.total_frames=48 \ + optim.batch_size=10 \ + collector.frames_per_batch=16 \ + env.train_num_envs=2 \ + collector.device=cuda:0 \ + optim.device=cuda:0 \ logger.mode=offline \ logger.backend= @@ -248,12 +262,21 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \ collector.total_frames=48 \ + optim.batch_size=10 \ collector.frames_per_batch=16 \ - collector.env_per_collector=1 \ + env.train_num_envs=1 \ + logger.mode=offline \ + optim.device=cuda:0 \ collector.device=cuda:0 \ - network.device=cuda:0 \ - buffer.batch_size=10 \ + logger.backend= +python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_online.py \ + collector.total_frames=48 \ + optim.batch_size=10 \ + collector.frames_per_batch=16 \ + collector.env_per_collector=1 \ logger.mode=offline \ + optim.device=cuda:0 \ + collector.device=cuda:0 \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \ collector.total_frames=48 \ diff --git a/examples/cql/cql_offline.py b/examples/cql/cql_offline.py index 122dd2579b8..87739763fa2 100644 --- a/examples/cql/cql_offline.py +++ b/examples/cql/cql_offline.py @@ -10,6 +10,8 @@ """ +import time + import hydra import numpy as np import torch @@ -18,16 +20,18 @@ from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + log_metrics, + make_continuous_cql_optimizer, + make_continuous_loss, make_cql_model, - make_cql_optimizer, make_environment, - make_loss, make_offline_replay_buffer, ) @hydra.main(config_path=".", config_name="offline_config", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 + # Create logger exp_name = generate_exp_name("CQL-offline", cfg.env.exp_name) logger = None if cfg.logger.backend: @@ -37,49 +41,96 @@ def main(cfg: "DictConfig"): # noqa: F821 experiment_name=exp_name, wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, ) - + # Set seeds torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) device = torch.device(cfg.optim.device) - # Make Env + # Create env train_env, eval_env = make_environment(cfg, cfg.logger.eval_envs) - # Make Buffer + # Create replay buffer replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) - # Make Model + # Create agent model = make_cql_model(cfg, train_env, eval_env, device) - # Make Loss - loss_module, target_net_updater = make_loss(cfg.loss, model) + # Create loss + loss_module, target_net_updater = make_continuous_loss(cfg.loss, model) - # Make Optimizer - optimizer = make_cql_optimizer(cfg.optim, loss_module) + # Create Optimizer + ( + policy_optim, + critic_optim, + alpha_optim, + alpha_prime_optim, + ) = make_continuous_cql_optimizer(cfg, loss_module) pbar = tqdm.tqdm(total=cfg.optim.gradient_steps) - r0 = None - l0 = None - gradient_steps = cfg.optim.gradient_steps + policy_eval_start = cfg.optim.policy_eval_start evaluation_interval = cfg.logger.eval_iter eval_steps = cfg.logger.eval_steps + # Training loop + start_time = time.time() for i in range(gradient_steps): pbar.update(i) + # sample data data = replay_buffer.sample() - # loss - loss_vals = loss_module(data) - # backprop - actor_loss = loss_vals["loss_actor"] + # compute loss + loss_vals = loss_module(data.clone().to(device)) + + # official cql implementation uses behavior cloning loss for first few updating steps as it helps for some tasks + if i >= policy_eval_start: + actor_loss = loss_vals["loss_actor"] + else: + actor_loss = loss_vals["loss_actor_bc"] q_loss = loss_vals["loss_qvalue"] - value_loss = loss_vals["loss_value"] - loss_val = actor_loss + q_loss + value_loss + cql_loss = loss_vals["loss_cql"] + + q_loss = q_loss + cql_loss + + alpha_loss = loss_vals["loss_alpha"] + alpha_prime_loss = loss_vals["loss_alpha_prime"] + + # update model + alpha_loss = loss_vals["loss_alpha"] + alpha_prime_loss = loss_vals["loss_alpha_prime"] + + alpha_optim.zero_grad() + alpha_loss.backward() + alpha_optim.step() - optimizer.zero_grad() - loss_val.backward() - optimizer.step() + policy_optim.zero_grad() + actor_loss.backward() + policy_optim.step() + + if alpha_prime_optim is not None: + alpha_prime_optim.zero_grad() + alpha_prime_loss.backward(retain_graph=True) + alpha_prime_optim.step() + + critic_optim.zero_grad() + # TODO: we have the option to compute losses independently retain is not needed? + q_loss.backward(retain_graph=False) + critic_optim.step() + + loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss + + # log metrics + to_log = { + "loss": loss.item(), + "loss_actor_bc": loss_vals["loss_actor_bc"].item(), + "loss_actor": loss_vals["loss_actor"].item(), + "loss_qvalue": q_loss.item(), + "loss_cql": cql_loss.item(), + "loss_alpha": alpha_loss.item(), + "loss_alpha_prime": alpha_prime_loss.item(), + } + + # update qnet_target params target_net_updater.step() # evaluation @@ -88,20 +139,13 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_td = eval_env.rollout( max_steps=eval_steps, policy=model[0], auto_cast_to_device=True ) + eval_reward = eval_td["next", "reward"].sum(1).mean().item() + to_log["evaluation_reward"] = eval_reward - if r0 is None: - r0 = eval_td["next", "reward"].sum(1).mean().item() - if l0 is None: - l0 = loss_val.item() - - for key, value in loss_vals.items(): - logger.log_scalar(key, value.item(), i) - eval_reward = eval_td["next", "reward"].sum(1).mean().item() - logger.log_scalar("evaluation_reward", eval_reward, i) + log_metrics(logger, to_log, i) - pbar.set_description( - f"loss: {loss_val.item(): 4.4f} (init: {l0: 4.4f}), evaluation_reward: {eval_reward: 4.4f} (init={r0: 4.4f})" - ) + pbar.close() + print(f"Training time: {time.time() - start_time}") if __name__ == "__main__": diff --git a/examples/cql/cql_online.py b/examples/cql/cql_online.py index beb1a71201d..93427a0d8cf 100644 --- a/examples/cql/cql_online.py +++ b/examples/cql/cql_online.py @@ -6,29 +6,36 @@ This is a self-contained example of an online CQL training script. +It works across Gym and MuJoCo over a variety of tasks. + The helper functions are coded in the utils.py associated with this script. """ +import time + import hydra import numpy as np import torch import tqdm +from tensordict import TensorDict from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + log_metrics, make_collector, + make_continuous_cql_optimizer, + make_continuous_loss, make_cql_model, - make_cql_optimizer, make_environment, - make_loss, make_replay_buffer, ) @hydra.main(version_base="1.1", config_path=".", config_name="online_config") def main(cfg: "DictConfig"): # noqa: F821 + # Create logger exp_name = generate_exp_name("CQL-online", cfg.env.exp_name) logger = None if cfg.logger.backend: @@ -39,14 +46,19 @@ def main(cfg: "DictConfig"): # noqa: F821 wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, ) + # Set seeds torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) device = torch.device(cfg.optim.device) - # Make Env - train_env, eval_env = make_environment(cfg, cfg.collector.env_per_collector) + # Create env + train_env, eval_env = make_environment( + cfg, + cfg.env.train_num_envs, + cfg.env.eval_num_envs, + ) - # Make Buffer + # Create replay buffer replay_buffer = make_replay_buffer( batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, @@ -54,25 +66,26 @@ def main(cfg: "DictConfig"): # noqa: F821 device="cpu", ) - # Make Model + # create agent model = make_cql_model(cfg, train_env, eval_env, device) - # Make Collector + # Create collector collector = make_collector(cfg, train_env, actor_model_explore=model[0]) - # Make Loss - loss_module, target_net_updater = make_loss(cfg.loss, model) + # Create loss + loss_module, target_net_updater = make_continuous_loss(cfg.loss, model) - # Make Optimizer - optimizer = make_cql_optimizer(cfg.optim, loss_module) - - rewards = [] - rewards_eval = [] + # Create optimizer + ( + policy_optim, + critic_optim, + alpha_optim, + alpha_prime_optim, + ) = make_continuous_cql_optimizer(cfg, loss_module) # Main loop + start_time = time.time() collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) - r0 = None - q_loss = None init_random_frames = cfg.collector.init_random_frames num_updates = int( @@ -81,28 +94,28 @@ def main(cfg: "DictConfig"): # noqa: F821 * cfg.optim.utd_ratio ) prb = cfg.replay_buffer.prb - env_per_collector = cfg.collector.env_per_collector eval_iter = cfg.logger.eval_iter - frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip - eval_rollout_steps = cfg.collector.max_frames_per_traj // frame_skip + frames_per_batch = cfg.collector.frames_per_batch + eval_rollout_steps = cfg.collector.max_frames_per_traj - for i, tensordict in enumerate(collector): + sampling_start = time.time() + for tensordict in collector: + sampling_time = time.time() - sampling_start + pbar.update(tensordict.numel()) # update weights of the inference policy collector.update_policy_weights_() - if r0 is None: - r0 = tensordict["next", "reward"].sum(-1).mean().item() - pbar.update(tensordict.numel()) - tensordict = tensordict.view(-1) current_frames = tensordict.numel() + # add to replay buffer replay_buffer.extend(tensordict.cpu()) collected_frames += current_frames # optimization steps + training_start = time.time() if collected_frames >= init_random_frames: - (actor_losses, q_losses, alpha_losses, alpha_primes) = ([], [], [], []) - for _ in range(num_updates): + log_loss_td = TensorDict({}, [num_updates]) + for j in range(num_updates): # sample from replay buffer sampled_tensordict = replay_buffer.sample() if sampled_tensordict.device != device: @@ -116,18 +129,29 @@ def main(cfg: "DictConfig"): # noqa: F821 actor_loss = loss_td["loss_actor"] q_loss = loss_td["loss_qvalue"] + cql_loss = loss_td["loss_cql"] + q_loss = q_loss + cql_loss alpha_loss = loss_td["loss_alpha"] alpha_prime_loss = loss_td["loss_alpha_prime"] - loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss - optimizer.zero_grad() - loss.backward() - optimizer.step() + alpha_optim.zero_grad() + alpha_loss.backward() + alpha_optim.step() + + policy_optim.zero_grad() + actor_loss.backward() + policy_optim.step() + + if alpha_prime_optim is not None: + alpha_prime_optim.zero_grad() + alpha_prime_loss.backward(retain_graph=True) + alpha_prime_optim.step() - q_losses.append(q_loss.item()) - actor_losses.append(actor_loss.item()) - alpha_losses.append(alpha_loss.item()) - alpha_primes.append(alpha_prime_loss.item()) + critic_optim.zero_grad() + q_loss.backward(retain_graph=False) + critic_optim.step() + + log_loss_td[j] = loss_td.detach() # update qnet_target params target_net_updater.step() @@ -136,45 +160,53 @@ def main(cfg: "DictConfig"): # noqa: F821 if prb: replay_buffer.update_priority(sampled_tensordict) - rewards.append( - (i, tensordict["next", "reward"].sum().item() / env_per_collector) - ) - train_log = { - "train_reward": rewards[-1][1], - "collected_frames": collected_frames, - } - if q_loss is not None: - train_log.update( - { - "actor_loss": np.mean(actor_losses), - "q_loss": np.mean(q_losses), - "alpha_loss": np.mean(alpha_losses), - "alpha_prime_loss": np.mean(alpha_primes), - "entropy": loss_td["entropy"], - } + training_time = time.time() - training_start + episode_rewards = tensordict["next", "episode_reward"][ + tensordict["next", "done"] + ] + # Logging + metrics_to_log = {} + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][ + tensordict["next", "done"] + ] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length ) - if logger is not None: - for key, value in train_log.items(): - logger.log_scalar(key, value, step=collected_frames) - if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip: + if collected_frames >= init_random_frames: + metrics_to_log["train/loss_actor"] = log_loss_td.get("loss_actor").mean() + metrics_to_log["train/loss_qvalue"] = log_loss_td.get("loss_qvalue").mean() + metrics_to_log["train/loss_alpha"] = log_loss_td.get("loss_alpha").mean() + metrics_to_log["train/loss_alpha_prime"] = log_loss_td.get( + "loss_alpha_prime" + ).mean() + metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean() + metrics_to_log["train/sampling_time"] = sampling_time + metrics_to_log["train/training_time"] = training_time + + # Evaluation + if abs(collected_frames % eval_iter) < frames_per_batch: with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, model[0], auto_cast_to_device=True, break_when_any_done=True, ) + eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() - rewards_eval.append((i, eval_reward)) - eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})" - if logger is not None: - logger.log_scalar( - "evaluation_reward", rewards_eval[-1][1], step=collected_frames - ) - if len(rewards_eval): - pbar.set_description( - f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str - ) + metrics_to_log["eval/reward"] = eval_reward + metrics_to_log["eval/time"] = eval_time + + log_metrics(logger, metrics_to_log, collected_frames) + sampling_start = time.time() + + collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") collector.shutdown() diff --git a/examples/cql/discrete_cql_config.yaml b/examples/cql/discrete_cql_config.yaml index 1bfbb6916e9..2b449629d16 100644 --- a/examples/cql/discrete_cql_config.yaml +++ b/examples/cql/discrete_cql_config.yaml @@ -2,7 +2,7 @@ env: name: CartPole-v1 task: "" - library: gym + backend: gym exp_name: cql_cartpole_gym n_samples_stats: 1000 max_episode_steps: 200 @@ -20,7 +20,8 @@ collector: annealing_frames: 10000 eps_start: 1.0 eps_end: 0.01 -# logger + +# Logger logger: backend: wandb log_interval: 5000 # record interval in frames @@ -42,8 +43,6 @@ optim: lr: 1e-3 weight_decay: 0.0 batch_size: 256 - lr_scheduler: "" - optim_steps_per_batch: 200 # Policy and model model: diff --git a/examples/cql/discrete_cql_online.py b/examples/cql/discrete_cql_online.py index 5dfde6a082d..0c93875ec9c 100644 --- a/examples/cql/discrete_cql_online.py +++ b/examples/cql/discrete_cql_online.py @@ -25,9 +25,9 @@ from utils import ( log_metrics, make_collector, - make_cql_optimizer, + make_discrete_cql_optimizer, + make_discrete_loss, make_discretecql_model, - make_discreteloss, make_environment, make_replay_buffer, ) @@ -59,7 +59,7 @@ def main(cfg: "DictConfig"): # noqa: F821 model, explore_policy = make_discretecql_model(cfg, train_env, eval_env, device) # Create loss - loss_module, target_net_updater = make_discreteloss(cfg.loss, model) + loss_module, target_net_updater = make_discrete_loss(cfg.loss, model) # Create off-policy collector collector = make_collector(cfg, train_env, explore_policy) @@ -74,7 +74,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create optimizers - optimizer = make_cql_optimizer(cfg, loss_module) + optimizer = make_discrete_cql_optimizer(cfg, loss_module) # Main loop collected_frames = 0 diff --git a/examples/cql/offline_config.yaml b/examples/cql/offline_config.yaml index 1c8d24073bf..517da255481 100644 --- a/examples/cql/offline_config.yaml +++ b/examples/cql/offline_config.yaml @@ -1,37 +1,36 @@ -# Task and env +# env and task env: - name: HalfCheetah-v2 + name: Hopper-v2 task: "" library: gym - exp_name: cql_halfcheetah-medium-v2 + exp_name: cql_${replay_buffer.dataset} n_samples_stats: 1000 - frame_skip: 1 - reward_scaling: 1.0 - noop: 1 seed: 0 + backend: gym # D4RL uses gym so we make sure gymnasium is hidden # logger logger: backend: wandb - eval_iter: 500 + eval_iter: 5000 eval_steps: 1000 mode: online eval_envs: 5 -# Buffer +# replay buffer replay_buffer: - dataset: halfcheetah-medium-v2 + dataset: hopper-medium-v2 batch_size: 256 -# Optimization +# optimization optim: device: cuda:0 - lr: 3e-4 + actor_lr: 3e-4 + critic_lr: 3e-4 weight_decay: 0.0 - batch_size: 256 - gradient_steps: 100000 + gradient_steps: 1_000_000 + policy_eval_start: 40_000 -# Policy and model +# policy and model model: hidden_sizes: [256, 256] activation: relu @@ -40,14 +39,14 @@ model: # loss loss: - loss_function: smooth_l1 + loss_function: l2 gamma: 0.99 tau: 0.005 - # CQL hyperparameter +# CQL specific hyperparameter temperature: 1.0 min_q_weight: 1.0 max_q_backup: False - deterministic_backup: True + deterministic_backup: False num_random: 10 - with_lagrange: False - lagrange_thresh: 0.0 + with_lagrange: True + lagrange_thresh: 5.0 # tau diff --git a/examples/cql/online_config.yaml b/examples/cql/online_config.yaml index 4528fe3fb8d..6c29820856b 100644 --- a/examples/cql/online_config.yaml +++ b/examples/cql/online_config.yaml @@ -2,30 +2,29 @@ env: name: Pendulum-v1 task: "" - library: gym - exp_name: cql_pendulum_gym - record_video: 0 + exp_name: cql_${env.name} n_samples_stats: 1000 - frame_skip: 1 - reward_scaling: 1.0 - noop: 1 seed: 0 + train_num_envs: 1 + eval_num_envs: 1 + backend: gym # Collector collector: - frames_per_batch: 200 - total_frames: 1000000 + frames_per_batch: 1000 + total_frames: 20000 multi_step: 0 - init_random_frames: 1000 + init_random_frames: 5_000 env_per_collector: 1 device: cpu - max_frames_per_traj: 200 + max_frames_per_traj: 1000 + # logger logger: backend: wandb log_interval: 5000 # record interval in frames - eval_steps: 200 + eval_steps: 1000 mode: online eval_iter: 1000 @@ -39,10 +38,10 @@ replay_buffer: optim: utd_ratio: 1 device: cuda:0 - lr: 3e-4 + actor_lr: 3e-4 + critic_lr: 3e-4 weight_decay: 0.0 batch_size: 256 - lr_scheduler: "" optim_steps_per_batch: 200 # Policy and model @@ -54,7 +53,7 @@ model: # loss loss: - loss_function: smooth_l1 + loss_function: l2 gamma: 0.99 tau: 0.005 # CQL hyperparameter diff --git a/examples/cql/utils.py b/examples/cql/utils.py index c64e9d62db7..f14d3784577 100644 --- a/examples/cql/utils.py +++ b/examples/cql/utils.py @@ -1,3 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import torch.nn import torch.optim from tensordict.nn import TensorDictModule, TensorDictSequential @@ -42,7 +46,7 @@ def env_maker(cfg, device="cpu"): - lib = cfg.env.library + lib = cfg.env.backend if lib in ("gym", "gymnasium"): with set_gym_backend(lib): return GymEnv( @@ -58,11 +62,12 @@ def env_maker(cfg, device="cpu"): raise NotImplementedError(f"Unknown lib {lib}.") -def apply_env_transforms(env, reward_scaling=1.0): +def apply_env_transforms( + env, +): transformed_env = TransformedEnv( env, Compose( - # RewardScaling(loc=0.0, scale=reward_scaling), DoubleToFloat(), RewardSum(), ), @@ -70,10 +75,10 @@ def apply_env_transforms(env, reward_scaling=1.0): return transformed_env -def make_environment(cfg, num_envs=1): +def make_environment(cfg, train_num_envs=1, eval_num_envs=1): """Make environments for training and evaluation.""" parallel_env = ParallelEnv( - num_envs, + train_num_envs, EnvCreator(lambda cfg=cfg: env_maker(cfg)), ) parallel_env.set_seed(cfg.env.seed) @@ -82,7 +87,7 @@ def make_environment(cfg, num_envs=1): eval_env = TransformedEnv( ParallelEnv( - num_envs, + eval_num_envs, EnvCreator(lambda cfg=cfg: env_maker(cfg)), ), train_env.transform.clone(), @@ -100,6 +105,7 @@ def make_collector(cfg, train_env, actor_model_explore): collector = SyncDataCollector( train_env, actor_model_explore, + init_random_frames=cfg.collector.init_random_frames, frames_per_batch=cfg.collector.frames_per_batch, max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, @@ -150,6 +156,8 @@ def make_offline_replay_buffer(rb_cfg): split_trajs=False, batch_size=rb_cfg.batch_size, sampler=SamplerWithoutReplacement(drop_last=False), + prefetch=4, + direct_download=True, ) data.append_transform(DoubleToFloat()) @@ -187,8 +195,10 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"): spec=action_spec, distribution_class=TanhNormal, distribution_kwargs={ - "min": action_spec.space.low, - "max": action_spec.space.high, + "min": action_spec.space.low[len(train_env.batch_size) :], + "max": action_spec.space.high[ + len(train_env.batch_size) : + ], # remove batch-size "tanh_loc": False, }, default_interaction_type=ExplorationType.RANDOM, @@ -284,7 +294,7 @@ def make_cql_modules_state(model_cfg, proof_environment): # --------- -def make_loss(loss_cfg, model): +def make_continuous_loss(loss_cfg, model): loss_module = CQLLoss( model[0], model[1], @@ -303,16 +313,7 @@ def make_loss(loss_cfg, model): return loss_module, target_net_updater -def make_cql_optimizer(cfg, loss_module): - optim = torch.optim.Adam( - loss_module.parameters(), - lr=cfg.optim.lr, - weight_decay=cfg.optim.weight_decay, - ) - return optim - - -def make_discreteloss(loss_cfg, model): +def make_discrete_loss(loss_cfg, model): loss_module = DiscreteCQLLoss( model, loss_function=loss_cfg.loss_function, @@ -325,6 +326,48 @@ def make_discreteloss(loss_cfg, model): return loss_module, target_net_updater +def make_discrete_cql_optimizer(cfg, loss_module): + optim = torch.optim.Adam( + loss_module.parameters(), + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + ) + return optim + + +def make_continuous_cql_optimizer(cfg, loss_module): + critic_params = loss_module.qvalue_network_params.flatten_keys().values() + actor_params = loss_module.actor_network_params.flatten_keys().values() + actor_optim = torch.optim.Adam( + actor_params, + lr=cfg.optim.actor_lr, + weight_decay=cfg.optim.weight_decay, + ) + critic_optim = torch.optim.Adam( + critic_params, + lr=cfg.optim.critic_lr, + weight_decay=cfg.optim.weight_decay, + ) + alpha_optim = torch.optim.Adam( + [loss_module.log_alpha], + lr=cfg.optim.actor_lr, + weight_decay=cfg.optim.weight_decay, + ) + if loss_module.with_lagrange: + alpha_prime_optim = torch.optim.Adam( + [loss_module.log_alpha_prime], + lr=cfg.optim.critic_lr, + ) + else: + alpha_prime_optim = None + return actor_optim, critic_optim, alpha_optim, alpha_prime_optim + + +# ==================================================================== +# General utils +# --------- + + def log_metrics(logger, metrics, step): if logger is not None: for metric_name, metric_value in metrics.items(): diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index d0600f66efe..5232901a114 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -241,6 +241,7 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling): transform=transforms, use_truncated_as_done=True, direct_download=True, + prefetch=4, ) loc = ( data._storage._storage.get(("_data", "observation")) diff --git a/examples/iql/iql_offline.py b/examples/iql/iql_offline.py new file mode 100644 index 00000000000..f6612048318 --- /dev/null +++ b/examples/iql/iql_offline.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""IQL Example. + +This is a self-contained example of an offline IQL training script. + +The helper functions are coded in the utils.py associated with this script. + +""" +import time + +import hydra +import numpy as np +import torch +import tqdm + +from torchrl.envs import set_gym_backend +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.record.loggers import generate_exp_name, get_logger + +from utils import ( + log_metrics, + make_environment, + make_iql_model, + make_iql_optimizer, + make_loss, + make_offline_replay_buffer, +) + + +@set_gym_backend("gym") +@hydra.main(config_path=".", config_name="offline_config") +def main(cfg: "DictConfig"): # noqa: F821 + + # Create logger + exp_name = generate_exp_name("IQL-offline", cfg.env.exp_name) + logger = None + if cfg.logger.backend: + logger = get_logger( + logger_type=cfg.logger.backend, + logger_name="iql_logging", + experiment_name=exp_name, + wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, + ) + + # Set seeds + torch.manual_seed(cfg.env.seed) + np.random.seed(cfg.env.seed) + device = torch.device(cfg.optim.device) + + # Creante env + train_env, eval_env = make_environment(cfg, cfg.logger.eval_envs) + + # Create replay buffer + replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) + + # Create agent + model = make_iql_model(cfg, train_env, eval_env, device) + + # Create loss + loss_module, target_net_updater = make_loss(cfg.loss, model) + + # Create optimizer + optimizer = make_iql_optimizer(cfg.optim, loss_module) + + pbar = tqdm.tqdm(total=cfg.optim.gradient_steps) + + gradient_steps = cfg.optim.gradient_steps + evaluation_interval = cfg.logger.eval_iter + eval_steps = cfg.logger.eval_steps + + # Training loop + start_time = time.time() + for i in range(gradient_steps): + pbar.update(i) + # sample data + data = replay_buffer.sample() + # compute loss + loss_vals = loss_module(data.clone().to(device)) + + actor_loss = loss_vals["loss_actor"] + q_loss = loss_vals["loss_qvalue"] + value_loss = loss_vals["loss_value"] + loss_val = actor_loss + q_loss + value_loss + + # update model + optimizer.zero_grad() + loss_val.backward() + optimizer.step() + target_net_updater.step() + + # log metrics + to_log = { + "loss_actor": actor_loss.item(), + "loss_qvalue": q_loss.item(), + "loss_value": value_loss.item(), + } + + # evaluation + if i % evaluation_interval == 0: + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_td = eval_env.rollout( + max_steps=eval_steps, policy=model[0], auto_cast_to_device=True + ) + eval_reward = eval_td["next", "reward"].sum(1).mean().item() + to_log["evaluation_reward"] = eval_reward + if logger is not None: + log_metrics(logger, to_log, i) + + pbar.close() + print(f"Training time: {time.time() - start_time}") + + +if __name__ == "__main__": + main() diff --git a/examples/iql/iql_online.py b/examples/iql/iql_online.py index f27adc1789a..290c6c2a8de 100644 --- a/examples/iql/iql_online.py +++ b/examples/iql/iql_online.py @@ -2,361 +2,188 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +"""IQL Example. +This is a self-contained example of an online IQL training script. -import hydra +It works across Gym and MuJoCo over a variety of tasks. + +The helper functions are coded in the utils.py associated with this script. + +""" +import time + +import hydra import numpy as np import torch -import torch.cuda import tqdm -from tensordict.nn import TensorDictModule -from tensordict.nn.distributions import NormalParamExtractor - -from torch import nn, optim -from torchrl.collectors import SyncDataCollector -from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer - -from torchrl.data.replay_buffers.storages import LazyMemmapStorage -from torchrl.envs import ( - CatTensors, - DMControlEnv, - EnvCreator, - ParallelEnv, - TransformedEnv, -) -from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from tensordict import TensorDict from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import MLP, ProbabilisticActor, ValueOperator -from torchrl.modules.distributions import TanhNormal - -from torchrl.objectives import SoftUpdate -from torchrl.objectives.iql import IQLLoss from torchrl.record.loggers import generate_exp_name, get_logger - -def env_maker(cfg, device="cpu"): - lib = cfg.env.library - if lib in ("gym", "gymnasium"): - with set_gym_backend(lib): - return GymEnv( - cfg.env.name, - device=device, - frame_skip=cfg.env.frame_skip, - ) - elif lib == "dm_control": - env = DMControlEnv(cfg.env.name, cfg.env.task, frame_skip=cfg.env.frame_skip) - return TransformedEnv( - env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") - ) - else: - raise NotImplementedError(f"Unknown lib {lib}.") - - -def make_replay_buffer( - batch_size, - prb=False, - buffer_size=1000000, - buffer_scratch_dir=None, - device="cpu", - prefetch=3, -): - if prb: - replay_buffer = TensorDictPrioritizedReplayBuffer( - alpha=0.7, - beta=0.5, - pin_memory=False, - prefetch=prefetch, - storage=LazyMemmapStorage( - buffer_size, - scratch_dir=buffer_scratch_dir, - device=device, - ), - batch_size=batch_size, - ) - else: - replay_buffer = TensorDictReplayBuffer( - pin_memory=False, - prefetch=prefetch, - storage=LazyMemmapStorage( - buffer_size, - scratch_dir=buffer_scratch_dir, - device=device, - ), - batch_size=batch_size, - ) - return replay_buffer +from utils import ( + log_metrics, + make_collector, + make_environment, + make_iql_model, + make_iql_optimizer, + make_loss, + make_replay_buffer, +) -@hydra.main(version_base="1.1", config_path=".", config_name="online_config") +@hydra.main(config_path=".", config_name="online_config") def main(cfg: "DictConfig"): # noqa: F821 - - device = torch.device(cfg.network.device) - - exp_name = generate_exp_name("Online_IQL", cfg.logger.exp_name) + # Create logger + exp_name = generate_exp_name("IQL-online", cfg.env.exp_name) logger = None if cfg.logger.backend: logger = get_logger( logger_type=cfg.logger.backend, logger_name="iql_logging", experiment_name=exp_name, - wandb_kwargs={"mode": cfg.logger.mode}, + wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, ) - torch.manual_seed(cfg.optim.seed) - np.random.seed(cfg.optim.seed) - - def env_factory(num_workers): - """Creates an instance of the environment.""" - - # 1.2 Create env vector - vec_env = ParallelEnv( - create_env_fn=EnvCreator(lambda cfg=cfg: env_maker(cfg=cfg)), - num_workers=num_workers, - ) - - return vec_env - - # Sanity check - test_env = env_factory(num_workers=cfg.collector.env_per_collector) - num_actions = test_env.action_spec.shape[-1] - - # Create Agent - # Define Actor Network - in_keys = ["observation"] - action_spec = test_env.action_spec - actor_net_kwargs = { - "num_cells": [256, 256], - "out_features": 2 * num_actions, - "activation_class": nn.ReLU, - } - - actor_net = MLP(**actor_net_kwargs) + # Set seeds + torch.manual_seed(cfg.env.seed) + np.random.seed(cfg.env.seed) + device = torch.device(cfg.optim.device) - dist_class = TanhNormal - dist_kwargs = { - "min": action_spec.space.low[-1], - "max": action_spec.space.high[-1], - "tanh_loc": cfg.network.tanh_loc, - } - - actor_extractor = NormalParamExtractor( - scale_mapping=f"biased_softplus_{cfg.network.default_policy_scale}", - scale_lb=cfg.network.scale_lb, - ) - - actor_net = nn.Sequential(actor_net, actor_extractor) - in_keys_actor = in_keys - actor_module = TensorDictModule( - actor_net, - in_keys=in_keys_actor, - out_keys=[ - "loc", - "scale", - ], - ) - actor = ProbabilisticActor( - spec=action_spec, - in_keys=["loc", "scale"], - module=actor_module, - distribution_class=dist_class, - distribution_kwargs=dist_kwargs, - default_interaction_type=ExplorationType.RANDOM, - return_log_prob=False, + # Create environments + train_env, eval_env = make_environment( + cfg, + cfg.env.train_num_envs, + cfg.env.eval_num_envs, ) - # Define Critic Network - qvalue_net_kwargs = { - "num_cells": [256, 256], - "out_features": 1, - "activation_class": nn.ReLU, - } - - qvalue_net = MLP( - **qvalue_net_kwargs, - ) - - qvalue = ValueOperator( - in_keys=["action"] + in_keys, - module=qvalue_net, - ) - - # Define Value Network - value_net_kwargs = { - "num_cells": [256, 256], - "out_features": 1, - "activation_class": nn.ReLU, - } - value_net = MLP(**value_net_kwargs) - value = ValueOperator( - in_keys=in_keys, - module=value_net, - ) - - model = nn.ModuleList([actor, qvalue, value]).to(device) - - # init nets - with torch.no_grad(): - td = test_env.reset() - td = td.to(device) - actor(td) - qvalue(td) - value(td) - - del td - test_env.close() - test_env.eval() - - # Create IQL loss - loss_module = IQLLoss( - actor_network=model[0], - qvalue_network=model[1], - value_network=model[2], - num_qvalue_nets=2, - temperature=cfg.loss.temperature, - expectile=cfg.loss.expectile, - loss_function=cfg.loss.loss_function, - ) - loss_module.make_value_estimator(gamma=cfg.loss.gamma) - - # Define Target Network Updater - target_net_updater = SoftUpdate(loss_module, eps=cfg.loss.target_update_polyak) - - # Make Off-Policy Collector - collector = SyncDataCollector( - env_factory, - create_env_kwargs={"num_workers": cfg.collector.env_per_collector}, - policy=model[0], - frames_per_batch=cfg.collector.frames_per_batch, - max_frames_per_traj=cfg.collector.max_frames_per_traj, - total_frames=cfg.collector.total_frames, - device=cfg.collector.device, - ) - collector.set_seed(cfg.optim.seed) - - # Make Replay Buffer + # Create replay buffer replay_buffer = make_replay_buffer( - buffer_size=cfg.buffer.size, + batch_size=cfg.optim.batch_size, + prb=cfg.replay_buffer.prb, + buffer_size=cfg.replay_buffer.size, device="cpu", - batch_size=cfg.buffer.batch_size, - prefetch=cfg.buffer.prefetch, - prb=cfg.buffer.prb, ) - # Optimizers - params = list(loss_module.parameters()) - optimizer = optim.Adam( - params, lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay, eps=cfg.optim.eps - ) + # Create model + model = make_iql_model(cfg, train_env, eval_env, device) + + # Create collector + collector = make_collector(cfg, train_env, actor_model_explore=model[0]) - rewards = [] - rewards_eval = [] + # Create loss + loss_module, target_net_updater = make_loss(cfg.loss, model) + + # Create optimizer + optimizer = make_iql_optimizer(cfg.optim, loss_module) # Main loop collected_frames = 0 - pbar = tqdm.tqdm(total=cfg.collector.total_frames) - r0 = None - loss = None - num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) - env_per_collector = cfg.collector.env_per_collector - prb = cfg.buffer.prb - max_frames_per_traj = cfg.collector.max_frames_per_traj - - for i, tensordict in enumerate(collector): + init_random_frames = cfg.collector.init_random_frames + num_updates = int( + cfg.collector.env_per_collector + * cfg.collector.frames_per_batch + * cfg.optim.utd_ratio + ) + prb = cfg.replay_buffer.prb + eval_iter = cfg.logger.eval_iter + frames_per_batch = cfg.collector.frames_per_batch + eval_rollout_steps = cfg.collector.max_frames_per_traj + sampling_start = start_time = time.time() + for tensordict in collector: + sampling_time = time.time() - sampling_start + pbar.update(tensordict.numel()) # update weights of the inference policy collector.update_policy_weights_() - if r0 is None: - r0 = tensordict["next", "reward"].sum(-1).mean().item() - pbar.update(tensordict.numel()) - - if "mask" in tensordict.keys(): - # if multi-step, a mask is present to help filter padded values - current_frames = tensordict["mask"].sum() - tensordict = tensordict[tensordict.get("mask").squeeze(-1)] - else: - tensordict = tensordict.view(-1) - current_frames = tensordict.numel() + tensordict = tensordict.view(-1) + current_frames = tensordict.numel() + # add to replay buffer replay_buffer.extend(tensordict.cpu()) collected_frames += current_frames - ( - actor_losses, - q_losses, - value_losses, - ) = ([], [], []) # optimization steps - for _ in range(num_updates): - # sample from replay buffer - sampled_tensordict = replay_buffer.sample() - if sampled_tensordict.device == device: - sampled_tensordict = sampled_tensordict.clone() - else: - sampled_tensordict = sampled_tensordict.to(device, non_blocking=True) - - loss_td = loss_module(sampled_tensordict) - - actor_loss = loss_td["loss_actor"] - q_loss = loss_td["loss_qvalue"] - value_loss = loss_td["loss_value"] - - loss = actor_loss + q_loss + value_loss - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - q_losses.append(q_loss.item()) - actor_losses.append(actor_loss.item()) - value_losses.append(value_loss.item()) - - # update qnet_target params - target_net_updater.step() - - # update priority - if prb: - replay_buffer.update_priority(sampled_tensordict) - - rewards.append( - (i, tensordict["next", "reward"].sum().item() / env_per_collector) - ) - train_log = { - "train_reward": rewards[-1][1], - "collected_frames": collected_frames, - } - if q_loss is not None: - train_log.update( - { - "actor_loss": np.mean(actor_losses), - "q_loss": np.mean(q_losses), - "value_loss": np.mean(value_losses), - } + training_start = time.time() + if collected_frames >= init_random_frames: + log_loss_td = TensorDict({}, [num_updates]) + for j in range(num_updates): + # sample from replay buffer + sampled_tensordict = replay_buffer.sample().clone() + if sampled_tensordict.device != device: + sampled_tensordict = sampled_tensordict.to( + device, non_blocking=True + ) + else: + sampled_tensordict = sampled_tensordict + # compute loss + loss_td = loss_module(sampled_tensordict) + + actor_loss = loss_td["loss_actor"] + q_loss = loss_td["loss_qvalue"] + value_loss = loss_td["loss_value"] + loss = actor_loss + q_loss + value_loss + + # update model + optimizer.zero_grad() + loss.backward() + optimizer.step() + + log_loss_td[j] = loss_td.detach() + + # update qnet_target params + target_net_updater.step() + + # update priority + if prb: + replay_buffer.update_priority(sampled_tensordict) + training_time = time.time() - training_start + episode_rewards = tensordict["next", "episode_reward"][ + tensordict["next", "done"] + ] + + # Logging + metrics_to_log = {} + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][ + tensordict["next", "done"] + ] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length ) - if logger is not None: - for key, value in train_log.items(): - logger.log_scalar(key, value, step=collected_frames) - - with set_exploration_type(ExplorationType.MEAN), torch.no_grad(): - eval_rollout = test_env.rollout( - max_steps=max_frames_per_traj, - policy=model[0], - auto_cast_to_device=True, - ).clone() - eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() - rewards_eval.append((i, eval_reward)) - eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})" - if logger is not None: - logger.log_scalar( - "test_reward", rewards_eval[-1][1], step=collected_frames + if collected_frames >= init_random_frames: + metrics_to_log["train/q_loss"] = log_loss_td.get("loss_qvalue").detach() + metrics_to_log["train/actor_loss"] = log_loss_td.get("loss_actor").detach() + metrics_to_log["train/value_loss"] = log_loss_td.get("loss_value").detach() + metrics_to_log["train/entropy"] = log_loss_td.get("entropy").detach() + metrics_to_log["train/sampling_time"] = sampling_time + metrics_to_log["train/training_time"] = training_time + + # Evaluation + if abs(collected_frames % eval_iter) < frames_per_batch: + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_start = time.time() + eval_rollout = eval_env.rollout( + eval_rollout_steps, + model[0], + auto_cast_to_device=True, + break_when_any_done=True, ) - if len(rewards_eval): - pbar.set_description( - f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str - ) + eval_time = time.time() - eval_start + eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + metrics_to_log["eval/reward"] = eval_reward + metrics_to_log["eval/time"] = eval_time + if logger is not None: + log_metrics(logger, metrics_to_log, collected_frames) + sampling_start = time.time() collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/iql/offline_config.yaml b/examples/iql/offline_config.yaml new file mode 100644 index 00000000000..473decb2980 --- /dev/null +++ b/examples/iql/offline_config.yaml @@ -0,0 +1,45 @@ +# env and task +env: + name: HalfCheetah-v2 + task: "" + library: gym + exp_name: iql_${replay_buffer.dataset} + n_samples_stats: 1000 + seed: 0 + +# logger +logger: + backend: wandb + eval_iter: 500 + eval_steps: 1000 + mode: online + eval_envs: 5 + +# replay buffer +replay_buffer: + dataset: halfcheetah-medium-v2 + batch_size: 256 + +# optimization +optim: + device: cuda:0 + lr: 3e-4 + weight_decay: 0.0 + gradient_steps: 50000 + +# network +model: + hidden_sizes: [256, 256] + activation: relu + default_policy_scale: 1.0 + scale_lb: 0.1 + +# loss +loss: + loss_function: l2 + gamma: 0.99 + tau: 0.005 + +# IQL specific hyperparameter + temperature: 3.0 + expectile: 0.7 diff --git a/examples/iql/online_config.yaml b/examples/iql/online_config.yaml index 350560ea9a1..e3ef0d081c4 100644 --- a/examples/iql/online_config.yaml +++ b/examples/iql/online_config.yaml @@ -1,46 +1,61 @@ +# task and env env: name: Pendulum-v1 - library: gym - async_collection: 1 - record_video: 0 - frame_skip: 1 + task: "" + exp_name: iql_${env.name} + n_samples_stats: 1000 + seed: 0 + train_num_envs: 1 + eval_num_envs: 1 + backend: gym + +# collector +collector: + frames_per_batch: 200 + total_frames: 20000 + multi_step: 0 + init_random_frames: 5000 + env_per_collector: 1 + device: cpu + max_frames_per_traj: 200 + +# logger logger: - exp_name: "iql_pendulum" backend: wandb + log_interval: 5000 # record interval in frames + eval_steps: 200 mode: online + eval_iter: 1000 +# replay buffer +replay_buffer: + prb: 0 + buffer_prefetch: 64 + size: 1_000_000 + +# optimization optim: - seed: 42 - utd_ratio: 1.0 + utd_ratio: 1 + device: cuda:0 lr: 3e-4 weight_decay: 0.0 - eps: 1e-4 + batch_size: 256 + optim_steps_per_batch: 200 -network: - tanh_loc: False +# network +model: + hidden_sizes: [256, 256] + activation: relu default_policy_scale: 1.0 scale_lb: 0.1 - device: "cuda:0" -collector: - total_frames: 1000000 - init_random_frames: 5000 - device: cuda:0 - frames_per_batch: 1000 # 5*200 - env_per_collector: 5 - max_frames_per_traj: 200 +# loss +loss: + loss_function: l2 + gamma: 0.99 + tau: 0.005 -# IQL hyperparameter -loss: +# IQL specific hyperparameter temperature: 3.0 expectile: 0.7 - gamma: 0.99 - target_update_polyak: 0.995 - loss_function: smooth_l1 - -buffer: - prefetch: 64 - prb: 0 - size: 100000 - batch_size: 256 diff --git a/examples/iql/utils.py b/examples/iql/utils.py new file mode 100644 index 00000000000..1dff1c7bd34 --- /dev/null +++ b/examples/iql/utils.py @@ -0,0 +1,287 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch.nn +import torch.optim +from tensordict.nn import TensorDictModule +from tensordict.nn.distributions import NormalParamExtractor + +from torchrl.collectors import SyncDataCollector +from torchrl.data import ( + LazyMemmapStorage, + TensorDictPrioritizedReplayBuffer, + TensorDictReplayBuffer, +) +from torchrl.data.datasets.d4rl import D4RLExperienceReplay +from torchrl.data.replay_buffers import SamplerWithoutReplacement +from torchrl.envs import ( + Compose, + DoubleToFloat, + EnvCreator, + InitTracker, + ParallelEnv, + RewardSum, + TransformedEnv, +) +from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator +from torchrl.objectives import IQLLoss, SoftUpdate + +from torchrl.trainers.helpers.models import ACTIVATIONS + + +# ==================================================================== +# Environment utils +# ----------------- + + +def env_maker(task, device="cpu", from_pixels=False): + with set_gym_backend("gym"): + return GymEnv(task, device=device, from_pixels=from_pixels) + + +def apply_env_transforms( + env, +): + transformed_env = TransformedEnv( + env, + Compose( + InitTracker(), + DoubleToFloat(), + RewardSum(), + ), + ) + return transformed_env + + +def make_environment(cfg, train_num_envs=1, eval_num_envs=1): + """Make environments for training and evaluation.""" + parallel_env = ParallelEnv( + train_num_envs, + EnvCreator(lambda: env_maker(task=cfg.env.name)), + ) + parallel_env.set_seed(cfg.env.seed) + + train_env = apply_env_transforms(parallel_env) + + eval_env = TransformedEnv( + ParallelEnv( + eval_num_envs, + EnvCreator(lambda: env_maker(task=cfg.env.name)), + ), + train_env.transform.clone(), + ) + return train_env, eval_env + + +# ==================================================================== +# Collector and replay buffer +# --------------------------- + + +def make_collector(cfg, train_env, actor_model_explore): + """Make collector.""" + collector = SyncDataCollector( + train_env, + actor_model_explore, + frames_per_batch=cfg.collector.frames_per_batch, + init_random_frames=cfg.collector.init_random_frames, + max_frames_per_traj=cfg.collector.max_frames_per_traj, + total_frames=cfg.collector.total_frames, + device=cfg.collector.device, + ) + collector.set_seed(cfg.env.seed) + return collector + + +def make_replay_buffer( + batch_size, + prb=False, + buffer_size=1000000, + buffer_scratch_dir=None, + device="cpu", + prefetch=3, +): + if prb: + replay_buffer = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.5, + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=buffer_scratch_dir, + device=device, + ), + batch_size=batch_size, + ) + else: + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=buffer_scratch_dir, + device=device, + ), + batch_size=batch_size, + ) + return replay_buffer + + +def make_offline_replay_buffer(rb_cfg): + data = D4RLExperienceReplay( + name=rb_cfg.dataset, + split_trajs=False, + batch_size=rb_cfg.batch_size, + sampler=SamplerWithoutReplacement(drop_last=False), + prefetch=4, + direct_download=True, + ) + + data.append_transform(DoubleToFloat()) + + return data + + +# ==================================================================== +# Model +# ----- +# +# We give one version of the model for learning from pixels, and one for state. +# TorchRL comes in handy at this point, as the high-level interactions with +# these models is unchanged, regardless of the modality. +# + + +def make_iql_model(cfg, train_env, eval_env, device="cpu"): + model_cfg = cfg.model + + in_keys = ["observation"] + action_spec = train_env.action_spec + if train_env.batch_size: + action_spec = action_spec[(0,) * len(train_env.batch_size)] + actor_net, q_net, value_net = make_iql_modules_state(model_cfg, eval_env) + + out_keys = ["loc", "scale"] + + actor_module = TensorDictModule(actor_net, in_keys=in_keys, out_keys=out_keys) + + # We use a ProbabilisticActor to make sure that we map the + # network output to the right space using a TanhDelta + # distribution. + actor = ProbabilisticActor( + module=actor_module, + in_keys=["loc", "scale"], + spec=action_spec, + distribution_class=TanhNormal, + distribution_kwargs={ + "min": action_spec.space.minimum, + "max": action_spec.space.maximum, + "tanh_loc": False, + }, + default_interaction_type=ExplorationType.RANDOM, + ) + + in_keys = ["observation", "action"] + + out_keys = ["state_action_value"] + qvalue = ValueOperator( + in_keys=in_keys, + out_keys=out_keys, + module=q_net, + ) + in_keys = ["observation"] + out_keys = ["state_value"] + value_net = ValueOperator( + in_keys=in_keys, + out_keys=out_keys, + module=value_net, + ) + model = torch.nn.ModuleList([actor, qvalue, value_net]).to(device) + # init nets + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): + td = eval_env.reset() + td = td.to(device) + for net in model: + net(td) + del td + eval_env.close() + + return model + + +def make_iql_modules_state(model_cfg, proof_environment): + action_spec = proof_environment.action_spec + + actor_net_kwargs = { + "num_cells": model_cfg.hidden_sizes, + "out_features": 2 * action_spec.shape[-1], + "activation_class": ACTIVATIONS[model_cfg.activation], + } + actor_net = MLP(**actor_net_kwargs) + actor_extractor = NormalParamExtractor( + scale_mapping=f"biased_softplus_{model_cfg.default_policy_scale}", + scale_lb=model_cfg.scale_lb, + ) + actor_net = torch.nn.Sequential(actor_net, actor_extractor) + + qvalue_net_kwargs = { + "num_cells": model_cfg.hidden_sizes, + "out_features": 1, + "activation_class": ACTIVATIONS[model_cfg.activation], + } + + q_net = MLP(**qvalue_net_kwargs) + + # Define Value Network + value_net_kwargs = { + "num_cells": model_cfg.hidden_sizes, + "out_features": 1, + "activation_class": ACTIVATIONS[model_cfg.activation], + } + value_net = MLP(**value_net_kwargs) + + return actor_net, q_net, value_net + + +# ==================================================================== +# IQL Loss +# --------- + + +def make_loss(loss_cfg, model): + loss_module = IQLLoss( + model[0], + model[1], + value_network=model[2], + loss_function=loss_cfg.loss_function, + temperature=loss_cfg.temperature, + expectile=loss_cfg.expectile, + ) + loss_module.make_value_estimator(gamma=loss_cfg.gamma) + target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau) + + return loss_module, target_net_updater + + +def make_iql_optimizer(optim_cfg, loss_module): + optim = torch.optim.Adam( + loss_module.parameters(), + lr=optim_cfg.lr, + weight_decay=optim_cfg.weight_decay, + ) + return optim + + +# ==================================================================== +# General utils +# --------- + + +def log_metrics(logger, metrics, step): + if logger is not None: + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) diff --git a/examples/sac/sac.py b/examples/sac/sac.py index ed0a38b144c..76bfea72e45 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -114,12 +114,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Optimization steps training_start = time.time() if collected_frames >= init_random_frames: - losses = TensorDict( - {}, - batch_size=[ - num_updates, - ], - ) + losses = TensorDict({}, batch_size=[num_updates]) for i in range(num_updates): # Sample from replay buffer sampled_tensordict = replay_buffer.sample() diff --git a/test/test_cost.py b/test/test_cost.py index 8bae683c5d5..1457aaf72a2 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -4860,7 +4860,7 @@ def _create_seq_mock_data_cql( return td @pytest.mark.parametrize("delay_actor", (True, False)) - @pytest.mark.parametrize("delay_qvalue", (True, False)) + @pytest.mark.parametrize("delay_qvalue", (True, True)) @pytest.mark.parametrize("max_q_backup", [True, False]) @pytest.mark.parametrize("deterministic_backup", [True, False]) @pytest.mark.parametrize("with_lagrange", [True, False]) @@ -4876,8 +4876,6 @@ def test_cql( device, td_est, ): - if delay_actor or delay_qvalue: - pytest.skip("incompatible config") torch.manual_seed(self.seed) td = self._create_mock_data_cql(device=device) @@ -4885,12 +4883,6 @@ def test_cql( actor = self._create_mock_actor(device=device) qvalue = self._create_mock_qvalue(device=device) - kwargs = {} - if delay_actor: - kwargs["delay_actor"] = True - if delay_qvalue: - kwargs["delay_qvalue"] = True - loss_fn = CQLLoss( actor_network=actor, qvalue_network=qvalue, @@ -4898,7 +4890,8 @@ def test_cql( max_q_backup=max_q_backup, deterministic_backup=deterministic_backup, with_lagrange=with_lagrange, - **kwargs, + delay_actor=delay_actor, + delay_qvalue=delay_qvalue, ) if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): @@ -4935,6 +4928,19 @@ def test_cql( include_nested=True, leaves_only=True ) ) + elif k == "loss_actor_bc": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) elif k == "loss_qvalue": assert all( (p.grad is None) or (p.grad == 0).all() @@ -4948,6 +4954,19 @@ def test_cql( include_nested=True, leaves_only=True ) ) + elif k == "loss_cql": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert not all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) elif k == "loss_alpha": assert all( (p.grad is None) or (p.grad == 0).all() diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 0c8caa5a60b..431f8503b3d 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -81,9 +81,6 @@ class CQLLoss(LossModule): with_lagrange (bool, optional): Whether to use the Lagrange multiplier. Default is ``False``. lagrange_thresh (float, optional): Lagrange threshold. Default is 0.0. - priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead] - Tensordict key where to write the - priority (for prioritized replay buffer usage). Defaults to ``"td_error"``. Examples: >>> import torch @@ -215,14 +212,33 @@ class _AcceptedKeys: state action value is expected. Defaults to ``"state_action_value"``. log_prob (NestedKey): The input tensordict key where the log probability is expected. Defaults to ``"_log_prob"``. + pred_q1 (NestedKey): The input tensordict key where the predicted Q1 values are expected. + Defaults to ``"pred_q1"``. + pred_q2 (NestedKey): The input tensordict key where the predicted Q2 values are expected. + Defaults to ``"pred_q2"``. priority (NestedKey): The input tensordict key where the target priority is written to. Defaults to ``"td_error"``. + cql_q1_loss (NestedKey): The input tensordict key where the CQL Q1 loss is expected. + Defaults to ``"cql_q1_loss"``. + cql_q2_loss (NestedKey): The input tensordict key where the CQL Q2 loss is expected. + Defaults to ``"cql_q2_loss"``. + reward (NestedKey): The input tensordict key where the reward is expected. + Defaults to ``"reward"``. + done (NestedKey): The input tensordict key where the done flag is expected. + Defaults to ``"done"``. + terminated (NestedKey): The input tensordict key where the terminated flag is expected. + Defaults to ``"terminated"``. """ action: NestedKey = "action" value: NestedKey = "state_value" state_action_value: NestedKey = "state_action_value" log_prob: NestedKey = "_log_prob" + pred_q1: NestedKey = "pred_q1" + pred_q2: NestedKey = "pred_q2" + priority: NestedKey = "td_error" + cql_q1_loss: NestedKey = "cql_q1_loss" + cql_q2_loss: NestedKey = "cql_q2_loss" priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" @@ -253,11 +269,9 @@ def __init__( num_random: int = 10, with_lagrange: bool = False, lagrange_thresh: float = 0.0, - priority_key: str = None, ) -> None: self._out_keys = None super().__init__() - self._set_deprecated_ctor_keys(priority_key=priority_key) # Actor self.delay_actor = delay_actor @@ -431,14 +445,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams } self._value_estimator.set_keys(**tensor_keys) - @property - def device(self) -> torch.device: - for p in self.parameters(): - return p.device - raise RuntimeError( - "At least one of the networks of CQLLoss must have trainable " "parameters." - ) - @property def in_keys(self): keys = [ @@ -458,7 +464,9 @@ def out_keys(self): if self._out_keys is None: keys = [ "loss_actor", + "loss_actor_bc", "loss_qvalue", + "loss_cql", "loss_alpha", "loss_alpha_prime", "alpha", @@ -480,28 +488,36 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: else: tensordict_reshape = tensordict - device = self.device - td_device = tensordict_reshape.to(device) + td_device = tensordict_reshape.to(tensordict.device) - loss_qvalue, loss_alpha_prime, priority = self._loss_qvalue_v(td_device) - loss_actor = self._loss_actor(td_device) - loss_alpha = self._loss_alpha(td_device) - tensordict_reshape.set(self.tensor_keys.priority, priority) - if loss_actor.shape != loss_qvalue.shape: - raise RuntimeError( - f"Losses shape mismatch: {loss_actor.shape} and {loss_qvalue.shape}" - ) + q_loss, metadata = self.q_loss(td_device) + cql_loss, cql_metadata = self.cql_loss(td_device) + if self.with_lagrange: + alpha_prime_loss, alpha_prime_metadata = self.alpha_prime_loss(td_device) + metadata.update(alpha_prime_metadata) + loss_actor_bc, bc_metadata = self.actor_bc_loss(td_device) + loss_actor, actor_metadata = self.actor_loss(td_device) + loss_alpha, alpha_metadata = self.alpha_loss(td_device) + metadata.update(bc_metadata) + metadata.update(cql_metadata) + metadata.update(actor_metadata) + metadata.update(alpha_metadata) + tensordict_reshape.set( + self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values + ) if shape: tensordict.update(tensordict_reshape.view(shape)) out = { "loss_actor": loss_actor.mean(), - "loss_qvalue": loss_qvalue.mean(), + "loss_actor_bc": loss_actor_bc.mean(), + "loss_qvalue": q_loss.mean(), + "loss_cql": cql_loss.mean(), "loss_alpha": loss_alpha.mean(), - "loss_alpha_prime": loss_alpha_prime, "alpha": self._alpha, "entropy": -td_device.get(self.tensor_keys.log_prob).mean().detach(), } - + if self.with_lagrange: + out["loss_alpha_prime"] = alpha_prime_loss.mean() return TensorDict(out, []) @property @@ -509,12 +525,29 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def _cached_detach_qvalue_params(self): return self.qvalue_network_params.detach() - def _loss_actor(self, tensordict: TensorDictBase) -> Tensor: + def actor_bc_loss(self, tensordict: TensorDictBase) -> Tensor: with set_exploration_type( ExplorationType.RANDOM ), self.actor_network_params.to_module(self.actor_network): - dist = self.actor_network.get_dist(tensordict) - a_reparm = dist.rsample() + dist = self.actor_network.get_dist( + tensordict, + ) + a_reparm = dist.rsample() + log_prob = dist.log_prob(a_reparm) + bc_log_prob = dist.log_prob(tensordict.get(self.tensor_keys.action)) + + bc_actor_loss = self._alpha * log_prob - bc_log_prob + metadata = {"bc_log_prob": bc_log_prob.mean().detach()} + return bc_actor_loss, metadata + + def actor_loss(self, tensordict: TensorDictBase) -> Tensor: + with set_exploration_type( + ExplorationType.RANDOM + ), self.actor_network_params.to_module(self.actor_network): + dist = self.actor_network.get_dist( + tensordict, + ) + a_reparm = dist.rsample() log_prob = dist.log_prob(a_reparm) td_q = tensordict.select(*self.qvalue_network.in_keys) @@ -534,7 +567,9 @@ def _loss_actor(self, tensordict: TensorDictBase) -> Tensor: # write log_prob in tensordict for alpha loss tensordict.set(self.tensor_keys.log_prob, log_prob.detach()) - return self._alpha * log_prob - min_q_logprob + actor_loss = self._alpha * log_prob - min_q_logprob + + return actor_loss, {} def _get_policy_actions(self, data, actor_params, num_actions=10): batch_size = data.batch_size @@ -580,7 +615,6 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params): next_state_value = next_tensordict_expand.get( self.tensor_keys.state_action_value ).min(0)[0] - # could be wrong to min if ( next_state_value.shape[-len(next_sample_log_prob.shape) :] != next_sample_log_prob.shape @@ -615,7 +649,7 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params): target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) return target_value - def _loss_qvalue_v(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: + def q_loss(self, tensordict: TensorDictBase) -> Tensor: # we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first. target_value = self._get_value_v( tensordict, @@ -629,7 +663,36 @@ def _loss_qvalue_v(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: tensordict_pred_q, self.qvalue_network_params ).get(self.tensor_keys.state_action_value) - # add CQL + # write pred values in tensordict for cql loss + tensordict.set(self.tensor_keys.pred_q1, q_pred[0]) + tensordict.set(self.tensor_keys.pred_q2, q_pred[1]) + + q_pred = q_pred.squeeze(-1) + loss_qval = distance_loss( + q_pred, + target_value.expand_as(q_pred), + loss_function=self.loss_function, + ) + td_error = (q_pred - target_value).pow(2) + metadata = {"td_error": td_error.detach()} + + return loss_qval.sum(0).mean(), metadata + + def cql_loss(self, tensordict: TensorDictBase) -> Tensor: + pred_q1 = tensordict.get(self.tensor_keys.pred_q1) + pred_q2 = tensordict.get(self.tensor_keys.pred_q2) + + if pred_q1 is None: + raise KeyError( + f"Couldn't find the pred_q1 with key {self.tensor_keys.pred_q1} in the input tensordict. " + "This could be caused by calling cql_loss method before q_loss method." + ) + if pred_q2 is None: + raise KeyError( + f"Couldn't find the pred_q2 with key {self.tensor_keys.pred_q2} in the input tensordict. " + "This could be caused by calling cql_loss method before q_loss method." + ) + random_actions_tensor = ( torch.FloatTensor( tensordict.shape[0] * self.num_random, @@ -729,46 +792,49 @@ def _loss_qvalue_v(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: ) min_qf1_loss = ( - torch.logsumexp(cat_q1 / self.temperature, dim=1).mean() + torch.logsumexp(cat_q1 / self.temperature, dim=1) * self.min_q_weight * self.temperature ) min_qf2_loss = ( - torch.logsumexp(cat_q2 / self.temperature, dim=1).mean() + torch.logsumexp(cat_q2 / self.temperature, dim=1) * self.min_q_weight * self.temperature ) # Subtract the log likelihood of data - min_qf1_loss = min_qf1_loss - q_pred[0].mean() * self.min_q_weight - min_qf2_loss = min_qf2_loss - q_pred[1].mean() * self.min_q_weight - alpha_prime_loss = 0 - if self.with_lagrange: - alpha_prime = torch.clamp( - self.log_alpha_prime.exp(), min=0.0, max=1000000.0 - ) - min_qf1_loss = alpha_prime * (min_qf1_loss - self.target_action_gap) - min_qf2_loss = alpha_prime * (min_qf2_loss - self.target_action_gap) + cql_q1_loss = min_qf1_loss - pred_q1 * self.min_q_weight + cql_q2_loss = min_qf2_loss - pred_q2 * self.min_q_weight - alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5 + # write cql losses in tensordict for alpha prime loss + tensordict.set(self.tensor_keys.cql_q1_loss, cql_q1_loss) + tensordict.set(self.tensor_keys.cql_q2_loss, cql_q2_loss) - q_pred = q_pred.squeeze(-1) - loss_qval = distance_loss( - q_pred, - target_value.expand_as(q_pred), - loss_function=self.loss_function, - ) + return (cql_q1_loss + cql_q2_loss).mean(), {} - qf1_loss = loss_qval[0] + min_qf1_loss - qf2_loss = loss_qval[1] + min_qf2_loss + def alpha_prime_loss(self, tensordict: TensorDictBase) -> Tensor: + cql_q1_loss = tensordict.get(self.tensor_keys.cql_q1_loss) + cql_q2_loss = tensordict.get(self.tensor_keys.cql_q2_loss) - loss_qval = qf1_loss + qf2_loss + if cql_q1_loss is None: + raise KeyError( + f"Couldn't find the cql_q1_loss with key {self.tensor_keys.cql_q1_loss} in the input tensordict. " + "This could be caused by calling alpha_prime_loss method before cql_loss method." + ) + if cql_q2_loss is None: + raise KeyError( + f"Couldn't find the cql_q2_loss with key {self.tensor_keys.cql_q2_loss} in the input tensordict. " + "This could be caused by calling alpha_prime_loss method before cql_loss method." + ) - td_error = abs(q_pred - target_value) + alpha_prime = torch.clamp_max(self.log_alpha_prime.exp(), max=1000000.0) + min_qf1_loss = alpha_prime * (cql_q1_loss.mean() - self.target_action_gap) + min_qf2_loss = alpha_prime * (cql_q2_loss.mean() - self.target_action_gap) - return loss_qval, alpha_prime_loss, td_error.detach().max(0)[0] + alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5 + return alpha_prime_loss, {} - def _loss_alpha(self, tensordict: TensorDictBase) -> Tensor: + def alpha_loss(self, tensordict: TensorDictBase) -> Tensor: log_pi = tensordict.get(self.tensor_keys.log_prob) if self.target_entropy is not None: # we can compute this loss even if log_alpha is not a parameter @@ -776,7 +842,7 @@ def _loss_alpha(self, tensordict: TensorDictBase) -> Tensor: else: # placeholder alpha_loss = torch.zeros_like(log_pi) - return alpha_loss + return alpha_loss, {} @property def _alpha(self): diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 59c3f32697f..dc0ed1e1df4 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -379,7 +379,6 @@ class DistributionalDQNLoss(LossModule): value_network (DistributionalQValueActor or nn.Module): the distributional Q value operator. gamma (scalar): a discount factor for return computation. - .. note:: Unlike :class:`DQNLoss`, this class does not currently support custom value functions. The next value estimation is always diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index e64dfa11f2d..a741d83ba13 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -290,6 +290,10 @@ def __init__( @property def device(self) -> torch.device: + warnings.warn( + "The device attributes of the looses will be deprecated in v0.3.", + category=DeprecationWarning, + ) for p in self.parameters(): return p.device raise RuntimeError( @@ -344,27 +348,29 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: else: tensordict_reshape = tensordict - device = self.device - td_device = tensordict_reshape.to(device) - - loss_actor = self._loss_actor(td_device) - loss_qvalue, priority = self._loss_qvalue(td_device) - loss_value = self._loss_value(td_device) + loss_actor, metadata = self.actor_loss(tensordict_reshape) + loss_qvalue, metadata_qvalue = self.qvalue_loss(tensordict_reshape) + loss_value, metadata_value = self.value_loss(tensordict_reshape) + metadata.update(**metadata_qvalue, **metadata_value) - tensordict_reshape.set(self.tensor_keys.priority, priority) if (loss_actor.shape != loss_qvalue.shape) or ( loss_value is not None and loss_actor.shape != loss_value.shape ): raise RuntimeError( f"Losses shape mismatch: {loss_actor.shape}, {loss_qvalue.shape} and {loss_value.shape}" ) + tensordict_reshape.set( + self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values + ) if shape: tensordict.update(tensordict_reshape.view(shape)) + + entropy = -tensordict_reshape.get(self.tensor_keys.log_prob).detach() out = { "loss_actor": loss_actor.mean(), "loss_qvalue": loss_qvalue.mean(), "loss_value": loss_value.mean(), - "entropy": -td_device.get(self.tensor_keys.log_prob).mean().detach(), + "entropy": entropy.mean(), } return TensorDict( @@ -372,7 +378,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: [], ) - def _loss_actor(self, tensordict: TensorDictBase) -> Tensor: + def actor_loss(self, tensordict: TensorDictBase) -> Tensor: # KL loss with self.actor_network_params.to_module(self.actor_network): dist = self.actor_network.get_dist(tensordict) @@ -402,9 +408,9 @@ def _loss_actor(self, tensordict: TensorDictBase) -> Tensor: # write log_prob in tensordict for alpha loss tensordict.set(self.tensor_keys.log_prob, log_prob.detach()) - return -(exp_a * log_prob).mean() + return -(exp_a * log_prob).mean(), {} - def _loss_value(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: + def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: # Min Q value td_q = tensordict.select(*self.qvalue_network.in_keys) td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params) @@ -415,9 +421,9 @@ def _loss_value(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: self.value_network(td_copy) value = td_copy.get(self.tensor_keys.value).squeeze(-1) value_loss = self.loss_value_diff(min_q - value, self.expectile).mean() - return value_loss + return value_loss, {} - def _loss_qvalue(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: + def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: obs_keys = self.actor_network.in_keys tensordict = tensordict.select("next", *obs_keys, self.tensor_keys.action) @@ -431,7 +437,7 @@ def _loss_qvalue(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: pred_val = tensordict_expand.get(self.tensor_keys.state_action_value).squeeze( -1 ) - td_error = abs(pred_val - target_value) + td_error = (pred_val - target_value).pow(2) loss_qval = ( distance_loss( pred_val, @@ -441,7 +447,8 @@ def _loss_qvalue(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: .sum(0) .mean() ) - return loss_qval, td_error.detach().max(0)[0] + metadata = {"td_error": td_error.detach()} + return loss_qval, metadata def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): if value_type is None: