diff --git a/test/test_cost.py b/test/test_cost.py index 70758ff7d5d..9ac2bb6b950 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -6854,6 +6854,71 @@ def test_cql( p.grad is None or p.grad.norm() == 0.0 ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + @pytest.mark.parametrize("delay_actor", (True,)) + @pytest.mark.parametrize("delay_qvalue", (True,)) + @pytest.mark.parametrize( + "max_q_backup", + [ + True, + ], + ) + @pytest.mark.parametrize( + "deterministic_backup", + [ + True, + ], + ) + @pytest.mark.parametrize( + "with_lagrange", + [ + True, + ], + ) + @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize("td_est", [None]) + def test_cql_qvalfromlist( + self, + delay_actor, + delay_qvalue, + max_q_backup, + deterministic_backup, + with_lagrange, + device, + td_est, + ): + torch.manual_seed(self.seed) + td = self._create_mock_data_cql(device=device) + + actor = self._create_mock_actor(device=device) + qvalue0 = self._create_mock_qvalue(device=device) + qvalue1 = self._create_mock_qvalue(device=device) + + loss_fn_single = CQLLoss( + actor_network=actor, + qvalue_network=qvalue0, + loss_function="l2", + max_q_backup=max_q_backup, + deterministic_backup=deterministic_backup, + with_lagrange=with_lagrange, + delay_actor=delay_actor, + delay_qvalue=delay_qvalue, + ) + loss_fn_mult = CQLLoss( + actor_network=actor, + qvalue_network=[qvalue0, qvalue1], + loss_function="l2", + max_q_backup=max_q_backup, + deterministic_backup=deterministic_backup, + with_lagrange=with_lagrange, + delay_actor=delay_actor, + delay_qvalue=delay_qvalue, + ) + # Check that all params have the same shape + p2 = dict(loss_fn_mult.named_parameters()) + for key, val in loss_fn_single.named_parameters(): + assert val.shape == p2[key].shape + assert len(dict(loss_fn_single.named_parameters())) == len(p2) + @pytest.mark.parametrize("delay_actor", (True, False)) @pytest.mark.parametrize("delay_qvalue", (True, False)) @pytest.mark.parametrize("max_q_backup", [True]) @@ -14605,6 +14670,62 @@ def init(mod): loss.from_stateful_net("module_a", module_a) assert (loss.module_a_params == 1).all() + def test_from_module_list(self): + class MyLoss(LossModule): + module_a: TensorDictModule + module_b: TensorDictModule + + module_a_params: TensorDict + module_b_params: TensorDict + + target_module_a_params: TensorDict + target_module_b_params: TensorDict + + def __init__(self, module_a, module_b0, module_b1, expand_dim=2): + super().__init__() + self.convert_to_functional(module_a, "module_a") + self.convert_to_functional( + [module_b0, module_b1], + "module_b", + # This will be ignored + compare_against=module_a.parameters(), + expand_dim=expand_dim, + ) + + module1 = nn.Linear(3, 4) + module2 = nn.Linear(3, 4) + module3a = nn.Linear(3, 4) + module3b = nn.Linear(3, 4) + + module_a = TensorDictModule( + nn.Sequential(module1, module2), in_keys=["a"], out_keys=["c"] + ) + + module_b0 = TensorDictModule( + nn.Sequential(module1, module3a), in_keys=["b"], out_keys=["c"] + ) + module_b1 = TensorDictModule( + nn.Sequential(module1, module3b), in_keys=["b"], out_keys=["c"] + ) + + loss = MyLoss(module_a, module_b0, module_b1) + + # This should be extended + assert not isinstance( + loss.module_b_params["module", "0", "weight"], nn.Parameter + ) + assert loss.module_b_params["module", "0", "weight"].shape[0] == 2 + assert ( + loss.module_b_params["module", "0", "weight"].data.data_ptr() + == loss.module_a_params["module", "0", "weight"].data.data_ptr() + ) + assert isinstance(loss.module_b_params["module", "1", "weight"], nn.Parameter) + assert loss.module_b_params["module", "1", "weight"].shape[0] == 2 + assert ( + loss.module_b_params["module", "1", "weight"].data.data_ptr() + != loss.module_a_params["module", "1", "weight"].data.data_ptr() + ) + def test_tensordict_keys(self): """Test configurable tensordict key behavior with derived classes.""" @@ -14962,10 +15083,10 @@ def __init__(self): assert v_p1 == v_p2 assert v_params1 == v_params2 assert v_buffers1 == v_buffers2 - for p in mod.parameters(): - assert isinstance(p, nn.Parameter) - for p in mod.buffers(): - assert isinstance(p, Buffer) + for k, p in mod.named_parameters(): + assert isinstance(p, nn.Parameter), k + for k, p in mod.named_buffers(): + assert isinstance(p, Buffer), k for p in mod.actor_params.values(True, True): assert isinstance(p, (nn.Parameter, Buffer)) for p in mod.value_params.values(True, True): diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 1471cde5141..bedd91e2e56 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -62,7 +62,7 @@ class A2CLoss(LossModule): Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. advantage_key (str): [Deprecated, use set_keys(advantage_key=advantage_key) instead] The input tensordict key where the advantage is expected to be written. default: "advantage" diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index a10e6ccf25e..f2b02825005 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -317,57 +317,67 @@ def convert_to_functional( # Otherwise, casting the module to a device will keep old references # to uncast tensors sep = self.SEP - params = TensorDict.from_module(module, as_module=True) - - for key in params.keys(True): - if sep in key: - raise KeyError( - f"The key {key} contains the '_sep_' pattern which is prohibited. Consider renaming the parameter / buffer." + if isinstance(module, (list, tuple)): + if len(module) != expand_dim: + raise RuntimeError( + "The ``expand_dim`` value must match the length of the module list/tuple " + "if a single module isn't provided." ) - if compare_against is not None: - compare_against = set(compare_against) + params = TensorDict.from_modules( + *module, as_module=True, expand_identical=True + ) else: - compare_against = set() - if expand_dim: - # Expands the dims of params and buffers. - # If the param already exist in the module, we return a simple expansion of the - # original one. Otherwise, we expand and resample it. - # For buffers, a cloned expansion (or equivalently a repeat) is returned. - - def _compare_and_expand(param): - if is_tensor_collection(param): - return param._apply_nest( + params = TensorDict.from_module(module, as_module=True) + + for key in params.keys(True): + if sep in key: + raise KeyError( + f"The key {key} contains the '_sep_' pattern which is prohibited. Consider renaming the parameter / buffer." + ) + if compare_against is not None: + compare_against = set(compare_against) + else: + compare_against = set() + if expand_dim: + # Expands the dims of params and buffers. + # If the param already exist in the module, we return a simple expansion of the + # original one. Otherwise, we expand and resample it. + # For buffers, a cloned expansion (or equivalently a repeat) is returned. + + def _compare_and_expand(param): + if is_tensor_collection(param): + return param._apply_nest( + _compare_and_expand, + batch_size=[expand_dim, *param.shape], + filter_empty=False, + call_on_nested=True, + ) + if not isinstance(param, nn.Parameter): + buffer = param.expand(expand_dim, *param.shape).clone() + return buffer + if param in compare_against: + expanded_param = param.data.expand(expand_dim, *param.shape) + # the expanded parameter must be sent to device when to() + # is called: + return expanded_param + else: + p_out = param.expand(expand_dim, *param.shape).clone() + p_out = nn.Parameter( + p_out.uniform_( + p_out.min().item(), p_out.max().item() + ).requires_grad_() + ) + return p_out + + params = TensorDictParams( + params.apply( _compare_and_expand, - batch_size=[expand_dim, *param.shape], + batch_size=[expand_dim, *params.shape], filter_empty=False, call_on_nested=True, - ) - if not isinstance(param, nn.Parameter): - buffer = param.expand(expand_dim, *param.shape).clone() - return buffer - if param in compare_against: - expanded_param = param.data.expand(expand_dim, *param.shape) - # the expanded parameter must be sent to device when to() - # is called: - return expanded_param - else: - p_out = param.expand(expand_dim, *param.shape).clone() - p_out = nn.Parameter( - p_out.uniform_( - p_out.min().item(), p_out.max().item() - ).requires_grad_() - ) - return p_out - - params = TensorDictParams( - params.apply( - _compare_and_expand, - batch_size=[expand_dim, *params.shape], - filter_empty=False, - call_on_nested=True, - ), - no_convert=True, - ) + ), + no_convert=True, + ) param_name = module_name + "_params" diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 98283b24ff7..0d2d869d1e1 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -9,7 +9,7 @@ from copy import deepcopy from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -46,8 +46,15 @@ class CQLLoss(LossModule): Args: actor_network (ProbabilisticActor): stochastic actor - qvalue_network (TensorDictModule): Q(s, a) parametric model. + qvalue_network (TensorDictModule or list of TensorDictModule): Q(s, a) parametric model. This module typically outputs a ``"state_action_value"`` entry. + If a single instance of `qvalue_network` is provided, it will be duplicated ``N`` + times (where ``N=2`` for this loss). If a list of modules is passed, their + parameters will be stacked unless they share the same identity (in which case + the original parameter will be expanded). + + .. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters + and all the parameters will be considered as untied. Keyword args: loss_function (str, optional): loss function to be used with @@ -266,7 +273,7 @@ class _AcceptedKeys: def __init__( self, actor_network: ProbabilisticActor, - qvalue_network: TensorDictModule, + qvalue_network: TensorDictModule | List[TensorDictModule], *, loss_function: str = "smooth_l1", alpha_init: float = 1.0, diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 22d35bd5799..355a33a4682 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -7,7 +7,7 @@ import math from dataclasses import dataclass from functools import wraps -from typing import Dict, Tuple, Union +from typing import Dict, List, Tuple, Union import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams @@ -54,6 +54,13 @@ class CrossQLoss(LossModule): actor_network (ProbabilisticActor): stochastic actor qvalue_network (TensorDictModule): Q(s, a) parametric model. This module typically outputs a ``"state_action_value"`` entry. + If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets`` + times. If a list of modules is passed, their + parameters will be stacked unless they share the same identity (in which case + the original parameter will be expanded). + + .. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters + and all the parameters will be considered as untied. Keyword Args: num_qvalue_nets (integer, optional): number of Q-Value networks used. @@ -81,7 +88,7 @@ class CrossQLoss(LossModule): priority (for prioritized replay buffer usage). Defaults to ``"td_error"``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, @@ -248,7 +255,7 @@ class _AcceptedKeys: def __init__( self, actor_network: ProbabilisticActor, - qvalue_network: TensorDictModule, + qvalue_network: TensorDictModule | List[TensorDictModule], *, num_qvalue_nets: int = 2, loss_function: str = "smooth_l1", diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 5ffbeaf029b..6e1cf0f5eb3 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -40,7 +40,7 @@ class DDPGLoss(LossModule): data collection. Default is ``True``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index dd2ac615b58..9e7115ac601 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -7,7 +7,7 @@ import math from dataclasses import dataclass from numbers import Number -from typing import Tuple, Union +from typing import List, Tuple, Union import numpy as np import torch @@ -41,6 +41,13 @@ class REDQLoss_deprecated(LossModule): actor_network (TensorDictModule): the actor to be trained qvalue_network (TensorDictModule): a single Q-value network that will be multiplied as many times as needed. + If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets`` + times. If a list of modules is passed, their + parameters will be stacked unless they share the same identity (in which case + the original parameter will be expanded). + + .. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters + and all the parameters will be considered as untied. Keyword Args: num_qvalue_nets (int, optional): Number of Q-value networks to be trained. @@ -75,7 +82,7 @@ class REDQLoss_deprecated(LossModule): ``"td_error"``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, @@ -134,7 +141,7 @@ class _AcceptedKeys: def __init__( self, actor_network: TensorDictModule, - qvalue_network: TensorDictModule, + qvalue_network: TensorDictModule | List[TensorDictModule], *, num_qvalue_nets: int = 10, sub_sample_len: int = 2, diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index a60d010d480..7fab95a95ed 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -6,7 +6,7 @@ import warnings from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams @@ -37,6 +37,14 @@ class IQLLoss(LossModule): Args: actor_network (ProbabilisticActor): stochastic actor qvalue_network (TensorDictModule): Q(s, a) parametric model + If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets`` + times. If a list of modules is passed, their + parameters will be stacked unless they share the same identity (in which case + the original parameter will be expanded). + + .. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters + and all the parameters will be considered as untied. + value_network (TensorDictModule, optional): V(s) parametric model. Keyword Args: @@ -55,7 +63,7 @@ class IQLLoss(LossModule): buffer usage). Default is `"td_error"`. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, @@ -247,7 +255,7 @@ class _AcceptedKeys: def __init__( self, actor_network: ProbabilisticActor, - qvalue_network: TensorDictModule, + qvalue_network: TensorDictModule | List[TensorDictModule], value_network: Optional[TensorDictModule], *, num_qvalue_nets: int = 2, @@ -548,7 +556,7 @@ class DiscreteIQLLoss(IQLLoss): buffer usage). Default is `"td_error"`. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 08afc2a13f4..16e2776805b 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -82,7 +82,7 @@ class PPOLoss(LossModule): before being used. Defaults to ``False``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. advantage_key (str, optional): [Deprecated, use set_keys(advantage_key=advantage_key) instead] The input tensordict key where the advantage is @@ -657,7 +657,7 @@ class ClipPPOLoss(PPOLoss): before being used. Defaults to ``False``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. advantage_key (str, optional): [Deprecated, use set_keys(advantage_key=advantage_key) instead] The input tensordict key where the advantage is @@ -896,7 +896,7 @@ class KLPENPPOLoss(PPOLoss): before being used. Defaults to ``False``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. advantage_key (str, optional): [Deprecated, use set_keys(advantage_key=advantage_key) instead] The input tensordict key where the advantage is diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index 00e5c24f08c..a0aaa96f7c5 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -7,7 +7,7 @@ import math from dataclasses import dataclass from numbers import Number -from typing import Union +from typing import List, Union import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams @@ -41,8 +41,14 @@ class REDQLoss(LossModule): Args: actor_network (TensorDictModule): the actor to be trained - qvalue_network (TensorDictModule): a single Q-value network that will - be multiplicated as many times as needed. + qvalue_network (TensorDictModule): a single Q-value network or a list of Q-value networks. + If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets`` + times. If a list of modules is passed, their + parameters will be stacked unless they share the same identity (in which case + the original parameter will be expanded). + + .. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters + and all the parameters will be considered as untied. Keyword Args: num_qvalue_nets (int, optional): Number of Q-value networks to be trained. @@ -77,7 +83,7 @@ class REDQLoss(LossModule): ``"td_error"``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, @@ -250,7 +256,7 @@ class _AcceptedKeys: def __init__( self, actor_network: TensorDictModule, - qvalue_network: TensorDictModule, + qvalue_network: TensorDictModule | List[TensorDictModule], *, num_qvalue_nets: int = 10, sub_sample_len: int = 2, diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index f32bea50d7e..d2d387e9a99 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -56,7 +56,7 @@ class ReinforceLoss(LossModule): value is expected to be written. Defaults to ``"value_target"``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. functional (bool, optional): whether modules should be functionalized. Functionalizing permits features like meta-RL, but makes it diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 51017384dbe..67ab7d7d8ce 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -9,7 +9,7 @@ from dataclasses import dataclass from functools import wraps from numbers import Number -from typing import Dict, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -57,6 +57,14 @@ class SACLoss(LossModule): actor_network (ProbabilisticActor): stochastic actor qvalue_network (TensorDictModule): Q(s, a) parametric model. This module typically outputs a ``"state_action_value"`` entry. + If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets`` + times. If a list of modules is passed, their + parameters will be stacked unless they share the same identity (in which case + the original parameter will be expanded). + + .. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters + and all the parameters will be considered as untied. + value_network (TensorDictModule, optional): V(s) parametric model. This module typically outputs a ``"state_value"`` entry. @@ -64,6 +72,7 @@ class SACLoss(LossModule): If not provided, the second version of SAC is assumed, where only the Q-Value network is needed. + Keyword Args: num_qvalue_nets (integer, optional): number of Q-Value networks used. Defaults to ``2``. loss_function (str, optional): loss function to be used with @@ -98,7 +107,7 @@ class SACLoss(LossModule): priority (for prioritized replay buffer usage). Defaults to ``"td_error"``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, @@ -280,7 +289,7 @@ class _AcceptedKeys: def __init__( self, actor_network: ProbabilisticActor, - qvalue_network: TensorDictModule, + qvalue_network: TensorDictModule | List[TensorDictModule], value_network: Optional[TensorDictModule] = None, *, num_qvalue_nets: int = 2, @@ -830,7 +839,7 @@ class DiscreteSACLoss(LossModule): Default is `"td_error"`. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index b569eb01345..db99237d39e 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -5,7 +5,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, Tuple +from typing import List, Optional, Tuple import torch @@ -34,8 +34,15 @@ class TD3Loss(LossModule): Args: actor_network (TensorDictModule): the actor to be trained - qvalue_network (TensorDictModule): a single Q-value network that will - be multiplicated as many times as needed. + qvalue_network (TensorDictModule): a single Q-value network or a list of + Q-value networks. + If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets`` + times. If a list of modules is passed, their + parameters will be stacked unless they share the same identity (in which case + the original parameter will be expanded). + + .. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters + and all the parameters will be considered as untied. Keyword Args: bounds (tuple of float, optional): the bounds of the action space. @@ -66,7 +73,7 @@ class TD3Loss(LossModule): the actor. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, @@ -218,7 +225,7 @@ class _AcceptedKeys: def __init__( self, actor_network: TensorDictModule, - qvalue_network: TensorDictModule, + qvalue_network: TensorDictModule | List[TensorDictModule], *, action_spec: TensorSpec = None, bounds: Optional[Tuple[float]] = None, diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py index 93845bb00bd..d5529e0b859 100644 --- a/torchrl/objectives/td3_bc.py +++ b/torchrl/objectives/td3_bc.py @@ -5,7 +5,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, Tuple +from typing import List, Optional, Tuple import torch @@ -43,8 +43,15 @@ class TD3BCLoss(LossModule): Args: actor_network (TensorDictModule): the actor to be trained - qvalue_network (TensorDictModule): a single Q-value network that will - be multiplicated as many times as needed. + qvalue_network (TensorDictModule): a single Q-value network or a list of + Q-value networks. + If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets`` + times. If a list of modules is passed, their + parameters will be stacked unless they share the same identity (in which case + the original parameter will be expanded). + + .. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters + and all the parameters will be considered as untied. Keyword Args: bounds (tuple of float, optional): the bounds of the action space. @@ -77,7 +84,7 @@ class TD3BCLoss(LossModule): the actor. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, @@ -233,7 +240,7 @@ class _AcceptedKeys: def __init__( self, actor_network: TensorDictModule, - qvalue_network: TensorDictModule, + qvalue_network: TensorDictModule | List[TensorDictModule], *, action_spec: TensorSpec = None, bounds: Optional[Tuple[float]] = None, diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index b977a3440dd..b7db2e8242e 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -502,7 +502,7 @@ class TD0Estimator(ValueEstimatorBase): skip_existing (bool, optional): if ``True``, the value network will skip modules which outputs are already present in the tensordict. - Defaults to ``None``, ie. the value of :func:`tensordict.nn.skip_existing()` + Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()` is not affected. advantage_key (str or tuple of str, optional): [Deprecated] the key of the advantage entry. Defaults to ``"advantage"``. @@ -701,7 +701,7 @@ class TD1Estimator(ValueEstimatorBase): skip_existing (bool, optional): if ``True``, the value network will skip modules which outputs are already present in the tensordict. - Defaults to ``None``, ie. the value of :func:`tensordict.nn.skip_existing()` + Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()` is not affected. advantage_key (str or tuple of str, optional): [Deprecated] the key of the advantage entry. Defaults to ``"advantage"``. @@ -922,7 +922,7 @@ class TDLambdaEstimator(ValueEstimatorBase): lambda return. Default is `True`. skip_existing (bool, optional): if ``True``, the value network will skip modules which outputs are already present in the tensordict. - Defaults to ``None``, ie. the value of :func:`tensordict.nn.skip_existing()` + Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()` is not affected. advantage_key (str or tuple of str, optional): [Deprecated] the key of the advantage entry. Defaults to ``"advantage"``. @@ -1164,7 +1164,7 @@ class GAE(ValueEstimatorBase): lambda return. Default is `True`. skip_existing (bool, optional): if ``True``, the value network will skip modules which outputs are already present in the tensordict. - Defaults to ``None``, ie. the value of :func:`tensordict.nn.skip_existing()` + Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()` is not affected. Defaults to "state_value". advantage_key (str or tuple of str, optional): [Deprecated] the key of @@ -1476,7 +1476,7 @@ class VTrace(ValueEstimatorBase): pass detached parameters for functional modules. skip_existing (bool, optional): if ``True``, the value network will skip modules which outputs are already present in the tensordict. - Defaults to ``None``, ie. the value of :func:`tensordict.nn.skip_existing()` + Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()` is not affected. Defaults to "state_value". advantage_key (str or tuple of str, optional): [Deprecated] the key of