diff --git a/test/test_cost.py b/test/test_cost.py index 5686249accf..41fc0781a01 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -490,7 +490,10 @@ def test_dqn(self, delay_value, double_dqn, device, action_spec_type, td_est): action_spec_type=action_spec_type, device=device ) loss_fn = DQNLoss( - actor, loss_function="l2", delay_value=delay_value, double_dqn=double_dqn + actor, + loss_function="l2", + delay_value=delay_value, + double_dqn=double_dqn, ) if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): @@ -699,7 +702,11 @@ def test_distributional_dqn( td = self._create_mock_data_dqn( action_spec_type=action_spec_type, atoms=atoms ).to(device) - loss_fn = DistributionalDQNLoss(actor, gamma=gamma, delay_value=delay_value) + loss_fn = DistributionalDQNLoss( + actor, + gamma=gamma, + delay_value=delay_value, + ) if td_est not in (None, ValueEstimators.TD0): with pytest.raises(NotImplementedError): @@ -717,6 +724,7 @@ def test_distributional_dqn( else contextlib.nullcontext() ): loss = loss_fn(td) + assert loss_fn.tensor_keys.priority in td.keys() sum([item for _, item in loss.items()]).backward() @@ -843,6 +851,58 @@ def test_distributional_dqn_tensordict_run(self, action_spec_type, td_est): _ = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys() + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + def test_dqn_reduction(self, reduction): + torch.manual_seed(self.seed) + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda") + ) + actor = self._create_mock_actor(action_spec_type="categorical", device=device) + td = self._create_mock_data_dqn(action_spec_type="categorical", device=device) + loss_fn = DQNLoss( + actor, + loss_function="l2", + delay_value=False, + reduction=reduction, + ) + loss_fn.make_value_estimator() + loss = loss_fn(td) + if reduction == "none": + for key in loss.keys(): + if key.startswith("loss"): + assert loss[key].shape == td.shape + else: + for key in loss.keys(): + assert loss[key].shape == torch.Size([]) + + @pytest.mark.parametrize("atoms", range(4, 10)) + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + def test_distributional_dqn_reduction(self, reduction, atoms): + torch.manual_seed(self.seed) + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda") + ) + actor = self._create_mock_distributional_actor( + action_spec_type="categorical", atoms=atoms + ).to(device) + td = self._create_mock_data_dqn(action_spec_type="categorical", device=device) + loss_fn = DistributionalDQNLoss( + actor, gamma=0.9, delay_value=False, reduction=reduction + ) + loss_fn.make_value_estimator() + loss = loss_fn(td) + if reduction == "none": + for key in loss.keys(): + if key.startswith("loss"): + assert loss[key].shape == td.shape + else: + for key in loss.keys(): + assert loss[key].shape == torch.Size([]) + class TestQMixer(LossModuleTestBase): seed = 0 @@ -1884,6 +1944,35 @@ def test_ddpg_notensordict(self): assert loss_actor == loss_val_td["loss_actor"] assert (target_value == loss_val_td["target_value"]).all() + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + def test_ddpg_reduction(self, reduction): + torch.manual_seed(self.seed) + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda") + ) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + td = 
self._create_mock_data_ddpg(device=device) + loss_fn = DDPGLoss( + actor, + value, + loss_function="l2", + delay_actor=False, + delay_value=False, + reduction=reduction, + ) + loss_fn.make_value_estimator() + loss = loss_fn(td) + if reduction == "none": + for key in loss.keys(): + if key.startswith("loss"): + assert loss[key].shape == td.shape + else: + for key in loss.keys(): + assert loss[key].shape == torch.Size([]) + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" @@ -2553,6 +2642,39 @@ def test_td3_notensordict( assert loss_actor == loss_val_td["loss_actor"] assert loss_qvalue == loss_val_td["loss_qvalue"] + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + def test_td3_reduction(self, reduction): + torch.manual_seed(self.seed) + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda") + ) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + td = self._create_mock_data_td3(device=device) + action_spec = actor.spec + bounds = None + loss_fn = TD3Loss( + actor, + value, + action_spec=action_spec, + bounds=bounds, + loss_function="l2", + delay_qvalue=False, + delay_actor=False, + reduction=reduction, + ) + loss_fn.make_value_estimator() + loss = loss_fn(td) + if reduction == "none": + for key in loss.keys(): + if key.startswith("loss"): + assert loss[key].shape == td.shape + else: + for key in loss.keys(): + assert loss[key].shape == torch.Size([]) + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" @@ -2820,6 +2942,7 @@ def test_sac( UserWarning, match="No target network updater" ): loss = loss_fn(td) + assert loss_fn.tensor_keys.priority in td.keys() # check that losses are independent @@ -3420,6 +3543,41 @@ def test_state_dict(self, version): ) loss.load_state_dict(state) + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + def test_sac_reduction(self, reduction, version): + torch.manual_seed(self.seed) + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda") + ) + td = self._create_mock_data_sac(device=device) + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + if version == 1: + value = self._create_mock_value(device=device) + else: + value = None + loss_fn = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + loss_function="l2", + delay_qvalue=False, + delay_actor=False, + delay_value=False, + reduction=reduction, + ) + loss_fn.make_value_estimator() + loss = loss_fn(td) + if reduction == "none": + for key in loss.keys(): + if key.startswith("loss"): + assert loss[key].shape == td.shape + else: + for key in loss.keys(): + assert loss[key].shape == torch.Size([]) + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" @@ -3609,6 +3767,7 @@ def test_discrete_sac( UserWarning, match="No target network updater" ): loss = loss_fn(td) + assert loss_fn.tensor_keys.priority in td.keys() # check that losses are independent @@ -3968,6 +4127,36 @@ def test_discrete_sac_notensordict( assert loss_actor == loss_val_td["loss_actor"] assert loss_alpha == loss_val_td["loss_alpha"] + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + def test_discrete_sac_reduction(self, reduction): + torch.manual_seed(self.seed) + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda") + 
) + td = self._create_mock_data_sac(device=device) + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + loss_fn = DiscreteSACLoss( + actor_network=actor, + qvalue_network=qvalue, + num_actions=actor.spec["action"].space.n, + loss_function="l2", + action_space="one-hot", + delay_qvalue=False, + reduction=reduction, + ) + loss_fn.make_value_estimator() + loss = loss_fn(td) + if reduction == "none": + for key in loss.keys(): + if key.startswith("loss"): + assert loss[key].shape == td.shape + else: + for key in loss.keys(): + assert loss[key].shape == torch.Size([]) + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" @@ -4874,6 +5063,44 @@ def test_redq_notensordict( assert loss_actor == loss_val_td["loss_actor"] assert loss_alpha == loss_val_td["loss_alpha"] + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + @pytest.mark.parametrize("deprecated_loss", [True, False]) + def test_redq_reduction(self, reduction, deprecated_loss): + torch.manual_seed(self.seed) + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda") + ) + td = self._create_mock_data_redq(device=device) + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + if deprecated_loss: + loss_fn = REDQLoss_deprecated( + actor_network=actor, + qvalue_network=qvalue, + loss_function="l2", + delay_qvalue=False, + reduction=reduction, + ) + else: + loss_fn = REDQLoss( + actor_network=actor, + qvalue_network=qvalue, + loss_function="l2", + delay_qvalue=False, + reduction=reduction, + ) + loss_fn.make_value_estimator() + loss = loss_fn(td) + if reduction == "none": + for key in loss.keys(): + if key.startswith("loss"): + assert loss[key].shape[-1] == td.shape[0] + else: + for key in loss.keys(): + assert loss[key].shape == torch.Size([]) + class TestCQL(LossModuleTestBase): seed = 0 diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 03e82689ad5..2ecdfde6bb3 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -5,6 +5,7 @@ from __future__ import annotations +import functools from copy import deepcopy from dataclasses import dataclass from typing import Tuple @@ -19,6 +20,7 @@ from torchrl.objectives.utils import ( _cache_values, _GAMMA_LMBDA_DEPREC_ERROR, + _reduce, default_value_kwargs, distance_loss, ValueEstimators, @@ -41,6 +43,10 @@ class DDPGLoss(LossModule): policy and critic will only be trained on the policy loss. Defaults to ``False``, ie. gradients are propagated to shared parameters for both policy and critic losses. + reduction (str, optional): Specifies the reduction to apply to the output: + ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, + ``"mean"``: the sum of the output will be divided by the number of + elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``. 
Examples: >>> import torch @@ -189,8 +195,11 @@ def __init__( delay_value: bool = True, gamma: float = None, separate_losses: bool = False, + reduction: str = None, ) -> None: self._in_keys = None + if reduction is None: + reduction = "mean" super().__init__() self.delay_actor = delay_actor self.delay_value = delay_value @@ -229,7 +238,7 @@ def __init__( ) self.loss_function = loss_function - + self.reduction = reduction if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) @@ -283,10 +292,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: loss_value, metadata = self.loss_value(tensordict) loss_actor, metadata_actor = self.loss_actor(tensordict) metadata.update(metadata_actor) - return TensorDict( + td_out = TensorDict( source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata}, batch_size=[], ) + td_out = td_out.apply( + functools.partial(_reduce, reduction=self.reduction), batch_size=[] + ) + return td_out def loss_actor( self, @@ -299,9 +312,9 @@ def loss_actor( td_copy = self.actor_network(td_copy) with self._cached_detached_value_params.to_module(self.value_network): td_copy = self.value_network(td_copy) - loss_actor = -td_copy.get(self.tensor_keys.state_action_value) + loss_actor = -td_copy.get(self.tensor_keys.state_action_value).squeeze(-1) metadata = {} - return loss_actor.mean(), metadata + return loss_actor, metadata def loss_value( self, @@ -333,13 +346,13 @@ def loss_value( ) with torch.no_grad(): metadata = { - "td_error": td_error.mean(), - "pred_value": pred_val.mean(), - "target_value": target_value.mean(), + "td_error": td_error, + "pred_value": pred_val, + "target_value": target_value, "target_value_max": target_value.max(), "pred_value_max": pred_val.max(), } - return loss_value.mean(), metadata + return loss_value, metadata def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): if value_type is None: diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 3ff093d445c..d08edee71bd 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import functools import math from dataclasses import dataclass from numbers import Number @@ -22,6 +23,7 @@ from torchrl.objectives.utils import ( _cache_values, _GAMMA_LMBDA_DEPREC_ERROR, + _reduce, _vmap_func, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -74,6 +76,10 @@ class REDQLoss_deprecated(LossModule): policy and critic will only be trained on the policy loss. Defaults to ``False``, ie. gradients are propagated to shared parameters for both policy and critic losses. + reduction (str, optional): Specifies the reduction to apply to the output: + ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, + ``"mean"``: the sum of the output will be divided by the number of + elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``. 
""" @dataclass @@ -136,9 +142,12 @@ def __init__( gamma: float = None, priority_key: str = None, separate_losses: bool = False, + reduction: str = None, ): self._in_keys = None self._out_keys = None + if reduction is None: + reduction = "mean" super().__init__() self._set_deprecated_ctor_keys(priority_key=priority_key) @@ -197,6 +206,7 @@ def __init__( self._action_spec = action_spec self.target_entropy_buffer = None self.gSDE = gSDE + self.reduction = reduction self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0)) @@ -298,15 +308,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ) td_out = TensorDict( { - "loss_actor": loss_actor.mean(), - "loss_qvalue": loss_qval.mean(), - "loss_alpha": loss_alpha.mean(), + "loss_actor": loss_actor, + "loss_qvalue": loss_qval, + "loss_alpha": loss_alpha, "alpha": self.alpha, - "entropy": -sample_log_prob.mean().detach(), + "entropy": -sample_log_prob.detach(), }, [], ) - + td_out = td_out.apply( + functools.partial(_reduce, reduction=self.reduction), batch_size=[] + ) return td_out @property @@ -330,7 +342,7 @@ def _actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: loss_actor = -( state_action_value - self.alpha * tensordict_clone.get("sample_log_prob").squeeze(-1) - ).mean(0) + ) return loss_actor, tensordict_clone.get("sample_log_prob") def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: @@ -385,7 +397,7 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: pred_val, target_value.expand_as(pred_val), loss_function=self.loss_function, - ).mean(0) + ) tensordict_save.set("td_error", td_error.detach().max(0)[0]) return loss_qval diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 2298c262368..28e40cea9d4 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -25,6 +25,7 @@ from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _GAMMA_LMBDA_DEPREC_ERROR, + _reduce, default_value_kwargs, distance_loss, ValueEstimators, @@ -58,6 +59,10 @@ class DQNLoss(LossModule): The key at which priority is assumed to be stored within TensorDicts added to this ReplayBuffer. This is to be used when the sampler is of type :class:`~torchrl.data.PrioritizedSampler`. Defaults to ``"td_error"``. + reduction (str, optional): Specifies the reduction to apply to the output: + ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, + ``"mean"``: the sum of the output will be divided by the number of + elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``. Examples: >>> from torchrl.modules import MLP @@ -171,6 +176,7 @@ def __init__( gamma: float = None, action_space: Union[str, TensorSpec] = None, priority_key: str = None, + reduction: str = None, ) -> None: if delay_value is None: warnings.warn( @@ -182,6 +188,8 @@ def __init__( "script." 
) delay_value = False + if reduction is None: + reduction = "mean" super().__init__() self._in_keys = None if double_dqn and not delay_value: @@ -225,7 +233,7 @@ def __init__( ) action_space = "one-hot" self.action_space = _find_action_space(action_space) - + self.reduction = reduction if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) @@ -362,7 +370,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: inplace=True, ) loss = distance_loss(pred_val_index, target_value, self.loss_function) - return TensorDict({"loss": loss.mean()}, []) + loss = _reduce(loss, reduction=self.reduction) + td_out = TensorDict({"loss": loss}, []) + return td_out class DistributionalDQNLoss(LossModule): @@ -385,13 +395,16 @@ class DistributionalDQNLoss(LossModule): Unlike :class:`DQNLoss`, this class does not currently support custom value functions. The next value estimation is always bootstrapped. + delay_value (bool): whether to duplicate the value network into a new + target value network to create double DQN priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead] The key at which priority is assumed to be stored within TensorDicts added to this ReplayBuffer. This is to be used when the sampler is of type :class:`~torchrl.data.PrioritizedSampler`. Defaults to ``"td_error"``. - - delay_value (bool): whether to duplicate the value network into a new - target value network to create double DQN + reduction (str, optional): Specifies the reduction to apply to the output: + ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, + ``"mean"``: the sum of the output will be divided by the number of + elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``. """ @dataclass @@ -432,9 +445,11 @@ class _AcceptedKeys: def __init__( self, value_network: Union[DistributionalQValueActor, nn.Module], + *, gamma: float, delay_value: bool = None, priority_key: str = None, + reduction: str = None, ): if delay_value is None: warnings.warn( @@ -446,6 +461,8 @@ def __init__( "script." ) delay_value = False + if reduction is None: + reduction = "mean" super().__init__() self._set_deprecated_ctor_keys(priority=priority_key) self.register_buffer("gamma", torch.tensor(gamma)) @@ -461,6 +478,7 @@ def __init__( create_target_params=self.delay_value, ) self.action_space = self.value_network.action_space + self.reduction = reduction def _forward_value_estimator_keys(self, **kwargs) -> None: pass @@ -596,8 +614,9 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: loss.detach().unsqueeze(1).to(input_tensordict.device), inplace=True, ) - loss_td = TensorDict({"loss": loss.mean()}, []) - return loss_td + loss = _reduce(loss, reduction=self.reduction) + td_out = TensorDict({"loss": loss}, []) + return td_out def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): if value_type is None: diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index 61aaf5990e4..1c4e8785240 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import functools import math from dataclasses import dataclass from numbers import Number @@ -21,6 +22,7 @@ from torchrl.objectives.utils import ( _cache_values, _GAMMA_LMBDA_DEPREC_ERROR, + _reduce, _vmap_func, default_value_kwargs, distance_loss, @@ -76,6 +78,10 @@ class REDQLoss(LossModule): policy and critic will only be trained on the policy loss. Defaults to ``False``, ie. gradients are propagated to shared parameters for both policy and critic losses. + reduction (str, optional): Specifies the reduction to apply to the output: + ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, + ``"mean"``: the sum of the output will be divided by the number of + elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``. Examples: >>> import torch @@ -252,7 +258,10 @@ def __init__( gamma: float = None, priority_key: str = None, separate_losses: bool = False, + reduction: str = None, ): + if reduction is None: + reduction = "mean" super().__init__() self._in_keys = None self._set_deprecated_ctor_keys(priority_key=priority_key) @@ -309,7 +318,7 @@ def __init__( self._target_entropy = target_entropy self._action_spec = action_spec self.target_entropy_buffer = None - + self.reduction = reduction self.gSDE = gSDE if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) @@ -520,9 +529,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: next_action_log_prob_qvalue, ) = sample_log_prob.unbind(0) - loss_actor = -( - state_action_value_actor - self.alpha * action_log_prob_actor - ).mean(0) + loss_actor = -(state_action_value_actor - self.alpha * action_log_prob_actor) next_state_value = ( next_state_action_value_qvalue - self.alpha * next_action_log_prob_qvalue @@ -542,7 +549,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: pred_val, target_value.expand_as(pred_val), loss_function=self.loss_function, - ).mean(0) + ) tensordict.set(self.tensor_keys.priority, td_error.detach().max(0)[0]) @@ -553,19 +560,21 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ) td_out = TensorDict( { - "loss_actor": loss_actor.mean(), - "loss_qvalue": loss_qval.mean(), - "loss_alpha": loss_alpha.mean(), + "loss_actor": loss_actor, + "loss_qvalue": loss_qval, + "loss_alpha": loss_alpha, "alpha": self.alpha.detach(), - "entropy": -sample_log_prob.mean().detach(), - "state_action_value_actor": state_action_value_actor.mean().detach(), - "action_log_prob_actor": action_log_prob_actor.mean().detach(), - "next.state_value": next_state_value.mean().detach(), - "target_value": target_value.mean().detach(), + "entropy": -sample_log_prob.detach(), + "state_action_value_actor": state_action_value_actor.detach(), + "action_log_prob_actor": action_log_prob_actor.detach(), + "next.state_value": next_state_value.detach(), + "target_value": target_value.detach(), }, [], ) - + td_out = td_out.apply( + functools.partial(_reduce, reduction=self.reduction), batch_size=[] + ) return td_out def _loss_alpha(self, log_pi: Tensor) -> Tensor: diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 5b722fd05f3..c80b56ae77b 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import functools import math import warnings from dataclasses import dataclass @@ -26,6 +27,7 @@ from torchrl.objectives.utils import ( _cache_values, _GAMMA_LMBDA_DEPREC_ERROR, + _reduce, _vmap_func, default_value_kwargs, distance_loss, @@ -97,6 +99,10 @@ class SACLoss(LossModule): policy and critic will only be trained on the policy loss. Defaults to ``False``, ie. gradients are propagated to shared parameters for both policy and critic losses. + reduction (str, optional): Specifies the reduction to apply to the output: + ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, + ``"mean"``: the sum of the output will be divided by the number of + elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``. Examples: >>> import torch @@ -280,9 +286,12 @@ def __init__( gamma: float = None, priority_key: str = None, separate_losses: bool = False, + reduction: str = None, ) -> None: self._in_keys = None self._out_keys = None + if reduction is None: + reduction = "mean" super().__init__() self._set_deprecated_ctor_keys(priority_key=priority_key) @@ -381,6 +390,7 @@ def __init__( self._vmap_qnetwork00 = _vmap_func( qvalue_network, randomness=self.vmap_randomness ) + self.reduction = reduction @property def target_entropy_buffer(self): @@ -557,17 +567,21 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ) if shape: tensordict.update(tensordict_reshape.view(shape)) - entropy = -metadata_actor["log_prob"].mean() + entropy = -metadata_actor["log_prob"] out = { - "loss_actor": loss_actor.mean(), - "loss_qvalue": loss_qvalue.mean(), - "loss_alpha": loss_alpha.mean(), + "loss_actor": loss_actor, + "loss_qvalue": loss_qvalue, + "loss_alpha": loss_alpha, "alpha": self._alpha, "entropy": entropy, } if self._version == 1: - out["loss_value"] = loss_value.mean() - return TensorDict(out, []) + out["loss_value"] = loss_value + td_out = TensorDict(out, []) + td_out = td_out.apply( + functools.partial(_reduce, reduction=self.reduction), batch_size=[] + ) + return td_out @property @_cache_values @@ -805,6 +819,10 @@ class DiscreteSACLoss(LossModule): policy and critic will only be trained on the policy loss. Defaults to ``False``, ie. gradients are propagated to shared parameters for both policy and critic losses. + reduction (str, optional): Specifies the reduction to apply to the output: + ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, + ``"mean"``: the sum of the output will be divided by the number of + elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``. 
Examples: >>> import torch @@ -970,7 +988,10 @@ def __init__( delay_qvalue: bool = True, priority_key: str = None, separate_losses: bool = False, + reduction: str = None, ): + if reduction is None: + reduction = "mean" self._in_keys = None super().__init__() self._set_deprecated_ctor_keys(priority_key=priority_key) @@ -1051,6 +1072,7 @@ def __init__( self._vmap_qnetworkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) + self.reduction = reduction def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: @@ -1106,15 +1128,19 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ) if shape: tensordict.update(tensordict_reshape.view(shape)) - entropy = -metadata_actor["log_prob"].mean() + entropy = -metadata_actor["log_prob"] out = { - "loss_actor": loss_actor.mean(), - "loss_qvalue": loss_value.mean(), - "loss_alpha": loss_alpha.mean(), + "loss_actor": loss_actor, + "loss_qvalue": loss_value, + "loss_alpha": loss_alpha, "alpha": self._alpha, "entropy": entropy, } - return TensorDict(out, []) + td_out = TensorDict(out, []) + td_out = td_out.apply( + functools.partial(_reduce, reduction=self.reduction), batch_size=[] + ) + return td_out def _compute_target(self, tensordict) -> Tensor: r"""Value network for SAC v2. @@ -1189,7 +1215,7 @@ def _value_loss( chosen_action_value, target_value.expand_as(chosen_action_value), loss_function=self.loss_function, - ).mean(0) + ).sum(0) metadata = { "td_error": td_error.detach().max(0)[0], diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 877a8f0c819..d6f4d2c10c8 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import functools from dataclasses import dataclass from typing import Optional, Tuple @@ -18,6 +19,7 @@ from torchrl.objectives.utils import ( _cache_values, _GAMMA_LMBDA_DEPREC_ERROR, + _reduce, _vmap_func, default_value_kwargs, distance_loss, @@ -65,6 +67,10 @@ class TD3Loss(LossModule): policy and critic will only be trained on the policy loss. Defaults to ``False``, ie. gradients are propagated to shared parameters for both policy and critic losses. + reduction (str, optional): Specifies the reduction to apply to the output: + ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, + ``"mean"``: the sum of the output will be divided by the number of + elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``. 
Examples: >>> import torch @@ -217,7 +223,10 @@ def __init__( gamma: float = None, priority_key: str = None, separate_losses: bool = False, + reduction: str = None, ) -> None: + if reduction is None: + reduction = "mean" super().__init__() self._in_keys = None self._set_deprecated_ctor_keys(priority=priority_key) @@ -299,6 +308,7 @@ def __init__( self._vmap_actor_network00 = _vmap_func( self.actor_network, randomness=self.vmap_randomness ) + self.reduction = reduction def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: @@ -361,9 +371,9 @@ def actor_loss(self, tensordict): .get(self.tensor_keys.state_action_value) .squeeze(-1) ) - loss_actor = -(state_action_value_actor[0]).mean() + loss_actor = -(state_action_value_actor[0]) metadata = { - "state_action_value_actor": state_action_value_actor.mean().detach(), + "state_action_value_actor": state_action_value_actor.detach(), } return loss_actor, metadata @@ -428,22 +438,17 @@ def value_loss(self, tensordict): target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) td_error = (current_qvalue - target_value).pow(2) - loss_qval = ( - distance_loss( - current_qvalue, - target_value.expand_as(current_qvalue), - loss_function=self.loss_function, - ) - .mean(-1) - .sum() - ) + loss_qval = distance_loss( + current_qvalue, + target_value.expand_as(current_qvalue), + loss_function=self.loss_function, + ).sum(0) metadata = { "td_error": td_error, - "next_state_value": next_target_qvalue.mean().detach(), - "pred_value": current_qvalue.mean().detach(), - "target_value": target_value.mean().detach(), + "next_state_value": next_target_qvalue.detach(), + "pred_value": current_qvalue.detach(), + "target_value": target_value.detach(), } - return loss_qval, metadata @dispatch @@ -467,7 +472,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: }, batch_size=[], ) - + td_out = td_out.apply( + functools.partial(_reduce, reduction=self.reduction), batch_size=[] + ) return td_out def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
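Note on the _reduce helper: it is imported from torchrl.objectives.utils throughout this patch, but the diff does not touch that file. Below is a minimal sketch, under the assumption that _reduce simply dispatches on the reduction string, of the behaviour the new reduction argument relies on, together with the td_out.apply(functools.partial(_reduce, ...), batch_size=[]) pattern used in each patched forward(). The helper body is an illustration, not part of this change.

import functools

import torch
from tensordict import TensorDict


def _reduce(tensor: torch.Tensor, reduction: str) -> torch.Tensor:
    # Assumed semantics: "none" keeps the per-sample values, "mean"/"sum"
    # collapse them to a scalar; any other string is rejected.
    if reduction == "none":
        return tensor
    if reduction == "mean":
        return tensor.mean()
    if reduction == "sum":
        return tensor.sum()
    raise NotImplementedError(f"Unknown reduction: {reduction}")


# Mimic what the patched forward() methods do: collect per-sample losses in an
# unbatched TensorDict, then reduce every entry according to `reduction`.
per_sample = TensorDict(
    {"loss_actor": torch.rand(256), "loss_qvalue": torch.rand(256)}, []
)
reduced = per_sample.apply(functools.partial(_reduce, reduction="mean"), batch_size=[])
assert reduced["loss_actor"].shape == torch.Size([])
kept = per_sample.apply(functools.partial(_reduce, reduction="none"), batch_size=[])
assert kept["loss_actor"].shape == torch.Size([256])

This mirrors the new tests above: with reduction="none" the "loss_*" entries keep the batch shape of the input tensordict, while "mean", "sum", and the None default (which falls back to "mean" in every constructor) yield scalar losses.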