Skip to content

Commit

Permalink
[Feature] Auto-init updaters and raise a warning if not present (#1092)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 26, 2023
1 parent 32339da commit cd02ee0
Show file tree
Hide file tree
Showing 12 changed files with 38 additions and 23 deletions.
2 changes: 0 additions & 2 deletions examples/discrete_sac/discrete_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,6 @@ def env_factory(num_workers):
rewards_eval = []

# Main loop
target_net_updater.init_()

collected_frames = 0
pbar = tqdm.tqdm(total=cfg.total_frames)
r0 = None
Expand Down
2 changes: 0 additions & 2 deletions examples/iql/iql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,6 @@ def env_factory(num_workers):
rewards_eval = []

# Main loop
target_net_updater.init_()

collected_frames = 0

pbar = tqdm.tqdm(total=cfg.total_frames)
Expand Down
2 changes: 0 additions & 2 deletions examples/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,6 @@ def main(cfg: "DictConfig"): # noqa: F821
rewards_eval = []

# Main loop
target_net_updater.init_()

collected_frames = 0
pbar = tqdm.tqdm(total=cfg.total_frames)
r0 = None
Expand Down
13 changes: 11 additions & 2 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import argparse
import warnings
from copy import deepcopy

from packaging import version as pack_version
Expand Down Expand Up @@ -1963,7 +1964,6 @@ def test_redq_shared(self, delay_qvalue, num_qvalue, device):

if delay_qvalue:
target_updater = SoftUpdate(loss_fn)
target_updater.init_()

with _check_td_steady(td):
loss = loss_fn(td)
Expand Down Expand Up @@ -3728,7 +3728,6 @@ def __init__(self):
)
elif mode == "soft":
upd = SoftUpdate(module, 1 - 1 / value_network_update_interval)
upd.init_()
for _, _v in upd._targets.items(True, True):
if _v.dtype is not torch.int64:
_v.copy_(torch.randn_like(_v))
Expand Down Expand Up @@ -5088,6 +5087,16 @@ def __init__(self, compare_against, expand_dim):
loss_module.module_b_params.flatten_keys()[key].requires_grad


@pytest.mark.parametrize("updater", [HardUpdate, SoftUpdate])
def test_updater_warning(updater):
dqn = DQNLoss(torch.nn.Linear(3, 4), delay_value=True, action_space="one_hot")
with pytest.warns(UserWarning):
dqn.target_value_network_params
updater(dqn)
with warnings.catch_warnings():
dqn.target_value_network_params


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
2 changes: 0 additions & 2 deletions torchrl/modules/tensordict_module/actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,8 +1104,6 @@ def __init__(
make_log_softmax: bool = True,
):

action_space, spec = _process_action_space_spec(action_space, spec)

action_space, spec = _process_action_space_spec(action_space, spec)
self.action_space = action_space
self.action_value_key = action_value_key
Expand Down
2 changes: 1 addition & 1 deletion torchrl/modules/tensordict_module/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ def ensure_tensordict_compatible(
out_keys: Optional[Iterable[str]] = None,
safe: bool = False,
wrapper_type: Optional[Type] = TensorDictModule,
**kwargs,
):
"""Checks and ensures an object with forward method is TensorDict compatible."""
if is_tensordict_compatible(module):
Expand Down Expand Up @@ -395,7 +396,6 @@ def ensure_tensordict_compatible(
)

# TODO: Check whether out_keys match (at least in number) if they are provided.
kwargs = {}
if in_keys is not None:
kwargs["in_keys"] = in_keys
if out_keys is not None:
Expand Down
12 changes: 12 additions & 0 deletions torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import warnings
from copy import deepcopy
from typing import Iterator, List, Optional, Tuple, Union

Expand All @@ -16,6 +17,7 @@
from torch import nn, Tensor
from torch.nn import Parameter

