Skip to content

Commit

Permalink
[Refactor] Better updaters (#1184)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored May 23, 2023
1 parent b4cf4d7 commit 12ad69e
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 17 deletions.
2 changes: 1 addition & 1 deletion examples/discrete_sac/discrete_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def env_factory(num_workers):
loss_module.make_value_estimator(gamma=cfg.gamma)

# Define Target Network Updater
target_net_updater = SoftUpdate(loss_module, cfg.target_update_polyak)
target_net_updater = SoftUpdate(loss_module, eps=cfg.target_update_polyak)

# Make Off-Policy Collector
collector = SyncDataCollector(
Expand Down
2 changes: 1 addition & 1 deletion examples/iql/iql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def env_factory(num_workers):
loss_module.make_value_estimator(gamma=cfg.gamma)

# Define Target Network Updater
target_net_updater = SoftUpdate(loss_module, cfg.target_update_polyak)
target_net_updater = SoftUpdate(loss_module, eps=cfg.target_update_polyak)

# Make Off-Policy Collector
collector = SyncDataCollector(
Expand Down
2 changes: 1 addition & 1 deletion examples/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def main(cfg: "DictConfig"): # noqa: F821
loss_module.make_value_estimator(gamma=cfg.gamma)

# Define Target Network Updater
target_net_updater = SoftUpdate(loss_module, cfg.target_update_polyak)
target_net_updater = SoftUpdate(loss_module, eps=cfg.target_update_polyak)

# Make Off-Policy Collector
collector = MultiaSyncDataCollector(
Expand Down
51 changes: 42 additions & 9 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -1973,7 +1973,7 @@ def test_redq_shared(self, delay_qvalue, num_qvalue, device):
)

if delay_qvalue:
target_updater = SoftUpdate(loss_fn)
target_updater = SoftUpdate(loss_fn, tau=0.05)

with _check_td_steady(td):
loss = loss_fn(td)
Expand Down Expand Up @@ -3709,15 +3709,19 @@ def __init__(self):
RuntimeError, match="Your module seems to have a target tensor list "
):
if mode == "hard":
upd = HardUpdate(module, value_network_update_interval)
upd = HardUpdate(
module, value_network_update_interval=value_network_update_interval
)
elif mode == "soft":
upd = SoftUpdate(module, 1 - 1 / value_network_update_interval)
upd = SoftUpdate(module, eps=1 - 1 / value_network_update_interval)

class custom_module(LossModule):
def __init__(self):
def __init__(self, delay_module=True):
super().__init__()
module1 = torch.nn.BatchNorm2d(10).eval()
self.convert_to_functional(module1, "module1", create_target_params=True)
self.convert_to_functional(
module1, "module1", create_target_params=delay_module
)
module2 = torch.nn.BatchNorm2d(10).eval()
self.module2 = module2
iterator_params = self.target_module1_params.values(
Expand All @@ -3729,15 +3733,38 @@ def __init__(self):
else:
target.data += 10

module = custom_module(delay_module=False)
with pytest.raises(RuntimeError, match="The target and source data are identical"):
if mode == "hard":
upd = HardUpdate(
module, value_network_update_interval=value_network_update_interval
)
elif mode == "soft":
upd = SoftUpdate(
module,
eps=1 - 1 / value_network_update_interval,
)
else:
raise NotImplementedError

module = custom_module().to(device).to(dtype)

if mode == "soft":
with pytest.raises(ValueError, match="One and only one argument"):
upd = SoftUpdate(
module,
eps=1 - 1 / value_network_update_interval,
tau=0.1,
)

_ = module.module1_params
_ = module.target_module1_params
if mode == "hard":
upd = HardUpdate(
module, value_network_update_interval=value_network_update_interval
)
elif mode == "soft":
upd = SoftUpdate(module, 1 - 1 / value_network_update_interval)
upd = SoftUpdate(module, eps=1 - 1 / value_network_update_interval)
for _, _v in upd._targets.items(True, True):
if _v.dtype is not torch.int64:
_v.copy_(torch.randn_like(_v))
Expand Down Expand Up @@ -5346,14 +5373,20 @@ def fun(a, b, time_dim=-2):
assert (z2 == 2).all()


@pytest.mark.parametrize("updater", [HardUpdate, SoftUpdate])
def test_updater_warning(updater):
@pytest.mark.parametrize(
"updater,kwarg",
[
(HardUpdate, {"value_network_update_interval": 1000}),
(SoftUpdate, {"eps": 0.99}),
],
)
def test_updater_warning(updater, kwarg):
with warnings.catch_warnings():
dqn = DQNLoss(torch.nn.Linear(3, 4), delay_value=True, action_space="one_hot")
with pytest.warns(UserWarning):
dqn.target_value_network_params
with warnings.catch_warnings():
updater(dqn)
updater(dqn, **kwarg)
with warnings.catch_warnings():
dqn.target_value_network_params

Expand Down
44 changes: 42 additions & 2 deletions torchrl/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ def _sources(self):
def init_(self) -> None:
if self.initialized:
warnings.warn("Updated already initialized.")
found_distinct = False
self._distinct = {}
for key, source in self._sources.items(True, True):
if not isinstance(key, tuple):
key = (key,)
Expand All @@ -218,7 +220,18 @@ def init_(self) -> None:
# for p_source, p_target in zip(source, target):
if target.requires_grad:
raise RuntimeError("the target parameter is part of a graph.")
self._distinct[key] = target.data_ptr() != source.data.data_ptr()
found_distinct = found_distinct or self._distinct[key]
target.data.copy_(source.data)
if not found_distinct:
raise RuntimeError(
f"The target and source data are identical for all params. "
"Have you created proper target parameters? "
"If the loss has a ``delay_value`` kwarg, make sure to set it "
"to True if it is not done by default. "
f"If no target parameter is needed, do not use a target updater such as {type(self)}."
)

self.initialized = True

def step(self) -> None:
Expand All @@ -231,6 +244,8 @@ def step(self) -> None:
if not isinstance(key, tuple):
key = (key,)
key = ("target_" + key[0], *key[1:])
if not self._distinct[key]:
continue
target = self._targets[key]
if target.requires_grad:
raise RuntimeError("the target parameter is part of a graph.")
Expand All @@ -255,14 +270,17 @@ class SoftUpdate(TargetNetUpdater):
This was proposed in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf
One and only one decay factor (tau or eps) must be specified.
Args:
loss_module (DQNLoss or DDPGLoss): loss module where the target network should be updated.
eps (scalar): epsilon in the update equation:
.. math::
\theta_t = \theta_{t-1} * \epsilon + \theta_t * (1-\epsilon)
Defaults to 0.999
Exclusive with ``tau``.
tau (scalar): Polyak tau. It is equal to ``1-eps``, and exclusive with it.
"""

def __init__(
Expand All @@ -274,8 +292,27 @@ def __init__(
"REDQLoss", # noqa: F821
"TD3Loss", # noqa: F821
],
eps: float = 0.999,
*,
eps: float = None,
tau: Optional[float] = None,
):
if eps is None and tau is None:
warnings.warn(
"Neither eps nor tau was provided. Taking the default value "
"eps=0.999. This behaviour will soon be deprecated.",
category=DeprecationWarning,
)
eps = 0.999
if (eps is None) ^ (tau is None):
if eps is None:
eps = 1 - tau
else:
raise ValueError("One and only one argument (tau or eps) can be specified.")
if eps < 0.5:
warnings.warn(
"Found an eps value < 0.5, which is unexpected. "
"You may want to use the `tau` keyword argument instead."
)
if not (eps <= 1.0 and eps >= 0.0):
raise ValueError(
f"Got eps = {eps} when it was supposed to be between 0 and 1."
Expand All @@ -295,13 +332,16 @@ class HardUpdate(TargetNetUpdater):
Args:
loss_module (DQNLoss or DDPGLoss): loss module where the target network should be updated.
Keyword Args:
value_network_update_interval (scalar): how often the target network should be updated.
default: 1000
"""

def __init__(
self,
loss_module: Union["DQNLoss", "DDPGLoss", "SACLoss", "TD3Loss"], # noqa: F821
*,
value_network_update_interval: float = 1000,
):
super(HardUpdate, self).__init__(loss_module)
Expand Down
5 changes: 3 additions & 2 deletions torchrl/trainers/helpers/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ def make_target_updater(
if cfg.loss == "double":
if not cfg.hard_update:
target_net_updater = SoftUpdate(
loss_module, 1 - 1 / cfg.value_network_update_interval
loss_module, eps=1 - 1 / cfg.value_network_update_interval
)
else:
target_net_updater = HardUpdate(
loss_module, cfg.value_network_update_interval
loss_module,
value_network_update_interval=cfg.value_network_update_interval,
)
else:
if cfg.hard_update:
Expand Down
2 changes: 1 addition & 1 deletion tutorials/sphinx-tutorials/coding_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def get_collector(
def get_loss_module(actor, gamma):
loss_module = DQNLoss(actor, delay_value=True)
loss_module.make_value_estimator(gamma=gamma)
target_updater = SoftUpdate(loss_module)
target_updater = SoftUpdate(loss_module, eps=0.995)
return loss_module, target_updater


Expand Down

0 comments on commit 12ad69e

Please sign in to comment.