[Feature] Make losses inherit from TDMBase (#1246)
vmoens authored Jun 8, 2023
1 parent fdc78be commit cd344a3
Showing 6 changed files with 126 additions and 35 deletions.
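The pattern exercised by the new tests below (selecting a subset of output keys, then calling the loss with keyword arguments) is behaviour the losses now pick up from tensordict's module base class rather than from per-loss code. As a point of reference, here is a minimal, self-contained sketch of that behaviour on a plain TensorDictModule; the toy module, key names and shapes are invented for illustration and are not part of this commit, and it assumes a tensordict version that already exposes select_out_keys, the method the new tests call.

import torch
from torch import nn
from tensordict.nn import TensorDictModule

class TwoHeads(nn.Module):
    # Toy network with two outputs, standing in for a loss that returns several terms.
    def forward(self, x):
        return x.sum(-1, keepdim=True), x.mean(-1, keepdim=True)

module = TensorDictModule(TwoHeads(), in_keys=["observation"], out_keys=["sum", "mean"])

# Keyword-style call handled by tensordict's dispatch machinery;
# the returned tuple follows the order of out_keys.
s, m = module(observation=torch.randn(2, 3))

# Restrict what subsequent calls return, as the new tests do on the losses.
module.select_out_keys("sum")
s_only = module(observation=torch.randn(2, 3))

The new tests in test_cost.py apply exactly this call pattern to DDPGLoss, SACLoss, A2CLoss and IQLLoss.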
23 changes: 23 additions & 0 deletions test/test_cost.py
@@ -997,6 +997,11 @@ def test_ddpg_notensordict(self):
loss_val = loss(**kwargs)
for i, key in enumerate(loss_val_td.keys()):
torch.testing.assert_close(loss_val_td.get(key), loss_val[i])
# test select
loss.select_out_keys("loss_actor", "target_value")
loss_actor, target_value = loss(**kwargs)
assert loss_actor == loss_val_td["loss_actor"]
assert (target_value == loss_val_td["target_value"]).all()


@pytest.mark.skipif(
@@ -1907,6 +1912,12 @@ def test_sac_notensordict(
torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[4])
if version == 1:
torch.testing.assert_close(loss_val_td.get("loss_value"), loss_val[5])
# test select
torch.manual_seed(self.seed)
loss.select_out_keys("loss_actor", "loss_alpha")
loss_actor, loss_alpha = loss(**kwargs)
assert loss_actor == loss_val_td["loss_actor"]
assert loss_alpha == loss_val_td["loss_alpha"]


@pytest.mark.skipif(
@@ -3746,6 +3757,12 @@ def test_a2c_notensordict(self, action_key, observation_key, reward_key, done_ke
# don't test entropy and loss_entropy, since they depend on a random sample
# from distribution
assert len(loss_val) == 4
# test select
torch.manual_seed(self.seed)
loss.select_out_keys("loss_objective", "loss_critic")
loss_objective, loss_critic = loss(**kwargs)
assert loss_objective == loss_val_td["loss_objective"]
assert loss_critic == loss_val_td["loss_critic"]


class TestReinforce(LossModuleTestBase):
@@ -4805,6 +4822,12 @@ def test_iql_notensordict(self, action_key, observation_key, reward_key, done_ke
torch.testing.assert_close(loss_val_td.get("loss_qvalue"), loss_val[1])
torch.testing.assert_close(loss_val_td.get("loss_value"), loss_val[2])
torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[3])
# test select
torch.manual_seed(self.seed)
loss.select_out_keys("loss_actor", "loss_value")
loss_actor, loss_value = loss(**kwargs)
assert loss_actor == loss_val_td["loss_actor"]
assert loss_value == loss_val_td["loss_value"]


def test_hold_out():
34 changes: 27 additions & 7 deletions torchrl/objectives/a2c.py
@@ -150,6 +150,19 @@ class A2CLoss(LossModule):
... next_reward = torch.randn(*batch, 1),
... next_observation = torch.randn(*batch, n_obs))
>>> loss_obj.backward()
The output keys can also be filtered using the :meth:`A2CLoss.select_out_keys`
method.
Examples:
>>> loss.select_out_keys('loss_objective', 'loss_critic')
>>> loss_obj, loss_critic = loss(
... observation = torch.randn(*batch, n_obs),
... action = spec.rand(batch),
... next_done = torch.zeros(*batch, 1, dtype=torch.bool),
... next_reward = torch.randn(*batch, 1),
... next_observation = torch.randn(*batch, n_obs))
>>> loss_obj.backward()
"""

@dataclass
@@ -200,6 +213,7 @@ def __init__(
advantage_key: str = None,
value_target_key: str = None,
):
self._out_keys = None
super().__init__()
self._set_deprecated_ctor_keys(
advantage=advantage_key, value_target=value_target_key
@@ -243,13 +257,19 @@ def in_keys(self):

@property
def out_keys(self):
outs = ["loss_objective"]
if self.critic_coef:
outs.append("loss_critic")
if self.entropy_bonus:
outs.append("entropy")
outs.append("loss_entropy")
return outs
if self._out_keys is None:
outs = ["loss_objective"]
if self.critic_coef:
outs.append("loss_critic")
if self.entropy_bonus:
outs.append("entropy")
outs.append("loss_entropy")
self._out_keys = outs
return self._out_keys

@out_keys.setter
def out_keys(self, value):
self._out_keys = value

def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
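The out_keys change above swaps a list rebuilt on every read for a cached list guarded by a setter. As the added setter suggests, select_out_keys (inherited from the tensordict base class) works by assigning to out_keys: a read-only property would reject that assignment, and an uncached one would rebuild the defaults on the next read and silently drop the selection. The same pattern is applied to SACLoss further down. A minimal, self-contained illustration of the caching pattern in plain Python follows; the class and key names are invented and are not TorchRL code.

class LazyOutKeys:
    def __init__(self):
        self._out_keys = None

    @property
    def out_keys(self):
        # Build the default list lazily, then cache it so that a later
        # assignment is not clobbered by a rebuild.
        if self._out_keys is None:
            self._out_keys = ["loss_objective", "loss_critic", "entropy", "loss_entropy"]
        return self._out_keys

    @out_keys.setter
    def out_keys(self, value):
        self._out_keys = value

m = LazyOutKeys()
m.out_keys = ["loss_objective", "loss_critic"]   # what select_out_keys effectively does
assert m.out_keys == ["loss_objective", "loss_critic"]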
9 changes: 7 additions & 2 deletions torchrl/objectives/common.py
@@ -12,7 +12,12 @@

import torch

from tensordict.nn import make_functional, repopulate_module, TensorDictModule
from tensordict.nn import (
make_functional,
repopulate_module,
TensorDictModule,
TensorDictModuleBase,
)

from tensordict.tensordict import TensorDictBase
from torch import nn, Tensor
@@ -39,7 +44,7 @@
FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality."


class LossModule(nn.Module):
class LossModule(TensorDictModuleBase):
"""A parent class for RL losses.
LossModule inherits from nn.Module. It is designed to read an input
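Re-basing LossModule on TensorDictModuleBase is what provides the in_keys/out_keys plumbing used throughout this diff: keyword-argument dispatch and output selection now come from tensordict instead of being re-declared per loss. Below is a minimal, self-contained sketch of a TensorDictModuleBase subclass written in the same style as the losses in the following files; the class, keys and arithmetic are invented for illustration, and it assumes the tensordict.nn imports shown in this diff (TensorDictModuleBase, dispatch) behave as described.

import torch
from tensordict.nn import dispatch, TensorDictModuleBase

class Doubler(TensorDictModuleBase):
    # Class-level key lists, in the same spirit as the out_keys attribute
    # added to DDPGLoss and IQLLoss below.
    in_keys = ["x"]
    out_keys = ["double", "square"]

    # A bare @dispatch resolves the output names from self.out_keys, which is
    # why the hard-coded dest=[...] lists can be dropped from the decorators.
    @dispatch
    def forward(self, tensordict):
        x = tensordict.get("x")
        tensordict.set("double", 2 * x)
        tensordict.set("square", x * x)
        return tensordict

mod = Doubler()
double, square = mod(x=torch.ones(3))  # keyword call; outputs follow out_keys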
32 changes: 22 additions & 10 deletions torchrl/objectives/ddpg.py
@@ -117,6 +117,19 @@ class DDPGLoss(LossModule):
... next_reward=torch.randn(1))
>>> loss_actor.backward()
The output keys can also be filtered using the :meth:`DDPGLoss.select_out_keys`
method.
Examples:
>>> loss.select_out_keys('loss_actor', 'loss_value')
>>> loss_actor, loss_value = loss(
... observation=torch.randn(n_obs),
... action=spec.rand(),
... next_done=torch.zeros(1, dtype=torch.bool),
... next_observation=torch.randn(n_obs),
... next_reward=torch.randn(1))
>>> loss_actor.backward()
"""

@dataclass
@@ -147,6 +160,14 @@ class _AcceptedKeys:

default_keys = _AcceptedKeys()
default_value_estimator: ValueEstimators = ValueEstimators.TD0
out_keys = [
"loss_actor",
"loss_value",
"pred_value",
"target_value",
"pred_value_max",
"target_value_max",
]

def __init__(
self,
@@ -210,16 +231,7 @@ def in_keys(self):
keys = list(set(keys))
return keys

@dispatch(
dest=[
"loss_actor",
"loss_value",
"pred_value",
"target_value",
"pred_value_max",
"target_value_max",
]
)
@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDict:
"""Computes the DDPG losses given a tensordict sampled from the replay buffer.
30 changes: 21 additions & 9 deletions torchrl/objectives/iql.py
@@ -153,14 +153,27 @@ class IQLLoss(LossModule):
>>> loss = IQLLoss(actor, qvalue, value)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> loss_actor, loss_qvlaue, loss_value, entropy = loss(
>>> loss_actor, loss_qvalue, loss_value, entropy = loss(
... observation=torch.randn(*batch, n_obs),
... action=action,
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
... next_observation=torch.zeros(*batch, n_obs),
... next_reward=torch.randn(*batch, 1))
>>> loss_actor.backward()
The output keys can also be filtered using the :meth:`IQLLoss.select_out_keys`
method.
Examples:
>>> loss.select_out_keys('loss_actor', 'loss_qvalue')
>>> loss_actor, loss_qvalue = loss(
... observation=torch.randn(*batch, n_obs),
... action=action,
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
... next_observation=torch.zeros(*batch, n_obs),
... next_reward=torch.randn(*batch, 1))
>>> loss_actor.backward()
"""

@dataclass
@@ -199,6 +212,12 @@ class _AcceptedKeys:

default_keys = _AcceptedKeys()
default_value_estimator = ValueEstimators.TD0
out_keys = [
"loss_actor",
"loss_qvalue",
"loss_value",
"entropy",
]

def __init__(
self,
@@ -292,14 +311,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
done=self.tensor_keys.done,
)

@dispatch(
dest=[
"loss_actor",
"loss_qvalue",
"loss_value",
"entropy",
]
)
@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
shape = None
if tensordict.ndimension() > 1:
33 changes: 26 additions & 7 deletions torchrl/objectives/sac.py
@@ -186,14 +186,26 @@ class SACLoss(LossModule):
>>> loss = SACLoss(actor, qvalue, value)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> loss_actor, loss_qvlaue, _, _, _, _ = loss(
>>> loss_actor, loss_qvalue, _, _, _, _ = loss(
... observation=torch.randn(*batch, n_obs),
... action=action,
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
... next_observation=torch.zeros(*batch, n_obs),
... next_reward=torch.randn(*batch, 1))
>>> loss_actor.backward()
The output keys can also be filtered using the :meth:`SACLoss.select_out_keys`
method.
Examples:
>>> loss.select_out_keys('loss_actor', 'loss_qvalue')
>>> loss_actor, loss_qvalue = loss(
... observation=torch.randn(*batch, n_obs),
... action=action,
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
... next_observation=torch.zeros(*batch, n_obs),
... next_reward=torch.randn(*batch, 1))
>>> loss_actor.backward()
"""

@dataclass
@@ -251,6 +263,7 @@ def __init__(
gamma: float = None,
priority_key: str = None,
) -> None:
self._out_keys = None
if not _has_functorch:
raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR
super().__init__()
@@ -425,12 +438,18 @@ def in_keys(self):

@property
def out_keys(self):
keys = ["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]
if self._version == 1:
keys.append("loss_value")
return keys

@dispatch()
if self._out_keys is None:
keys = ["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]
if self._version == 1:
keys.append("loss_value")
self._out_keys = keys
return self._out_keys

@out_keys.setter
def out_keys(self, values):
self._out_keys = values

@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
shape = None
if tensordict.ndimension() > 1:
