[BugFix] Fix param tying in loss modules (#1037)
vmoens authored Apr 12, 2023
1 parent 9df4529 commit 993c738
Showing 4 changed files with 163 additions and 32 deletions.
47 changes: 47 additions & 0 deletions test/test_cost.py
@@ -2407,6 +2407,7 @@ def test_ppo_shared(self, loss_class, device, advantage):
actor,
value,
loss_critic_type="l2",
separate_losses=True,
)

if advantage is not None:
@@ -4883,6 +4884,52 @@ def test_non_differentiable(self, adv, kwargs):
assert td["advantage"].is_leaf


class TestBase:
@pytest.mark.parametrize("expand_dim", [None, 2])
@pytest.mark.parametrize("compare_against", [True, False])
def test_convert_to_func(self, compare_against, expand_dim):
class MyLoss(LossModule):
def __init__(self, compare_against, expand_dim):
super().__init__()
module1 = nn.Linear(3, 4)
module2 = nn.Linear(3, 4)
module3 = nn.Linear(3, 4)
module_a = TensorDictModule(
nn.Sequential(module1, module2), in_keys=["a"], out_keys=["c"]
)
module_b = TensorDictModule(
nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"]
)
self.convert_to_functional(module_a, "module_a")
self.convert_to_functional(
module_b,
"module_b",
compare_against=module_a.parameters() if compare_against else [],
expand_dim=expand_dim,
)

loss_module = MyLoss(compare_against=compare_against, expand_dim=expand_dim)

for key in ["module.0.bias", "module.0.weight"]:
if compare_against:
assert not loss_module.module_b_params.flatten_keys()[key].requires_grad
else:
assert loss_module.module_b_params.flatten_keys()[key].requires_grad
if expand_dim:
assert (
loss_module.module_b_params.flatten_keys()[key].shape[0]
== expand_dim
)
else:
assert (
loss_module.module_b_params.flatten_keys()[key].shape[0]
!= expand_dim
)

for key in ["module.1.bias", "module.1.weight"]:
assert loss_module.module_b_params.flatten_keys()[key].requires_grad


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
13 changes: 12 additions & 1 deletion torchrl/objectives/a2c.py
@@ -49,6 +49,10 @@ class A2CLoss(LossModule):
critic_coef (float): the weight of the critic loss.
loss_critic_type (str): loss function for the value discrepancy.
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
parameters for both policy and critic losses.
.. note:
The advantage (typically GAE) can be computed by the loss function or
@@ -78,12 +82,19 @@ def __init__(
critic_coef: float = 1.0,
loss_critic_type: str = "smooth_l1",
gamma: float = None,
separate_losses: bool = False,
):
super().__init__()
self.convert_to_functional(
actor, "actor", funs_to_decorate=["forward", "get_dist"]
)
self.convert_to_functional(critic, "critic", compare_against=self.actor_params)
if separate_losses:
# we want to make sure there are no duplicates in the params: the
# params of critic must be refs to actor if they're shared
policy_params = list(actor.parameters())
else:
policy_params = None
self.convert_to_functional(critic, "critic", compare_against=policy_params)
self.advantage_key = advantage_key
self.value_target_key = value_target_key
self.samples_mc_entropy = samples_mc_entropy
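
For context, a minimal sketch of the shared-backbone setup that ``separate_losses`` targets: with ``separate_losses=True``, only the policy loss trains the shared weights. Module sizes, keys and the distribution choice below are illustrative assumptions, not code from this commit.

    from torch import nn
    from tensordict.nn import TensorDictModule
    from tensordict.nn.distributions import NormalParamExtractor
    from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
    from torchrl.objectives import A2CLoss

    # Raw backbone shared by the policy and the value function (hypothetical sizes).
    backbone = nn.Linear(4, 32)

    actor_net = nn.Sequential(backbone, nn.Tanh(), nn.Linear(32, 8), NormalParamExtractor())
    actor = ProbabilisticActor(
        TensorDictModule(actor_net, in_keys=["observation"], out_keys=["loc", "scale"]),
        in_keys=["loc", "scale"],
        distribution_class=TanhNormal,
        return_log_prob=True,
    )
    critic = ValueOperator(
        nn.Sequential(backbone, nn.Tanh(), nn.Linear(32, 1)), in_keys=["observation"]
    )

    # With separate_losses=True, the backbone parameters appear detached in the
    # critic branch of the loss: only the policy loss updates them.
    loss = A2CLoss(actor, critic, separate_losses=True)
    print([p.requires_grad for p in loss.critic_params.values(True, True)])
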
74 changes: 63 additions & 11 deletions torchrl/objectives/common.py
@@ -88,7 +88,54 @@ def convert_to_functional(
compare_against: Optional[List[Parameter]] = None,
funs_to_decorate=None,
) -> None:
"""Converts a module to functional to be used in the loss."""
"""Converts a module to functional to be used in the loss.
Args:
module (TensorDictModule or compatible): a stateful tensordict module.
This module will be made functional, yet still stateful, meaning
that it will be callable with the following alternative signatures:
>>> module(tensordict)
>>> module(tensordict, params=params)
``params`` is a :class:`tensordict.TensorDict` instance with parameters
structured as the output of :func:`tensordict.nn.make_functional`.
module_name (str): name where the module will be found.
The parameters of the module will be found under ``loss_module.<module_name>_params``
whereas the module will be found under ``loss_module.<module_name>``.
expand_dim (int, optional): if provided, the parameters of the module
will be expanded ``N`` times, where ``N = expand_dim`` along the
first dimension. This option is to be used whenever a target
network with more than one configuration is to be used.
.. note::
If a ``compare_against`` list of values is provided, the
resulting parameters will simply be a detached expansion
of the original parameters. If ``compare_against`` is not
provided, the value of the parameters will be resampled uniformly
between the minimum and maximum value of the parameter content.
create_target_params (bool, optional): if ``True``, a detached
copy of the parameter will be available to feed a target network
under the name ``loss_module.<module_name>_target_params``.
If ``False`` (default), this attribute will still be available
but it will be a detached instance of the parameters, not a copy.
In other words, any modification of the parameter value
will directly be reflected in the target parameters.
compare_against (iterable of parameters, optional): if provided,
this list of parameters will be used as a comparison set for
the parameters of the module. If the parameters are expanded
(``expand_dim > 0``), the resulting parameters for the module
will be a simple expansion of the original parameter. Otherwise,
the resulting parameters will be a detached version of the
original parameters. If ``None``, the resulting parameters
will carry gradients as expected.
funs_to_decorate (list of str, optional): if provided, the list of
methods of ``module`` to make functional, i.e., the list of
methods that will accept the ``params`` keyword argument.
"""
if funs_to_decorate is None:
funs_to_decorate = ["forward"]
# To make it robust to device casting, we must register list of
@@ -102,10 +149,6 @@ def convert_to_functional(
params = make_functional(module, funs_to_decorate=funs_to_decorate)
functional_module = deepcopy(module)
repopulate_module(module, params)
# params = make_functional(
# module, funs_to_decorate=funs_to_decorate, keep_params=True
# )
# functional_module = module

params_and_buffers = params
# we transform the buffers in params to make sure they follow the device
@@ -135,15 +178,15 @@ def create_buffers(tensor):
"expanding params is only possible when functorch is installed,"
"as this feature requires calls to the vmap operator."
)
if compare_against is not None:
compare_against = set(compare_against)
else:
compare_against = set()
if expand_dim:
# Expands the dims of params and buffers.
# If the param already exist in the module, we return a simple expansion of the
# original one. Otherwise, we expand and resample it.
# For buffers, a cloned expansion (or equivalently a repeat) is returned.
if compare_against is not None:
compare_against = set(compare_against)
else:
compare_against = set()

def _compare_and_expand(param):

@@ -186,11 +229,16 @@ def _compare_and_expand(param):
if parameter not in prev_set_params:
setattr(self, "_sep_".join([module_name, key]), parameter)
else:
# if the parameter is already present, we register a string pointing
# to it instead. If the string ends with a '_detached' suffix, the
# value will be detached.
for _param_name, p in self.named_parameters():
if parameter is p:
break
else:
raise RuntimeError("parameter not found")
if compare_against is not None and p in compare_against:
_param_name = _param_name + "_detached"
setattr(self, "_sep_".join([module_name, key]), _param_name)
prev_set_buffers = set(self.buffers())
for key, buffer in buffers.items():
@@ -259,8 +307,12 @@ def _param_getter(self, network_name):
key = (key,)
value_to_set = getattr(self, "_sep_".join([network_name, *key]))
if isinstance(value_to_set, str):
value_to_set = getattr(self, value_to_set).detach()
params.set(key, value_to_set)
if value_to_set.endswith("_detached"):
value_to_set = value_to_set[:-9]
value_to_set = getattr(self, value_to_set).detach()
else:
value_to_set = getattr(self, value_to_set)
params._set(key, value_to_set)
return params
else:
params = getattr(self, param_name)
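
To illustrate the documented behaviour of ``compare_against``, here is a small hypothetical sketch of a custom ``LossModule`` tying a shared submodule between two functionalized modules; names and shapes are assumptions mirroring the new test, not part of the diff.

    from torch import nn
    from tensordict.nn import TensorDictModule
    from torchrl.objectives.common import LossModule

    class MyLoss(LossModule):
        def __init__(self):
            super().__init__()
            shared = nn.Linear(3, 4)
            module_a = TensorDictModule(
                nn.Sequential(shared, nn.Linear(4, 4)), in_keys=["a"], out_keys=["c"]
            )
            module_b = TensorDictModule(
                nn.Sequential(shared, nn.Linear(4, 4)), in_keys=["b"], out_keys=["c"]
            )
            self.convert_to_functional(module_a, "module_a")
            # Parameters already registered through module_a are stored as detached
            # references here, so module_b's loss cannot update the shared weights.
            self.convert_to_functional(
                module_b, "module_b", compare_against=list(module_a.parameters())
            )

    loss = MyLoss()
    flat = loss.module_b_params.flatten_keys()
    print(flat["module.0.weight"].requires_grad)  # expected False: shared, detached
    print(flat["module.1.weight"].requires_grad)  # expected True: module_b's own head
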
61 changes: 41 additions & 20 deletions torchrl/objectives/ppo.py
@@ -42,27 +42,31 @@ class PPOLoss(LossModule):
Args:
actor (ProbabilisticTensorDictSequential): policy operator.
critic (ValueOperator): value operator.
advantage_key (str): the input tensordict key where the advantage is
advantage_key (str, optional): the input tensordict key where the advantage is
expected to be written.
Defaults to ``"advantage"``.
value_target_key (str): the input tensordict key where the target state
value_target_key (str, optional): the input tensordict key where the target state
value is expected to be written. Defaults to ``"value_target"``.
entropy_bonus (bool): if ``True``, an entropy bonus will be added to the
entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
loss to favour exploratory policies.
samples_mc_entropy (int): if the distribution retrieved from the policy
samples_mc_entropy (int, optional): if the distribution retrieved from the policy
operator does not have a closed form
formula for the entropy, a Monte-Carlo estimate will be used.
``samples_mc_entropy`` will control how many
samples will be used to compute this estimate.
Defaults to ``1``.
entropy_coef (scalar): entropy multiplier when computing the total loss.
entropy_coef (scalar, optional): entropy multiplier when computing the total loss.
Defaults to ``0.01``.
critic_coef (scalar): critic loss multiplier when computing the total
critic_coef (scalar, optional): critic loss multiplier when computing the total
loss. Defaults to ``1.0``.
loss_critic_type (str): loss function for the value discrepancy.
loss_critic_type (str, optional): loss function for the value discrepancy.
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
normalize_advantage (bool): if ``True``, the advantage will be normalized
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
before being used. Defaults to ``False``.
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
parameters for both policy and critic losses.
.. note:
The advantage (typically GAE) can be computed by the loss function or
@@ -107,14 +111,19 @@ def __init__(
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
gamma: float = None,
separate_losses: bool = False,
):
super().__init__()
self.convert_to_functional(
actor, "actor", funs_to_decorate=["forward", "get_dist"]
)
# we want to make sure there are no duplicates in the params: the
# params of critic must be refs to actor if they're shared
self.convert_to_functional(critic, "critic", compare_against=self.actor_params)
if separate_losses:
# we want to make sure there are no duplicates in the params: the
# params of critic must be refs to actor if they're shared
policy_params = list(actor.parameters())
else:
policy_params = None
self.convert_to_functional(critic, "critic", compare_against=policy_params)
self.advantage_key = advantage_key
self.value_target_key = value_target_key
self.samples_mc_entropy = samples_mc_entropy
@@ -248,28 +257,32 @@ class ClipPPOLoss(PPOLoss):
Args:
actor (ProbabilisticTensorDictSequential): policy operator.
critic (ValueOperator): value operator.
advantage_key (str): the input tensordict key where the advantage is expected to be written.
advantage_key (str, optional): the input tensordict key where the advantage is expected to be written.
Defaults to ``"advantage"``.
value_target_key (str): the input tensordict key where the target state
value_target_key (str, optional): the input tensordict key where the target state
value is expected to be written. Defaults to ``"value_target"``.
clip_epsilon (scalar): weight clipping threshold in the clipped PPO loss equation.
clip_epsilon (scalar, optional): weight clipping threshold in the clipped PPO loss equation.
default: 0.2
entropy_bonus (bool): if ``True``, an entropy bonus will be added to the
entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
loss to favour exploratory policies.
samples_mc_entropy (int): if the distribution retrieved from the policy
samples_mc_entropy (int, optional): if the distribution retrieved from the policy
operator does not have a closed form
formula for the entropy, a Monte-Carlo estimate will be used.
``samples_mc_entropy`` will control how many
samples will be used to compute this estimate.
Defaults to ``1``.
entropy_coef (scalar): entropy multiplier when computing the total loss.
entropy_coef (scalar, optional): entropy multiplier when computing the total loss.
Defaults to ``0.01``.
critic_coef (scalar): critic loss multiplier when computing the total
critic_coef (scalar, optional): critic loss multiplier when computing the total
loss. Defaults to ``1.0``.
loss_critic_type (str): loss function for the value discrepancy.
loss_critic_type (str, optional): loss function for the value discrepancy.
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
normalize_advantage (bool): if ``True``, the advantage will be normalized
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
before being used. Defaults to ``False``.
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
parameters for both policy and critic losses.
.. note:
The advantage (typically GAE) can be computed by the loss function or
@@ -312,6 +325,7 @@ def __init__(
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = True,
gamma: float = None,
separate_losses: bool = False,
**kwargs,
):
super(ClipPPOLoss, self).__init__(
@@ -325,6 +339,7 @@ def __init__(
loss_critic_type=loss_critic_type,
normalize_advantage=normalize_advantage,
gamma=gamma,
separate_losses=separate_losses,
**kwargs,
)
self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon))
@@ -425,6 +440,10 @@ class KLPENPPOLoss(PPOLoss):
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
before being used. Defaults to ``False``.
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
parameters for both policy and critic losses.
.. note:
@@ -472,6 +491,7 @@ def __init__(
Expand Down Expand Up @@ -472,6 +491,7 @@ def __init__(
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = True,
gamma: float = None,
separate_losses: bool = False,
**kwargs,
):
super(KLPENPPOLoss, self).__init__(
Expand All @@ -485,6 +505,7 @@ def __init__(
loss_critic_type=loss_critic_type,
normalize_advantage=normalize_advantage,
gamma=gamma,
separate_losses=separate_losses,
**kwargs,
)

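
The same flag is threaded through ``PPOLoss``, ``ClipPPOLoss`` and ``KLPENPPOLoss``; a brief hypothetical sketch analogous to the A2C example above (module layout and keys are assumptions, not code from this commit):

    from torch import nn
    from tensordict.nn import TensorDictModule
    from tensordict.nn.distributions import NormalParamExtractor
    from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
    from torchrl.objectives import ClipPPOLoss

    backbone = nn.Linear(4, 32)  # shared between the policy and the value network
    actor = ProbabilisticActor(
        TensorDictModule(
            nn.Sequential(backbone, nn.Tanh(), nn.Linear(32, 8), NormalParamExtractor()),
            in_keys=["observation"],
            out_keys=["loc", "scale"],
        ),
        in_keys=["loc", "scale"],
        distribution_class=TanhNormal,
        return_log_prob=True,
    )
    critic = ValueOperator(
        nn.Sequential(backbone, nn.Tanh(), nn.Linear(32, 1)), in_keys=["observation"]
    )

    # With separate_losses=True, the critic loss no longer backpropagates into
    # `backbone`; only the clipped policy objective updates the shared weights.
    loss = ClipPPOLoss(actor, critic, separate_losses=True)
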
