Skip to content

Commit

Permalink
amend
Browse files: browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 11, 2024
1 parent d0e4c04 commit 8aac458
Showing 1 changed file with 52 additions and 24 deletions.
76 changes: 52 additions & 24 deletions torchrl/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,23 +203,37 @@ def __init__(

# Property exposing the loss module's target entries as a single TensorDict,
# keyed by the names listed in self._target_names.
# NOTE(review): this span is a flattened diff hunk, so two bodies appear back
# to back under one `def`. The first looks like the pre-commit body and the
# second like its replacement -- confirm against the repository.
@property
def _targets(self):
# First body: builds a fresh TensorDict on every access by reading each name
# in self._target_names off self.loss_module.
return TensorDict(
{name: getattr(self.loss_module, name) for name in self._target_names},
[],
)
# Second body: looks up "_targets_val" in the instance __dict__ and only
# builds (and stores) the TensorDict when that entry is missing or None, so
# later accesses return the stored object instead of rebuilding it.
targets = self.__dict__.get("_targets_val", None)
if targets is None:
targets = self.__dict__["_targets_val"] = TensorDict(
{name: getattr(self.loss_module, name) for name in self._target_names},
[],
)
return targets

@_targets.setter
def _targets(self, targets):
    """Store ``targets`` under ``"_targets_val"`` in the instance ``__dict__``.

    This is the same entry the ``_targets`` getter looks up, so the assigned
    object is what subsequent reads return (unless it is ``None``).
    """
    vars(self)["_targets_val"] = targets

# Property exposing the loss module's source entries as a single TensorDict,
# keyed by the names listed in self._source_names.
# NOTE(review): this span is a flattened diff hunk, so two bodies appear back
# to back under one `def`. The first looks like the pre-commit body and the
# second like its replacement -- confirm against the repository.
@property
def _sources(self):
# First body: builds a fresh TensorDict on every access by reading each name
# in self._source_names off self.loss_module.
return TensorDict(
{name: getattr(self.loss_module, name) for name in self._source_names},
[],
)
# Second body: looks up "_sources_val" in the instance __dict__ and only
# builds (and stores) the TensorDict when that entry is missing or None, so
# later accesses return the stored object instead of rebuilding it.
sources = self.__dict__.get("_sources_val", None)
if sources is None:
sources = self.__dict__["_sources_val"] = TensorDict(
{name: getattr(self.loss_module, name) for name in self._source_names},
[],
)
return sources

@_sources.setter
def _sources(self, sources):
    """Store ``sources`` under ``"_sources_val"`` in the instance ``__dict__``.

    This is the same entry the ``_sources`` getter looks up, so the assigned
    object is what subsequent reads return (unless it is ``None``).
    """
    vars(self)["_sources_val"] = sources

def init_(self) -> None:
if self.initialized:
warnings.warn("Updated already initialized.")
found_distinct = False
self._distinct = {}
self._distinct_and_params = {}
for key, source in self._sources.items(True, True):
if not isinstance(key, tuple):
key = (key,)
Expand All @@ -228,8 +242,12 @@ def init_(self) -> None:
# for p_source, p_target in zip(source, target):
if target.requires_grad:
raise RuntimeError("the target parameter is part of a graph.")
self._distinct[key] = target.data_ptr() != source.data.data_ptr()
found_distinct = found_distinct or self._distinct[key]
self._distinct_and_params[key] = (
target.is_leaf
and source.requires_grad
and target.data_ptr() != source.data.data_ptr()
)
found_distinct = found_distinct or self._distinct_and_params[key]
target.data.copy_(source.data)
if not found_distinct:
raise RuntimeError(
Expand All @@ -239,6 +257,22 @@ def init_(self) -> None:
"to True if it is not done by default. "
f"If no target parameter is needed, do not use a target updater such as {type(self)}."
)
# filter the target_ out
def filter_target(key):
    """Strip the leading ``"target_"`` prefix from a key.

    Args:
        key: either a string key, or a tuple of key components. For a tuple
            only the first component is rewritten; the remaining components
            are passed through unchanged.

    Returns:
        The key without the prefix, in the same shape (``str`` or ``tuple``)
        as the input.

    Raises:
        ValueError: if the (first component of the) key does not start with
            ``"target_"``. Slicing a fixed number of characters off such a
            key would silently yield an unrelated name.
    """
    prefix = "target_"
    if isinstance(key, tuple):
        return (filter_target(key[0]), *key[1:])
    if not key.startswith(prefix):
        raise ValueError(
            f"expected a key starting with {prefix!r}, got {key!r}"
        )
    # Derive the slice from the prefix instead of hard-coding its length.
    return key[len(prefix):]

self._sources = self._sources.select(
*[
filter_target(key)
for (key, val) in self._distinct_and_params.items()
if val
]
).lock_()
self._targets = self._targets.select(
*(key for (key, val) in self._distinct_and_params.items() if val)
).lock_()

self.initialized = True

Expand All @@ -248,19 +282,11 @@ def step(self) -> None:
f"{self.__class__.__name__} must be "
f"initialized (`{self.__class__.__name__}.init_()`) before calling step()"
)
for key, source in self._sources.items(True, True):
if not isinstance(key, tuple):
key = (key,)
key = ("target_" + key[0], *key[1:])
if not self._distinct[key]:
continue
target = self._targets[key]
for key, param in self._sources.items():
target = self._targets.get("target_{}".format(key))
if target.requires_grad:
raise RuntimeError("the target parameter is part of a graph.")
if target.is_leaf:
self._step(source, target)
else:
target.copy_(source)
self._step(param, target)

# Per-entry update hook: step() calls self._step(param, target) for each
# source/target pair it iterates over. This definition only raises
# NotImplementedError, so a concrete updater has to provide its own _step
# that writes the updated value into p_target.
def _step(self, p_source: Tensor, p_target: Tensor) -> None:
raise NotImplementedError
Expand Down Expand Up @@ -326,8 +352,10 @@ def __init__(
super(SoftUpdate, self).__init__(loss_module)
self.eps = eps

# Soft update of one target entry: target <- eps * target + (1 - eps) * source.
# The blended value is computed as a new expression from the two .data views
# and then written into p_target.data with copy_.
# NOTE(review): this page is a flattened diff and the same method name is
# defined again directly below; this looks like the pre-commit body -- confirm.
def _step(self, p_source: Tensor, p_target: Tensor) -> None:
p_target.data.copy_(p_target.data * self.eps + p_source.data * (1 - self.eps))
def _step(
self, p_source: Tensor | TensorDictBase, p_target: Tensor | TensorDictBase
) -> None:
p_target.data.mul_(self.eps).add_(p_source.data, (1 - self.eps))


class HardUpdate(TargetNetUpdater):
Expand Down

0 comments on commit 8aac458

Please sign in to comment.