diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst
index 96a887196aa..537d4542910 100644
--- a/docs/source/reference/objectives.rst
+++ b/docs/source/reference/objectives.rst
@@ -29,6 +29,15 @@ The main characteristics of TorchRL losses are:
 
     >>> loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))
 
+.. note::
+    Initializing parameters in losses can be done via a call to :meth:`~torchrl.objectives.LossModule.get_stateful_net`,
+    which will return a stateful version of the network that can be initialized like any other module.
+    If the modification is done in-place, it will be propagated to any other module that uses the same parameter
+    set (within and outside of the loss): for instance, modifying the ``actor_network`` parameters from the loss
+    will also modify the actor in the collector.
+    If the parameters are modified out-of-place, :meth:`~torchrl.objectives.LossModule.from_stateful_net` can be
+    used to reset the parameters in the loss to the new values.
+
 Training value functions
 ------------------------
 
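The workflow documented in the note above, as a minimal doctest-style sketch (the ``loss`` instance and the ``"actor_network"`` name are illustrative assumptions, not part of this patch):

    >>> # hypothetical: any LossModule wrapping an "actor_network" sub-module
    >>> actor = loss.get_stateful_net("actor_network")
    >>> for p in actor.parameters():
    ...     torch.nn.init.normal_(p)  # in-place: propagates to the loss (and collector) directly
    >>> loss.from_stateful_net("actor_network", actor)  # only needed after out-of-place edits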
diff --git a/test/test_cost.py b/test/test_cost.py
index a318f5694cd..70758ff7d5d 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -14549,6 +14549,62 @@ def __init__(self, compare_against, expand_dim):
         for key in ["module.1.bias", "module.1.weight"]:
             loss_module.module_b_params.flatten_keys()[key].requires_grad
 
+    def test_init_params(self):
+        class MyLoss(LossModule):
+            module_a: TensorDictModule
+            module_b: TensorDictModule
+            module_a_params: TensorDict
+            module_b_params: TensorDict
+            target_module_a_params: TensorDict
+            target_module_b_params: TensorDict
+
+            def __init__(self, expand_dim=2):
+                super().__init__()
+                module1 = nn.Linear(3, 4)
+                module2 = nn.Linear(3, 4)
+                module3 = nn.Linear(3, 4)
+                module_a = TensorDictModule(
+                    nn.Sequential(module1, module2), in_keys=["a"], out_keys=["c"]
+                )
+                module_b = TensorDictModule(
+                    nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"]
+                )
+                self.convert_to_functional(module_a, "module_a")
+                self.convert_to_functional(
+                    module_b,
+                    "module_b",
+                    compare_against=module_a.parameters(),
+                    expand_dim=expand_dim,
+                )
+
+        loss = MyLoss()
+
+        module_a = loss.get_stateful_net("module_a", copy=False)
+        assert module_a is loss.module_a
+
+        module_a = loss.get_stateful_net("module_a")
+        assert module_a is not loss.module_a
+
+        def init(mod):
+            if hasattr(mod, "weight"):
+                mod.weight.data.zero_()
+            if hasattr(mod, "bias"):
+                mod.bias.data.zero_()
+
+        module_a.apply(init)
+        assert (loss.module_a_params == 0).all()
+
+        def init(mod):
+            if hasattr(mod, "weight"):
+                mod.weight = torch.nn.Parameter(mod.weight.data + 1)
+            if hasattr(mod, "bias"):
+                mod.bias = torch.nn.Parameter(mod.bias.data + 1)
+
+        module_a.apply(init)
+        assert (loss.module_a_params == 0).all()
+        loss.from_stateful_net("module_a", module_a)
+        assert (loss.module_a_params == 1).all()
+
     def test_tensordict_keys(self):
         """Test configurable tensordict key behavior with derived classes."""
 
diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py
index a48ad5b634b..1a92c4b33eb 100644
--- a/torchrl/modules/models/multiagent.py
+++ b/torchrl/modules/models/multiagent.py
@@ -154,8 +154,8 @@ def get_stateful_net(self, copy: bool = True):
 
         This can be used to initialize parameters.
 
-        Such networks will generally not be callable out-of-the-box and will require some `vmap`
-        execution. to work
+        Such networks will often not be callable out-of-the-box and will require a `vmap` call
+        to be executable.
 
         Args:
             copy (bool, optional): if ``True``, a deepcopy of the network is made.
@@ -203,7 +203,6 @@ def get_stateful_net(self, copy: bool = True):
         self.params.to_module(net)
         return net
 
-    @abc.abstractmethod
     def from_stateful_net(self, stateful_net: nn.Module):
         """Populates the parameters given a stateful version of the network.
 
diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py
index cbfc218327d..a10e6ccf25e 100644
--- a/torchrl/objectives/common.py
+++ b/torchrl/objectives/common.py
@@ -8,6 +8,7 @@
 import abc
 import functools
 import warnings
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import Iterator, List, Optional, Tuple
 
@@ -138,6 +139,67 @@ def __init__(self):
         self._tensor_keys = self._AcceptedKeys()
         self.register_forward_pre_hook(_updater_check_forward_prehook)
 
+    @property
+    def functional(self):
+        """Whether the module is functional.
+
+        Unless it has been specifically designed not to be functional, all losses are functional.
+        """
+        return True
+
+    def get_stateful_net(self, network_name: str, copy: bool | None = None):
+        """Returns a stateful version of the network.
+
+        This can be used to initialize parameters.
+
+        Such networks will often not be callable out-of-the-box and will require a `vmap` call
+        to be executable.
+
+        Args:
+            network_name (str): the name of the network to gather.
+            copy (bool, optional): if ``True``, a deepcopy of the network is made.
+                Defaults to ``True``.
+
+        .. note:: if the module is not functional, no copy is made.
+        """
+        net = getattr(self, network_name)
+        if not self.functional:
+            if copy is not None and copy:
+                raise RuntimeError("Cannot copy module in non-functional mode.")
+            return net
+        copy = True if copy is None else copy
+        if copy:
+            net = deepcopy(net)
+        params = getattr(self, network_name + "_params")
+        params.to_module(net)
+        return net
+
+    def from_stateful_net(self, network_name: str, stateful_net: nn.Module):
+        """Populates the parameters of a model given a stateful version of the network.
+
+        See :meth:`~.get_stateful_net` for details on how to gather a stateful version of the network.
+
+        Args:
+            network_name (str): the name of the network to reset.
+            stateful_net (nn.Module): the stateful network from which the params should be
+                gathered.
+
+        """
+        if not self.functional:
+            getattr(self, network_name).load_state_dict(stateful_net.state_dict())
+            return
+        params = TensorDict.from_module(stateful_net, as_module=True)
+        keyset0 = set(params.keys(True, True))
+        self_params = getattr(self, network_name + "_params")
+        keyset1 = set(self_params.keys(True, True))
+        if keyset0 != keyset1:
+            raise RuntimeError(
+                f"The keys of the loss parameters and the provided module differ: "
+                f"{keyset1 - keyset0} are in the loss but not in the module, "
+                f"{keyset0 - keyset1} are in the module but not in the loss."
+            )
+        self_params.data.update_(params.data)
+
     def _set_deprecated_ctor_keys(self, **kwargs) -> None:
         for key, value in kwargs.items():
             if value is not None:
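For reference, the round trip exercised by ``test_init_params`` above, condensed into a doctest-style sketch (it reuses the toy ``MyLoss`` class defined in the test; the behavior shown is exactly what the test asserts):

    >>> loss = MyLoss()
    >>> net = loss.get_stateful_net("module_a")  # deepcopy populated with the loss parameters
    >>> def reinit(mod):
    ...     if isinstance(mod, nn.Linear):  # out-of-place: rebinds new Parameter objects
    ...         mod.weight = nn.Parameter(torch.ones_like(mod.weight))
    ...         mod.bias = nn.Parameter(torch.ones_like(mod.bias))
    >>> net.apply(reinit)  # the loss parameters are still untouched at this point
    >>> loss.from_stateful_net("module_a", net)  # sync the new values back into the loss
    >>> assert (loss.module_a_params == 1).all()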