Skip to content

Commit

Permalink
[BugFix, Feature] Vmap randomness in losses (pytorch#1740)
Browse files Browse the repository at this point in the history
Co-authored-by: vmoens <vincentmoens@gmail.com>
  • Loading branch information
BY571 and vmoens authored Jan 9, 2024
1 parent 11a82c3 commit eb603ab
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 30 deletions.
63 changes: 58 additions & 5 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import functools
import itertools
import operator
import re
import warnings
from copy import deepcopy
from dataclasses import asdict, dataclass
Expand Down Expand Up @@ -120,6 +121,7 @@
from torchrl.objectives.redq import REDQLoss
from torchrl.objectives.reinforce import ReinforceLoss
from torchrl.objectives.utils import (
_vmap_func,
HardUpdate,
hold_out_net,
SoftUpdate,
Expand Down Expand Up @@ -233,6 +235,52 @@ def set_advantage_keys_through_loss_test(
)


@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("vmap_randomness", (None, "different", "same", "error"))
@pytest.mark.parametrize("dropout", (0.0, 0.1))
def test_loss_vmap_random(device, vmap_randomness, dropout):
class VmapTestLoss(LossModule):
def __init__(self):
super().__init__()
layers = [nn.Linear(4, 4), nn.ReLU()]
if dropout > 0.0:
layers.append(nn.Dropout(dropout))
layers.append(nn.Linear(4, 4))
net = nn.Sequential(*layers).to(device)
model = TensorDictModule(net, in_keys=["obs"], out_keys=["action"])
self.convert_to_functional(model, "model", expand_dim=4)
self.vmap_model = _vmap_func(
self.model,
(None, 0),
randomness="error"
if vmap_randomness == "error"
else self.vmap_randomness,
)

def forward(self, td):
out = self.vmap_model(td, self.model_params)
return {"loss": out["action"].mean()}

loss_module = VmapTestLoss()
td = TensorDict({"obs": torch.randn(3, 4).to(device)}, [3])

# If user sets vmap randomness to a specific value
if vmap_randomness in ("different", "same") and dropout > 0.0:
loss_module.set_vmap_randomness(vmap_randomness)
# Fail case
elif vmap_randomness == "error" and dropout > 0.0:
with pytest.raises(RuntimeError) as exc_info:
loss_module(td)["loss"]

# Accessing cause of the caught exception
cause = exc_info.value.__cause__
assert re.match(
r"vmap: called random operation while in randomness error mode", str(cause)
)
return
loss_module(td)["loss"]


class TestDQN(LossModuleTestBase):
seed = 0

Expand Down Expand Up @@ -1803,12 +1851,17 @@ def _create_mock_actor(
device="cpu",
in_keys=None,
out_keys=None,
dropout=0.0,
):
# Actor
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
module = nn.Linear(obs_dim, action_dim)
module = nn.Sequential(
nn.Linear(obs_dim, obs_dim),
nn.Dropout(dropout),
nn.Linear(obs_dim, action_dim),
)
actor = Actor(
spec=action_spec, module=module, in_keys=in_keys, out_keys=out_keys
)
Expand Down Expand Up @@ -1984,6 +2037,7 @@ def _create_seq_mock_data_td3(
@pytest.mark.parametrize("noise_clip", [0.1, 1.0])
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@pytest.mark.parametrize("use_action_spec", [True, False])
@pytest.mark.parametrize("dropout", [0.0, 0.1])
def test_td3(
self,
delay_actor,
Expand All @@ -1993,9 +2047,10 @@ def test_td3(
noise_clip,
td_est,
use_action_spec,
dropout,
):
torch.manual_seed(self.seed)
actor = self._create_mock_actor(device=device)
actor = self._create_mock_actor(device=device, dropout=dropout)
value = self._create_mock_value(device=device)
td = self._create_mock_data_td3(device=device)
if use_action_spec:
Expand Down Expand Up @@ -4876,7 +4931,6 @@ def test_cql(
device,
td_est,
):

torch.manual_seed(self.seed)
td = self._create_mock_data_cql(device=device)

Expand Down Expand Up @@ -6075,7 +6129,7 @@ def zero_param(p):
p.grad = None
loss_objective.backward()
named_parameters = loss_fn.named_parameters()
for (name, other_p) in named_parameters:
for name, other_p in named_parameters:
p = params.get(tuple(name.split(".")))
assert other_p.shape == p.shape
assert other_p.dtype == p.dtype
Expand Down Expand Up @@ -11137,7 +11191,6 @@ def test_set_deprecated_keys(self, adv, kwargs):
)

with pytest.warns(DeprecationWarning):

if adv is VTrace:
actor_net = TensorDictModule(
nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]
Expand Down
27 changes: 25 additions & 2 deletions torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams
from torch import nn
from torch.nn import Parameter

from torchrl._utils import RL_WARNINGS
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives.utils import ValueEstimators

from torchrl.objectives.utils import RANDOM_MODULE_LIST, ValueEstimators
from torchrl.objectives.value import ValueEstimatorBase


Expand Down Expand Up @@ -81,6 +81,7 @@ class _AcceptedKeys:

pass

_vmap_randomness = None
default_value_estimator: ValueEstimators = None
SEP = "."
TARGET_NET_WARNING = (
Expand Down Expand Up @@ -429,6 +430,28 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams

return self

@property
def vmap_randomness(self):
if self._vmap_randomness is None:
do_break = False
for val in self.__dict__.values():
if isinstance(val, torch.nn.Module):
for module in val.modules():
if isinstance(module, RANDOM_MODULE_LIST):
self._vmap_randomness = "different"
do_break = True
break
if do_break:
# double break
break
else:
self._vmap_randomness = "error"

return self._vmap_randomness

def set_vmap_randomness(self, value):
self._vmap_randomness = value

@staticmethod
def _make_meta_params(param):
is_param = isinstance(param, nn.Parameter)
Expand Down
8 changes: 6 additions & 2 deletions torchrl/objectives/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,12 @@ def __init__(
torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)),
)

self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))
self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network)
self._vmap_qvalue_networkN0 = _vmap_func(
self.qvalue_network, (None, 0), randomness=self.vmap_randomness
)
self._vmap_qvalue_network00 = _vmap_func(
self.qvalue_network, randomness=self.vmap_randomness
)

