diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py
index de83b023662..d67db859713 100644
--- a/torchrl/objectives/ddpg.py
+++ b/torchrl/objectives/ddpg.py
@@ -11,10 +11,10 @@
 from typing import Tuple

 import torch
-
 from tensordict.nn import dispatch, make_functional, repopulate_module, TensorDictModule
 from tensordict.tensordict import TensorDict, TensorDictBase
-from tensordict.utils import NestedKey
+
+from tensordict.utils import NestedKey, unravel_key
 from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import (
@@ -216,6 +216,9 @@ def __init__(
         self.actor_critic.module[1] = self.value_network

         self.actor_in_keys = actor_network.in_keys
+        self.value_exclusive_keys = set(self.value_network.in_keys) - (
+            set(self.actor_in_keys) | set(self.actor_network.out_keys)
+        )

         self.loss_function = loss_function
@@ -233,14 +236,15 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
             self._set_in_keys()

     def _set_in_keys(self):
-        keys = [
-            ("next", self.tensor_keys.reward),
-            ("next", self.tensor_keys.done),
+        in_keys = {
+            unravel_key(("next", self.tensor_keys.reward)),
+            unravel_key(("next", self.tensor_keys.done)),
             *self.actor_in_keys,
-            *[("next", key) for key in self.actor_in_keys],
+            *[unravel_key(("next", key)) for key in self.actor_in_keys],
             *self.value_network.in_keys,
-        ]
-        self._in_keys = list(set(keys))
+            *[unravel_key(("next", key)) for key in self.value_network.in_keys],
+        }
+        self._in_keys = sorted(in_keys, key=str)

     @property
     def in_keys(self):
@@ -293,7 +297,9 @@ def _loss_actor(
         self,
         tensordict: TensorDictBase,
     ) -> torch.Tensor:
-        td_copy = tensordict.select(*self.actor_in_keys).detach()
+        td_copy = tensordict.select(
+            *self.actor_in_keys, *self.value_exclusive_keys
+        ).detach()
         td_copy = self.actor_network(
             td_copy,
             params=self.actor_network_params,
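
Note (not part of the patch): a minimal sketch of the set arithmetic behind the new value_exclusive_keys, with invented key names for illustration. The point of the _loss_actor change is that any input the value network needs beyond what the actor consumes or produces must survive the select(...) call, since the actor loss feeds the actor's output into the value network.

    # Sketch: what value_exclusive_keys captures. Key names are hypothetical;
    # the real sets come from the actor and value TensorDictModules.
    value_in_keys = {"observation", "action", "goal"}  # value network inputs
    actor_in_keys = {"observation"}                    # actor inputs
    actor_out_keys = {"action"}                        # actor outputs

    # Keys the value network reads but the actor neither reads nor writes.
    # _loss_actor must keep these through tensordict.select(), otherwise the
    # value call on the actor's output would fail on a missing key.
    value_exclusive_keys = value_in_keys - (actor_in_keys | actor_out_keys)
    assert value_exclusive_keys == {"goal"}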
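
A second sketch, also outside the patch, on the _set_in_keys change: a nested key can be spelled either as a nested tuple or as a flat tuple, so the diff normalizes with unravel_key before set-deduplication, and sorts by str because strings and tuples are not mutually orderable. The unravel_key below is a simplified stand-in for tensordict.utils.unravel_key, written only so the example runs standalone; the real function handles more cases.

    # Stand-in mimicking tensordict.utils.unravel_key on str/tuple keys.
    def unravel_key(key):
        if isinstance(key, str):
            return key
        flat = []
        for part in key:
            part = unravel_key(part)
            flat.extend(part) if isinstance(part, tuple) else flat.append(part)
        return tuple(flat) if len(flat) > 1 else flat[0]

    # ("next", ("agent", "obs")) and ("next", "agent", "obs") name the same
    # entry, but as raw tuples they would survive set() as two elements.
    keys = {unravel_key(("next", ("agent", "obs"))), ("next", "agent", "obs")}
    assert keys == {("next", "agent", "obs")}

    # sorted(..., key=str) yields a deterministic in_keys order even though
    # the elements mix strings and tuples.
    print(sorted({"done", ("next", "reward")}, key=str))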