diff --git a/test/test_cost.py b/test/test_cost.py index 649d50da935..97e901487a1 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -1011,20 +1011,34 @@ def test_ddpg_notensordict(self): class TestTD3(LossModuleTestBase): seed = 0 - def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + def _create_mock_actor( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + in_keys=None, + out_keys=None, + ): # Actor action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) module = nn.Linear(obs_dim, action_dim) actor = Actor( - spec=action_spec, - module=module, + spec=action_spec, module=module, in_keys=in_keys, out_keys=out_keys ) return actor.to(device) def _create_mock_value( - self, batch=2, obs_dim=3, action_dim=4, device="cpu", out_keys=None + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + out_keys=None, + action_key="action", + observation_key="observation", ): # Actor class ValueClass(nn.Module): @@ -1038,7 +1052,7 @@ def forward(self, obs, act): module = ValueClass() value = ValueOperator( module=module, - in_keys=["observation", "action"], + in_keys=[observation_key, action_key], out_keys=out_keys, ) return value.to(device) @@ -1049,7 +1063,16 @@ def _create_mock_distributional_actor( raise NotImplementedError def _create_mock_data_td3( - self, batch=8, obs_dim=3, action_dim=4, atoms=None, device="cpu" + self, + batch=8, + obs_dim=3, + action_dim=4, + atoms=None, + device="cpu", + action_key="action", + observation_key="observation", + reward_key="reward", + done_key="done", ): # create a tensordict obs = torch.randn(batch, obs_dim, device=device) @@ -1063,13 +1086,13 @@ def _create_mock_data_td3( td = TensorDict( batch_size=(batch,), source={ - "observation": obs, + observation_key: obs, "next": { - "observation": next_obs, - "done": done, - "reward": reward, + observation_key: next_obs, + done_key: done, + reward_key: reward, }, - "action": action, + action_key: action, }, device=device, ) @@ -1311,6 +1334,8 @@ def test_td3_tensordict_keys(self, td_est): "priority": "td_error", "state_action_value": "state_action_value", "action": "action", + "reward": "reward", + "done": "done", } self.tensordict_keys_test( @@ -1320,12 +1345,16 @@ def test_td3_tensordict_keys(self, td_est): ) value = self._create_mock_value(out_keys=["state_action_value_test"]) - loss_fn = DDPGLoss( + loss_fn = TD3Loss( actor, value, - loss_function="l2", + action_spec=actor.spec, ) - key_mapping = {"state_action_value": ("value", "state_action_value_test")} + key_mapping = { + "state_action_value": ("value", "state_action_value_test"), + "reward": ("reward", "reward_test"), + "done": ("done", ("done", "test")), + } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @pytest.mark.parametrize("spec", [True, False]) @@ -1353,6 +1382,43 @@ def test_constructor(self, spec, bounds): bounds=bounds, ) + # TODO: test for action_key, atm the action key of the TD3 loss is not configurable, + # since it is used in it's constructor + @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) + @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) + @pytest.mark.parametrize("done_key", ["done", "done2"]) + def test_td3_notensordict(self, observation_key, reward_key, done_key): + + torch.manual_seed(self.seed) + actor = self._create_mock_actor(in_keys=[observation_key]) + qvalue = self._create_mock_value( + observation_key=observation_key, out_keys=["state_action_value"] + ) + td = self._create_mock_data_td3( + observation_key=observation_key, reward_key=reward_key, done_key=done_key + ) + loss = TD3Loss(actor, qvalue, action_spec=actor.spec) + loss.set_keys(reward=reward_key, done=done_key) + + kwargs = { + observation_key: td.get(observation_key), + f"next_{reward_key}": td.get(("next", reward_key)), + f"next_{done_key}": td.get(("next", done_key)), + f"next_{observation_key}": td.get(("next", observation_key)), + "action": td.get("action"), + } + td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") + + loss_val_td = loss(td) + loss_val = loss(**kwargs) + for i, key in enumerate(loss_val_td.keys()): + torch.testing.assert_close(loss_val_td.get(key), loss_val[i]) + # test select + loss.select_out_keys("loss_actor", "loss_qvalue") + loss_actor, loss_qvalue = loss(**kwargs) + assert loss_actor == loss_val_td["loss_actor"] + assert loss_qvalue == loss_val_td["loss_qvalue"] + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 44db7ec90cd..f268ed5dee0 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -7,7 +7,7 @@ from typing import Optional, Tuple import torch -from tensordict.nn import TensorDictModule +from tensordict.nn import dispatch, TensorDictModule from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey @@ -68,6 +68,96 @@ class TD3Loss(LossModule): delay_qvalue (bool, optional): Whether to separate the target Q value networks from the Q value networks used for data collection. Default is ``True``. + + Examples: + >>> import torch + >>> from torch import nn + >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator + >>> from torchrl.modules.tensordict_module.common import SafeModule + >>> from torchrl.objectives.td3 import TD3Loss + >>> from tensordict.tensordict import TensorDict + >>> n_act, n_obs = 4, 3 + >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> module = nn.Linear(n_obs, n_act) + >>> actor = Actor( + ... module=module, + ... spec=spec) + >>> class ValueClass(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.linear = nn.Linear(n_obs + n_act, 1) + ... def forward(self, obs, act): + ... return self.linear(torch.cat([obs, act], -1)) + >>> module = ValueClass() + >>> qvalue = ValueOperator( + ... module=module, + ... in_keys=['observation', 'action']) + >>> loss = TD3Loss(actor, qvalue, action_spec=actor.spec) + >>> batch = [2, ] + >>> action = spec.rand(batch) + >>> data = TensorDict({ + ... "observation": torch.randn(*batch, n_obs), + ... "action": action, + ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "reward"): torch.randn(*batch, 1), + ... ("next", "observation"): torch.randn(*batch, n_obs), + ... }, batch) + >>> loss(data) + TensorDict( + fields={ + loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + next_state_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + state_action_value_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + + This class is compatible with non-tensordict based modules too and can be + used without recurring to any tensordict-related primitive. In this case, + the expected keyword arguments are: + ``["action", "next_reward", "next_done"]`` + in_keys of the actor and qvalue network + The return value is a tuple of tensors in the following order: + ``["loss_actor", "loss_qvalue", "pred_value", "state_action_value_actor", "next_state_value", "target_value",]``. + + Examples: + >>> import torch + >>> from torch import nn + >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator + >>> from torchrl.objectives.td3 import TD3Loss + >>> n_act, n_obs = 4, 3 + >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> module = nn.Linear(n_obs, n_act) + >>> actor = Actor( + ... module=module, + ... spec=spec) + >>> class ValueClass(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.linear = nn.Linear(n_obs + n_act, 1) + ... def forward(self, obs, act): + ... return self.linear(torch.cat([obs, act], -1)) + >>> module = ValueClass() + >>> qvalue = ValueOperator( + ... module=module, + ... in_keys=['observation', 'action']) + >>> loss = TD3Loss(actor, qvalue, action_spec=actor.spec) + >>> _ = loss.select_out_keys("loss_actor", "loss_qvalue") + >>> batch = [2, ] + >>> action = spec.rand(batch) + >>> loss_actor, loss_qvalue = loss( + ... observation=torch.randn(*batch, n_obs), + ... action=action, + ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_reward=torch.randn(*batch, 1), + ... next_observation=torch.randn(*batch, n_obs)) + >>> loss_actor.backward() + """ @dataclass @@ -84,14 +174,29 @@ class _AcceptedKeys: Will be used for the underlying value estimator. Defaults to ``"state_action_value"``. priority (NestedKey): The input tensordict key where the target priority is written to. Defaults to ``"td_error"``. + reward (NestedKey): The input tensordict key where the reward is expected. + Will be used for the underlying value estimator. Defaults to ``"reward"``. + done (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is done. Will be used for the underlying value estimator. + Defaults to ``"done"``. """ action: NestedKey = "action" state_action_value: NestedKey = "state_action_value" priority: NestedKey = "td_error" + reward: NestedKey = "reward" + done: NestedKey = "done" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 + out_keys = [ + "loss_actor", + "loss_qvalue", + "pred_value", + "state_action_value_actor", + "next_state_value", + "target_value", + ] def __init__( self, @@ -115,6 +220,7 @@ def __init__( ) super().__init__() + self._in_keys = None self._set_deprecated_ctor_keys(priority=priority_key) self.delay_actor = delay_actor @@ -178,8 +284,33 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: self._value_estimator.set_keys( value=self._tensor_keys.state_action_value, + reward=self.tensor_keys.reward, + done=self.tensor_keys.done, ) + self._set_in_keys() + def _set_in_keys(self): + keys = [ + self.tensor_keys.action, + ("next", self.tensor_keys.reward), + ("next", self.tensor_keys.done), + *self.actor_network.in_keys, + *[("next", key) for key in self.actor_network.in_keys], + *self.qvalue_network.in_keys, + ] + self._in_keys = list(set(keys)) + + @property + def in_keys(self): + if self._in_keys is None: + self._set_in_keys() + return self._in_keys + + @in_keys.setter + def in_keys(self, values): + self._in_keys = values + + @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: obs_keys = self.actor_network.in_keys tensordict_save = tensordict @@ -333,5 +464,9 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams else: raise NotImplementedError(f"Unknown value type {value_type}") - tensor_keys = {"value": self.tensor_keys.state_action_value} + tensor_keys = { + "value": self.tensor_keys.state_action_value, + "reward": self.tensor_keys.reward, + "done": self.tensor_keys.done, + } self._value_estimator.set_keys(**tensor_keys)