@property
def target_entropy(self):
Expand Down
4 changes: 3 additions & 1 deletion torchrl/objectives/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,9 @@ def __init__(
if gamma is not None:
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma
self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))
self._vmap_qvalue_networkN0 = _vmap_func(
self.qvalue_network, (None, 0), randomness=self.vmap_randomness
)

@property
def device(self) -> torch.device:
Expand Down
9 changes: 6 additions & 3 deletions torchrl/objectives/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ def __init__(
priority_key: str = None,
separate_losses: bool = False,
):

super().__init__()
self._in_keys = None
self._set_deprecated_ctor_keys(priority_key=priority_key)
Expand Down Expand Up @@ -318,8 +317,12 @@ def __init__(
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma

self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network)
self._vmap_getdist = _vmap_func(self.actor_network, func="get_dist_params")
self._vmap_qvalue_network00 = _vmap_func(
self.qvalue_network, randomness=self.vmap_randomness
)
self._vmap_getdist = _vmap_func(
self.actor_network, func="get_dist_params", randomness=self.vmap_randomness
)

@property
def target_entropy(self):
Expand Down
14 changes: 9 additions & 5 deletions torchrl/objectives/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,13 @@ def __init__(
if gamma is not None:
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma
self._vmap_qnetworkN0 = _vmap_func(self.qvalue_network, (None, 0))
self._vmap_qnetworkN0 = _vmap_func(
self.qvalue_network, (None, 0), randomness=self.vmap_randomness
)
if self._version == 1:
self._vmap_qnetwork00 = _vmap_func(qvalue_network)
self._vmap_qnetwork00 = _vmap_func(
qvalue_network, randomness=self.vmap_randomness
)

@property
def target_entropy_buffer(self):
Expand Down Expand Up @@ -411,7 +415,6 @@ def target_entropy(self):
isinstance(self.tensor_keys.action, tuple)
and len(self.tensor_keys.action) > 1
):

action_container_shape = action_spec[self.tensor_keys.action[:-1]].shape
else:
action_container_shape = action_spec.shape
Expand Down Expand Up @@ -753,7 +756,6 @@ def _value_loss(
return loss_value, {}

def _alpha_loss(self, log_prob: Tensor) -> Tensor:

if self.target_entropy is not None:
# we can compute this loss even if log_alpha is not a parameter
alpha_loss = -self.log_alpha * (log_prob + self.target_entropy)
Expand Down Expand Up @@ -1049,7 +1051,9 @@ def __init__(
self.register_buffer(
"target_entropy", torch.tensor(target_entropy, device=device)
)
self._vmap_qnetworkN0 = _vmap_func(self.qvalue_network, (None, 0))
self._vmap_qnetworkN0 = _vmap_func(
self.qvalue_network, (None, 0), randomness=self.vmap_randomness
)

def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
Expand Down
9 changes: 6 additions & 3 deletions torchrl/objectives/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ def __init__(
priority_key: str = None,
separate_losses: bool = False,
) -> None:

super().__init__()
self._in_keys = None
self._set_deprecated_ctor_keys(priority=priority_key)
Expand Down Expand Up @@ -296,8 +295,12 @@ def __init__(
if gamma is not None:
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma
self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network)
self._vmap_actor_network00 = _vmap_func(self.actor_network)
self._vmap_qvalue_network00 = _vmap_func(
self.qvalue_network, randomness=self.vmap_randomness
)
self._vmap_actor_network00 = _vmap_func(
self.actor_network, randomness=self.vmap_randomness
)

def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
Expand Down
32 changes: 23 additions & 9 deletions torchrl/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import functools
import re
import warnings
from enum import Enum
from typing import Iterable, Optional, Union
Expand All @@ -13,6 +14,7 @@
from tensordict.tensordict import TensorDict, TensorDictBase
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.modules import dropout

try:
from torch import vmap
Expand All @@ -29,6 +31,8 @@
"run `loss_module.make_value_estimator(ValueEstimators.<value_fun>, gamma=val)`."
)

RANDOM_MODULE_LIST = (dropout._DropoutNd,)


class ValueEstimators(Enum):
"""Value function enumerator for custom-built estimators.
Expand Down Expand Up @@ -478,13 +482,23 @@ def new_fun(self, netname=None):


def _vmap_func(module, *args, func=None, **kwargs):
def decorated_module(*module_args_params):
params = module_args_params[-1]
module_args = module_args_params[:-1]
with params.to_module(module):
if func is None:
return module(*module_args)
else:
return getattr(module, func)(*module_args)
try:

return vmap(decorated_module, *args, **kwargs) # noqa: TOR101
def decorated_module(*module_args_params):
params = module_args_params[-1]
module_args = module_args_params[:-1]
with params.to_module(module):
if func is None:
return module(*module_args)
else:
return getattr(module, func)(*module_args)

return vmap(decorated_module, *args, **kwargs) # noqa: TOR101

except RuntimeError as err:
if re.match(
r"vmap: called random operation while in randomness error mode", str(err)
):
raise RuntimeError(
"Please use <loss_module>.set_vmap_randomness('different') to handle random operations during vmap."
) from err

0 comments on commit eb603ab

Please sign in to comment.