[Refactor] the usage of tensordict keys in loss modules (#1175)
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
Blonck and vmoens authored May 31, 2023
1 parent 35081b3 commit 86c69df
Showing 14 changed files with 1,972 additions and 456 deletions.
954 changes: 909 additions & 45 deletions test/test_cost.py

Large diffs are not rendered by default.

92 changes: 64 additions & 28 deletions torchrl/objectives/a2c.py
@@ -3,11 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from dataclasses import dataclass
from typing import Tuple

import torch
from tensordict.nn import ProbabilisticTensorDictSequential, TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import NestedKey
from torch import distributions as d

from torchrl.objectives.common import LossModule
@@ -33,10 +35,6 @@ class A2CLoss(LossModule):
Args:
actor (ProbabilisticTensorDictSequential): policy operator.
critic (ValueOperator): value operator.
advantage_key (str): the input tensordict key where the advantage is expected to be written.
default: "advantage"
value_target_key (str): the input tensordict key where the target state
value is expected to be written. Defaults to ``"value_target"``.
entropy_bonus (bool): if ``True``, an entropy bonus will be added to the
loss to favour exploratory policies.
samples_mc_entropy (int): if the distribution retrieved from the policy
@@ -53,6 +51,10 @@ class A2CLoss(LossModule):
policy and critic will only be trained on the policy loss.
Defaults to ``False``, ie. gradients are propagated to shared
parameters for both policy and critic losses.
advantage_key (str): [Deprecated, use set_keys(advantage_key=advantage_key) instead]
The input tensordict key where the advantage is expected to be written. default: "advantage"
value_target_key (str): [Deprecated, use set_keys() instead] the input
tensordict key where the target state value is expected to be written. Defaults to ``"value_target"``.
.. note::
The advantage (typically GAE) can be computed by the loss function or
@@ -67,24 +69,52 @@ class A2CLoss(LossModule):
"""

@dataclass
class _AcceptedKeys:
"""Maintains default values for all configurable tensordict keys.
This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
default values.
Attributes:
advantage (NestedKey): The input tensordict key where the advantage is expected.
Will be used for the underlying value estimator. Defaults to ``"advantage"``.
value_target (NestedKey): The input tensordict key where the target state value is expected.
Will be used for the underlying value estimator. Defaults to ``"value_target"``.
value (NestedKey): The input tensordict key where the state value is expected.
Will be used for the underlying value estimator. Defaults to ``"state_value"``.
action (NestedKey): The input tensordict key where the action is expected.
Defaults to ``"action"``.
"""

advantage: NestedKey = "advantage"
value_target: NestedKey = "value_target"
value: NestedKey = "state_value"
action: NestedKey = "action"

default_keys = _AcceptedKeys()
default_value_estimator: ValueEstimators = ValueEstimators.GAE

def __init__(
self,
actor: ProbabilisticTensorDictSequential,
critic: TensorDictModule,
*,
advantage_key: str = "advantage",
value_target_key: str = "value_target",
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
entropy_coef: float = 0.01,
critic_coef: float = 1.0,
loss_critic_type: str = "smooth_l1",
gamma: float = None,
separate_losses: bool = False,
advantage_key: str = None,
value_target_key: str = None,
):
super().__init__()
self._set_deprecated_ctor_keys(
advantage=advantage_key, value_target=value_target_key
)

self.convert_to_functional(
actor, "actor", funs_to_decorate=["forward", "get_dist"]
)
@@ -95,8 +125,6 @@ def __init__(
else:
policy_params = None
self.convert_to_functional(critic, "critic", compare_against=policy_params)
self.advantage_key = advantage_key
self.value_target_key = value_target_key
self.samples_mc_entropy = samples_mc_entropy
self.entropy_bonus = entropy_bonus and entropy_coef
self.register_buffer(
@@ -110,6 +138,14 @@ def __init__(
self.gamma = gamma
self.loss_critic_type = loss_critic_type

def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
self._value_estimator.set_keys(
advantage=self._tensor_keys.advantage,
value_target=self._tensor_keys.value_target,
value=self._tensor_keys.value,
)

def reset(self) -> None:
pass

@@ -125,9 +161,11 @@ def _log_probs(
self, tensordict: TensorDictBase
) -> Tuple[torch.Tensor, d.Distribution]:
# current log_prob of actions
action = tensordict.get("action")
action = tensordict.get(self.tensor_keys.action)
if action.requires_grad:
raise RuntimeError("tensordict stored action require grad.")
raise RuntimeError(
f"tensordict stored {self.tensor_keys.action} require grad."
)
tensordict_clone = tensordict.select(*self.actor.in_keys).clone()

dist = self.actor.get_dist(tensordict_clone, params=self.actor_params)
@@ -139,20 +177,20 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
try:
# TODO: if the advantage is gathered by forward, this introduces an
# overhead that we could easily reduce.
target_return = tensordict.get(self.value_target_key)
target_return = tensordict.get(self.tensor_keys.value_target)
tensordict_select = tensordict.select(*self.critic.in_keys)
state_value = self.critic(
tensordict_select,
params=self.critic_params,
).get("state_value")
).get(self.tensor_keys.value)
loss_value = distance_loss(
target_return,
state_value,
loss_function=self.loss_critic_type,
)
except KeyError:
raise KeyError(
f"the key {self.value_target_key} was not found in the input tensordict. "
f"the key {self.tensor_keys.value_target} was not found in the input tensordict. "
f"Make sure you provided the right key and the value_target (i.e. the target "
f"return) has been retrieved accordingly. Advantage classes such as GAE, "
f"TDLambdaEstimate and TDEstimate all return a 'value_target' entry that "
@@ -162,14 +200,14 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict = tensordict.clone(False)
advantage = tensordict.get(self.advantage_key, None)
advantage = tensordict.get(self.tensor_keys.advantage, None)
if advantage is None:
self.value_estimator(
tensordict,
params=self.critic_params.detach(),
target_params=self.target_critic_params,
)
advantage = tensordict.get(self.advantage_key)
advantage = tensordict.get(self.tensor_keys.advantage)
log_probs, dist = self._log_probs(tensordict)
loss = -(log_probs * advantage)
td_out = TensorDict({"loss_objective": loss.mean()}, [])
@@ -190,22 +228,20 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
hp.update(hyperparams)
if hasattr(self, "gamma"):
hp["gamma"] = self.gamma
value_key = "state_value"
if value_type == ValueEstimators.TD1:
self._value_estimator = TD1Estimator(
value_network=self.critic, value_key=value_key, **hp
)
self._value_estimator = TD1Estimator(value_network=self.critic, **hp)
elif value_type == ValueEstimators.TD0:
self._value_estimator = TD0Estimator(
value_network=self.critic, value_key=value_key, **hp
)
self._value_estimator = TD0Estimator(value_network=self.critic, **hp)
elif value_type == ValueEstimators.GAE:
self._value_estimator = GAE(
value_network=self.critic, value_key=value_key, **hp
)
self._value_estimator = GAE(value_network=self.critic, **hp)
elif value_type == ValueEstimators.TDLambda:
self._value_estimator = TDLambdaEstimator(
value_network=self.critic, value_key=value_key, **hp
)
self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp)
else:
raise NotImplementedError(f"Unknown value type {value_type}")

tensor_keys = {
"advantage": self.tensor_keys.advantage,
"value": self.tensor_keys.value,
"value_target": self.tensor_keys.value_target,
}
self._value_estimator.set_keys(**tensor_keys)
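
The a2c.py diff above moves the configurable tensordict keys ("advantage", "value_target", "state_value", "action") out of the constructor and into an _AcceptedKeys dataclass driven by set_keys(). As a minimal sketch of the new workflow, the snippet below builds a toy actor/critic pair and renames two keys after construction; the network shapes, the module choices and the key names "observation", "adv" and "target_value" are illustrative assumptions, not taken from this commit.

import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import A2CLoss

n_obs, n_act = 4, 2

# Policy: reads "observation", writes distribution parameters, samples "action".
policy_net = torch.nn.Sequential(torch.nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
policy_module = TensorDictModule(policy_net, in_keys=["observation"], out_keys=["loc", "scale"])
actor = ProbabilisticActor(
    policy_module,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    return_log_prob=True,
)
# Critic: reads "observation", writes "state_value" (the default "value" key).
critic = ValueOperator(torch.nn.Linear(n_obs, 1), in_keys=["observation"])

loss = A2CLoss(actor, critic)
# New style: rename the tensordict entries the loss reads after construction,
# instead of passing advantage_key/value_target_key to the constructor (deprecated).
loss.set_keys(advantage="adv", value_target="target_value")
# The overridden keys are forwarded to the value estimator when it is built.
loss.make_value_estimator()  # A2CLoss defaults to ValueEstimators.GAE
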
73 changes: 73 additions & 0 deletions torchrl/objectives/common.py
@@ -7,6 +7,7 @@

import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import Iterator, List, Optional, Tuple, Union

import torch
@@ -57,13 +58,47 @@ class LossModule(nn.Module):
By default, the forward method is always decorated with a
:class:`torchrl.envs.ExplorationType.MODE` exploration context.
To make the tensordict keys configurable via :meth:`~.set_keys()`,
a subclass must define an _AcceptedKeys dataclass listing all keys
that are intended to be configurable. In addition, the subclass must
implement the :meth:`._forward_value_estimator_keys` method, which
forwards any altered tensordict keys to the underlying value estimator.
Examples:
>>> class MyLoss(LossModule):
>>> @dataclass
>>> class _AcceptedKeys:
>>> action = "action"
>>>
>>> def _forward_value_estimator_keys(self, **kwargs) -> None:
>>> pass
>>>
>>> loss = MyLoss()
>>> loss.set_keys(action="action2")
"""

@dataclass
class _AcceptedKeys:
"""Maintains default values for all configurable tensordict keys.
This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
default values.
"""

pass

default_value_estimator: ValueEstimators = None
SEP = "_sep_"

@property
def tensor_keys(self) -> _AcceptedKeys:
return self._tensor_keys

def __new__(cls, *args, **kwargs):
cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward)
cls._tensor_keys = cls._AcceptedKeys()
return super().__new__(cls)

def __init__(self):
@@ -74,6 +109,44 @@ def __init__(self):
self.value_type = self.default_value_estimator
# self.register_forward_pre_hook(_parameters_to_tensordict)

def _set_deprecated_ctor_keys(self, **kwargs) -> None:
"""Helper function to set a tensordict key from a constructor and raise a warning simultaneously."""
for key, value in kwargs.items():
if value is not None:
warnings.warn(
f"Setting '{key}' via the constructor is deprecated, use .set_keys(<key>='some_key') instead.",
category=DeprecationWarning,
)
self.set_keys(**{key: value})

def set_keys(self, **kwargs) -> None:
"""Set tensordict key names.
Examples:
>>> from torchrl.objectives import DQNLoss
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> dqn_loss.set_keys(priority_key="td_error", action_value_key="action_value")
"""
for key, value in kwargs.items():
if key not in self._AcceptedKeys.__dict__:
raise ValueError(f"{key} is not an accepted tensordict key")
if value is not None:
setattr(self.tensor_keys, key, value)
else:
setattr(self.tensor_keys, key, getattr(self.default_keys, key))

try:
self._forward_value_estimator_keys(**kwargs)
except AttributeError:
raise AttributeError(
"To utilize `.set_keys(...)` for tensordict key configuration, the subclassed loss module "
"must define an _AcceptedKeys dataclass containing all keys intended for configuration. "
"Moreover, the subclass needs to implement `._forward_value_estimator_keys()` method to "
"facilitate forwarding of any modified tensordict keys to the underlying value_estimator."
)

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"""It is designed to read an input TensorDict and return another tensordict with loss keys named "loss*".