From 993c7381b33a794aa6a3cd7720b9c26da5cdfda4 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Wed, 12 Apr 2023 13:10:26 +0100
Subject: [PATCH] [BugFix] Fix param tying in loss modules (#1037)

---
 test/test_cost.py            | 47 +++++++++++++++++++++++
 torchrl/objectives/a2c.py    | 13 ++++++-
 torchrl/objectives/common.py | 74 ++++++++++++++++++++++++++++++------
 torchrl/objectives/ppo.py    | 61 +++++++++++++++++++----------
 4 files changed, 163 insertions(+), 32 deletions(-)

diff --git a/test/test_cost.py b/test/test_cost.py
index 0b049086486..e0d702c7b1b 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -2407,6 +2407,7 @@ def test_ppo_shared(self, loss_class, device, advantage):
             actor,
             value,
             loss_critic_type="l2",
+            separate_losses=True,
         )

         if advantage is not None:
@@ -4883,6 +4884,52 @@ def test_non_differentiable(self, adv, kwargs):
     assert td["advantage"].is_leaf


+class TestBase:
+    @pytest.mark.parametrize("expand_dim", [None, 2])
+    @pytest.mark.parametrize("compare_against", [True, False])
+    def test_convert_to_func(self, compare_against, expand_dim):
+        class MyLoss(LossModule):
+            def __init__(self, compare_against, expand_dim):
+                super().__init__()
+                module1 = nn.Linear(3, 4)
+                module2 = nn.Linear(3, 4)
+                module3 = nn.Linear(3, 4)
+                module_a = TensorDictModule(
+                    nn.Sequential(module1, module2), in_keys=["a"], out_keys=["c"]
+                )
+                module_b = TensorDictModule(
+                    nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"]
+                )
+                self.convert_to_functional(module_a, "module_a")
+                self.convert_to_functional(
+                    module_b,
+                    "module_b",
+                    compare_against=module_a.parameters() if compare_against else [],
+                    expand_dim=expand_dim,
+                )
+
+        loss_module = MyLoss(compare_against=compare_against, expand_dim=expand_dim)
+
+        for key in ["module.0.bias", "module.0.weight"]:
+            if compare_against:
+                assert not loss_module.module_b_params.flatten_keys()[key].requires_grad
+            else:
+                assert loss_module.module_b_params.flatten_keys()[key].requires_grad
+            if expand_dim:
+                assert (
+                    loss_module.module_b_params.flatten_keys()[key].shape[0]
+                    == expand_dim
+                )
+            else:
+                assert (
+                    loss_module.module_b_params.flatten_keys()[key].shape[0]
+                    != expand_dim
+                )
+
+        for key in ["module.1.bias", "module.1.weight"]:
+            assert loss_module.module_b_params.flatten_keys()[key].requires_grad
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
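
The new TestBase.test_convert_to_func case above exercises the tying logic directly on LossModule. For readers who want to poke at the behaviour outside pytest, here is a self-contained sketch (layer sizes and the `shared` name are illustrative, not part of the patch):

    from torch import nn
    from tensordict.nn import TensorDictModule
    from torchrl.objectives.common import LossModule

    class MyLoss(LossModule):
        def __init__(self):
            super().__init__()
            shared = nn.Linear(3, 4)  # layer reused by both modules
            module_a = TensorDictModule(
                nn.Sequential(shared, nn.Linear(4, 4)), in_keys=["a"], out_keys=["c"]
            )
            module_b = TensorDictModule(
                nn.Sequential(shared, nn.Linear(4, 4)), in_keys=["b"], out_keys=["c"]
            )
            self.convert_to_functional(module_a, "module_a")
            # tie module_b's copy of `shared` to module_a's parameters
            self.convert_to_functional(
                module_b, "module_b", compare_against=list(module_a.parameters())
            )

    loss = MyLoss()
    params_b = loss.module_b_params.flatten_keys()
    assert not params_b["module.0.weight"].requires_grad  # tied layer: detached
    assert params_b["module.1.weight"].requires_grad      # own layer: trainable
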
diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index 8c4d0ef70ef..bb23793d0df 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -49,6 +49,10 @@ class A2CLoss(LossModule):
         critic_coef (float): the weight of the critic loss.
         loss_critic_type (str): loss function for the value discrepancy.
             Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
+        separate_losses (bool, optional): if ``True``, shared parameters between
+            policy and critic will only be trained on the policy loss.
+            Defaults to ``False``, i.e., gradients are propagated to shared
+            parameters for both policy and critic losses.

     .. note:
       The advantage (typically GAE) can be computed by the loss function or
@@ -78,12 +82,19 @@ def __init__(
         critic_coef: float = 1.0,
         loss_critic_type: str = "smooth_l1",
         gamma: float = None,
+        separate_losses: bool = False,
     ):
         super().__init__()
         self.convert_to_functional(
             actor, "actor", funs_to_decorate=["forward", "get_dist"]
         )
-        self.convert_to_functional(critic, "critic", compare_against=self.actor_params)
+        if separate_losses:
+            # we want to make sure there are no duplicates in the params: the
+            # params of critic must be refs to actor if they're shared
+            policy_params = list(actor.parameters())
+        else:
+            policy_params = None
+        self.convert_to_functional(critic, "critic", compare_against=policy_params)
         self.advantage_key = advantage_key
         self.value_target_key = value_target_key
         self.samples_mc_entropy = samples_mc_entropy
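
To see what `separate_losses` changes in practice, here is a minimal sketch of an actor and a critic sharing a common backbone (module layout, names and dimensions are illustrative). With ``separate_losses=True``, the backbone's entries in ``critic_params`` come out detached, so the critic loss cannot train them:

    from torch import nn
    from tensordict.nn import TensorDictModule
    from torchrl.modules import (
        NormalParamWrapper, ProbabilisticActor, TanhNormal, ValueOperator,
    )
    from torchrl.objectives import A2CLoss

    obs_dim, hidden_dim, action_dim = 4, 32, 2
    common = nn.Linear(obs_dim, hidden_dim)  # backbone shared by actor and critic

    policy_net = NormalParamWrapper(
        nn.Sequential(common, nn.Tanh(), nn.Linear(hidden_dim, 2 * action_dim))
    )
    actor = ProbabilisticActor(
        TensorDictModule(policy_net, in_keys=["observation"], out_keys=["loc", "scale"]),
        in_keys=["loc", "scale"],
        distribution_class=TanhNormal,
    )
    critic = ValueOperator(
        nn.Sequential(common, nn.Tanh(), nn.Linear(hidden_dim, 1)),
        in_keys=["observation"],
    )

    loss_fn = A2CLoss(actor, critic, separate_losses=True)
    # entries of critic_params belonging to the shared backbone are detached:
    for name, param in loss_fn.critic_params.flatten_keys().items():
        print(name, param.requires_grad)
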
diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py
index 380123be00a..1c0a80e2876 100644
--- a/torchrl/objectives/common.py
+++ b/torchrl/objectives/common.py
@@ -88,7 +88,54 @@ def convert_to_functional(
         compare_against: Optional[List[Parameter]] = None,
         funs_to_decorate=None,
     ) -> None:
-        """Converts a module to functional to be used in the loss."""
+        """Converts a module to functional to be used in the loss.
+
+        Args:
+            module (TensorDictModule or compatible): a stateful tensordict module.
+                This module will be made functional, yet still stateful, meaning
+                that it will be callable with the following alternative signatures:
+
+                >>> module(tensordict)
+                >>> module(tensordict, params=params)
+
+                ``params`` is a :class:`tensordict.TensorDict` instance with parameters
+                structured as the output of :func:`tensordict.nn.make_functional`.
+            module_name (str): name where the module will be found.
+                The parameters of the module will be found under
+                ``loss_module.<module_name>_params`` whereas the module will be found
+                under ``loss_module.<module_name>``.
+            expand_dim (int, optional): if provided, the parameters of the module
+                will be expanded ``N`` times, where ``N = expand_dim`` along the
+                first dimension. This option is to be used whenever a target
+                network with more than one configuration is to be used.
+
+                .. note::
+                  If a ``compare_against`` list of values is provided, the
+                  resulting parameters will simply be a detached expansion
+                  of the original parameters. If ``compare_against`` is not
+                  provided, the value of the parameters will be resampled uniformly
+                  between the minimum and maximum value of the parameter content.
+
+            create_target_params (bool, optional): if ``True``, a detached
+                copy of the parameters will be available to feed a target network
+                under the name ``loss_module._target_<module_name>_params``.
+                If ``False`` (default), this attribute will still be available
+                but it will be a detached instance of the parameters, not a copy.
+                In other words, any modification of the parameter value
+                will directly be reflected in the target parameters.
+            compare_against (iterable of parameters, optional): if provided,
+                this list of parameters will be used as a comparison set for
+                the parameters of the module. If the parameters are expanded
+                (``expand_dim > 0``), the resulting parameters for the module
+                will be a simple expansion of the original parameters. Otherwise,
+                the resulting parameters will be a detached version of the
+                original parameters. If ``None``, the resulting parameters
+                will carry gradients as expected.
+            funs_to_decorate (list of str, optional): if provided, the list of
+                methods of ``module`` to make functional, i.e., the list of
+                methods that will accept the ``params`` keyword argument.
+
+        """
         if funs_to_decorate is None:
             funs_to_decorate = ["forward"]
         # To make it robust to device casting, we must register list of
@@ -102,10 +149,6 @@ def convert_to_functional(
         params = make_functional(module, funs_to_decorate=funs_to_decorate)
         functional_module = deepcopy(module)
         repopulate_module(module, params)
-        # params = make_functional(
-        #     module, funs_to_decorate=funs_to_decorate, keep_params=True
-        # )
-        # functional_module = module

         params_and_buffers = params
         # we transform the buffers in params to make sure they follow the device
@@ -135,15 +178,15 @@ def create_buffers(tensor):
                 "expanding params is only possible when functorch is installed,"
                 "as this feature requires calls to the vmap operator."
             )
+        if compare_against is not None:
+            compare_against = set(compare_against)
+        else:
+            compare_against = set()
         if expand_dim:
             # Expands the dims of params and buffers.
             # If the param already exists in the module, we return a simple expansion of the
             # original one. Otherwise, we expand and resample it.
             # For buffers, a cloned expansion (or equivalently a repeat) is returned.
-            if compare_against is not None:
-                compare_against = set(compare_against)
-            else:
-                compare_against = set()

             def _compare_and_expand(param):

@@ -186,11 +229,16 @@ def _compare_and_expand(param):
             if parameter not in prev_set_params:
                 setattr(self, "_sep_".join([module_name, key]), parameter)
             else:
+                # if the parameter is already present, we register a string pointing
+                # to it instead. If the string ends with a '_detached' suffix, the
+                # value will be detached.
                 for _param_name, p in self.named_parameters():
                     if parameter is p:
                         break
                 else:
                     raise RuntimeError("parameter not found")
+                if compare_against is not None and p in compare_against:
+                    _param_name = _param_name + "_detached"
                 setattr(self, "_sep_".join([module_name, key]), _param_name)
         prev_set_buffers = set(self.buffers())
         for key, buffer in buffers.items():
@@ -259,8 +307,12 @@ def _param_getter(self, network_name):
                     key = (key,)
                 value_to_set = getattr(self, "_sep_".join([network_name, *key]))
                 if isinstance(value_to_set, str):
-                    value_to_set = getattr(self, value_to_set).detach()
-                params.set(key, value_to_set)
+                    if value_to_set.endswith("_detached"):
+                        value_to_set = value_to_set[:-9]
+                        value_to_set = getattr(self, value_to_set).detach()
+                    else:
+                        value_to_set = getattr(self, value_to_set)
+                params._set(key, value_to_set)
             return params
         else:
             params = getattr(self, param_name)
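
The registration/getter dance above can be hard to follow in diff form: a tied parameter is stored once on the loss module, and every other owner is registered as a string pointing at it; a "_detached" suffix on that string tells the getter to hand back a gradient-free view. A toy standalone analogue (not torchrl code, names are made up):

    import torch
    from torch import nn

    class TiedParams(nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(4, 3))
            # a second "owner" is registered as a string pointer; the
            # "_detached" suffix asks the getter for a gradient-free view
            self._module_b_weight = "weight_detached"

        def get(self, name):
            value = getattr(self, name)
            if isinstance(value, str):  # string pointer -> resolve it
                if value.endswith("_detached"):
                    return getattr(self, value[: -len("_detached")]).detach()
                return getattr(self, value)
            return value

    tied = TiedParams()
    assert tied.get("weight").requires_grad                # the real parameter
    assert not tied.get("_module_b_weight").requires_grad  # tied, detached view
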
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index 8dcb681678d..7f2392480d1 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -42,27 +42,31 @@ class PPOLoss(LossModule):
     Args:
         actor (ProbabilisticTensorDictSequential): policy operator.
         critic (ValueOperator): value operator.
-        advantage_key (str): the input tensordict key where the advantage is
+        advantage_key (str, optional): the input tensordict key where the advantage is
            expected to be written. Defaults to ``"advantage"``.
-        value_target_key (str): the input tensordict key where the target state
+        value_target_key (str, optional): the input tensordict key where the target state
            value is expected to be written. Defaults to ``"value_target"``.
-        entropy_bonus (bool): if ``True``, an entropy bonus will be added to the
+        entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
            loss to favour exploratory policies.
-        samples_mc_entropy (int): if the distribution retrieved from the policy
+        samples_mc_entropy (int, optional): if the distribution retrieved from the policy
            operator does not have a closed form formula for the entropy, a
            Monte-Carlo estimate will be used. ``samples_mc_entropy`` will control how many
            samples will be used to compute this estimate. Defaults to ``1``.
-        entropy_coef (scalar): entropy multiplier when computing the total loss.
+        entropy_coef (scalar, optional): entropy multiplier when computing the total loss.
            Defaults to ``0.01``.
-        critic_coef (scalar): critic loss multiplier when computing the total
+        critic_coef (scalar, optional): critic loss multiplier when computing the total
            loss. Defaults to ``1.0``.
-        loss_critic_type (str): loss function for the value discrepancy.
+        loss_critic_type (str, optional): loss function for the value discrepancy.
            Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
-        normalize_advantage (bool): if ``True``, the advantage will be normalized
+        normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
            before being used. Defaults to ``False``.
+        separate_losses (bool, optional): if ``True``, shared parameters between
+            policy and critic will only be trained on the policy loss.
+            Defaults to ``False``, i.e., gradients are propagated to shared
+            parameters for both policy and critic losses.

     .. note:
       The advantage (typically GAE) can be computed by the loss function or
@@ -107,14 +111,19 @@ def __init__(
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
         gamma: float = None,
+        separate_losses: bool = False,
     ):
         super().__init__()
         self.convert_to_functional(
             actor, "actor", funs_to_decorate=["forward", "get_dist"]
         )
-        # we want to make sure there are no duplicates in the params: the
-        # params of critic must be refs to actor if they're shared
-        self.convert_to_functional(critic, "critic", compare_against=self.actor_params)
+        if separate_losses:
+            # we want to make sure there are no duplicates in the params: the
+            # params of critic must be refs to actor if they're shared
+            policy_params = list(actor.parameters())
+        else:
+            policy_params = None
+        self.convert_to_functional(critic, "critic", compare_against=policy_params)
         self.advantage_key = advantage_key
         self.value_target_key = value_target_key
         self.samples_mc_entropy = samples_mc_entropy
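
The gradient semantics of `separate_losses` can be emulated in plain PyTorch. Note this is only an analogue: the real implementation detaches the shared *parameters* inside ``critic_params`` rather than the activations, but the effect on the shared weights' gradients is the same:

    import torch
    from torch import nn

    trunk = nn.Linear(4, 8)        # shared between policy and value heads
    policy_head = nn.Linear(8, 2)
    value_head = nn.Linear(8, 1)

    obs = torch.randn(3, 4)
    hidden = trunk(obs)
    policy_loss = policy_head(hidden).pow(2).mean()         # dummy policy objective
    value_loss = value_head(hidden.detach()).pow(2).mean()  # cut off from the trunk
    (policy_loss + value_loss).backward()
    # the trunk's gradient comes from the policy term only
    assert trunk.weight.grad is not None
    assert value_head.weight.grad is not None
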
@@ -248,28 +257,32 @@ class ClipPPOLoss(PPOLoss):
     Args:
         actor (ProbabilisticTensorDictSequential): policy operator.
         critic (ValueOperator): value operator.
-        advantage_key (str): the input tensordict key where the advantage is expected to be written.
+        advantage_key (str, optional): the input tensordict key where the advantage is expected to be written.
            Defaults to ``"advantage"``.
-        value_target_key (str): the input tensordict key where the target state
+        value_target_key (str, optional): the input tensordict key where the target state
            value is expected to be written. Defaults to ``"value_target"``.
-        clip_epsilon (scalar): weight clipping threshold in the clipped PPO loss equation.
+        clip_epsilon (scalar, optional): weight clipping threshold in the clipped PPO loss equation.
            default: 0.2
-        entropy_bonus (bool): if ``True``, an entropy bonus will be added to the
+        entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
            loss to favour exploratory policies.
-        samples_mc_entropy (int): if the distribution retrieved from the policy
+        samples_mc_entropy (int, optional): if the distribution retrieved from the policy
            operator does not have a closed form formula for the entropy, a
            Monte-Carlo estimate will be used. ``samples_mc_entropy`` will control how many
            samples will be used to compute this estimate. Defaults to ``1``.
-        entropy_coef (scalar): entropy multiplier when computing the total loss.
+        entropy_coef (scalar, optional): entropy multiplier when computing the total loss.
            Defaults to ``0.01``.
-        critic_coef (scalar): critic loss multiplier when computing the total
+        critic_coef (scalar, optional): critic loss multiplier when computing the total
            loss. Defaults to ``1.0``.
-        loss_critic_type (str): loss function for the value discrepancy.
+        loss_critic_type (str, optional): loss function for the value discrepancy.
            Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
-        normalize_advantage (bool): if ``True``, the advantage will be normalized
+        normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
            before being used. Defaults to ``False``.
+        separate_losses (bool, optional): if ``True``, shared parameters between
+            policy and critic will only be trained on the policy loss.
+            Defaults to ``False``, i.e., gradients are propagated to shared
+            parameters for both policy and critic losses.

     .. note:
       The advantage (typically GAE) can be computed by the loss function or
@@ -312,6 +325,7 @@ def __init__(
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = True,
         gamma: float = None,
+        separate_losses: bool = False,
         **kwargs,
     ):
         super(ClipPPOLoss, self).__init__(
@@ -325,6 +339,7 @@ def __init__(
             loss_critic_type=loss_critic_type,
             normalize_advantage=normalize_advantage,
             gamma=gamma,
+            separate_losses=separate_losses,
             **kwargs,
         )
         self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon))
@@ -425,6 +440,10 @@ class KLPENPPOLoss(PPOLoss):
            Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
         normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
            before being used. Defaults to ``False``.
+        separate_losses (bool, optional): if ``True``, shared parameters between
+            policy and critic will only be trained on the policy loss.
+            Defaults to ``False``, i.e., gradients are propagated to shared
+            parameters for both policy and critic losses.

     .. note:
@@ -472,6 +491,7 @@ def __init__(
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = True,
         gamma: float = None,
+        separate_losses: bool = False,
         **kwargs,
     ):
         super(KLPENPPOLoss, self).__init__(
@@ -485,6 +505,7 @@ def __init__(
             loss_critic_type=loss_critic_type,
             normalize_advantage=normalize_advantage,
             gamma=gamma,
+            separate_losses=separate_losses,
             **kwargs,
         )
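
The flag is threaded identically through ClipPPOLoss and KLPENPPOLoss. A hypothetical usage, reusing the shared-backbone `actor` and `critic` from the A2C sketch above (names are again illustrative):

    from torchrl.objectives import ClipPPOLoss

    ppo_loss = ClipPPOLoss(actor, critic, separate_losses=True)
    # the critic's view of the shared backbone is detached, exactly as the
    # new TestBase case asserts for tied parameters:
    shared = [
        key
        for key, p in ppo_loss.critic_params.flatten_keys().items()
        if not p.requires_grad
    ]
    print(shared)  # expect only the shared backbone entries
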