diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 528acd6d331..0f3685ee59e 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -200,6 +200,15 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online collector.device=cuda:0 \ logger.mode=offline \ logger.backend= +python .github/unittest/helpers/coverage_run_parallel.py examples/iql/discrete_iql.py \ + collector.total_frames=48 \ + optim.batch_size=10 \ + collector.frames_per_batch=16 \ + env.train_num_envs=2 \ + optim.device=cuda:0 \ + collector.device=cuda:0 \ + logger.mode=offline \ + logger.backend= python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_online.py \ collector.total_frames=48 \ optim.batch_size=10 \ diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index 29bfa7d466e..1aec88f2d11 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -127,6 +127,7 @@ IQL :template: rl_template_noinherit.rst IQLLoss + DiscreteIQLLoss CQL ---- diff --git a/examples/iql/discrete_iql.py b/examples/iql/discrete_iql.py new file mode 100644 index 00000000000..39009923d02 --- /dev/null +++ b/examples/iql/discrete_iql.py @@ -0,0 +1,195 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""IQL Example. + +This is a self-contained example of an online discrete IQL training script. + +It works across Gym and MuJoCo over a variety of tasks. + +The helper functions are coded in the utils.py associated with this script. 
+ +""" +import logging +import time + +import hydra +import numpy as np +import torch +import tqdm +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.record.loggers import generate_exp_name, get_logger + +from utils import ( + log_metrics, + make_collector, + make_discrete_iql_model, + make_discrete_loss, + make_environment, + make_iql_optimizer, + make_replay_buffer, +) + + +@hydra.main(config_path=".", config_name="discrete_iql") +def main(cfg: "DictConfig"): # noqa: F821 + # Create logger + exp_name = generate_exp_name("Discrete-IQL-online", cfg.env.exp_name) + logger = None + if cfg.logger.backend: + logger = get_logger( + logger_type=cfg.logger.backend, + logger_name="iql_logging", + experiment_name=exp_name, + wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, + ) + + # Set seeds + torch.manual_seed(cfg.env.seed) + np.random.seed(cfg.env.seed) + device = torch.device(cfg.optim.device) + + # Create environments + train_env, eval_env = make_environment( + cfg, + cfg.env.train_num_envs, + cfg.env.eval_num_envs, + ) + + # Create replay buffer + replay_buffer = make_replay_buffer( + batch_size=cfg.optim.batch_size, + prb=cfg.replay_buffer.prb, + buffer_size=cfg.replay_buffer.size, + device="cpu", + ) + + # Create model + model = make_discrete_iql_model(cfg, train_env, eval_env, device) + + # Create collector + collector = make_collector(cfg, train_env, actor_model_explore=model[0]) + + # Create loss + loss_module, target_net_updater = make_discrete_loss(cfg.loss, model) + + # Create optimizer + optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer( + cfg.optim, loss_module + ) + + # Main loop + collected_frames = 0 + pbar = tqdm.tqdm(total=cfg.collector.total_frames) + + init_random_frames = cfg.collector.init_random_frames + num_updates = int( + cfg.collector.env_per_collector + * cfg.collector.frames_per_batch + * cfg.optim.utd_ratio + ) + prb = cfg.replay_buffer.prb + eval_iter = cfg.logger.eval_iter + frames_per_batch = cfg.collector.frames_per_batch + eval_rollout_steps = cfg.collector.max_frames_per_traj + sampling_start = start_time = time.time() + for tensordict in collector: + sampling_time = time.time() - sampling_start + pbar.update(tensordict.numel()) + # update weights of the inference policy + collector.update_policy_weights_() + + tensordict = tensordict.reshape(-1) + current_frames = tensordict.numel() + # add to replay buffer + replay_buffer.extend(tensordict.cpu()) + collected_frames += current_frames + + # optimization steps + training_start = time.time() + if collected_frames >= init_random_frames: + for _ in range(num_updates): + # sample from replay buffer + sampled_tensordict = replay_buffer.sample().clone() + if sampled_tensordict.device != device: + sampled_tensordict = sampled_tensordict.to( + device, non_blocking=True + ) + else: + sampled_tensordict = sampled_tensordict + # compute losses + actor_loss, _ = loss_module.actor_loss(sampled_tensordict) + optimizer_actor.zero_grad() + actor_loss.backward() + optimizer_actor.step() + + value_loss, _ = loss_module.value_loss(sampled_tensordict) + optimizer_value.zero_grad() + value_loss.backward() + optimizer_value.step() + + q_loss, metadata = loss_module.qvalue_loss(sampled_tensordict) + optimizer_critic.zero_grad() + q_loss.backward() + optimizer_critic.step() + + # update qnet_target params + target_net_updater.step() + + # update priority + if prb: + sampled_tensordict.set( + loss_module.tensor_keys.priority, + metadata.pop("td_error").detach().max(0).values, 
+ ) + replay_buffer.update_priority(sampled_tensordict) + + training_time = time.time() - training_start + episode_rewards = tensordict["next", "episode_reward"][ + tensordict["next", "done"] + ] + + # Logging + metrics_to_log = {} + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][ + tensordict["next", "done"] + ] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length + ) + if collected_frames >= init_random_frames: + metrics_to_log["train/q_loss"] = q_loss.detach() + metrics_to_log["train/actor_loss"] = actor_loss.detach() + metrics_to_log["train/value_loss"] = value_loss.detach() + metrics_to_log["train/sampling_time"] = sampling_time + metrics_to_log["train/training_time"] = training_time + + # Evaluation + if abs(collected_frames % eval_iter) < frames_per_batch: + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_start = time.time() + eval_rollout = eval_env.rollout( + eval_rollout_steps, + model[0], + auto_cast_to_device=True, + break_when_any_done=True, + ) + eval_time = time.time() - eval_start + eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + metrics_to_log["eval/reward"] = eval_reward + metrics_to_log["eval/time"] = eval_time + if logger is not None: + log_metrics(logger, metrics_to_log, collected_frames) + sampling_start = time.time() + + collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + logging.info(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/iql/discrete_iql.yaml b/examples/iql/discrete_iql.yaml new file mode 100644 index 00000000000..52b6f8e13ca --- /dev/null +++ b/examples/iql/discrete_iql.yaml @@ -0,0 +1,58 @@ +# task and env +env: + name: CartPole-v1 + task: "" + exp_name: iql_${env.name} + n_samples_stats: 1000 + seed: 0 + train_num_envs: 1 + eval_num_envs: 1 + backend: gym + + +# collector +collector: + frames_per_batch: 200 + total_frames: 20000 + init_random_frames: 1000 + env_per_collector: 1 + device: cpu + max_frames_per_traj: 200 + +# logger +logger: + backend: wandb + log_interval: 5000 # record interval in frames + eval_steps: 200 + mode: online + eval_iter: 1000 + +# replay buffer +replay_buffer: + prb: 0 + buffer_prefetch: 64 + size: 1_000_000 + +# optimization +optim: + utd_ratio: 1 + device: cuda:0 + lr: 3e-4 + weight_decay: 0.0 + batch_size: 256 + +# network +model: + hidden_sizes: [256, 256] + activation: relu + + +# loss +loss: + loss_function: l2 + gamma: 0.99 + hard_update_interval: 10 + +# IQL specific hyperparameter + temperature: 100 + expectile: 0.8 diff --git a/examples/iql/iql_offline.py b/examples/iql/iql_offline.py index b5df32d7f2d..927ac924e90 100644 --- a/examples/iql/iql_offline.py +++ b/examples/iql/iql_offline.py @@ -34,7 +34,6 @@ @set_gym_backend("gym") @hydra.main(config_path=".", config_name="offline_config") def main(cfg: "DictConfig"): # noqa: F821 - # Create logger exp_name = generate_exp_name("IQL-offline", cfg.env.exp_name) logger = None @@ -64,7 +63,9 @@ def main(cfg: "DictConfig"): # noqa: F821 loss_module, target_net_updater = make_loss(cfg.loss, model) # Create optimizer - optimizer = make_iql_optimizer(cfg.optim, loss_module) + optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer( + cfg.optim, loss_module + ) pbar = tqdm.tqdm(total=cfg.optim.gradient_steps) @@ -78,18 +79,29 @@ def main(cfg: "DictConfig"): # noqa: F821 
pbar.update(i) # sample data data = replay_buffer.sample() - # compute loss - loss_vals = loss_module(data.clone().to(device)) - - actor_loss = loss_vals["loss_actor"] - q_loss = loss_vals["loss_qvalue"] - value_loss = loss_vals["loss_value"] - loss_val = actor_loss + q_loss + value_loss - - # update model - optimizer.zero_grad() - loss_val.backward() - optimizer.step() + + if data.device != device: + data = data.to(device, non_blocking=True) + + # compute losses + loss_info = loss_module(data) + actor_loss = loss_info["loss_actor"] + value_loss = loss_info["loss_value"] + q_loss = loss_info["loss_qvalue"] + + optimizer_actor.zero_grad() + actor_loss.backward() + optimizer_actor.step() + + optimizer_value.zero_grad() + value_loss.backward() + optimizer_value.step() + + optimizer_critic.zero_grad() + q_loss.backward() + optimizer_critic.step() + + # update qnet_target params target_net_updater.step() # log metrics diff --git a/examples/iql/iql_online.py b/examples/iql/iql_online.py index 8dd7c0fdd07..663aa2d82d3 100644 --- a/examples/iql/iql_online.py +++ b/examples/iql/iql_online.py @@ -18,7 +18,6 @@ import numpy as np import torch import tqdm -from tensordict import TensorDict from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.record.loggers import generate_exp_name, get_logger @@ -76,7 +75,9 @@ def main(cfg: "DictConfig"): # noqa: F821 loss_module, target_net_updater = make_loss(cfg.loss, model) # Create optimizer - optimizer = make_iql_optimizer(cfg.optim, loss_module) + optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer( + cfg.optim, loss_module + ) # Main loop collected_frames = 0 @@ -108,8 +109,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # optimization steps training_start = time.time() if collected_frames >= init_random_frames: - log_loss_td = TensorDict({}, [num_updates]) - for j in range(num_updates): + for _ in range(num_updates): # sample from replay buffer sampled_tensordict = replay_buffer.sample().clone() if sampled_tensordict.device != device: @@ -118,20 +118,23 @@ def main(cfg: "DictConfig"): # noqa: F821 ) else: sampled_tensordict = sampled_tensordict - # compute loss - loss_td = loss_module(sampled_tensordict) + # compute losses + loss_info = loss_module(sampled_tensordict) + actor_loss = loss_info["loss_actor"] + value_loss = loss_info["loss_value"] + q_loss = loss_info["loss_qvalue"] - actor_loss = loss_td["loss_actor"] - q_loss = loss_td["loss_qvalue"] - value_loss = loss_td["loss_value"] - loss = actor_loss + q_loss + value_loss + optimizer_actor.zero_grad() + actor_loss.backward() + optimizer_actor.step() - # update model - optimizer.zero_grad() - loss.backward() - optimizer.step() + optimizer_value.zero_grad() + value_loss.backward() + optimizer_value.step() - log_loss_td[j] = loss_td.detach() + optimizer_critic.zero_grad() + q_loss.backward() + optimizer_critic.step() # update qnet_target params target_net_updater.step() @@ -155,10 +158,10 @@ def main(cfg: "DictConfig"): # noqa: F821 episode_length ) if collected_frames >= init_random_frames: - metrics_to_log["train/q_loss"] = log_loss_td.get("loss_qvalue").detach() - metrics_to_log["train/actor_loss"] = log_loss_td.get("loss_actor").detach() - metrics_to_log["train/value_loss"] = log_loss_td.get("loss_value").detach() - metrics_to_log["train/entropy"] = log_loss_td.get("entropy").detach() + metrics_to_log["train/q_loss"] = q_loss.detach() + metrics_to_log["train/actor_loss"] = actor_loss.detach() + metrics_to_log["train/value_loss"] = value_loss.detach() + 
metrics_to_log["train/entropy"] = loss_info.get("entropy").detach() metrics_to_log["train/sampling_time"] = sampling_time metrics_to_log["train/training_time"] = training_time diff --git a/examples/iql/offline_config.yaml b/examples/iql/offline_config.yaml index 473decb2980..8b8cbe8c776 100644 --- a/examples/iql/offline_config.yaml +++ b/examples/iql/offline_config.yaml @@ -2,7 +2,7 @@ env: name: HalfCheetah-v2 task: "" - library: gym + backend: gym exp_name: iql_${replay_buffer.dataset} n_samples_stats: 1000 seed: 0 diff --git a/examples/iql/utils.py b/examples/iql/utils.py index e69cf45a0cd..31dcb732b00 100644 --- a/examples/iql/utils.py +++ b/examples/iql/utils.py @@ -4,11 +4,12 @@ # LICENSE file in the root directory of this source tree. import torch.nn import torch.optim -from tensordict.nn import TensorDictModule +from tensordict.nn import InteractionType, TensorDictModule from tensordict.nn.distributions import NormalParamExtractor from torchrl.collectors import SyncDataCollector from torchrl.data import ( + CompositeSpec, LazyMemmapStorage, TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer, @@ -16,7 +17,9 @@ from torchrl.data.datasets.d4rl import D4RLExperienceReplay from torchrl.data.replay_buffers import SamplerWithoutReplacement from torchrl.envs import ( + CatTensors, Compose, + DMControlEnv, DoubleToFloat, EnvCreator, InitTracker, @@ -24,10 +27,18 @@ RewardSum, TransformedEnv, ) + from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator -from torchrl.objectives import IQLLoss, SoftUpdate +from torchrl.modules import ( + MLP, + OneHotCategorical, + ProbabilisticActor, + SafeModule, + TanhNormal, + ValueOperator, +) +from torchrl.objectives import DiscreteIQLLoss, HardUpdate, IQLLoss, SoftUpdate from torchrl.trainers.helpers.models import ACTIVATIONS @@ -37,9 +48,21 @@ # ----------------- -def env_maker(task, device="cpu", from_pixels=False): - with set_gym_backend("gym"): - return GymEnv(task, device=device, from_pixels=from_pixels) +def env_maker(cfg, device="cpu"): + lib = cfg.env.backend + if lib in ("gym", "gymnasium"): + with set_gym_backend(lib): + return GymEnv( + cfg.env.name, + device=device, + ) + elif lib == "dm_control": + env = DMControlEnv(cfg.env.name, cfg.env.task) + return TransformedEnv( + env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") + ) + else: + raise NotImplementedError(f"Unknown lib {lib}.") def apply_env_transforms( @@ -60,7 +83,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1): """Make environments for training and evaluation.""" parallel_env = ParallelEnv( train_num_envs, - EnvCreator(lambda: env_maker(task=cfg.env.name)), + EnvCreator(lambda: env_maker(cfg)), ) parallel_env.set_seed(cfg.env.seed) @@ -69,7 +92,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1): eval_env = TransformedEnv( ParallelEnv( eval_num_envs, - EnvCreator(lambda: env_maker(task=cfg.env.name)), + EnvCreator(lambda: env_maker(cfg)), ), train_env.transform.clone(), ) @@ -247,6 +270,82 @@ def make_iql_modules_state(model_cfg, proof_environment): return actor_net, q_net, value_net +def make_discrete_iql_model(cfg, train_env, eval_env, device): + """Make discrete IQL agent.""" + # Define Actor Network + in_keys = ["observation"] + action_spec = train_env.action_spec + if train_env.batch_size: + action_spec = action_spec[(0,) * len(train_env.batch_size)] + # Define Actor 
Network + in_keys = ["observation"] + + actor_net_kwargs = { + "num_cells": cfg.model.hidden_sizes, + "out_features": action_spec.shape[-1], + "activation_class": ACTIVATIONS[cfg.model.activation], + } + + actor_net = MLP(**actor_net_kwargs) + + actor_module = SafeModule( + module=actor_net, + in_keys=in_keys, + out_keys=["logits"], + ) + actor = ProbabilisticActor( + spec=CompositeSpec(action=eval_env.action_spec), + module=actor_module, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + distribution_kwargs={}, + default_interaction_type=InteractionType.RANDOM, + return_log_prob=False, + ) + + # Define Critic Network + qvalue_net_kwargs = { + "num_cells": cfg.model.hidden_sizes, + "out_features": action_spec.shape[-1], + "activation_class": ACTIVATIONS[cfg.model.activation], + } + qvalue_net = MLP( + **qvalue_net_kwargs, + ) + + qvalue = TensorDictModule( + in_keys=["observation"], + out_keys=["state_action_value"], + module=qvalue_net, + ) + + # Define Value Network + value_net_kwargs = { + "num_cells": cfg.model.hidden_sizes, + "out_features": 1, + "activation_class": ACTIVATIONS[cfg.model.activation], + } + value_net = MLP(**value_net_kwargs) + value_net = TensorDictModule( + in_keys=["observation"], + out_keys=["state_value"], + module=value_net, + ) + + model = torch.nn.ModuleList([actor, qvalue, value_net]).to(device) + # init nets + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): + td = eval_env.reset() + td = td.to(device) + for net in model: + net(td) + del td + eval_env.close() + + return model + + # ==================================================================== # IQL Loss # --------- @@ -267,13 +366,44 @@ def make_loss(loss_cfg, model): return loss_module, target_net_updater +def make_discrete_loss(loss_cfg, model): + loss_module = DiscreteIQLLoss( + model[0], + model[1], + value_network=model[2], + loss_function=loss_cfg.loss_function, + temperature=loss_cfg.temperature, + expectile=loss_cfg.expectile, + ) + loss_module.make_value_estimator(gamma=loss_cfg.gamma) + target_net_updater = HardUpdate( + loss_module, value_network_update_interval=loss_cfg.hard_update_interval + ) + + return loss_module, target_net_updater + + def make_iql_optimizer(optim_cfg, loss_module): - optim = torch.optim.Adam( - loss_module.parameters(), + critic_params = list(loss_module.qvalue_network_params.flatten_keys().values()) + actor_params = list(loss_module.actor_network_params.flatten_keys().values()) + value_params = list(loss_module.value_network_params.flatten_keys().values()) + + optimizer_actor = torch.optim.Adam( + actor_params, + lr=optim_cfg.lr, + weight_decay=optim_cfg.weight_decay, + ) + optimizer_critic = torch.optim.Adam( + critic_params, + lr=optim_cfg.lr, + weight_decay=optim_cfg.weight_decay, + ) + optimizer_value = torch.optim.Adam( + value_params, lr=optim_cfg.lr, weight_decay=optim_cfg.weight_decay, ) - return optim + return optimizer_actor, optimizer_critic, optimizer_value # ==================================================================== diff --git a/test/test_cost.py b/test/test_cost.py index dc9b75e7d87..8d704566c39 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -101,6 +101,7 @@ CQLLoss, DDPGLoss, DiscreteCQLLoss, + DiscreteIQLLoss, DiscreteSACLoss, DistributionalDQNLoss, DQNLoss, @@ -8950,6 +8951,768 @@ def test_iql_notensordict( assert loss_value == loss_val_td["loss_value"] +@pytest.mark.skipif( + not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" +) +class 
TestDiscreteIQL(LossModuleTestBase): + seed = 0 + + def _create_mock_actor( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + observation_key="observation", + action_key="action", + ): + # Actor + action_spec = OneHotDiscreteTensorSpec(action_dim) + net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"]) + actor = ProbabilisticActor( + spec=action_spec, + module=module, + in_keys=["logits"], + out_keys=[action_key], + distribution_class=OneHotCategorical, + return_log_prob=False, + ) + return actor.to(device) + + def _create_mock_qvalue( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + out_keys=None, + observation_key="observation", + ): + class ValueClass(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(obs_dim, action_dim) + + def forward(self, obs): + return self.linear(obs) + + module = ValueClass() + qvalue = ValueOperator( + module=module, in_keys=[observation_key], out_keys=["state_action_value"] + ) + return qvalue.to(device) + + def _create_mock_value( + self, + batch=2, + obs_dim=3, + device="cpu", + out_keys=None, + observation_key="observation", + ): + module = nn.Linear(obs_dim, 1) + value = ValueOperator( + module=module, in_keys=[observation_key], out_keys=out_keys + ) + return value.to(device) + + def _create_mock_common_layer_setup( + self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2, T=10 + ): + common_net = MLP( + num_cells=ncells, + in_features=n_obs, + depth=3, + out_features=n_hidden, + ) + actor_net = MLP( + num_cells=ncells, + in_features=n_hidden, + depth=1, + out_features=2 * n_act, + ) + value_net = MLP( + in_features=n_hidden, + num_cells=ncells, + depth=1, + out_features=1, + ) + qvalue_net = MLP( + in_features=n_hidden, + num_cells=ncells, + depth=1, + out_features=n_act, + ) + batch = [batch, T] + td = TensorDict( + { + "obs": torch.randn(*batch, n_obs), + "action": torch.randn(*batch, n_act), + "sample_log_prob": torch.randn(*batch), + "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + "next": { + "obs": torch.randn(*batch, n_obs), + "reward": torch.randn(*batch, 1), + "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + }, + }, + batch, + names=[None, "time"], + ) + common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"]) + actor = ProbSeq( + common, + Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), + Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["logits"]), + ProbMod( + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + ), + ) + value = Seq( + common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"]) + ) + qvalue = Seq( + common, + Mod( + qvalue_net, + in_keys=["hidden"], + out_keys=["state_action_value"], + ), + ) + qvalue(actor(td.clone())) + value(td.clone()) + return actor, value, qvalue, common, td + + def _create_mock_distributional_actor( + self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5 + ): + raise NotImplementedError + + def _create_mock_data_discrete_iql( + self, + batch=16, + obs_dim=3, + action_dim=4, + atoms=None, + device="cpu", + observation_key="observation", + action_key="action", + done_key="done", + terminated_key="terminated", + reward_key="reward", + ): + # create a tensordict + obs = torch.randn(batch, obs_dim, device=device) + next_obs = torch.randn(batch, obs_dim, device=device) + 
if atoms: + raise NotImplementedError + else: + action_value = torch.randn(batch, action_dim, device=device) + action = (action_value == action_value.max(-1, True)[0]).to(torch.long) + reward = torch.randn(batch, 1, device=device) + done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch,), + source={ + observation_key: obs, + "next": { + observation_key: next_obs, + done_key: done, + terminated_key: terminated, + reward_key: reward, + }, + action_key: action, + }, + device=device, + ) + return td + + def _create_seq_mock_data_discrete_iql( + self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + total_obs = torch.randn(batch, T + 1, obs_dim, device=device) + obs = total_obs[:, :T] + next_obs = total_obs[:, 1:] + if atoms: + action = torch.randn(batch, T, atoms, action_dim, device=device).clamp( + -1, 1 + ) + else: + action_value = torch.randn(batch, T, action_dim, device=device) + action = (action_value == action_value.max(-1, True)[0]).to(torch.long) + + reward = torch.randn(batch, T, 1, device=device) + done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = torch.ones(batch, T, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "next": { + "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "done": done, + "terminated": terminated, + "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), + }, + "collector": {"mask": mask}, + "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0), + }, + names=[None, "time"], + device=device, + ) + return td + + @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0]) + @pytest.mark.parametrize("expectile", [0.1, 0.5]) + @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) + def test_discrete_iql( + self, + num_qvalue, + device, + temperature, + expectile, + td_est, + ): + torch.manual_seed(self.seed) + td = self._create_mock_data_discrete_iql(device=device) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + value = self._create_mock_value(device=device) + + loss_fn = DiscreteIQLLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=num_qvalue, + temperature=temperature, + expectile=expectile, + loss_function="l2", + ) + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + with pytest.raises(NotImplementedError): + loss_fn.make_value_estimator(td_est) + return + if td_est is not None: + loss_fn.make_value_estimator(td_est) + + with _check_td_steady(td), pytest.warns( + UserWarning, match="No target network updater" + ): + loss = loss_fn(td) + assert loss_fn.tensor_keys.priority in td.keys() + + # check that losses are independent + for k in loss.keys(): + if not k.startswith("loss"): + continue + loss[k].sum().backward(retain_graph=True) + if k == "loss_actor": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.value_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + 
) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + elif k == "loss_value": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.value_network_params.values( + include_nested=True, leaves_only=True + ) + ) + elif k == "loss_qvalue": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.value_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + else: + raise NotImplementedError(k) + loss_fn.zero_grad() + + sum([item for _, item in loss.items()]).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + for name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + @pytest.mark.parametrize("num_qvalue", [2]) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("temperature", [0.0]) + @pytest.mark.parametrize("expectile", [0.1]) + def test_discrete_iql_state_dict( + self, + num_qvalue, + device, + temperature, + expectile, + ): + torch.manual_seed(self.seed) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + value = self._create_mock_value(device=device) + + loss_fn = DiscreteIQLLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=num_qvalue, + temperature=temperature, + expectile=expectile, + loss_function="l2", + ) + sd = loss_fn.state_dict() + loss_fn2 = DiscreteIQLLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=num_qvalue, + temperature=temperature, + expectile=expectile, + loss_function="l2", + ) + loss_fn2.load_state_dict(sd) + + @pytest.mark.parametrize("separate_losses", [False, True]) + def test_discrete_iql_separate_losses(self, separate_losses): + torch.manual_seed(self.seed) + actor, value, qvalue, common, td = self._create_mock_common_layer_setup() + loss_fn = DiscreteIQLLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + loss_function="l2", + separate_losses=separate_losses, + ) + with pytest.warns(UserWarning, match="No target network updater has been"): + loss = loss_fn(td) + + assert loss_fn.tensor_keys.priority in td.keys() + + # check that losses are independent + for k in loss.keys(): + if not k.startswith("loss"): + continue + loss[k].sum().backward(retain_graph=True) + common_layers_no = len(list(common.parameters())) + if k == 
"loss_actor": + if separate_losses: + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.value_network_params.values( + include_nested=True, leaves_only=True + ) + ) + else: + common_layers = itertools.islice( + loss_fn.value_network_params.values(True, True), + common_layers_no, + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in common_layers + ) + value_layers = itertools.islice( + loss_fn.value_network_params.values(True, True), + common_layers_no, + None, + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in value_layers + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + elif k == "loss_value": + if separate_losses: + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + else: + common_layers = itertools.islice( + loss_fn.actor_network_params.values(True, True), + common_layers_no, + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in common_layers + ) + actor_layers = itertools.islice( + loss_fn.actor_network_params.values(True, True), + common_layers_no, + None, + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in actor_layers + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + if separate_losses: + common_layers = itertools.islice( + loss_fn.value_network_params.values(True, True), + common_layers_no, + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in common_layers + ) + value_layers = itertools.islice( + loss_fn.value_network_params.values(True, True), + common_layers_no, + None, + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in value_layers + ) + else: + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.value_network_params.values( + include_nested=True, leaves_only=True + ) + ) + elif k == "loss_qvalue": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.value_network_params.values( + include_nested=True, leaves_only=True + ) + ) + if separate_losses: + common_layers = itertools.islice( + loss_fn.qvalue_network_params.values(True, True), + common_layers_no, + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in common_layers + ) + qvalue_layers = itertools.islice( + loss_fn.qvalue_network_params.values(True, True), + common_layers_no, + None, + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in qvalue_layers + ) + else: + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + else: + raise NotImplementedError(k) + loss_fn.zero_grad() + + @pytest.mark.parametrize("n", list(range(4))) + @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) + @pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0]) + @pytest.mark.parametrize("expectile", [0.1, 0.5]) + @pytest.mark.parametrize("device", get_default_devices()) + def 
test_discrete_iql_batcher( + self, + n, + num_qvalue, + temperature, + expectile, + device, + gamma=0.9, + ): + torch.manual_seed(self.seed) + td = self._create_seq_mock_data_discrete_iql(device=device) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + value = self._create_mock_value(device=device) + + loss_fn = DiscreteIQLLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=num_qvalue, + temperature=temperature, + expectile=expectile, + loss_function="l2", + ) + + ms = MultiStep(gamma=gamma, n_steps=n).to(device) + + td_clone = td.clone() + ms_td = ms(td_clone) + + torch.manual_seed(0) + np.random.seed(0) + with _check_td_steady(ms_td), pytest.warns( + UserWarning, match="No target network updater" + ): + loss_ms = loss_fn(ms_td) + assert loss_fn.tensor_keys.priority in ms_td.keys() + + with torch.no_grad(): + torch.manual_seed(0) # log-prob is computed with a random action + np.random.seed(0) + loss = loss_fn(td) + if n == 0: + assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) + _loss = sum([item for _, item in loss.items()]) + _loss_ms = sum([item for _, item in loss_ms.items()]) + assert ( + abs(_loss - _loss_ms) < 1e-3 + ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" + else: + with pytest.raises(AssertionError): + assert_allclose_td(loss, loss_ms) + sum([item for _, item in loss_ms.items()]).backward() + named_parameters = loss_fn.named_parameters() + for name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + # Check param update effect on targets + target_qvalue = [ + p.clone() + for p in loss_fn.target_qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ] + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + target_qvalue2 = [ + p.clone() + for p in loss_fn.target_qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ] + if loss_fn.delay_qvalue: + assert all( + (p1 == p2).all() for p1, p2 in zip(target_qvalue, target_qvalue2) + ) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_qvalue, target_qvalue2) + ) + + # check that policy is updated after parameter update + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + + @pytest.mark.parametrize( + "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] + ) + def test_discrete_iql_tensordict_keys(self, td_est): + actor = self._create_mock_actor() + qvalue = self._create_mock_qvalue() + value = self._create_mock_value() + + loss_fn = DiscreteIQLLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + loss_function="l2", + ) + + default_keys = { + "priority": "td_error", + "log_prob": "_log_prob", + "action": "action", + "state_action_value": "state_action_value", + "value": "state_value", + "reward": "reward", + "done": "done", + "terminated": "terminated", + } + + self.tensordict_keys_test( + loss_fn, + default_keys=default_keys, + td_est=td_est, + ) + + value = self._create_mock_value(out_keys=["value_test"]) + loss_fn = 
DiscreteIQLLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + loss_function="l2", + ) + + key_mapping = { + "value": ("value", "value_test"), + "done": ("done", "done_test"), + "terminated": ("terminated", "terminated_test"), + "reward": ("reward", ("reward", "test")), + } + self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) + + @pytest.mark.parametrize("action_key", ["action", "action2"]) + @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) + @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) + @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_discrete_iql_notensordict( + self, action_key, observation_key, reward_key, done_key, terminated_key + ): + torch.manual_seed(self.seed) + td = self._create_mock_data_discrete_iql( + action_key=action_key, + observation_key=observation_key, + reward_key=reward_key, + done_key=done_key, + terminated_key=terminated_key, + ) + + actor = self._create_mock_actor(observation_key=observation_key) + qvalue = self._create_mock_qvalue( + observation_key=observation_key, + out_keys=["state_action_value"], + ) + value = self._create_mock_value(observation_key=observation_key) + + loss = DiscreteIQLLoss( + actor_network=actor, qvalue_network=qvalue, value_network=value + ) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) + + kwargs = { + action_key: td.get(action_key), + observation_key: td.get(observation_key), + f"next_{reward_key}": td.get(("next", reward_key)), + f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), + f"next_{observation_key}": td.get(("next", observation_key)), + } + td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") + + with pytest.warns( + UserWarning, + match="No target network updater has been associated with this loss module", + ): + loss_val = loss(**kwargs) + loss_val_td = loss(td) + assert len(loss_val) == 4 + torch.testing.assert_close(loss_val_td.get("loss_actor"), loss_val[0]) + torch.testing.assert_close(loss_val_td.get("loss_qvalue"), loss_val[1]) + torch.testing.assert_close(loss_val_td.get("loss_value"), loss_val[2]) + torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[3]) + # test select + torch.manual_seed(self.seed) + loss.select_out_keys("loss_actor", "loss_value") + if torch.__version__ >= "2.0.0": + loss_actor, loss_value = loss(**kwargs) + else: + with pytest.raises( + RuntimeError, + match="You are likely using tensordict.nn.dispatch with keyword arguments", + ): + loss_actor, loss_value = loss(**kwargs) + return + assert loss_actor == loss_val_td["loss_actor"] + assert loss_value == loss_val_td["loss_value"] + + @pytest.mark.parametrize("create_target_params", [True, False]) @pytest.mark.parametrize( "cast", [None, torch.float, torch.double, *get_default_devices()] diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 4840d12b2d4..f8d2bd1d977 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -10,7 +10,7 @@ from .decision_transformer import DTLoss, OnlineDTLoss from .dqn import DistributionalDQNLoss, DQNLoss from .dreamer import DreamerActorLoss, DreamerModelLoss, DreamerValueLoss -from .iql import IQLLoss +from .iql import DiscreteIQLLoss, IQLLoss from .multiagent import QMixerLoss from .ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss from .redq 
import REDQLoss diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 08fce2dd4dd..aa0ada74801 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -4,17 +4,18 @@ # LICENSE file in the root directory of this source tree. import warnings from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch from tensordict.nn import dispatch, TensorDictModule from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import Tensor +from torchrl.data.tensor_specs import TensorSpec +from torchrl.data.utils import _find_action_space from torchrl.modules import ProbabilisticActor from torchrl.objectives.common import LossModule - from torchrl.objectives.utils import ( _GAMMA_LMBDA_DEPREC_WARNING, _vmap_func, @@ -492,3 +493,352 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) + + +class DiscreteIQLLoss(IQLLoss): + r"""TorchRL implementation of the discrete IQL loss. + + Presented in "Offline Reinforcement Learning with Implicit Q-Learning" https://arxiv.org/abs/2110.06169 + + Args: + actor_network (ProbabilisticActor): stochastic actor + qvalue_network (TensorDictModule): Q(s) parametric model + value_network (TensorDictModule, optional): V(s) parametric model. + + Keyword Args: + action_space (str or TensorSpec): Action space. Must be one of + ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``, + or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`, + :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`, + :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`). + num_qvalue_nets (integer, optional): number of Q-Value networks used. + Defaults to ``2``. + loss_function (str, optional): loss function to be used with + the value function loss. Default is `"smooth_l1"`. + temperature (float, optional): Inverse temperature (beta). + For smaller hyperparameter values, the objective behaves similarly to + behavioral cloning, while for larger values, it attempts to recover the + maximum of the Q-function. + expectile (float, optional): expectile :math:`\tau`. A larger value of :math:`\tau` is crucial + for antmaze tasks that require dynamic programming ("stitching"). + priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead] + tensordict key where to write the priority (for prioritized replay + buffer usage). Default is `"td_error"`. + separate_losses (bool, optional): if ``True``, shared parameters between + policy and critic will only be trained on the policy loss. + Defaults to ``False``, i.e., gradients are propagated to shared + parameters for both policy and critic losses. 
+ + Examples: + >>> import torch + >>> from torch import nn + >>> from tensordict.nn import TensorDictModule + >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec + >>> from torchrl.modules.distributions.continuous import NormalParamWrapper + >>> from torchrl.modules.distributions.discrete import OneHotCategorical + >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator + >>> from torchrl.modules.tensordict_module.common import SafeModule + >>> from torchrl.objectives.iql import DiscreteIQLLoss + >>> from tensordict.tensordict import TensorDict + >>> n_act, n_obs = 4, 3 + >>> spec = OneHotDiscreteTensorSpec(n_act) + >>> module = TensorDictModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"]) + >>> actor = ProbabilisticActor( + ... module=module, + ... in_keys=["logits"], + ... out_keys=["action"], + ... spec=spec, + ... distribution_class=OneHotCategorical) + >>> qvalue = TensorDictModule( + ... nn.Linear(n_obs, n_act), + ... in_keys=["observation"], + ... out_keys=["state_action_value"], + ... ) + >>> value = TensorDictModule( + ... nn.Linear(n_obs, 1), + ... in_keys=["observation"], + ... out_keys=["state_value"], + ... ) + >>> loss = DiscreteIQLLoss(actor, qvalue, value) + >>> batch = [2, ] + >>> action = spec.rand(batch) + >>> data = TensorDict({ + ... "observation": torch.randn(*batch, n_obs), + ... "action": action, + ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "reward"): torch.randn(*batch, 1), + ... ("next", "observation"): torch.randn(*batch, n_obs), + ... }, batch) + >>> loss(data) + TensorDict( + fields={ + entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + + This class is compatible with non-tensordict based modules too and can be + used without resorting to any tensordict-related primitive. In this case, + the expected keyword arguments are: + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor, value, and qvalue network. + The return value is a tuple of tensors in the following order: + ``["loss_actor", "loss_qvalue", "loss_value", "entropy"]``. + + Examples: + >>> import torch + >>> from torch import nn + >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec + >>> from torchrl.modules.distributions.continuous import NormalParamWrapper + >>> from torchrl.modules.distributions.discrete import OneHotCategorical + >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator + >>> from torchrl.modules.tensordict_module.common import SafeModule + >>> from torchrl.objectives.iql import DiscreteIQLLoss + >>> from tensordict.tensordict import TensorDict + >>> _ = torch.manual_seed(42) + >>> n_act, n_obs = 4, 3 + >>> spec = OneHotDiscreteTensorSpec(n_act) + >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> module = SafeModule(net, in_keys=["observation"], out_keys=["logits"]) + >>> actor = ProbabilisticActor( + ... module=module, + ... in_keys=["logits"], + ... out_keys=["action"], + ... spec=spec, + ... 
distribution_class=OneHotCategorical) + >>> class ValueClass(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.linear = nn.Linear(n_obs, n_act) + ... def forward(self, obs): + ... return self.linear(obs) + >>> module = ValueClass() + >>> qvalue = ValueOperator( + ... module=module, + ... in_keys=['observation']) + >>> module = nn.Linear(n_obs, 1) + >>> value = ValueOperator( + ... module=module, + ... in_keys=["observation"]) + >>> loss = DiscreteIQLLoss(actor, qvalue, value) + >>> batch = [2, ] + >>> action = spec.rand(batch) + >>> loss_actor, loss_qvalue, loss_value, entropy = loss( + ... observation=torch.randn(*batch, n_obs), + ... action=action, + ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_observation=torch.zeros(*batch, n_obs), + ... next_reward=torch.randn(*batch, 1)) + >>> loss_actor.backward() + + + The output keys can also be filtered using the :meth:`DiscreteIQLLoss.select_out_keys` + method. + + Examples: + >>> loss.select_out_keys('loss_actor', 'loss_qvalue', 'loss_value') + >>> loss_actor, loss_qvalue, loss_value = loss( + ... observation=torch.randn(*batch, n_obs), + ... action=action, + ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_observation=torch.zeros(*batch, n_obs), + ... next_reward=torch.randn(*batch, 1)) + >>> loss_actor.backward() + """ + + @dataclass + class _AcceptedKeys: + """Maintains default values for all configurable tensordict keys. + + This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their + default values + + Attributes: + value (NestedKey): The input tensordict key where the state value is expected. + Will be used for the underlying value estimator. Defaults to ``"state_value"``. + action (NestedKey): The input tensordict key where the action is expected. + Defaults to ``"action"``. + log_prob (NestedKey): The input tensordict key where the log probability is expected. + Defaults to ``"_log_prob"``. + priority (NestedKey): The input tensordict key where the target priority is written to. + Defaults to ``"td_error"``. + state_action_value (NestedKey): The input tensordict key where the + state action value is expected. Will be used for the underlying + value estimator as value key. Defaults to ``"state_action_value"``. + reward (NestedKey): The input tensordict key where the reward is expected. + Will be used for the underlying value estimator. Defaults to ``"reward"``. + done (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is done. Will be used for the underlying value estimator. + Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. 
+ """ + + value: NestedKey = "state_value" + action: NestedKey = "action" + log_prob: NestedKey = "_log_prob" + priority: NestedKey = "td_error" + state_action_value: NestedKey = "state_action_value" + reward: NestedKey = "reward" + done: NestedKey = "done" + terminated: NestedKey = "terminated" + + default_keys = _AcceptedKeys() + default_value_estimator = ValueEstimators.TD0 + out_keys = [ + "loss_actor", + "loss_qvalue", + "loss_value", + "entropy", + ] + + def __init__( + self, + actor_network: ProbabilisticActor, + qvalue_network: TensorDictModule, + value_network: Optional[TensorDictModule], + *, + action_space: Union[str, TensorSpec] = None, + num_qvalue_nets: int = 2, + loss_function: str = "smooth_l1", + temperature: float = 1.0, + expectile: float = 0.5, + gamma: float = None, + priority_key: str = None, + separate_losses: bool = False, + ) -> None: + self._in_keys = None + self._out_keys = None + if expectile >= 1.0: + raise ValueError(f"Expectile should be lower than 1.0 but is {expectile}") + super().__init__( + actor_network=actor_network, + qvalue_network=qvalue_network, + value_network=value_network, + num_qvalue_nets=num_qvalue_nets, + loss_function=loss_function, + temperature=temperature, + expectile=expectile, + gamma=gamma, + priority_key=priority_key, + separate_losses=separate_losses, + ) + if action_space is None: + warnings.warn( + "action_space was not specified. DiscreteIQLLoss will default to 'one-hot'." + "This behaviour will be deprecated soon and a space will have to be passed." + "Check the DiscreteIQLLoss documentation to see how to pass the action space. " + ) + action_space = "one-hot" + self.action_space = _find_action_space(action_space) + + def actor_loss(self, tensordict: TensorDictBase) -> Tensor: + # KL loss + with self.actor_network_params.to_module(self.actor_network): + dist = self.actor_network.get_dist(tensordict) + + log_prob = dist.log_prob(tensordict[self.tensor_keys.action]) + + # Min Q value + td_q = tensordict.select(*self.qvalue_network.in_keys) + td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params) + state_action_value = td_q.get(self.tensor_keys.state_action_value) + action = tensordict.get(self.tensor_keys.action) + if self.action_space == "categorical": + if action.shape != state_action_value.shape: + # unsqueeze the action if it lacks on trailing singleton dim + action = action.unsqueeze(-1) + chosen_state_action_value = torch.gather( + state_action_value, -1, index=action + ).squeeze(-1) + else: + action = action.to(torch.float) + chosen_state_action_value = (state_action_value * action).sum(-1) + min_Q, _ = torch.min(chosen_state_action_value, dim=0) + if log_prob.shape != min_Q.shape: + raise RuntimeError( + f"Losses shape mismatch: {log_prob.shape} and {min_Q.shape}" + ) + with torch.no_grad(): + # state value + td_copy = tensordict.select(*self.value_network.in_keys).detach() + with self.value_network_params.to_module(self.value_network): + self.value_network(td_copy) + value = td_copy.get(self.tensor_keys.value).squeeze( + -1 + ) # assert has no gradient + + exp_a = torch.exp((min_Q - value) * self.temperature) + exp_a = torch.min(exp_a, torch.FloatTensor([100.0]).to(self.device)) + + # write log_prob in tensordict for alpha loss + tensordict.set(self.tensor_keys.log_prob, log_prob.detach()) + return -(exp_a * log_prob).mean(), {} + + def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: + # Min Q value + with torch.no_grad(): + # Min Q value + td_q = 
tensordict.select(*self.qvalue_network.in_keys) + td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params) + state_action_value = td_q.get(self.tensor_keys.state_action_value) + action = tensordict.get(self.tensor_keys.action) + if self.action_space == "categorical": + if action.shape != state_action_value.shape: + # unsqueeze the action if it lacks on trailing singleton dim + action = action.unsqueeze(-1) + chosen_state_action_value = torch.gather( + state_action_value, -1, index=action + ).squeeze(-1) + else: + action = action.to(torch.float) + chosen_state_action_value = (state_action_value * action).sum(-1) + min_Q, _ = torch.min(chosen_state_action_value, dim=0) + # state value + td_copy = tensordict.select(*self.value_network.in_keys) + with self.value_network_params.to_module(self.value_network): + self.value_network(td_copy) + value = td_copy.get(self.tensor_keys.value).squeeze(-1) + value_loss = self.loss_value_diff(min_Q - value, self.expectile).mean() + return value_loss, {} + + def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: + obs_keys = self.actor_network.in_keys + next_td = tensordict.select("next", *obs_keys, self.tensor_keys.action) + with torch.no_grad(): + target_value = self.value_estimator.value_estimate( + next_td, target_params=self.target_value_network_params + ).squeeze(-1) + + # predict current Q value + td_q = tensordict.select(*self.qvalue_network.in_keys) + td_q = self._vmap_qvalue_networkN0(td_q, self.qvalue_network_params) + state_action_value = td_q.get(self.tensor_keys.state_action_value) + action = tensordict.get(self.tensor_keys.action) + if self.action_space == "categorical": + if action.shape != state_action_value.shape: + # unsqueeze the action if it lacks on trailing singleton dim + action = action.unsqueeze(-1) + pred_val = torch.gather(state_action_value, -1, index=action).squeeze(-1) + else: + action = action.to(torch.float) + pred_val = (state_action_value * action).sum(-1) + + td_error = (pred_val - target_value.expand_as(pred_val)).pow(2) + loss_qval = ( + distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ) + .sum(0) + .mean() + ) + metadata = {"td_error": td_error.detach()} + return loss_qval, metadata
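Below is a minimal, self-contained usage sketch (not part of the patch) of how the pieces introduced above fit together: a one-hot actor, a critic that returns one Q-value per action, and a state-value network wired into DiscreteIQLLoss with a HardUpdate target updater and one Adam optimizer per network, mirroring examples/iql/discrete_iql.py and the helpers in examples/iql/utils.py. Network sizes, hyperparameters, and the random data batch are illustrative placeholders.

import torch
from tensordict.nn import TensorDictModule
from tensordict.tensordict import TensorDict
from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec
from torchrl.modules import MLP, OneHotCategorical, ProbabilisticActor
from torchrl.objectives import DiscreteIQLLoss, HardUpdate

n_obs, n_act, batch = 3, 4, 8
spec = OneHotDiscreteTensorSpec(n_act)

# Actor: observation -> logits -> OneHotCategorical, as in make_discrete_iql_model
actor = ProbabilisticActor(
    module=TensorDictModule(
        MLP(in_features=n_obs, out_features=n_act, num_cells=[32, 32]),
        in_keys=["observation"],
        out_keys=["logits"],
    ),
    in_keys=["logits"],
    out_keys=["action"],
    spec=spec,
    distribution_class=OneHotCategorical,
)
# Critic outputs one Q-value per action; value network outputs V(s)
qvalue = TensorDictModule(
    MLP(in_features=n_obs, out_features=n_act, num_cells=[32, 32]),
    in_keys=["observation"],
    out_keys=["state_action_value"],
)
value = TensorDictModule(
    MLP(in_features=n_obs, out_features=1, num_cells=[32, 32]),
    in_keys=["observation"],
    out_keys=["state_value"],
)

loss_module = DiscreteIQLLoss(
    actor, qvalue, value_network=value,
    loss_function="l2", temperature=100.0, expectile=0.8,
)
loss_module.make_value_estimator(gamma=0.99)
target_net_updater = HardUpdate(loss_module, value_network_update_interval=10)

# One optimizer per network, as in make_iql_optimizer
optim_actor = torch.optim.Adam(
    list(loss_module.actor_network_params.flatten_keys().values()), lr=3e-4
)
optim_value = torch.optim.Adam(
    list(loss_module.value_network_params.flatten_keys().values()), lr=3e-4
)
optim_critic = torch.optim.Adam(
    list(loss_module.qvalue_network_params.flatten_keys().values()), lr=3e-4
)

# A fake transition batch in the layout the loss expects
data = TensorDict(
    {
        "observation": torch.randn(batch, n_obs),
        "action": spec.rand([batch]),
        ("next", "observation"): torch.randn(batch, n_obs),
        ("next", "reward"): torch.randn(batch, 1),
        ("next", "done"): torch.zeros(batch, 1, dtype=torch.bool),
        ("next", "terminated"): torch.zeros(batch, 1, dtype=torch.bool),
    },
    [batch],
)

# One optimization step per loss term, as in the discrete_iql.py training loop
actor_loss, _ = loss_module.actor_loss(data)
optim_actor.zero_grad()
actor_loss.backward()
optim_actor.step()

value_loss, _ = loss_module.value_loss(data)
optim_value.zero_grad()
value_loss.backward()
optim_value.step()

q_loss, metadata = loss_module.qvalue_loss(data)  # metadata carries "td_error"
optim_critic.zero_grad()
q_loss.backward()
optim_critic.step()

target_net_updater.step()  # hard-update the target Q parameters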