from torchrl._utils import RL_WARNINGS
from torchrl.modules.utils import Buffer
from torchrl.objectives.utils import ValueEstimators
from torchrl.objectives.value import ValueEstimatorBase
Expand Down Expand Up @@ -60,6 +62,7 @@ def __init__(self):
super().__init__()
self._param_maps = {}
self._value_estimator = None
self._has_update_associated = False
# self.register_forward_pre_hook(_parameters_to_tensordict)

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
Expand Down Expand Up @@ -331,6 +334,15 @@ def _target_param_getter(self, network_name):
target_name = "_target_" + network_name + "_params"
param_name = network_name + "_params"
if target_name in self.__dict__:
if not self._has_update_associated and RL_WARNINGS:
warnings.warn(
"No target network updater has been associated "
"with this loss module, but target parameters have been found."
"While this is supported, it is expected that the target network "
"updates will be manually performed. You can deactivate this warning "
"by turning the RL_WARNINGS env variable to False.",
category=UserWarning,
)
target_params = getattr(self, target_name)
if target_params is not None:
# get targets and update
Expand Down
14 changes: 8 additions & 6 deletions torchrl/objectives/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def __init__(
super().__init__()
self.delay_value = delay_value
value_network = ensure_tensordict_compatible(
module=value_network, wrapper_type=QValueActor
module=value_network,
wrapper_type=QValueActor,
action_space=action_space,
)

self.convert_to_functional(
Expand All @@ -91,11 +93,11 @@ def __init__(
action_space = value_network.spec
except AttributeError:
# let's try with action_space then
pass
try:
action_space = self.value_network.action_space
except AttributeError:
raise ValueError(self.ACTION_SPEC_ERROR)
try:
action_space = value_network.action_space
except AttributeError:
raise ValueError(self.ACTION_SPEC_ERROR)

self.action_space = _find_action_space(action_space)

if gamma is not None:
Expand Down
7 changes: 6 additions & 1 deletion torchrl/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import functools
import warnings
from enum import Enum
from typing import Iterable, Optional, Union

Expand Down Expand Up @@ -138,7 +139,7 @@ class TargetNetUpdater:

def __init__(
self,
loss_module: Union["DQNLoss", "DDPGLoss", "SACLoss", "TD3Loss"], # noqa: F821
loss_module: "LossModule", # noqa: F821
):

_target_names = []
Expand Down Expand Up @@ -185,6 +186,7 @@ def __init__(
self._source_names = _source_names
self.loss_module = loss_module
self.initialized = False
self.init_()

@property
def _targets(self):
Expand All @@ -201,6 +203,8 @@ def _sources(self):
)

def init_(self) -> None:
if self.initialized:
warnings.warn("Updated already initialized.")
for key, source in self._sources.items(True, True):
if not isinstance(key, tuple):
key = (key,)
Expand All @@ -211,6 +215,7 @@ def init_(self) -> None:
raise RuntimeError("the target parameter is part of a graph.")
target.data.copy_(source.data)
self.initialized = True
self.loss_module._has_update_associated = True

def step(self) -> None:
if not self.initialized:
Expand Down
2 changes: 0 additions & 2 deletions torchrl/trainers/helpers/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ def make_target_updater(
target_net_updater = HardUpdate(
loss_module, cfg.value_network_update_interval
)
# assert len(target_net_updater.net_pairs) == 3, "length of target_net_updater nets should be 3"
target_net_updater.init_()
else:
if cfg.hard_update:
raise RuntimeError(
Expand Down
2 changes: 0 additions & 2 deletions tutorials/sphinx-tutorials/coding_ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,8 +1061,6 @@ def ceil_div(x, y):
from torchrl.objectives.utils import SoftUpdate

target_net_updater = SoftUpdate(loss_module, eps=1 - tau)
# This class will raise an error if `init_` is not called first.
target_net_updater.init_()

###############################################################################
# Optimizer
Expand Down
1 change: 0 additions & 1 deletion tutorials/sphinx-tutorials/coding_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,6 @@ def get_loss_module(actor, gamma):
# Get model
actor, actor_explore = make_model(test_env)
loss_module, target_net_updater = get_loss_module(actor, gamma)
target_net_updater.init_()

collector = get_collector(
stats, num_collectors, actor_explore, frames_per_batch, total_frames, device
Expand Down

0 comments on commit cd02ee0

Please sign in to comment.