diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py
index 017394de04b..31954005195 100644
--- a/torchrl/objectives/utils.py
+++ b/torchrl/objectives/utils.py
@@ -203,23 +203,37 @@ def __init__(
 
     @property
     def _targets(self):
-        return TensorDict(
-            {name: getattr(self.loss_module, name) for name in self._target_names},
-            [],
-        )
+        targets = self.__dict__.get("_targets_val", None)
+        if targets is None:
+            targets = self.__dict__["_targets_val"] = TensorDict(
+                {name: getattr(self.loss_module, name) for name in self._target_names},
+                [],
+            )
+        return targets
+
+    @_targets.setter
+    def _targets(self, targets):
+        self.__dict__["_targets_val"] = targets
 
     @property
     def _sources(self):
-        return TensorDict(
-            {name: getattr(self.loss_module, name) for name in self._source_names},
-            [],
-        )
+        sources = self.__dict__.get("_sources_val", None)
+        if sources is None:
+            sources = self.__dict__["_sources_val"] = TensorDict(
+                {name: getattr(self.loss_module, name) for name in self._source_names},
+                [],
+            )
+        return sources
+
+    @_sources.setter
+    def _sources(self, sources):
+        self.__dict__["_sources_val"] = sources
 
     def init_(self) -> None:
         if self.initialized:
             warnings.warn("Updated already initialized.")
         found_distinct = False
-        self._distinct = {}
+        self._distinct_and_params = {}
         for key, source in self._sources.items(True, True):
             if not isinstance(key, tuple):
                 key = (key,)
@@ -228,8 +242,12 @@ def init_(self) -> None:
             # for p_source, p_target in zip(source, target):
             if target.requires_grad:
                 raise RuntimeError("the target parameter is part of a graph.")
-            self._distinct[key] = target.data_ptr() != source.data.data_ptr()
-            found_distinct = found_distinct or self._distinct[key]
+            self._distinct_and_params[key] = (
+                target.is_leaf
+                and source.requires_grad
+                and target.data_ptr() != source.data.data_ptr()
+            )
+            found_distinct = found_distinct or self._distinct_and_params[key]
             target.data.copy_(source.data)
         if not found_distinct:
             raise RuntimeError(
@@ -240,6 +258,23 @@ def init_(self) -> None:
                 f"If no target parameter is needed, do not use a target updater such as {type(self)}."
             )
 
+        # filter the target_ out
+        def filter_target(key):
+            if isinstance(key, tuple):
+                return (filter_target(key[0]), *key[1:])
+            return key[7:]
+
+        self._sources = self._sources.select(
+            *[
+                filter_target(key)
+                for (key, val) in self._distinct_and_params.items()
+                if val
+            ]
+        ).lock_()
+        self._targets = self._targets.select(
+            *(key for (key, val) in self._distinct_and_params.items() if val)
+        ).lock_()
+
         self.initialized = True
 
     def step(self) -> None:
@@ -248,19 +283,11 @@ def step(self) -> None:
                 f"{self.__class__.__name__} must be "
                 f"initialized (`{self.__class__.__name__}.init_()`) before calling step()"
             )
-        for key, source in self._sources.items(True, True):
-            if not isinstance(key, tuple):
-                key = (key,)
-            key = ("target_" + key[0], *key[1:])
-            if not self._distinct[key]:
-                continue
-            target = self._targets[key]
+        for key, param in self._sources.items():
+            target = self._targets.get("target_{}".format(key))
             if target.requires_grad:
                 raise RuntimeError("the target parameter is part of a graph.")
-            if target.is_leaf:
-                self._step(source, target)
-            else:
-                target.copy_(source)
+            self._step(param, target)
 
     def _step(self, p_source: Tensor, p_target: Tensor) -> None:
         raise NotImplementedError
@@ -326,8 +353,10 @@ def __init__(
        super(SoftUpdate, self).__init__(loss_module)
        self.eps = eps
 
-    def _step(self, p_source: Tensor, p_target: Tensor) -> None:
-        p_target.data.copy_(p_target.data * self.eps + p_source.data * (1 - self.eps))
+    def _step(
+        self, p_source: Tensor | TensorDictBase, p_target: Tensor | TensorDictBase
+    ) -> None:
+        p_target.data.lerp_(p_source.data, 1 - self.eps)
 
 
 class HardUpdate(TargetNetUpdater):
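
Note (not part of the patch): a minimal sanity-check sketch, assuming only `torch`, showing that the new in-place `lerp_` call in `SoftUpdate._step` reproduces the polyak average previously computed with `copy_`. Tensor names and the `eps` value below are illustrative.

# Sanity-check sketch (illustrative names, assumes only `torch`):
# Tensor.lerp_(end, weight) computes self + weight * (end - self),
# so target.lerp_(source, 1 - eps) == eps * target + (1 - eps) * source,
# which matches the copy_-based formula removed in this diff.
import torch

eps = 0.995
source = torch.randn(4)
target = torch.randn(4)

expected = target * eps + source * (1 - eps)  # old SoftUpdate._step formula
target.lerp_(source, 1 - eps)                 # new SoftUpdate._step formula

assert torch.allclose(target, expected)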