diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index 201971f4161..5d46ad5dda3 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -270,6 +270,12 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/multiagent/q train.num_epochs=3 \ train.minibatch_size=100 \ logger.backend= +python .circleci/unittest/helpers/coverage_run_parallel.py examples/multiagent/sac.py \ + collector.n_iters=2 \ + collector.frames_per_batch=200 \ + train.num_epochs=3 \ + train.minibatch_size=100 \ + logger.backend= python .circleci/unittest/helpers/coverage_run_parallel.py examples/bandits/dqn.py --n_steps=100 diff --git a/examples/discrete_sac/discrete_sac.py b/examples/discrete_sac/discrete_sac.py index c5c03cf7042..12ac76f20e7 100644 --- a/examples/discrete_sac/discrete_sac.py +++ b/examples/discrete_sac/discrete_sac.py @@ -9,7 +9,7 @@ import torch import torch.cuda import tqdm -from tensordict.nn import InteractionType +from tensordict.nn import InteractionType, TensorDictModule from torch import nn, optim from torchrl.collectors import SyncDataCollector @@ -27,7 +27,7 @@ from torchrl.modules import MLP, SafeModule from torchrl.modules.distributions import OneHotCategorical -from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator +from torchrl.modules.tensordict_module.actors import ProbabilisticActor from torchrl.objectives import DiscreteSACLoss, SoftUpdate from torchrl.record.loggers import generate_exp_name, get_logger @@ -150,8 +150,9 @@ def env_factory(num_workers): **qvalue_net_kwargs, ) - qvalue = ValueOperator( + qvalue = TensorDictModule( in_keys=in_keys, + out_keys=["action_value"], module=qvalue_net, ).to(device) @@ -171,6 +172,7 @@ def env_factory(num_workers): # Create SAC loss loss_module = DiscreteSACLoss( actor_network=model[0], + action_space=test_env.action_spec, qvalue_network=model[1], num_actions=num_actions, num_qvalue_nets=2, diff --git a/examples/multiagent/sac.py b/examples/multiagent/sac.py new file mode 100644 index 00000000000..ae25038eefd --- /dev/null +++ b/examples/multiagent/sac.py @@ -0,0 +1,321 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import time + +import hydra +import torch + +from tensordict.nn import TensorDictModule +from tensordict.nn.distributions import NormalParamExtractor +from torch import nn +from torch.distributions import Categorical +from torchrl.collectors import SyncDataCollector +from torchrl.data import TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement +from torchrl.data.replay_buffers.storages import LazyTensorStorage +from torchrl.envs import RewardSum, TransformedEnv +from torchrl.envs.libs.vmas import VmasEnv +from torchrl.envs.utils import ExplorationType, set_exploration_type + +from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator +from torchrl.modules.models.multiagent import MultiAgentMLP +from torchrl.objectives import DiscreteSACLoss, SACLoss, SoftUpdate, ValueEstimators +from utils.logging import init_logging, log_evaluation, log_training + + +def rendering_callback(env, td): + env.frames.append(env.render(mode="rgb_array", agent_index_focus=None)) + + +@hydra.main(version_base="1.1", config_path=".", config_name="sac") +def train(cfg: "DictConfig"): # noqa: F821 + # Device + cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0" + cfg.env.device = cfg.train.device + + # Seeding + torch.manual_seed(cfg.seed) + + # Sampling + cfg.env.vmas_envs = cfg.collector.frames_per_batch // cfg.env.max_steps + cfg.collector.total_frames = cfg.collector.frames_per_batch * cfg.collector.n_iters + cfg.buffer.memory_size = cfg.collector.frames_per_batch + + # Create env and env_test + env = VmasEnv( + scenario=cfg.env.scenario_name, + num_envs=cfg.env.vmas_envs, + continuous_actions=cfg.env.continuous_actions, + max_steps=cfg.env.max_steps, + device=cfg.env.device, + seed=cfg.seed, + # Scenario kwargs + **cfg.env.scenario, + ) + env = TransformedEnv( + env, + RewardSum(in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")]), + ) + + env_test = VmasEnv( + scenario=cfg.env.scenario_name, + num_envs=cfg.eval.evaluation_episodes, + continuous_actions=cfg.env.continuous_actions, + max_steps=cfg.env.max_steps, + device=cfg.env.device, + seed=cfg.seed, + # Scenario kwargs + **cfg.env.scenario, + ) + + # Policy + if cfg.env.continuous_actions: + actor_net = nn.Sequential( + MultiAgentMLP( + n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1], + n_agent_outputs=2 * env.action_spec.shape[-1], + n_agents=env.n_agents, + centralised=False, + share_params=cfg.model.shared_parameters, + device=cfg.train.device, + depth=2, + num_cells=256, + activation_class=nn.Tanh, + ), + NormalParamExtractor(), + ) + policy_module = TensorDictModule( + actor_net, + in_keys=[("agents", "observation")], + out_keys=[("agents", "loc"), ("agents", "scale")], + ) + + policy = ProbabilisticActor( + module=policy_module, + spec=env.unbatched_action_spec, + in_keys=[("agents", "loc"), ("agents", "scale")], + out_keys=[env.action_key], + distribution_class=TanhNormal, + distribution_kwargs={ + "min": env.unbatched_action_spec[("agents", "action")].space.minimum, + "max": env.unbatched_action_spec[("agents", "action")].space.maximum, + }, + return_log_prob=True, + ) + + # Critic + module = MultiAgentMLP( + n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1] + + env.action_spec.shape[-1], # Q critic takes action and value + n_agent_outputs=1, + n_agents=env.n_agents, + centralised=cfg.model.centralised_critic, + share_params=cfg.model.shared_parameters, + device=cfg.train.device, + depth=2, + num_cells=256, + 
activation_class=nn.Tanh, + ) + value_module = ValueOperator( + module=module, + in_keys=[("agents", "observation"), env.action_key], + out_keys=[("agents", "state_action_value")], + ) + else: + actor_net = nn.Sequential( + MultiAgentMLP( + n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1], + n_agent_outputs=env.action_spec.space.n, + n_agents=env.n_agents, + centralised=False, + share_params=cfg.model.shared_parameters, + device=cfg.train.device, + depth=2, + num_cells=256, + activation_class=nn.Tanh, + ), + ) + policy_module = TensorDictModule( + actor_net, + in_keys=[("agents", "observation")], + out_keys=[("agents", "logits")], + ) + policy = ProbabilisticActor( + module=policy_module, + spec=env.unbatched_action_spec, + in_keys=[("agents", "logits")], + out_keys=[env.action_key], + distribution_class=Categorical, + return_log_prob=True, + ) + + # Critic + module = MultiAgentMLP( + n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1], + n_agent_outputs=env.action_spec.space.n, + n_agents=env.n_agents, + centralised=cfg.model.centralised_critic, + share_params=cfg.model.shared_parameters, + device=cfg.train.device, + depth=2, + num_cells=256, + activation_class=nn.Tanh, + ) + value_module = ValueOperator( + module=module, + in_keys=[("agents", "observation")], + out_keys=[("agents", "action_value")], + ) + + collector = SyncDataCollector( + env, + policy, + device=cfg.env.device, + storing_device=cfg.train.device, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + ) + + replay_buffer = TensorDictReplayBuffer( + storage=LazyTensorStorage(cfg.buffer.memory_size, device=cfg.train.device), + sampler=SamplerWithoutReplacement(), + batch_size=cfg.train.minibatch_size, + ) + + if cfg.env.continuous_actions: + loss_module = SACLoss( + actor_network=policy, qvalue_network=value_module, delay_qvalue=True + ) + loss_module.set_keys( + state_action_value=("agents", "state_action_value"), + action=env.action_key, + reward=env.reward_key, + ) + else: + loss_module = DiscreteSACLoss( + actor_network=policy, + qvalue_network=value_module, + delay_qvalue=True, + num_actions=env.action_spec.space.n, + action_space=env.unbatched_action_spec, + ) + loss_module.set_keys( + action_value=("agents", "action_value"), + action=env.action_key, + reward=env.reward_key, + ) + + loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma) + target_net_updater = SoftUpdate(loss_module, eps=1 - cfg.loss.tau) + + optim = torch.optim.Adam(loss_module.parameters(), cfg.train.lr) + + # Logging + if cfg.logger.backend: + model_name = ( + ("Het" if not cfg.model.shared_parameters else "") + + ("MA" if cfg.model.centralised_critic else "I") + + "SAC" + ) + logger = init_logging(cfg, model_name) + + total_time = 0 + total_frames = 0 + sampling_start = time.time() + for i, tensordict_data in enumerate(collector): + print(f"\nIteration {i}") + + sampling_time = time.time() - sampling_start + + tensordict_data.set( + ("next", "done"), + tensordict_data.get(("next", "done")) + .unsqueeze(-1) + .expand(tensordict_data.get(("next", env.reward_key)).shape), + ) # We need to expand the done to match the reward shape + + current_frames = tensordict_data.numel() + total_frames += current_frames + data_view = tensordict_data.reshape(-1) + replay_buffer.extend(data_view) + + training_tds = [] + training_start = time.time() + for _ in range(cfg.train.num_epochs): + for _ in range(cfg.collector.frames_per_batch // cfg.train.minibatch_size): + 
subdata = replay_buffer.sample()
+                loss_vals = loss_module(subdata)
+                training_tds.append(loss_vals.detach())
+
+                loss_value = (
+                    loss_vals["loss_actor"]
+                    + loss_vals["loss_alpha"]
+                    + loss_vals["loss_qvalue"]
+                )
+
+                loss_value.backward()
+
+                total_norm = torch.nn.utils.clip_grad_norm_(
+                    loss_module.parameters(), cfg.train.max_grad_norm
+                )
+                training_tds[-1].set("grad_norm", total_norm.mean())
+
+                optim.step()
+                optim.zero_grad()
+                target_net_updater.step()
+
+        collector.update_policy_weights_()
+
+        training_time = time.time() - training_start
+
+        iteration_time = sampling_time + training_time
+        total_time += iteration_time
+        training_tds = torch.stack(training_tds)
+
+        # More logs
+        if cfg.logger.backend:
+            log_training(
+                logger,
+                training_tds,
+                tensordict_data,
+                sampling_time,
+                training_time,
+                total_time,
+                i,
+                current_frames,
+                total_frames,
+                step=i,
+            )
+
+        if (
+            cfg.eval.evaluation_episodes > 0
+            and i % cfg.eval.evaluation_interval == 0
+            and cfg.logger.backend
+        ):
+            evaluation_start = time.time()
+            with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+                env_test.frames = []
+                rollouts = env_test.rollout(
+                    max_steps=cfg.env.max_steps,
+                    policy=policy,
+                    callback=rendering_callback,
+                    auto_cast_to_device=True,
+                    break_when_any_done=False,
+                    # We are running vectorized evaluation, so we do not want it to stop when just one env is done
+                )
+
+                evaluation_time = time.time() - evaluation_start
+
+                log_evaluation(logger, rollouts, env_test, evaluation_time, step=i)
+
+        if cfg.logger.backend == "wandb":
+            logger.experiment.log({}, commit=True)
+        sampling_start = time.time()
+
+
+if __name__ == "__main__":
+    train()
diff --git a/examples/multiagent/sac.yaml b/examples/multiagent/sac.yaml
new file mode 100644
index 00000000000..98f55045464
--- /dev/null
+++ b/examples/multiagent/sac.yaml
@@ -0,0 +1,40 @@
+seed: 0
+
+env:
+  continuous_actions: True # False for discrete sac
+  max_steps: 100
+  scenario_name: "balance"
+  scenario:
+    n_agents: 3
+  device: ??? # These values will be populated dynamically
+  vmas_envs: ???
+
+model:
+  shared_parameters: True
+  centralised_critic: True
+
+collector:
+  frames_per_batch: 60_000 # Frames sampled each sampling iteration
+  n_iters: 500 # Number of sampling/training iterations
+  total_frames: ???
+
+buffer:
+  memory_size: ???
+
+loss:
+  gamma: 0.9
+  tau: 0.005 # For target net
+
+train:
+  num_epochs: 45 # optimization steps per batch of data collected
+  minibatch_size: 4096 # size of minibatches used in each epoch
+  lr: 5e-5
+  max_grad_norm: 2.0
+  device: ???
+ +eval: + evaluation_interval: 20 + evaluation_episodes: 200 + +logger: + backend: wandb # Delete to remove logging diff --git a/test/test_cost.py b/test/test_cost.py index 38faaaf1048..aa1ae5d245d 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -3058,7 +3058,9 @@ def forward(self, obs): return self.linear(obs) module = ValueClass() - qvalue = ValueOperator(module=module, in_keys=[observation_key]) + qvalue = ValueOperator( + module=module, in_keys=[observation_key], out_keys=["action_value"] + ) return qvalue.to(device) def _create_mock_distributional_actor( @@ -3473,14 +3475,6 @@ def test_discrete_sac_notensordict( torch.testing.assert_close(loss_val_td.get("loss_alpha"), loss_val[2]) torch.testing.assert_close(loss_val_td.get("alpha"), loss_val[3]) torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[4]) - torch.testing.assert_close( - loss_val_td.get("state_action_value_actor"), loss_val[5] - ) - torch.testing.assert_close( - loss_val_td.get("action_log_prob_actor"), loss_val[6] - ) - torch.testing.assert_close(loss_val_td.get("next.state_value"), loss_val[7]) - torch.testing.assert_close(loss_val_td.get("target_value"), loss_val[8]) # test select torch.manual_seed(self.seed) loss.select_out_keys("loss_actor", "loss_alpha") diff --git a/torchrl/data/utils.py b/torchrl/data/utils.py index 67756660856..dad28c710bb 100644 --- a/torchrl/data/utils.py +++ b/torchrl/data/utils.py @@ -12,9 +12,14 @@ from torch import Tensor from torchrl.data.tensor_specs import ( + BinaryDiscreteTensorSpec, CompositeSpec, + DiscreteTensorSpec, LazyStackedCompositeSpec, LazyStackedTensorSpec, + MultiDiscreteTensorSpec, + MultiOneHotDiscreteTensorSpec, + OneHotDiscreteTensorSpec, TensorSpec, ) @@ -43,6 +48,27 @@ INDEX_TYPING = Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]] +ACTION_SPACE_MAP = { + OneHotDiscreteTensorSpec: "one_hot", + MultiOneHotDiscreteTensorSpec: "mult_one_hot", + BinaryDiscreteTensorSpec: "binary", + DiscreteTensorSpec: "categorical", + "one_hot": "one_hot", + "one-hot": "one_hot", + "mult_one_hot": "mult_one_hot", + "mult-one-hot": "mult_one_hot", + "multi_one_hot": "mult_one_hot", + "multi-one-hot": "mult_one_hot", + "binary": "binary", + "categorical": "categorical", + MultiDiscreteTensorSpec: "multi_categorical", + "multi_categorical": "multi_categorical", + "multi-categorical": "multi_categorical", + "multi_discrete": "multi_categorical", + "multi-discrete": "multi_categorical", +} + + def consolidate_spec( spec: CompositeSpec, recurse_through_entries: bool = True, @@ -129,7 +155,7 @@ def _empty_like_spec(specs: List[TensorSpec], shape): shape = list(shape[: spec.stack_dim]) + list(shape[spec.stack_dim + 1 :]) return LazyStackedTensorSpec( *[_empty_like_spec(spec._specs, shape) for _ in spec._specs], - dim=spec.stack_dim + dim=spec.stack_dim, ) else: # the exclusive key has values which are TensorSpecs -> @@ -206,3 +232,79 @@ def __call__(self, *args, **kwargs) -> Any: kwargs = {k: item for k, item in kwargs.items()} kwargs.update(self.kwargs) return self.fn(**kwargs) + + +def _process_action_space_spec(action_space, spec): + original_spec = spec + composite_spec = False + if isinstance(spec, CompositeSpec): + # this will break whenever our action is more complex than a single tensor + try: + if "action" in spec.keys(): + _key = "action" + else: + # the first key is the action + for _key in spec.keys(True, True): + if isinstance(_key, tuple) and _key[-1] == "action": + break + else: + raise KeyError + spec = spec[_key] + composite_spec = True + 
except KeyError: + raise KeyError( + "action could not be found in the spec. Make sure " + "you pass a spec that is either a native action spec or a composite action spec " + "with a leaf 'action' entry. Otherwise, simply remove the spec and use the action_space only." + ) + if action_space is not None: + if isinstance(action_space, CompositeSpec): + raise ValueError("action_space cannot be of type CompositeSpec.") + if ( + spec is not None + and isinstance(action_space, TensorSpec) + and action_space is not spec + ): + raise ValueError( + "Passing an action_space as a TensorSpec and a spec isn't allowed, unless they match." + ) + if isinstance(action_space, TensorSpec): + spec = action_space + action_space = _find_action_space(action_space) + # check that the spec and action_space match + if spec is not None and _find_action_space(spec) != action_space: + raise ValueError( + f"The action spec and the action space do not match: got action_space={action_space} and spec={spec}." + ) + elif spec is not None: + action_space = _find_action_space(spec) + else: + raise ValueError( + "Neither action_space nor spec was defined. The action space cannot be inferred." + ) + if composite_spec: + spec = original_spec + return action_space, spec + + +def _find_action_space(action_space): + if isinstance(action_space, TensorSpec): + if isinstance(action_space, CompositeSpec): + if "action" in action_space.keys(): + _key = "action" + else: + # the first key is the action + for _key in action_space.keys(True, True): + if isinstance(_key, tuple) and _key[-1] == "action": + break + else: + raise KeyError + action_space = action_space[_key] + action_space = type(action_space) + try: + action_space = ACTION_SPACE_MAP[action_space] + except KeyError: + raise ValueError( + f"action_space was not specified/not compatible and could not be retrieved from the value network. Got action_space={action_space}." + ) + return action_space diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 1cd6745826a..4ae4c9a9de1 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -20,6 +20,7 @@ from torch.distributions import Categorical from torchrl.data.tensor_specs import CompositeSpec, TensorSpec +from torchrl.data.utils import _process_action_space_spec from torchrl.modules.models.models import DistributionalDQNnet from torchrl.modules.tensordict_module.common import SafeModule from torchrl.modules.tensordict_module.probabilistic import ( @@ -27,7 +28,6 @@ SafeProbabilisticTensorDictSequential, ) from torchrl.modules.tensordict_module.sequence import SafeSequential -from torchrl.modules.utils.utils import _find_action_space class Actor(SafeModule): @@ -682,59 +682,6 @@ def _binary(self, value: torch.Tensor) -> torch.Tensor: ) -def _process_action_space_spec(action_space, spec): - original_spec = spec - composite_spec = False - if isinstance(spec, CompositeSpec): - # this will break whenever our action is more complex than a single tensor - try: - if "action" in spec.keys(): - _key = "action" - else: - # the first key is the action - for _key in spec.keys(True, True): - if isinstance(_key, tuple) and _key[-1] == "action": - break - else: - raise KeyError - spec = spec[_key] - composite_spec = True - except KeyError: - raise KeyError( - "action could not be found in the spec. Make sure " - "you pass a spec that is either a native action spec or a composite action spec " - "with a leaf 'action' entry. 
Otherwise, simply remove the spec and use the action_space only." - ) - if action_space is not None: - if isinstance(action_space, CompositeSpec): - raise ValueError("action_space cannot be of type CompositeSpec.") - if ( - spec is not None - and isinstance(action_space, TensorSpec) - and action_space is not spec - ): - raise ValueError( - "Passing an action_space as a TensorSpec and a spec isn't allowed, unless they match." - ) - if isinstance(action_space, TensorSpec): - spec = action_space - action_space = _find_action_space(action_space) - # check that the spec and action_space match - if spec is not None and _find_action_space(spec) != action_space: - raise ValueError( - f"The action spec and the action space do not match: got action_space={action_space} and spec={spec}." - ) - elif spec is not None: - action_space = _find_action_space(spec) - else: - raise ValueError( - "Neither action_space nor spec was defined. The action space cannot be inferred." - ) - if composite_spec: - spec = original_spec - return action_space, spec - - class QValueHook: """Q-Value hook for Q-value policies. diff --git a/torchrl/modules/utils/utils.py b/torchrl/modules/utils/utils.py index 95427fce078..e69de29bb2d 100644 --- a/torchrl/modules/utils/utils.py +++ b/torchrl/modules/utils/utils.py @@ -1,51 +0,0 @@ -from torchrl.data.tensor_specs import ( - BinaryDiscreteTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, - TensorSpec, -) - -ACTION_SPACE_MAP = {} -ACTION_SPACE_MAP[OneHotDiscreteTensorSpec] = "one_hot" -ACTION_SPACE_MAP[MultiOneHotDiscreteTensorSpec] = "mult_one_hot" -ACTION_SPACE_MAP[BinaryDiscreteTensorSpec] = "binary" -ACTION_SPACE_MAP[DiscreteTensorSpec] = "categorical" -ACTION_SPACE_MAP["one_hot"] = "one_hot" -ACTION_SPACE_MAP["one-hot"] = "one_hot" -ACTION_SPACE_MAP["mult_one_hot"] = "mult_one_hot" -ACTION_SPACE_MAP["mult-one-hot"] = "mult_one_hot" -ACTION_SPACE_MAP["multi_one_hot"] = "mult_one_hot" -ACTION_SPACE_MAP["multi-one-hot"] = "mult_one_hot" -ACTION_SPACE_MAP["binary"] = "binary" -ACTION_SPACE_MAP["categorical"] = "categorical" -# TODO for the future ;) -# ACTION_SPACE_MAP[MultiDiscreteTensorSpec] = "multi_categorical" -# ACTION_SPACE_MAP["multi_categorical"] = "multi_categorical" -# ACTION_SPACE_MAP["multi-categorical"] = "multi_categorical" -# ACTION_SPACE_MAP["multi_discrete"] = "multi_categorical" -# ACTION_SPACE_MAP["multi-discrete"] = "multi_categorical" - - -def _find_action_space(action_space): - if isinstance(action_space, TensorSpec): - if isinstance(action_space, CompositeSpec): - if "action" in action_space.keys(): - _key = "action" - else: - # the first key is the action - for _key in action_space.keys(True, True): - if isinstance(_key, tuple) and _key[-1] == "action": - break - else: - raise KeyError - action_space = action_space[_key] - action_space = type(action_space) - try: - action_space = ACTION_SPACE_MAP[action_space] - except KeyError: - raise ValueError( - f"action_space was not specified/not compatible and could not be retrieved from the value network. Got action_space={action_space}." 
- ) - return action_space diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index c37328e9326..527af5bf481 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -13,6 +13,8 @@ from torch import nn from torchrl.data.tensor_specs import TensorSpec +from torchrl.data.utils import _find_action_space + from torchrl.envs.utils import step_mdp from torchrl.modules.tensordict_module.actors import ( DistributionalQValueActor, @@ -20,8 +22,6 @@ ) from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible -from torchrl.modules.utils.utils import _find_action_space - from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _GAMMA_LMBDA_DEPREC_WARNING, diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index e9eca7ce293..f09fa3c0e03 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -18,12 +18,12 @@ from torchrl.data.tensor_specs import TensorSpec +from torchrl.data.utils import _find_action_space + from torchrl.modules import SafeSequential from torchrl.modules.tensordict_module.actors import QValueActor from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible -from torchrl.modules.utils.utils import _find_action_space - from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index aeb9adbafea..a1ebf0e5873 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -6,7 +6,7 @@ import warnings from dataclasses import dataclass from numbers import Number -from typing import Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import numpy as np import torch @@ -15,8 +15,9 @@ from tensordict.utils import NestedKey from torch import Tensor -from torchrl.data import CompositeSpec -from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp +from torchrl.data import CompositeSpec, TensorSpec +from torchrl.data.utils import _find_action_space +from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ProbabilisticActor from torchrl.modules.tensordict_module.actors import ActorCriticWrapper @@ -516,18 +517,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: else: tensordict_reshape = tensordict - device = self.device - td_device = tensordict_reshape.to(device) - if self._version == 1: - loss_qvalue, priority = self._loss_qvalue_v1(td_device) - loss_value = self._loss_value(td_device) + loss_qvalue, value_metadata = self._qvalue_v1_loss(tensordict_reshape) + loss_value, _ = self._value_loss(tensordict_reshape) else: - loss_qvalue, priority = self._loss_qvalue_v2(td_device) + loss_qvalue, value_metadata = self._qvalue_v2_loss(tensordict_reshape) loss_value = None - loss_actor = self._loss_actor(td_device) - loss_alpha = self._loss_alpha(td_device) - tensordict_reshape.set(self.tensor_keys.priority, priority) + loss_actor, metadata_actor = self._actor_loss(tensordict_reshape) + loss_alpha = self._alpha_loss(log_prob=metadata_actor["log_prob"]) + tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"]) if (loss_actor.shape != loss_qvalue.shape) or ( loss_value is not None and loss_actor.shape != loss_value.shape ): @@ -536,12 +534,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ) if shape: tensordict.update(tensordict_reshape.view(shape)) + entropy = 
-metadata_actor["log_prob"].mean() out = { "loss_actor": loss_actor.mean(), "loss_qvalue": loss_qvalue.mean(), "loss_alpha": loss_alpha.mean(), "alpha": self._alpha, - "entropy": -td_device.get(self.tensor_keys.log_prob).mean().detach(), + "entropy": entropy, } if self._version == 1: out["loss_value"] = loss_value.mean() @@ -552,7 +551,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def _cached_detached_qvalue_params(self): return self.qvalue_network_params.detach() - def _loss_actor(self, tensordict: TensorDictBase) -> Tensor: + def _actor_loss( + self, tensordict: TensorDictBase + ) -> Tuple[Tensor, Dict[str, Tensor]]: with set_exploration_type(ExplorationType.RANDOM): dist = self.actor_network.get_dist( tensordict, @@ -575,9 +576,7 @@ def _loss_actor(self, tensordict: TensorDictBase) -> Tensor: f"Losses shape mismatch: {log_prob.shape} and {min_q_logprob.shape}" ) - # write log_prob in tensordict for alpha loss - tensordict.set(self.tensor_keys.log_prob, log_prob.detach()) - return self._alpha * log_prob - min_q_logprob + return self._alpha * log_prob - min_q_logprob, {"log_prob": log_prob.detach()} @property @_cache_values @@ -593,7 +592,9 @@ def _cached_target_params_actor_value(self): _run_checks=False, ) - def _loss_qvalue_v1(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: + def _qvalue_v1_loss( + self, tensordict: TensorDictBase + ) -> Tuple[Tensor, Dict[str, Tensor]]: target_params = self._cached_target_params_actor_value with set_exploration_type(ExplorationType.MODE): target_value = self.value_estimator.value_estimate( @@ -623,11 +624,13 @@ def _loss_qvalue_v1(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: loss_value = distance_loss( pred_val, target_chunks, loss_function=self.loss_function ).view(*shape) - priority_value = torch.cat((pred_val - target_chunks).pow(2).unbind(0), 0) + metadata = { + "td_error": torch.cat((pred_val - target_chunks).pow(2).unbind(0), 0) + } - return loss_value, priority_value + return loss_value, metadata - def _get_value_v2(self, tensordict, _alpha, actor_params, qval_params): + def _compute_target_v2(self, tensordict) -> Tensor: r"""Value network for SAC v2. 
SAC v2 is based on a value estimate of the form: @@ -645,14 +648,16 @@ def _get_value_v2(self, tensordict, _alpha, actor_params, qval_params): with set_exploration_type(ExplorationType.RANDOM): next_tensordict = tensordict.get("next").clone(False) next_dist = self.actor_network.get_dist( - next_tensordict, params=actor_params + next_tensordict, params=self.actor_network_params ) next_action = next_dist.rsample() next_tensordict.set(self.tensor_keys.action, next_action) next_sample_log_prob = next_dist.log_prob(next_action) # get q-values - next_tensordict_expand = self._vmap_qnetworkN0(next_tensordict, qval_params) + next_tensordict_expand = self._vmap_qnetworkN0( + next_tensordict, self.target_qvalue_network_params + ) state_action_value = next_tensordict_expand.get( self.tensor_keys.state_action_value ) @@ -661,7 +666,7 @@ def _get_value_v2(self, tensordict, _alpha, actor_params, qval_params): != next_sample_log_prob.shape ): next_sample_log_prob = next_sample_log_prob.unsqueeze(-1) - next_state_value = state_action_value - _alpha * next_sample_log_prob + next_state_value = state_action_value - self._alpha * next_sample_log_prob next_state_value = next_state_value.min(0)[0] tensordict.set( ("next", self.value_estimator.tensor_keys.value), next_state_value @@ -669,14 +674,11 @@ def _get_value_v2(self, tensordict, _alpha, actor_params, qval_params): target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) return target_value - def _loss_qvalue_v2(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: + def _qvalue_v2_loss( + self, tensordict: TensorDictBase + ) -> Tuple[Tensor, Dict[str, Tensor]]: # we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first. - target_value = self._get_value_v2( - tensordict, - self._alpha, - self.actor_network_params, - self.target_qvalue_network_params, - ) + target_value = self._compute_target_v2(tensordict) tensordict_expand = self._vmap_qnetworkN0( tensordict.select(*self.qvalue_network.in_keys), @@ -691,9 +693,12 @@ def _loss_qvalue_v2(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: target_value.expand_as(pred_val), loss_function=self.loss_function, ).mean(0) - return loss_qval, td_error.detach().max(0)[0] + metadata = {"td_error": td_error.detach().max(0)[0]} + return loss_qval, metadata - def _loss_value(self, tensordict: TensorDictBase) -> Tensor: + def _value_loss( + self, tensordict: TensorDictBase + ) -> Tuple[Tensor, Dict[str, Tensor]]: # value loss td_copy = tensordict.select(*self.value_network.in_keys).detach() self.value_network( @@ -729,16 +734,16 @@ def _loss_value(self, tensordict: TensorDictBase) -> Tensor: loss_value = distance_loss( pred_val, target_val, loss_function=self.loss_function ) - return loss_value + return loss_value, {} + + def _alpha_loss(self, log_prob: Tensor) -> Tensor: - def _loss_alpha(self, tensordict: TensorDictBase) -> Tensor: - log_pi = tensordict.get(self.tensor_keys.log_prob) if self.target_entropy is not None: # we can compute this loss even if log_alpha is not a parameter - alpha_loss = -self.log_alpha * (log_pi.detach() + self.target_entropy) + alpha_loss = -self.log_alpha * (log_prob + self.target_entropy) else: # placeholder - alpha_loss = torch.zeros_like(log_pi) + alpha_loss = torch.zeros_like(log_prob) return alpha_loss @property @@ -756,9 +761,15 @@ class DiscreteSACLoss(LossModule): Args: actor_network (ProbabilisticActor): the actor to be trained qvalue_network (TensorDictModule): a single Q-value network that will be 
multiplicated as many times as needed.
-        num_actions (int): number of actions in the action space.
+        action_space (str or TensorSpec): Action space. Must be one of
+            ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``,
+            or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
+            :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
+            :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
+        num_actions (int, optional): number of actions in the action space.
+            To be provided if target_entropy is set to "auto".
         num_qvalue_nets (int, optional): Number of Q-value networks to be trained. Default is 10.
-        loss_function (str, optional): loss function to be used for the Q-value. Can be one of `"smooth_l1"`, "l2",
+        loss_function (str, optional): loss function to be used for the Q-value. Can be one of `"smooth_l1"`, "l2", "l1".
             Default is "smooth_l1".
         alpha_init (float, optional): initial entropy multiplier.
             Default is 1.0.
@@ -788,52 +799,44 @@ class DiscreteSACLoss(LossModule):
         >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
         >>> from torchrl.modules.tensordict_module.common import SafeModule
         >>> from torchrl.objectives.sac import DiscreteSACLoss
-        >>> from tensordict.tensordict import TensorDict
+        >>> from tensordict import TensorDict
+        >>> from tensordict.nn import TensorDictModule
         >>> n_act, n_obs = 4, 3
         >>> spec = OneHotDiscreteTensorSpec(n_act)
-        >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
-        >>> module = SafeModule(net, in_keys=["observation"], out_keys=["logits"])
+        >>> module = TensorDictModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
         >>> actor = ProbabilisticActor(
         ...     module=module,
         ...     in_keys=["logits"],
         ...     out_keys=["action"],
         ...     spec=spec,
         ...     distribution_class=OneHotCategorical)
-        >>> class ValueClass(nn.Module):
-        ...     def __init__(self):
-        ...         super().__init__()
-        ...         self.linear = nn.Linear(n_obs, n_act)
-        ...     def forward(self, obs):
-        ...         return self.linear(obs)
-        >>> module = ValueClass()
-        >>> qvalue = ValueOperator(
-        ...     module=module,
-        ...     in_keys=['observation'])
-        >>> loss = DiscreteSACLoss(actor, qvalue, num_actions=actor.spec["action"].space.n)
-        >>> batch = [2, ]
+        >>> qvalue = TensorDictModule(
+        ...     nn.Linear(n_obs, n_act),
+        ...     in_keys=["observation"],
+        ...     out_keys=["action_value"],
+        ... )
+        >>> loss = DiscreteSACLoss(actor, qvalue, action_space=spec, num_actions=spec.space.n)
+        >>> batch = [2,]
         >>> action = spec.rand(batch)
         >>> data = TensorDict({
-        ...     "observation": torch.randn(*batch, n_obs),
-        ...     "action": action,
-        ...     ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
-        ...     ("next", "reward"): torch.randn(*batch, 1),
-        ...     ("next", "observation"): torch.randn(*batch, n_obs),
+        ...     "observation": torch.randn(*batch, n_obs),
+        ...     "action": action,
+        ...     ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
+        ...     ("next", "reward"): torch.randn(*batch, 1),
+        ...     ("next", "observation"): torch.randn(*batch, n_obs),
+        ... 
}, batch) >>> loss(data) TensorDict( - fields={ - action_log_prob_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - next.state_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - state_action_value_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([]), - device=None, - is_shared=False) + fields={ + alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, @@ -841,8 +844,7 @@ class DiscreteSACLoss(LossModule): ``["action", "next_reward", "next_done"]`` + in_keys of the actor and qvalue network. The return value is a tuple of tensors in the following order: ``["loss_actor", "loss_qvalue", "loss_alpha", - "alpha", "entropy", "state_action_value_actor", - "action_log_prob_actor", "next.state_value", "target_value"]`` + "alpha", "entropy"]`` The output keys can also be filtered using :meth:`DiscreteSACLoss.select_out_keys` method. Examples: @@ -911,9 +913,11 @@ class _AcceptedKeys: action: NestedKey = "action" value: NestedKey = "state_value" + action_value: NestedKey = "action_value" priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" + log_prob: NestedKey = "log_prob" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 @@ -924,18 +928,15 @@ class _AcceptedKeys: "loss_alpha", "alpha", "entropy", - "state_action_value_actor", - "action_log_prob_actor", - "next.state_value", - "target_value", ] def __init__( self, actor_network: ProbabilisticActor, qvalue_network: TensorDictModule, - num_actions: int, # replace with spec? *, + action_space: Union[str, TensorSpec] = None, + num_actions: Optional[int] = None, num_qvalue_nets: int = 2, loss_function: str = "smooth_l1", alpha_init: float = 1.0, @@ -958,7 +959,7 @@ def __init__( actor_network, "actor_network", create_target_params=self.delay_actor, - funs_to_decorate=["forward", "get_dist_params"], + funs_to_decorate=["forward", "get_dist"], ) if separate_losses: # we want to make sure there are no duplicates in the params: the @@ -1011,22 +1012,24 @@ def __init__( torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), ) + if action_space is None: + warnings.warn( + "action_space was not specified. DiscreteSACLoss will default to 'one-hot'." 
+ "This behaviour will be deprecated soon and a space will have to be passed." + "Check the DiscreteSACLoss documentation to see how to pass the action space. " + ) + action_space = "one-hot" + self.action_space = _find_action_space(action_space) if target_entropy == "auto": + if num_actions is None: + raise ValueError( + "num_actions needs to be provided if target_entropy == 'auto'" + ) target_entropy = -float(np.log(1.0 / num_actions) * target_entropy_weight) self.register_buffer( "target_entropy", torch.tensor(target_entropy, device=device) ) - - self._vmap_getdist = vmap(self.actor_network.get_dist_params) - self._vmap_qnetwork = vmap(self.qvalue_network) - - @property - def alpha(self): - if self.min_log_alpha is not None: - self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) - with torch.no_grad(): - alpha = self.log_alpha.exp() - return alpha + self._vmap_qnetworkN0 = vmap(self.qvalue_network, (None, 0)) def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: @@ -1060,172 +1063,172 @@ def in_keys(self, values): @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - obs_keys = self.actor_network.in_keys - tensordict_select = tensordict.clone(False).select( - "next", *obs_keys, self.tensor_keys.action - ) + shape = None + if tensordict.ndimension() > 1: + shape = tensordict.shape + tensordict_reshape = tensordict.reshape(-1) + else: + tensordict_reshape = tensordict - actor_params = torch.stack( - [self.actor_network_params, self.target_actor_network_params], 0 + loss_value, metadata_value = self._value_loss(tensordict_reshape) + loss_actor, metadata_actor = self._actor_loss(tensordict_reshape) + loss_alpha = self._alpha_loss( + log_prob=metadata_actor["log_prob"], ) - tensordict_actor_grad = tensordict_select.select( - *obs_keys - ) # to avoid overwriting keys - next_td_actor = step_mdp(tensordict_select).select( - *self.actor_network.in_keys - ) # next_observation -> - tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0) - tensordict_actor = tensordict_actor.contiguous() - - with set_exploration_type(ExplorationType.RANDOM): - # vmap doesn't support sampling, so we take it out from the vmap - td_params = self._vmap_getdist( - tensordict_actor, - actor_params, + tensordict_reshape.set(self.tensor_keys.priority, metadata_value["td_error"]) + if loss_actor.shape != loss_value.shape: + raise RuntimeError( + f"Losses shape mismatch: {loss_actor.shape}, and {loss_value.shape}" ) - if isinstance(self.actor_network, ProbabilisticActor): - tensordict_actor_dist = self.actor_network.build_dist_from_params( - td_params - ) - else: - tensordict_actor_dist = self.actor_network.build_dist_from_params( - td_params - ) - probs = tensordict_actor_dist.probs - z = (probs == 0.0).float() * 1e-8 - logp_pi = torch.log(probs + z) - logp_pi_pol = torch.sum(probs * logp_pi, dim=-1, keepdim=True) - - # repeat tensordict_actor to match the qvalue size - _actor_loss_td = ( - tensordict_actor[0] - .select(*self.qvalue_network.in_keys) - .expand(self.num_qvalue_nets, *tensordict_actor[0].batch_size) - ) # for actor loss - _qval_td = tensordict_select.select(*self.qvalue_network.in_keys).expand( - self.num_qvalue_nets, - *tensordict_select.select(*self.qvalue_network.in_keys).batch_size, - ) # for qvalue loss - _next_val_td = ( - tensordict_actor[1] - .select(*self.qvalue_network.in_keys) - .expand(self.num_qvalue_nets, *tensordict_actor[1].batch_size) - ) # for next value estimation - tensordict_qval = torch.cat( 
- [ - _actor_loss_td, - _next_val_td, - _qval_td, - ], - 0, - ) + if shape: + tensordict.update(tensordict_reshape.view(shape)) + entropy = -metadata_actor["log_prob"].mean() + out = { + "loss_actor": loss_actor.mean(), + "loss_qvalue": loss_value.mean(), + "loss_alpha": loss_alpha.mean(), + "alpha": self._alpha, + "entropy": entropy, + } + return TensorDict(out, []) - # cat params - q_params_detach = self.qvalue_network_params.detach() - qvalue_params = torch.cat( - [ - q_params_detach, - self.target_qvalue_network_params, - self.qvalue_network_params, - ], - 0, - ) - tensordict_qval = self._vmap_qnetwork( - tensordict_qval, - qvalue_params, - ) + def _compute_target(self, tensordict) -> Tensor: + r"""Value network for SAC v2. - state_action_value = tensordict_qval.get(self.tensor_keys.value).squeeze(-1) - ( - state_action_value_actor, - next_state_action_value_qvalue, - state_action_value_qvalue, - ) = state_action_value.split( - [self.num_qvalue_nets, self.num_qvalue_nets, self.num_qvalue_nets], - dim=0, - ) + SAC v2 is based on a value estimate of the form: - loss_actor = -( - (state_action_value_actor.min(0)[0] * probs[0]).sum(-1, keepdim=True) - - self.alpha * logp_pi_pol[0] - ).mean() + .. math:: - pred_next_val = ( - probs[1] - * (next_state_action_value_qvalue.min(0)[0] - self.alpha * logp_pi[1]) - ).sum(dim=-1, keepdim=True) + V = Q(s,a) - \alpha * \log p(a | s) - tensordict_select.set( - ("next", self.value_estimator.tensor_keys.value), pred_next_val - ) - target_value = self.value_estimator.value_estimate(tensordict_select).squeeze( - -1 - ) + This class computes this value given the actor and qvalue network + + """ + tensordict = tensordict.clone(False) + # get actions and log-probs + with torch.no_grad(): + next_tensordict = tensordict.get("next").clone(False) - actions = torch.argmax(tensordict_select.get(self.tensor_keys.action), dim=-1) + # get probs and log probs for actions computed from "next" + next_dist = self.actor_network.get_dist( + next_tensordict, params=self.actor_network_params + ) + next_prob = next_dist.probs + next_log_prob = torch.log(torch.where(next_prob == 0, 1e-8, next_prob)) - pred_val_1 = ( - state_action_value_qvalue[0].gather(-1, actions.unsqueeze(-1)).unsqueeze(0) - ) - pred_val_2 = ( - state_action_value_qvalue[1].gather(-1, actions.unsqueeze(-1)).unsqueeze(0) - ) - pred_val = torch.cat([pred_val_1, pred_val_2], dim=0).squeeze() - td_error = (pred_val - target_value.expand_as(pred_val)).pow(2) - loss_qval = ( - distance_loss( - pred_val, - target_value.expand_as(pred_val), - loss_function=self.loss_function, + # get q-values for all actions + next_tensordict_expand = self._vmap_qnetworkN0( + next_tensordict, self.target_qvalue_network_params + ) + next_action_value = next_tensordict_expand.get( + self.tensor_keys.action_value ) - .mean(-1) - .sum() - * 0.5 - ) - tensordict.set(self.tensor_keys.priority, td_error.detach().max(0)[0]) + # like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term + next_state_value = next_action_value.min(0)[0] - self._alpha * next_log_prob + # unlike in continuous SAC, we can compute the exact expectation over all discrete actions + next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1) - loss_alpha = self._loss_alpha(logp_pi_pol) - if not loss_qval.shape == loss_actor.shape: - raise RuntimeError( - f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}" + tensordict.set( + ("next", self.value_estimator.tensor_keys.value), 
next_state_value ) - td_out = TensorDict( - { - "loss_actor": loss_actor.mean(), - "loss_qvalue": loss_qval.mean(), - "loss_alpha": loss_alpha.mean(), - "alpha": self.alpha.detach(), - "entropy": -logp_pi.mean().detach(), - "state_action_value_actor": state_action_value_actor.mean().detach(), - "action_log_prob_actor": logp_pi.mean().detach(), - "next.state_value": pred_next_val.mean().detach(), - "target_value": target_value.mean().detach(), - }, - [], + target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) + return target_value + + def _value_loss( + self, tensordict: TensorDictBase + ) -> Tuple[Tensor, Dict[str, Tensor]]: + target_value = self._compute_target(tensordict) + tensordict_expand = self._vmap_qnetworkN0( + tensordict.select(*self.qvalue_network.in_keys), + self.qvalue_network_params, ) - return td_out + action_value = tensordict_expand.get(self.tensor_keys.action_value) + action = tensordict.get(self.tensor_keys.action) + action = action.expand((action_value.shape[0], *action.shape)) # Add vmap dim + + # TODO this block comes from the dqn loss, we need to swap all these with a proper + # helper function which selects the value given the action for all discrete spaces + if self.action_space == "categorical": + if action.shape != action_value.shape: + # unsqueeze the action if it lacks on trailing singleton dim + action = action.unsqueeze(-1) + chosen_action_value = torch.gather(action_value, -1, index=action).squeeze( + -1 + ) + else: + action = action.to(torch.float) + chosen_action_value = (action_value * action).sum(-1) + + td_error = torch.abs(chosen_action_value - target_value) + loss_qval = distance_loss( + chosen_action_value, + target_value.expand_as(chosen_action_value), + loss_function=self.loss_function, + ).mean(0) + + metadata = { + "td_error": td_error.detach().max(0)[0], + } + return loss_qval, metadata - def _loss_alpha(self, log_pi: Tensor) -> Tensor: - if torch.is_grad_enabled() and not log_pi.requires_grad: + def _actor_loss( + self, tensordict: TensorDictBase + ) -> Tuple[Tensor, Dict[str, Tensor]]: + # get probs and log probs for actions + dist = self.actor_network.get_dist( + tensordict, + params=self.actor_network_params, + ) + prob = dist.probs + log_prob = torch.log(torch.where(prob == 0, 1e-8, prob)) + + td_q = tensordict.select(*self.qvalue_network.in_keys) + td_q = self._vmap_qnetworkN0( + td_q, self._cached_detached_qvalue_params # should we clone? 
+ ) + min_q = td_q.get(self.tensor_keys.action_value).min(0)[0] + + if log_prob.shape != min_q.shape: raise RuntimeError( - "expected log_pi to require gradient for the alpha loss)" + f"Losses shape mismatch: {log_prob.shape} and {min_q.shape}" ) + + # like in continuous SAC, we take the entropy term and subtract the minimum of the value ensemble + loss = self._alpha * log_prob - min_q + # unlike in continuous SAC, we can compute the exact expectation over all discrete actions + loss = (prob * loss).sum(-1) + + return loss, {"log_prob": (log_prob * prob).sum(-1).detach()} + + def _alpha_loss(self, log_prob: Tensor) -> Tensor: if self.target_entropy is not None: # we can compute this loss even if log_alpha is not a parameter - alpha_loss = -self.log_alpha * (log_pi.detach() + self.target_entropy) + alpha_loss = -self.log_alpha * (log_prob + self.target_entropy) else: # placeholder - alpha_loss = torch.zeros_like(log_pi) + alpha_loss = torch.zeros_like(log_prob) return alpha_loss + @property + def _alpha(self): + if self.min_log_alpha is not None: + self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) + with torch.no_grad(): + alpha = self.log_alpha.exp() + return alpha + + @property + @_cache_values + def _cached_detached_qvalue_params(self): + return self.qvalue_network_params.detach() + def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): if value_type is None: value_type = self.default_value_estimator self.value_type = value_type - value_net = None hp = dict(default_value_kwargs(value_type)) hp.update(hyperparams) if hasattr(self, "gamma"): @@ -1233,12 +1236,12 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams if value_type is ValueEstimators.TD1: self._value_estimator = TD1Estimator( **hp, - value_network=value_net, + value_network=None, ) elif value_type is ValueEstimators.TD0: self._value_estimator = TD0Estimator( **hp, - value_network=value_net, + value_network=None, ) elif value_type is ValueEstimators.GAE: raise NotImplementedError( @@ -1247,7 +1250,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams elif value_type is ValueEstimators.TDLambda: self._value_estimator = TDLambdaEstimator( **hp, - value_network=value_net, + value_network=None, ) else: raise NotImplementedError(f"Unknown value type {value_type}")
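
Usage sketch (not part of the diff): the snippet below mirrors the updated DiscreteSACLoss docstring example to show the API this refactor converges on: a plain TensorDictModule Q-network writing "action_value" (the default key the loss now reads), and an action_space argument that accepts either a spec instance or a string alias resolved through _find_action_space / ACTION_SPACE_MAP. This is a minimal sketch under the assumption of the post-PR API; keys and defaults follow the docstring example above, not a verified end-to-end run.

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import OneHotCategorical, ProbabilisticActor
from torchrl.objectives import DiscreteSACLoss

n_act, n_obs = 4, 3
spec = OneHotDiscreteTensorSpec(n_act)

# Actor: observations -> logits -> one-hot categorical distribution over actions.
actor = ProbabilisticActor(
    module=TensorDictModule(
        nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"]
    ),
    in_keys=["logits"],
    out_keys=["action"],
    spec=spec,
    distribution_class=OneHotCategorical,
)

# Q-network: a plain TensorDictModule writing per-action values under "action_value",
# the default key read by the refactored loss.
qvalue = TensorDictModule(
    nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["action_value"]
)

# action_space may be the spec itself or a string alias such as "one-hot";
# num_actions is still required when target_entropy == "auto".
loss = DiscreteSACLoss(actor, qvalue, action_space=spec, num_actions=n_act)

batch = [2]
data = TensorDict(
    {
        "observation": torch.randn(*batch, n_obs),
        "action": spec.rand(batch),
        ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
        ("next", "reward"): torch.randn(*batch, 1),
        ("next", "observation"): torch.randn(*batch, n_obs),
    },
    batch,
)
out = loss(data)  # keys: loss_actor, loss_qvalue, loss_alpha, alpha, entropy

Passing a string alias instead of the spec (e.g. action_space="one-hot", or "multi-discrete" which maps to "multi_categorical" in ACTION_SPACE_MAP) is equivalent, since both routes go through _find_action_space.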