[Feature] Separate losses (#1240)
MateuszGuzek authored Jun 26, 2023
1 parent 359f9a4 commit 56518a7
Showing 8 changed files with 1,226 additions and 74 deletions.
1,168 changes: 1,115 additions & 53 deletions test/test_cost.py

Large diffs are not rendered by default.
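The flag added in each file below behaves the same way: parameters shared between the policy and the critic are trained on the policy loss only when ``separate_losses=True``; with the default ``False``, both losses propagate gradients into the shared parameters. The snippet below is a minimal, framework-free sketch of that gradient routing. The shared encoder, the heads and the dummy losses are hypothetical and only illustrate the intent; TorchRL implements the behaviour through the ``compare_against`` argument of ``convert_to_functional``, as the diffs show.

import torch
from torch import nn

# Hypothetical shared trunk feeding both a policy head and a critic head.
encoder = nn.Linear(8, 16)
policy_head = nn.Linear(16, 4)
critic_head = nn.Linear(16, 1)

obs = torch.randn(32, 8)
hidden = encoder(obs)

policy_loss = policy_head(hidden).pow(2).mean()            # dummy stand-in for the actor loss
critic_loss = critic_head(hidden.detach()).pow(2).mean()   # detached: the critic loss cannot reach the encoder

(policy_loss + critic_loss).backward()

# The encoder only receives gradients from the policy loss, which is the
# behaviour requested by separate_losses=True; with separate_losses=False
# both losses would update the shared encoder.
assert encoder.weight.grad is not None
assert critic_head.weight.grad is not None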

13 changes: 12 additions & 1 deletion torchrl/objectives/ddpg.py
@@ -38,6 +38,10 @@ class DDPGLoss(LossModule):
data collection. Default is ``False``.
delay_value (bool, optional): whether to separate the target value networks from the value networks used for
data collection. Default is ``True``.
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
parameters for both policy and critic losses.
Examples:
>>> import torch
@@ -178,6 +182,7 @@ def __init__(
delay_actor: bool = False,
delay_value: bool = True,
gamma: float = None,
separate_losses: bool = False,
) -> None:
self._in_keys = None
super().__init__()
@@ -195,11 +200,17 @@ def __init__(
"actor_network",
create_target_params=self.delay_actor,
)
if separate_losses:
# we want to make sure there are no duplicates in the params: the
# params of critic must be refs to actor if they're shared
policy_params = list(actor_network.parameters())
else:
policy_params = None
self.convert_to_functional(
value_network,
"value_network",
create_target_params=self.delay_value,
compare_against=list(actor_network.parameters()),
compare_against=policy_params,
)
self.actor_critic.module[0] = self.actor_network
self.actor_critic.module[1] = self.value_network
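For context, a sketch of how the new keyword argument added to ``DDPGLoss`` above would be passed at construction time. The network sizes and module definitions are hypothetical and not part of this commit; they only show where ``separate_losses`` enters the constructor.

import torch
from torch import nn
from torchrl.modules import Actor, ValueOperator
from torchrl.objectives import DDPGLoss

n_obs, n_act = 8, 4

# Hypothetical deterministic policy: observation -> action.
actor = Actor(nn.Linear(n_obs, n_act), in_keys=["observation"])

# Hypothetical state-action value network: (observation, action) -> scalar value.
class QValueNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(n_obs + n_act, 1)

    def forward(self, obs, act):
        return self.net(torch.cat([obs, act], dim=-1))

value = ValueOperator(QValueNet(), in_keys=["observation", "action"])

# With separate_losses=True, any parameters shared between the two networks
# would be trained on the actor loss only.
loss_fn = DDPGLoss(actor, value, separate_losses=True)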
14 changes: 12 additions & 2 deletions torchrl/objectives/deprecated.py
@@ -85,6 +85,10 @@ class REDQLoss_deprecated(LossModule):
priority_key (str, optional): [Deprecated] Key where to write the priority value
for prioritized replay buffers. Default is
``"td_error"``.
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
parameters for both policy and critic losses.
"""

@dataclass
@@ -142,6 +146,7 @@ def __init__(
gSDE: bool = False,
gamma: float = None,
priority_key: str = None,
separate_losses: bool = False,
):
self._in_keys = None
self._out_keys = None
@@ -155,7 +160,12 @@ def __init__(
"actor_network",
create_target_params=self.delay_actor,
)

if separate_losses:
# we want to make sure there are no duplicates in the params: the
# params of critic must be refs to actor if they're shared
policy_params = list(actor_network.parameters())
else:
policy_params = None
# let's make sure that actor_network has `return_log_prob` set to True
self.actor_network.return_log_prob = True

@@ -165,7 +175,7 @@ def __init__(
"qvalue_network",
expand_dim=num_qvalue_nets,
create_target_params=self.delay_qvalue,
compare_against=actor_network.parameters(),
compare_against=policy_params,
)
self.num_qvalue_nets = num_qvalue_nets
self.sub_sample_len = max(1, min(sub_sample_len, num_qvalue_nets - 1))
24 changes: 19 additions & 5 deletions torchrl/objectives/iql.py
@@ -59,6 +59,10 @@ class IQLLoss(LossModule):
priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
tensordict key where to write the priority (for prioritized replay
buffer usage). Default is `"td_error"`.
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
parameters for both policy and critic losses.
Examples:
>>> import torch
@@ -233,6 +237,7 @@ def __init__(
expectile: float = 0.5,
gamma: float = None,
priority_key: str = None,
separate_losses: bool = False,
) -> None:
self._in_keys = None
self._out_keys = None
@@ -252,26 +257,35 @@ def __init__(
create_target_params=False,
funs_to_decorate=["forward", "get_dist"],
)

if separate_losses:
# we want to make sure there are no duplicates in the params: the
# params of critic must be refs to actor if they're shared
policy_params = list(actor_network.parameters())
else:
policy_params = None
# Value Function Network
self.convert_to_functional(
value_network,
"value_network",
create_target_params=False,
compare_against=list(actor_network.parameters()),
compare_against=policy_params,
)

# Q Function Network
self.delay_qvalue = True
self.num_qvalue_nets = num_qvalue_nets

if separate_losses and policy_params is not None:
qvalue_policy_params = list(actor_network.parameters()) + list(
value_network.parameters()
)
else:
qvalue_policy_params = None
self.convert_to_functional(
qvalue_network,
"qvalue_network",
num_qvalue_nets,
create_target_params=True,
compare_against=list(actor_network.parameters())
+ list(value_network.parameters()),
compare_against=qvalue_policy_params,
)

self.loss_function = loss_function
14 changes: 12 additions & 2 deletions torchrl/objectives/redq.py
@@ -83,6 +83,10 @@ class REDQLoss(LossModule):
priority_key (str, optional): [Deprecated, use .set_keys() instead] Key where to write the priority value
for prioritized replay buffers. Default is
``"td_error"``.
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
parameters for both policy and critic losses.
Examples:
>>> import torch
@@ -253,6 +257,7 @@ def __init__(
gSDE: bool = False,
gamma: float = None,
priority_key: str = None,
separate_losses: bool = False,
):
if not _has_functorch:
raise ImportError("Failed to import functorch.") from FUNCTORCH_ERR
@@ -270,14 +275,19 @@ def __init__(

# let's make sure that actor_network has `return_log_prob` set to True
self.actor_network.return_log_prob = True

if separate_losses:
# we want to make sure there are no duplicates in the params: the
# params of critic must be refs to actor if they're shared
policy_params = list(actor_network.parameters())
else:
policy_params = None
self.delay_qvalue = delay_qvalue
self.convert_to_functional(
qvalue_network,
"qvalue_network",
num_qvalue_nets,
create_target_params=self.delay_qvalue,
compare_against=list(actor_network.parameters()),
compare_against=policy_params,
)
self.num_qvalue_nets = num_qvalue_nets
self.sub_sample_len = max(1, min(sub_sample_len, num_qvalue_nets - 1))
14 changes: 12 additions & 2 deletions torchrl/objectives/reinforce.py
@@ -41,6 +41,10 @@ class ReinforceLoss(LossModule):
value_target_key (str): [Deprecated, use .set_keys(value_target_key=value_target_key) instead]
The input tensordict key where the target state
value is expected to be written. Defaults to ``"value_target"``.
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
parameters for both policy and critic losses.
.. note:
The advantage (typically GAE) can be computed by the loss function or
@@ -194,6 +198,7 @@ def __init__(
gamma: float = None,
advantage_key: str = None,
value_target_key: str = None,
separate_losses: bool = False,
) -> None:
super().__init__()
self.in_keys = None
@@ -210,14 +215,19 @@ def __init__(
"actor_network",
create_target_params=False,
)

if separate_losses:
# we want to make sure there are no duplicates in the params: the
# params of critic must be refs to actor if they're shared
policy_params = list(actor.parameters())
else:
policy_params = None
# Value
if critic is not None:
self.convert_to_functional(
critic,
"critic",
create_target_params=self.delay_value,
compare_against=list(actor.parameters()),
compare_against=policy_params,
)
if gamma is not None:
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
39 changes: 32 additions & 7 deletions torchrl/objectives/sac.py
@@ -92,6 +92,10 @@ class SACLoss(LossModule):
priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
Tensordict key where to write the
priority (for prioritized replay buffer usage). Defaults to ``"td_error"``.
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
parameters for both policy and critic losses.
Examples:
>>> import torch
@@ -267,6 +271,7 @@ def __init__(
delay_value: bool = True,
gamma: float = None,
priority_key: str = None,
separate_losses: bool = False,
) -> None:
self._in_keys = None
self._out_keys = None
@@ -283,7 +288,13 @@ def __init__(
create_target_params=self.delay_actor,
funs_to_decorate=["forward", "get_dist"],
)

if separate_losses:
# we want to make sure there are no duplicates in the params: the
# params of critic must be refs to actor if they're shared
policy_params = list(actor_network.parameters())
else:
policy_params = None
q_value_policy_params = None
# Value
if value_network is not None:
self._version = 1
@@ -292,7 +303,7 @@ def __init__(
value_network,
"value_network",
create_target_params=self.delay_value,
compare_against=list(actor_network.parameters()),
compare_against=policy_params,
)
else:
self._version = 2
@@ -301,15 +312,19 @@ def __init__(
self.delay_qvalue = delay_qvalue
self.num_qvalue_nets = num_qvalue_nets
if self._version == 1:
value_params = list(value_network.parameters())
if separate_losses:
value_params = list(value_network.parameters())
q_value_policy_params = policy_params + value_params
else:
q_value_policy_params = policy_params
else:
value_params = []
q_value_policy_params = policy_params
self.convert_to_functional(
qvalue_network,
"qvalue_network",
num_qvalue_nets,
create_target_params=self.delay_qvalue,
compare_against=list(actor_network.parameters()) + value_params,
compare_against=q_value_policy_params,
)

self.loss_function = loss_function
@@ -751,6 +766,10 @@ class DiscreteSACLoss(LossModule):
priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
Key where to write the priority value for prioritized replay buffers.
Default is `"td_error"`.
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
parameters for both policy and critic losses.
Examples:
>>> import torch
@@ -919,6 +938,7 @@ def __init__(
target_entropy: Union[str, Number] = "auto",
delay_qvalue: bool = True,
priority_key: str = None,
separate_losses: bool = False,
):
self._in_keys = None
if not _has_functorch:
@@ -932,14 +952,19 @@ def __init__(
create_target_params=self.delay_actor,
funs_to_decorate=["forward", "get_dist_params"],
)

if separate_losses:
# we want to make sure there are no duplicates in the params: the
# params of critic must be refs to actor if they're shared
policy_params = list(actor_network.parameters())
else:
policy_params = None
self.delay_qvalue = delay_qvalue
self.convert_to_functional(
qvalue_network,
"qvalue_network",
num_qvalue_nets,
create_target_params=self.delay_qvalue,
compare_against=list(actor_network.parameters()),
compare_against=policy_params,
)
self.num_qvalue_nets = num_qvalue_nets
self.loss_function = loss_function
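``SACLoss`` is the one place where the change has two branches: with a ``value_network`` (the ``_version == 1`` path above) the q-value networks are compared against the actor and value parameters together, and without one (``_version == 2``) only against the actor parameters. A hedged construction sketch for the version-2 case follows; the policy and critic modules, sizes and action spec are hypothetical, and only the ``separate_losses`` keyword comes from this commit.

import torch
from torch import nn
from tensordict.nn import TensorDictModule
from torchrl.data import BoundedTensorSpec
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import SACLoss

n_obs, n_act = 8, 4

# Hypothetical stochastic policy emitting loc/scale for a TanhNormal distribution.
class PolicyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(n_obs, 2 * n_act)

    def forward(self, obs):
        loc, scale = self.net(obs).chunk(2, dim=-1)
        return loc, scale.exp()

policy_module = TensorDictModule(
    PolicyNet(), in_keys=["observation"], out_keys=["loc", "scale"]
)
actor = ProbabilisticActor(
    policy_module,
    in_keys=["loc", "scale"],
    spec=BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)),
    distribution_class=TanhNormal,
)

# Hypothetical q-value network: (observation, action) -> scalar.
class QValueNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(n_obs + n_act, 1)

    def forward(self, obs, act):
        return self.net(torch.cat([obs, act], dim=-1))

qvalue = ValueOperator(QValueNet(), in_keys=["observation", "action"])

# No value_network is passed, so this hits the version-2 branch; parameters
# shared between actor and critic would only be trained on the actor loss.
loss_fn = SACLoss(actor, qvalue, separate_losses=True)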
14 changes: 12 additions & 2 deletions torchrl/objectives/td3.py
@@ -71,6 +71,10 @@ class TD3Loss(LossModule):
spec (TensorSpec, optional): the action tensor spec. If not provided
and the target entropy is ``"auto"``, it will be retrieved from
the actor.
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
parameters for both policy and critic losses.
Examples:
>>> import torch
@@ -216,6 +220,7 @@ def __init__(
delay_qvalue: bool = True,
gamma: float = None,
priority_key: str = None,
separate_losses: bool = False,
) -> None:
if not _has_functorch:
raise ImportError(
@@ -234,13 +239,18 @@ def __init__(
"actor_network",
create_target_params=self.delay_actor,
)

if separate_losses:
# we want to make sure there are no duplicates in the params: the
# params of critic must be refs to actor if they're shared
policy_params = list(actor_network.parameters())
else:
policy_params = None
self.convert_to_functional(
qvalue_network,
"qvalue_network",
num_qvalue_nets,
create_target_params=self.delay_qvalue,
compare_against=list(actor_network.parameters()),
compare_against=policy_params,
)

for p in self.parameters():
