diff --git a/test/test_cost.py b/test/test_cost.py
index 27869659bd5..f841fae4e0d 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -997,6 +997,11 @@ def test_ddpg_notensordict(self):
         loss_val = loss(**kwargs)
         for i, key in enumerate(loss_val_td.keys()):
             torch.testing.assert_close(loss_val_td.get(key), loss_val[i])
+        # test select
+        loss.select_out_keys("loss_actor", "target_value")
+        loss_actor, target_value = loss(**kwargs)
+        assert loss_actor == loss_val_td["loss_actor"]
+        assert (target_value == loss_val_td["target_value"]).all()
 
 
 @pytest.mark.skipif(
@@ -1907,6 +1912,12 @@ def test_sac_notensordict(
         torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[4])
         if version == 1:
             torch.testing.assert_close(loss_val_td.get("loss_value"), loss_val[5])
+        # test select
+        torch.manual_seed(self.seed)
+        loss.select_out_keys("loss_actor", "loss_alpha")
+        loss_actor, loss_alpha = loss(**kwargs)
+        assert loss_actor == loss_val_td["loss_actor"]
+        assert loss_alpha == loss_val_td["loss_alpha"]
 
 
 @pytest.mark.skipif(
@@ -3746,6 +3757,12 @@ def test_a2c_notensordict(self, action_key, observation_key, reward_key, done_ke
         # don't test entropy and loss_entropy, since they depend on a random sample
         # from distribution
         assert len(loss_val) == 4
+        # test select
+        torch.manual_seed(self.seed)
+        loss.select_out_keys("loss_objective", "loss_critic")
+        loss_objective, loss_critic = loss(**kwargs)
+        assert loss_objective == loss_val_td["loss_objective"]
+        assert loss_critic == loss_val_td["loss_critic"]
 
 
 class TestReinforce(LossModuleTestBase):
@@ -4805,6 +4822,12 @@ def test_iql_notensordict(self, action_key, observation_key, reward_key, done_ke
         torch.testing.assert_close(loss_val_td.get("loss_qvalue"), loss_val[1])
         torch.testing.assert_close(loss_val_td.get("loss_value"), loss_val[2])
         torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[3])
+        # test select
+        torch.manual_seed(self.seed)
+        loss.select_out_keys("loss_actor", "loss_value")
+        loss_actor, loss_value = loss(**kwargs)
+        assert loss_actor == loss_val_td["loss_actor"]
+        assert loss_value == loss_val_td["loss_value"]
 
 
 def test_hold_out():
diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index fe7ad37b99f..2298921464e 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -150,6 +150,19 @@ class A2CLoss(LossModule):
         ...     next_reward = torch.randn(*batch, 1),
         ...     next_observation = torch.randn(*batch, n_obs))
         >>> loss_obj.backward()
+
+    The output keys can also be filtered using the :meth:`A2CLoss.select_out_keys`
+    method.
+
+    Examples:
+        >>> loss.select_out_keys('loss_objective', 'loss_critic')
+        >>> loss_obj, loss_critic = loss(
+        ...     observation = torch.randn(*batch, n_obs),
+        ...     action = spec.rand(batch),
+        ...     next_done = torch.zeros(*batch, 1, dtype=torch.bool),
+        ...     next_reward = torch.randn(*batch, 1),
+        ...     next_observation = torch.randn(*batch, n_obs))
+        >>> loss_obj.backward()
     """
 
     @dataclass
@@ -200,6 +213,7 @@ def __init__(
         advantage_key: str = None,
         value_target_key: str = None,
     ):
+        self._out_keys = None
        super().__init__()
         self._set_deprecated_ctor_keys(
             advantage=advantage_key, value_target=value_target_key
@@ -243,13 +257,19 @@ def in_keys(self):
 
     @property
     def out_keys(self):
-        outs = ["loss_objective"]
-        if self.critic_coef:
-            outs.append("loss_critic")
-        if self.entropy_bonus:
-            outs.append("entropy")
-            outs.append("loss_entropy")
-        return outs
+        if self._out_keys is None:
+            outs = ["loss_objective"]
+            if self.critic_coef:
+                outs.append("loss_critic")
+            if self.entropy_bonus:
+                outs.append("entropy")
+                outs.append("loss_entropy")
+            self._out_keys = outs
+        return self._out_keys
+
+    @out_keys.setter
+    def out_keys(self, value):
+        self._out_keys = value
 
     def _forward_value_estimator_keys(self, **kwargs) -> None:
         if self._value_estimator is not None:
diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py
index 706c2243b1b..f6fee87bd46 100644
--- a/torchrl/objectives/common.py
+++ b/torchrl/objectives/common.py
@@ -12,7 +12,12 @@
 
 import torch
 
-from tensordict.nn import make_functional, repopulate_module, TensorDictModule
+from tensordict.nn import (
+    make_functional,
+    repopulate_module,
+    TensorDictModule,
+    TensorDictModuleBase,
+)
 from tensordict.tensordict import TensorDictBase
 from torch import nn, Tensor
 
@@ -39,7 +44,7 @@
 FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality."
 
 
-class LossModule(nn.Module):
+class LossModule(TensorDictModuleBase):
     """A parent class for RL losses.
 
     LossModule inherits from nn.Module. It is designed to read an input
diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py
index 1809edc12b9..4c613ec040a 100644
--- a/torchrl/objectives/ddpg.py
+++ b/torchrl/objectives/ddpg.py
@@ -117,6 +117,19 @@ class DDPGLoss(LossModule):
         ...     next_reward=torch.randn(1))
         >>> loss_actor.backward()
 
+    The output keys can also be filtered using the :meth:`DDPGLoss.select_out_keys`
+    method.
+
+    Examples:
+        >>> loss.select_out_keys('loss_actor', 'loss_value')
+        >>> loss_actor, loss_value = loss(
+        ...     observation=torch.randn(n_obs),
+        ...     action=spec.rand(),
+        ...     next_done=torch.zeros(1, dtype=torch.bool),
+        ...     next_observation=torch.randn(n_obs),
+        ...     next_reward=torch.randn(1))
+        >>> loss_actor.backward()
+
     """
 
     @dataclass
@@ -147,6 +160,14 @@ class _AcceptedKeys:
 
     default_keys = _AcceptedKeys()
     default_value_estimator: ValueEstimators = ValueEstimators.TD0
+    out_keys = [
+        "loss_actor",
+        "loss_value",
+        "pred_value",
+        "target_value",
+        "pred_value_max",
+        "target_value_max",
+    ]
 
     def __init__(
         self,
@@ -210,16 +231,7 @@ def in_keys(self):
         keys = list(set(keys))
         return keys
 
-    @dispatch(
-        dest=[
-            "loss_actor",
-            "loss_value",
-            "pred_value",
-            "target_value",
-            "pred_value_max",
-            "target_value_max",
-        ]
-    )
+    @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDict:
         """Computes the DDPG losses given a tensordict sampled from the replay buffer.
 
diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py
index 5a47354edc2..37af156867c 100644
--- a/torchrl/objectives/iql.py
+++ b/torchrl/objectives/iql.py
@@ -153,7 +153,7 @@ class IQLLoss(LossModule):
         >>> loss = IQLLoss(actor, qvalue, value)
         >>> batch = [2, ]
         >>> action = spec.rand(batch)
-        >>> loss_actor, loss_qvlaue, loss_value, entropy = loss(
+        >>> loss_actor, loss_qvalue, loss_value, entropy = loss(
         ...     observation=torch.randn(*batch, n_obs),
         ...     action=action,
         ...     next_done=torch.zeros(*batch, 1, dtype=torch.bool),
@@ -161,6 +161,19 @@ class IQLLoss(LossModule):
         ...     next_reward=torch.randn(*batch, 1))
         >>> loss_actor.backward()
 
+
+    The output keys can also be filtered using the :meth:`IQLLoss.select_out_keys`
+    method.
+
+    Examples:
+        >>> loss.select_out_keys('loss_actor', 'loss_qvalue')
+        >>> loss_actor, loss_qvalue = loss(
+        ...     observation=torch.randn(*batch, n_obs),
+        ...     action=action,
+        ...     next_done=torch.zeros(*batch, 1, dtype=torch.bool),
+        ...     next_observation=torch.zeros(*batch, n_obs),
+        ...     next_reward=torch.randn(*batch, 1))
+        >>> loss_actor.backward()
     """
 
     @dataclass
@@ -199,6 +212,12 @@ class _AcceptedKeys:
 
     default_keys = _AcceptedKeys()
     default_value_estimator = ValueEstimators.TD0
+    out_keys = [
+        "loss_actor",
+        "loss_qvalue",
+        "loss_value",
+        "entropy",
+    ]
 
     def __init__(
         self,
@@ -292,14 +311,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
             done=self.tensor_keys.done,
         )
 
-    @dispatch(
-        dest=[
-            "loss_actor",
-            "loss_qvalue",
-            "loss_value",
-            "entropy",
-        ]
-    )
+    @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         shape = None
         if tensordict.ndimension() > 1:
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index f974f432673..250223ecc25 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -186,7 +186,7 @@ class SACLoss(LossModule):
         >>> loss = SACLoss(actor, qvalue, value)
         >>> batch = [2, ]
         >>> action = spec.rand(batch)
-        >>> loss_actor, loss_qvlaue, _, _, _, _ = loss(
+        >>> loss_actor, loss_qvalue, _, _, _, _ = loss(
         ...     observation=torch.randn(*batch, n_obs),
         ...     action=action,
         ...     next_done=torch.zeros(*batch, 1, dtype=torch.bool),
@@ -194,6 +194,18 @@ class SACLoss(LossModule):
         ...     next_reward=torch.randn(*batch, 1))
         >>> loss_actor.backward()
 
+    The output keys can also be filtered using the :meth:`SACLoss.select_out_keys`
+    method.
+
+    Examples:
+        >>> loss.select_out_keys('loss_actor', 'loss_qvalue')
+        >>> loss_actor, loss_qvalue = loss(
+        ...     observation=torch.randn(*batch, n_obs),
+        ...     action=action,
+        ...     next_done=torch.zeros(*batch, 1, dtype=torch.bool),
+        ...     next_observation=torch.zeros(*batch, n_obs),
+        ...     next_reward=torch.randn(*batch, 1))
+        >>> loss_actor.backward()
     """
 
     @dataclass
@@ -251,6 +263,7 @@ def __init__(
         gamma: float = None,
         priority_key: str = None,
     ) -> None:
+        self._out_keys = None
         if not _has_functorch:
             raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR
         super().__init__()
@@ -425,12 +438,18 @@ def in_keys(self):
 
     @property
     def out_keys(self):
-        keys = ["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]
-        if self._version == 1:
-            keys.append("loss_value")
-        return keys
-
-    @dispatch()
+        if self._out_keys is None:
+            keys = ["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]
+            if self._version == 1:
+                keys.append("loss_value")
+            self._out_keys = keys
+        return self._out_keys
+
+    @out_keys.setter
+    def out_keys(self, values):
+        self._out_keys = values
+
+    @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         shape = None
         if tensordict.ndimension() > 1:
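
# ---------------------------------------------------------------------------
# Usage sketch (outside the diff above): a minimal illustration of the
# behaviour the patch enables, assuming the `actor`, `value`, `spec` and
# `n_obs` objects built in the existing DDPGLoss docstring example; they are
# not defined here. `select_out_keys` is inherited from tensordict's
# TensorDictModuleBase, which LossModule now subclasses.
import torch
from torchrl.objectives import DDPGLoss

loss = DDPGLoss(actor, value)
# Restrict which outputs the dispatched (non-tensordict) call returns.
loss.select_out_keys("loss_actor", "loss_value")
loss_actor, loss_value = loss(
    observation=torch.randn(n_obs),
    action=spec.rand(),
    next_done=torch.zeros(1, dtype=torch.bool),
    next_observation=torch.randn(n_obs),
    next_reward=torch.randn(1),
)
loss_actor.backward()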