From 3206ddbb118b2381d096bffb6c3dd5482f5760a1 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 7 Dec 2023 15:57:55 +0100 Subject: [PATCH 01/18] add randomness param to vmap for offpolicy algos --- torchrl/objectives/cql.py | 8 ++++++-- torchrl/objectives/iql.py | 4 +++- torchrl/objectives/redq.py | 4 +++- torchrl/objectives/sac.py | 6 ++++-- torchrl/objectives/td3.py | 8 ++++++-- 5 files changed, 22 insertions(+), 8 deletions(-) diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 0c8caa5a60b..f16ffed2b79 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -335,8 +335,12 @@ def __init__( torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)), ) - self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0)) - self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network) + self._vmap_qvalue_networkN0 = _vmap_func( + self.qvalue_network, (None, 0), randomness="different" + ) + self._vmap_qvalue_network00 = _vmap_func( + self.qvalue_network, randomness="different" + ) @property def target_entropy(self): diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index e64dfa11f2d..72fdd46e289 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -286,7 +286,9 @@ def __init__( if gamma is not None: warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) self.gamma = gamma - self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0)) + self._vmap_qvalue_networkN0 = _vmap_func( + self.qvalue_network, (None, 0), randomness="different" + ) @property def device(self) -> torch.device: diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index 347becc24ae..e34fb1bb305 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -318,7 +318,9 @@ def __init__( warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) self.gamma = gamma - self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network) + self._vmap_qvalue_network00 = _vmap_func( + self.qvalue_network, randomness="different" + ) self._vmap_getdist = _vmap_func(self.actor_network, func="get_dist_params") @property diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index d0617dedc74..32814c91554 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -376,9 +376,11 @@ def __init__( if gamma is not None: warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) self.gamma = gamma - self._vmap_qnetworkN0 = _vmap_func(self.qvalue_network, (None, 0)) + self._vmap_qnetworkN0 = _vmap_func( + self.qvalue_network, (None, 0), randomness="different" + ) if self._version == 1: - self._vmap_qnetwork00 = _vmap_func(qvalue_network) + self._vmap_qnetwork00 = _vmap_func(qvalue_network, randomness="different") @property def target_entropy_buffer(self): diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 082873a2358..e50a8a724fd 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -296,8 +296,12 @@ def __init__( if gamma is not None: warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) self.gamma = gamma - self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network) - self._vmap_actor_network00 = _vmap_func(self.actor_network) + self._vmap_qvalue_network00 = _vmap_func( + self.qvalue_network, randomness="different" + ) + self._vmap_actor_network00 = _vmap_func( + self.actor_network, randomness="different" + ) def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is 
not None: From 82ec8977de1e1de509950ec52ccea8787c4d775d Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 8 Dec 2023 14:10:22 +0100 Subject: [PATCH 02/18] update td3 with vmap_randomness --- torchrl/objectives/td3.py | 25 ++++++++++++++++++++++--- torchrl/objectives/utils.py | 28 +++++++++++++++++++--------- 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index e50a8a724fd..cec4204c1de 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -201,6 +201,7 @@ class _AcceptedKeys: "next_state_value", "target_value", ] + _vmap_randomness = None def __init__( self, @@ -219,7 +220,6 @@ def __init__( priority_key: str = None, separate_losses: bool = False, ) -> None: - super().__init__() self._in_keys = None self._set_deprecated_ctor_keys(priority=priority_key) @@ -297,10 +297,10 @@ def __init__( warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) self.gamma = gamma self._vmap_qvalue_network00 = _vmap_func( - self.qvalue_network, randomness="different" + self.qvalue_network, randomness=self.vmap_randomness ) self._vmap_actor_network00 = _vmap_func( - self.actor_network, randomness="different" + self.actor_network, randomness=self.vmap_randomness ) def _forward_value_estimator_keys(self, **kwargs) -> None: @@ -347,6 +347,25 @@ def _cached_stack_actor_params(self): [self.actor_network_params, self.target_actor_network_params], 0 ) + @property + def vmap_randomness(self): + if self._vmap_randomness is None: + # look for nn.Dropout modules + dropouts = (torch.nn.Dropout, torch.nn.Dropout2d, torch.nn.Dropout3d) + for a, q in zip( + self.actor_network.modules(), self.qvalue_network.modules() + ): + if isinstance(a, dropouts) or isinstance(q, dropouts): + self._vmap_randomness = "different" + break + else: + self._vmap_randomness = "error" + + return self._vmap_randomness + + def set_vmap_randomness(self, value): + self._vmap_randomness = value + def actor_loss(self, tensordict): tensordict_actor_grad = tensordict.select(*self.actor_network.in_keys) with self.actor_network_params.to_module(self.actor_network): diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index c3e7dbc68ce..545c5d7f496 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -478,13 +478,23 @@ def new_fun(self, netname=None): def _vmap_func(module, *args, func=None, **kwargs): - def decorated_module(*module_args_params): - params = module_args_params[-1] - module_args = module_args_params[:-1] - with params.to_module(module): - if func is None: - return module(*module_args) - else: - return getattr(module, func)(*module_args) + try: - return vmap(decorated_module, *args, **kwargs) # noqa: TOR101 + def decorated_module(*module_args_params): + params = module_args_params[-1] + module_args = module_args_params[:-1] + with params.to_module(module): + if func is None: + return module(*module_args) + else: + return getattr(module, func)(*module_args) + + return vmap(decorated_module, *args, **kwargs) # noqa: TOR101 + + except RuntimeError as err: + if "vmap: called random operation while in randomness error mode" in str( + err + ): # better to use re.match here but anyway + raise RuntimeError( + "Please use loss_module.set_vmap_randomness to handle random operations during vmap." 
+ ) from err From 29a5cda22c0c3ecaaa957717bc585da441b4c257 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 11 Dec 2023 14:30:38 +0100 Subject: [PATCH 03/18] undo vmap randomness zip actor and critic modules --- torchrl/objectives/td3.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index cec4204c1de..75c717335cf 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -350,13 +350,21 @@ def _cached_stack_actor_params(self): @property def vmap_randomness(self): if self._vmap_randomness is None: - # look for nn.Dropout modules - dropouts = (torch.nn.Dropout, torch.nn.Dropout2d, torch.nn.Dropout3d) - for a, q in zip( - self.actor_network.modules(), self.qvalue_network.modules() - ): - if isinstance(a, dropouts) or isinstance(q, dropouts): - self._vmap_randomness = "different" + RANDOM_MODULE_LIST = ( + torch.nn.Dropout, + torch.nn.Dropout2d, + torch.nn.Dropout3d, + ) + do_break = False + for val in self.__dict__.values(): + if isinstance(val, torch.nn.Module): + for module in val.modules(): + if isinstance(module, RANDOM_MODULE_LIST): + self._vmap_randomness = "different" + do_break = True + break + if do_break: + # double break break else: self._vmap_randomness = "error" From 53ca415e36e48df168840eb45b26509bbceee893 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 11 Dec 2023 14:32:59 +0100 Subject: [PATCH 04/18] expand random module list --- torchrl/objectives/td3.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 75c717335cf..c98c674e788 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -26,6 +26,25 @@ ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator +RANDOM_MODULE_LIST = ( + torch.nn.Dropout, + torch.nn.Dropout2d, + torch.nn.Dropout3d, + torch.nn.AlphaDropout, + torch.nn.FeatureAlphaDropout, + torch.nn.GaussianDropout, + torch.nn.GaussianNoise, + torch.nn.SyncBatchNorm, + torch.nn.GroupNorm, + torch.nn.LayerNorm, + torch.nn.LocalResponseNorm, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.SpatialDropout, + torch.nn.SpatialCrossMapLRN, +) + class TD3Loss(LossModule): """TD3 Loss module. 
@@ -350,11 +369,6 @@ def _cached_stack_actor_params(self): @property def vmap_randomness(self): if self._vmap_randomness is None: - RANDOM_MODULE_LIST = ( - torch.nn.Dropout, - torch.nn.Dropout2d, - torch.nn.Dropout3d, - ) do_break = False for val in self.__dict__.values(): if isinstance(val, torch.nn.Module): From 50aa8a4907953894d1a0e13dfec97ba11fa5ec4f Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 11 Dec 2023 14:38:42 +0100 Subject: [PATCH 05/18] move random_module_list to utils --- torchrl/objectives/td3.py | 20 +------------------- torchrl/objectives/utils.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index c98c674e788..b2b1f67a511 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -22,29 +22,11 @@ _vmap_func, default_value_kwargs, distance_loss, + RANDOM_MODULE_LIST, ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator -RANDOM_MODULE_LIST = ( - torch.nn.Dropout, - torch.nn.Dropout2d, - torch.nn.Dropout3d, - torch.nn.AlphaDropout, - torch.nn.FeatureAlphaDropout, - torch.nn.GaussianDropout, - torch.nn.GaussianNoise, - torch.nn.SyncBatchNorm, - torch.nn.GroupNorm, - torch.nn.LayerNorm, - torch.nn.LocalResponseNorm, - torch.nn.InstanceNorm1d, - torch.nn.InstanceNorm2d, - torch.nn.InstanceNorm3d, - torch.nn.SpatialDropout, - torch.nn.SpatialCrossMapLRN, -) - class TD3Loss(LossModule): """TD3 Loss module. diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 545c5d7f496..33b03871aa5 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -29,6 +29,21 @@ "run `loss_module.make_value_estimator(ValueEstimators., gamma=val)`." ) +RANDOM_MODULE_LIST = ( + nn.Dropout, + nn.Dropout2d, + nn.Dropout3d, + nn.AlphaDropout, + nn.FeatureAlphaDropout, + nn.SyncBatchNorm, + nn.GroupNorm, + nn.LayerNorm, + nn.LocalResponseNorm, + nn.InstanceNorm1d, + nn.InstanceNorm2d, + nn.InstanceNorm3d, +) + class ValueEstimators(Enum): """Value function enumerator for custom-built estimators. 
From 3f54635f91f0f7d7f7449958b6c632d084f30c71 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 13 Dec 2023 10:17:02 +0100 Subject: [PATCH 06/18] update sac losses --- torchrl/objectives/sac.py | 57 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 54 insertions(+), 3 deletions(-) diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 32814c91554..49f4126e710 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -29,6 +29,7 @@ _vmap_func, default_value_kwargs, distance_loss, + RANDOM_MODULE_LIST, ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -257,6 +258,7 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" + _vmap_randomness = None default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 @@ -377,10 +379,12 @@ def __init__( warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) self.gamma = gamma self._vmap_qnetworkN0 = _vmap_func( - self.qvalue_network, (None, 0), randomness="different" + self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) if self._version == 1: - self._vmap_qnetwork00 = _vmap_func(qvalue_network, randomness="different") + self._vmap_qnetwork00 = _vmap_func( + qvalue_network, randomness=self.vmap_randomness + ) @property def target_entropy_buffer(self): @@ -532,6 +536,28 @@ def out_keys(self): def out_keys(self, values): self._out_keys = values + @property + def vmap_randomness(self): + if self._vmap_randomness is None: + do_break = False + for val in self.__dict__.values(): + if isinstance(val, torch.nn.Module): + for module in val.modules(): + if isinstance(module, RANDOM_MODULE_LIST): + self._vmap_randomness = "different" + do_break = True + break + if do_break: + # double break + break + else: + self._vmap_randomness = "error" + + return self._vmap_randomness + + def set_vmap_randomness(self, value): + self._vmap_randomness = value + @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: shape = None @@ -946,6 +972,7 @@ class _AcceptedKeys: default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 delay_actor: bool = False + _vmap_randomness = None out_keys = [ "loss_actor", "loss_qvalue", @@ -1051,7 +1078,9 @@ def __init__( self.register_buffer( "target_entropy", torch.tensor(target_entropy, device=device) ) - self._vmap_qnetworkN0 = _vmap_func(self.qvalue_network, (None, 0)) + self._vmap_qnetworkN0 = _vmap_func( + self.qvalue_network, (None, 0), randomness=self.vmap_randomness + ) def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: @@ -1085,6 +1114,28 @@ def in_keys(self): def in_keys(self, values): self._in_keys = values + @property + def vmap_randomness(self): + if self._vmap_randomness is None: + do_break = False + for val in self.__dict__.values(): + if isinstance(val, torch.nn.Module): + for module in val.modules(): + if isinstance(module, RANDOM_MODULE_LIST): + self._vmap_randomness = "different" + do_break = True + break + if do_break: + # double break + break + else: + self._vmap_randomness = "error" + + return self._vmap_randomness + + def set_vmap_randomness(self, value): + self._vmap_randomness = value + @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: shape = None From c4ff9264c78f322b55170a95fe7f3fdca282545a Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 13 Dec 2023 10:25:26 +0100 Subject: [PATCH 07/18] update iql objective --- torchrl/objectives/iql.py | 24 
++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 72fdd46e289..1dab2cdf8b5 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -20,6 +20,7 @@ _vmap_func, default_value_kwargs, distance_loss, + RANDOM_MODULE_LIST, ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -221,6 +222,7 @@ class _AcceptedKeys: "loss_value", "entropy", ] + _vmap_randomness = None def __init__( self, @@ -321,6 +323,28 @@ def in_keys(self): def in_keys(self, values): self._in_keys = values + @property + def vmap_randomness(self): + if self._vmap_randomness is None: + do_break = False + for val in self.__dict__.values(): + if isinstance(val, torch.nn.Module): + for module in val.modules(): + if isinstance(module, RANDOM_MODULE_LIST): + self._vmap_randomness = "different" + do_break = True + break + if do_break: + # double break + break + else: + self._vmap_randomness = "error" + + return self._vmap_randomness + + def set_vmap_randomness(self, value): + self._vmap_randomness = value + @staticmethod def loss_value_diff(diff, expectile=0.8): """Loss function for iql expectile value difference.""" From 1c44c35c3665fce236e792cad9eeaa0352a53b57 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 13 Dec 2023 10:29:13 +0100 Subject: [PATCH 08/18] update redq objective --- torchrl/objectives/redq.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index e34fb1bb305..dd124dfc5ec 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -25,6 +25,7 @@ _vmap_func, default_value_kwargs, distance_loss, + RANDOM_MODULE_LIST, ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -234,6 +235,7 @@ class _AcceptedKeys: "next.state_value", "target_value", ] + _vmap_randomness = None def __init__( self, @@ -255,7 +257,6 @@ def __init__( priority_key: str = None, separate_losses: bool = False, ): - super().__init__() self._in_keys = None self._set_deprecated_ctor_keys(priority_key=priority_key) @@ -319,9 +320,11 @@ def __init__( self.gamma = gamma self._vmap_qvalue_network00 = _vmap_func( - self.qvalue_network, randomness="different" + self.qvalue_network, randomness=self.vmap_randomness + ) + self._vmap_getdist = _vmap_func( + self.actor_network, func="get_dist_params", randomess=self.vmap_randomness ) - self._vmap_getdist = _vmap_func(self.actor_network, func="get_dist_params") @property def target_entropy(self): @@ -406,6 +409,28 @@ def in_keys(self): def in_keys(self, values): self._in_keys = values + @property + def vmap_randomness(self): + if self._vmap_randomness is None: + do_break = False + for val in self.__dict__.values(): + if isinstance(val, torch.nn.Module): + for module in val.modules(): + if isinstance(module, RANDOM_MODULE_LIST): + self._vmap_randomness = "different" + do_break = True + break + if do_break: + # double break + break + else: + self._vmap_randomness = "error" + + return self._vmap_randomness + + def set_vmap_randomness(self, value): + self._vmap_randomness = value + @property @_cache_values def _cached_detach_qvalue_network_params(self): From bc03b928c093a8d06761ccb00d51de907ed9a29b Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 13 Dec 2023 10:36:21 +0100 Subject: [PATCH 09/18] update conti cql --- torchrl/objectives/cql.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 
insertions(+), 2 deletions(-) diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index f16ffed2b79..34bb7e00f96 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -30,6 +30,7 @@ _vmap_func, default_value_kwargs, distance_loss, + RANDOM_MODULE_LIST, ValueEstimators, ) @@ -230,6 +231,7 @@ class _AcceptedKeys: default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 + _vmap_randomness = None def __init__( self, @@ -336,10 +338,10 @@ def __init__( ) self._vmap_qvalue_networkN0 = _vmap_func( - self.qvalue_network, (None, 0), randomness="different" + self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) self._vmap_qvalue_network00 = _vmap_func( - self.qvalue_network, randomness="different" + self.qvalue_network, randomness=self.vmap_randomness ) @property @@ -475,6 +477,28 @@ def out_keys(self): def out_keys(self, values): self._out_keys = values + @property + def vmap_randomness(self): + if self._vmap_randomness is None: + do_break = False + for val in self.__dict__.values(): + if isinstance(val, torch.nn.Module): + for module in val.modules(): + if isinstance(module, RANDOM_MODULE_LIST): + self._vmap_randomness = "different" + do_break = True + break + if do_break: + # double break + break + else: + self._vmap_randomness = "error" + + return self._vmap_randomness + + def set_vmap_randomness(self, value): + self._vmap_randomness = value + @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: shape = None From 9aa65ed1b51c37db51636140f03f695884888662 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 14 Dec 2023 10:31:54 +0100 Subject: [PATCH 10/18] move vmap_randomness to LossModule --- torchrl/objectives/common.py | 27 ++++++++++++++++++-- torchrl/objectives/cql.py | 24 ------------------ torchrl/objectives/iql.py | 24 ------------------ torchrl/objectives/redq.py | 24 ------------------ torchrl/objectives/sac.py | 49 ------------------------------------ torchrl/objectives/td3.py | 24 ------------------ torchrl/objectives/utils.py | 7 ------ 7 files changed, 25 insertions(+), 154 deletions(-) diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 00ba8cf456a..04a7708e7db 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -16,10 +16,10 @@ from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams from torch import nn from torch.nn import Parameter - from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.objectives.utils import ValueEstimators + +from torchrl.objectives.utils import RANDOM_MODULE_LIST, ValueEstimators from torchrl.objectives.value import ValueEstimatorBase @@ -81,6 +81,7 @@ class _AcceptedKeys: pass + _vmap_randomness = None default_value_estimator: ValueEstimators = None SEP = "." 
TARGET_NET_WARNING = ( @@ -429,6 +430,28 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams return self + @property + def vmap_randomness(self): + if self._vmap_randomness is None: + do_break = False + for val in self.__dict__.values(): + if isinstance(val, torch.nn.Module): + for module in val.modules(): + if isinstance(module, RANDOM_MODULE_LIST): + self._vmap_randomness = "different" + do_break = True + break + if do_break: + # double break + break + else: + self._vmap_randomness = "error" + + return self._vmap_randomness + + def set_vmap_randomness(self, value): + self._vmap_randomness = value + @staticmethod def _make_meta_params(param): is_param = isinstance(param, nn.Parameter) diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 34bb7e00f96..e5ef74bbb35 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -30,7 +30,6 @@ _vmap_func, default_value_kwargs, distance_loss, - RANDOM_MODULE_LIST, ValueEstimators, ) @@ -231,7 +230,6 @@ class _AcceptedKeys: default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 - _vmap_randomness = None def __init__( self, @@ -477,28 +475,6 @@ def out_keys(self): def out_keys(self, values): self._out_keys = values - @property - def vmap_randomness(self): - if self._vmap_randomness is None: - do_break = False - for val in self.__dict__.values(): - if isinstance(val, torch.nn.Module): - for module in val.modules(): - if isinstance(module, RANDOM_MODULE_LIST): - self._vmap_randomness = "different" - do_break = True - break - if do_break: - # double break - break - else: - self._vmap_randomness = "error" - - return self._vmap_randomness - - def set_vmap_randomness(self, value): - self._vmap_randomness = value - @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: shape = None diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 1dab2cdf8b5..72fdd46e289 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -20,7 +20,6 @@ _vmap_func, default_value_kwargs, distance_loss, - RANDOM_MODULE_LIST, ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -222,7 +221,6 @@ class _AcceptedKeys: "loss_value", "entropy", ] - _vmap_randomness = None def __init__( self, @@ -323,28 +321,6 @@ def in_keys(self): def in_keys(self, values): self._in_keys = values - @property - def vmap_randomness(self): - if self._vmap_randomness is None: - do_break = False - for val in self.__dict__.values(): - if isinstance(val, torch.nn.Module): - for module in val.modules(): - if isinstance(module, RANDOM_MODULE_LIST): - self._vmap_randomness = "different" - do_break = True - break - if do_break: - # double break - break - else: - self._vmap_randomness = "error" - - return self._vmap_randomness - - def set_vmap_randomness(self, value): - self._vmap_randomness = value - @staticmethod def loss_value_diff(diff, expectile=0.8): """Loss function for iql expectile value difference.""" diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index dd124dfc5ec..9436b4e9b61 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -25,7 +25,6 @@ _vmap_func, default_value_kwargs, distance_loss, - RANDOM_MODULE_LIST, ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -235,7 +234,6 @@ class _AcceptedKeys: "next.state_value", "target_value", ] - _vmap_randomness = None def __init__( self, @@ -409,28 +407,6 @@ def 
in_keys(self): def in_keys(self, values): self._in_keys = values - @property - def vmap_randomness(self): - if self._vmap_randomness is None: - do_break = False - for val in self.__dict__.values(): - if isinstance(val, torch.nn.Module): - for module in val.modules(): - if isinstance(module, RANDOM_MODULE_LIST): - self._vmap_randomness = "different" - do_break = True - break - if do_break: - # double break - break - else: - self._vmap_randomness = "error" - - return self._vmap_randomness - - def set_vmap_randomness(self, value): - self._vmap_randomness = value - @property @_cache_values def _cached_detach_qvalue_network_params(self): diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 49f4126e710..171e0270435 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -29,7 +29,6 @@ _vmap_func, default_value_kwargs, distance_loss, - RANDOM_MODULE_LIST, ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -258,7 +257,6 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" - _vmap_randomness = None default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 @@ -417,7 +415,6 @@ def target_entropy(self): isinstance(self.tensor_keys.action, tuple) and len(self.tensor_keys.action) > 1 ): - action_container_shape = action_spec[self.tensor_keys.action[:-1]].shape else: action_container_shape = action_spec.shape @@ -536,28 +533,6 @@ def out_keys(self): def out_keys(self, values): self._out_keys = values - @property - def vmap_randomness(self): - if self._vmap_randomness is None: - do_break = False - for val in self.__dict__.values(): - if isinstance(val, torch.nn.Module): - for module in val.modules(): - if isinstance(module, RANDOM_MODULE_LIST): - self._vmap_randomness = "different" - do_break = True - break - if do_break: - # double break - break - else: - self._vmap_randomness = "error" - - return self._vmap_randomness - - def set_vmap_randomness(self, value): - self._vmap_randomness = value - @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: shape = None @@ -781,7 +756,6 @@ def _value_loss( return loss_value, {} def _alpha_loss(self, log_prob: Tensor) -> Tensor: - if self.target_entropy is not None: # we can compute this loss even if log_alpha is not a parameter alpha_loss = -self.log_alpha * (log_prob + self.target_entropy) @@ -972,7 +946,6 @@ class _AcceptedKeys: default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 delay_actor: bool = False - _vmap_randomness = None out_keys = [ "loss_actor", "loss_qvalue", @@ -1114,28 +1087,6 @@ def in_keys(self): def in_keys(self, values): self._in_keys = values - @property - def vmap_randomness(self): - if self._vmap_randomness is None: - do_break = False - for val in self.__dict__.values(): - if isinstance(val, torch.nn.Module): - for module in val.modules(): - if isinstance(module, RANDOM_MODULE_LIST): - self._vmap_randomness = "different" - do_break = True - break - if do_break: - # double break - break - else: - self._vmap_randomness = "error" - - return self._vmap_randomness - - def set_vmap_randomness(self, value): - self._vmap_randomness = value - @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: shape = None diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index b2b1f67a511..8b9548a952b 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -22,7 +22,6 @@ _vmap_func, default_value_kwargs, distance_loss, - 
RANDOM_MODULE_LIST, ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -202,7 +201,6 @@ class _AcceptedKeys: "next_state_value", "target_value", ] - _vmap_randomness = None def __init__( self, @@ -348,28 +346,6 @@ def _cached_stack_actor_params(self): [self.actor_network_params, self.target_actor_network_params], 0 ) - @property - def vmap_randomness(self): - if self._vmap_randomness is None: - do_break = False - for val in self.__dict__.values(): - if isinstance(val, torch.nn.Module): - for module in val.modules(): - if isinstance(module, RANDOM_MODULE_LIST): - self._vmap_randomness = "different" - do_break = True - break - if do_break: - # double break - break - else: - self._vmap_randomness = "error" - - return self._vmap_randomness - - def set_vmap_randomness(self, value): - self._vmap_randomness = value - def actor_loss(self, tensordict): tensordict_actor_grad = tensordict.select(*self.actor_network.in_keys) with self.actor_network_params.to_module(self.actor_network): diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 33b03871aa5..4caff20e85b 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -35,13 +35,6 @@ nn.Dropout3d, nn.AlphaDropout, nn.FeatureAlphaDropout, - nn.SyncBatchNorm, - nn.GroupNorm, - nn.LayerNorm, - nn.LocalResponseNorm, - nn.InstanceNorm1d, - nn.InstanceNorm2d, - nn.InstanceNorm3d, ) From 3ff3ef86e205c7149784d633ee8d12c3fc3483eb Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 14 Dec 2023 10:44:42 +0100 Subject: [PATCH 11/18] fix --- torchrl/objectives/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 4caff20e85b..8079b00856f 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import functools +import re import warnings from enum import Enum from typing import Iterable, Optional, Union @@ -500,9 +501,9 @@ def decorated_module(*module_args_params): return vmap(decorated_module, *args, **kwargs) # noqa: TOR101 except RuntimeError as err: - if "vmap: called random operation while in randomness error mode" in str( - err - ): # better to use re.match here but anyway + if re.match( + r"vmap: called random operation while in randomness error mode", str(err) + ): raise RuntimeError( - "Please use loss_module.set_vmap_randomness to handle random operations during vmap." + "Please use .set_vmap_randomness('different') to handle random operations during vmap." 
) from err From febf277397aeefe46750188b4445fbd1a770c949 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 14 Dec 2023 10:47:30 +0100 Subject: [PATCH 12/18] fix --- torchrl/objectives/iql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 72fdd46e289..cac49f3edc6 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -287,7 +287,7 @@ def __init__( warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) self.gamma = gamma self._vmap_qvalue_networkN0 = _vmap_func( - self.qvalue_network, (None, 0), randomness="different" + self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) @property From 3af72c522377ca496fcea102a179b065607a5faa Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 14 Dec 2023 11:05:24 +0100 Subject: [PATCH 13/18] add test example --- test/test_cost.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 8bae683c5d5..4c4364c5bd1 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -1803,12 +1803,17 @@ def _create_mock_actor( device="cpu", in_keys=None, out_keys=None, + dropout=0.0, ): # Actor action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - module = nn.Linear(obs_dim, action_dim) + module = nn.Sequential( + nn.Linear(obs_dim, obs_dim), + nn.Dropout(dropout), + nn.Linear(obs_dim, action_dim), + ) actor = Actor( spec=action_spec, module=module, in_keys=in_keys, out_keys=out_keys ) @@ -1984,6 +1989,7 @@ def _create_seq_mock_data_td3( @pytest.mark.parametrize("noise_clip", [0.1, 1.0]) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) @pytest.mark.parametrize("use_action_spec", [True, False]) + @pytest.mark.parametrize("dropout", [0.0, 0.1]) def test_td3( self, delay_actor, @@ -1993,9 +1999,10 @@ def test_td3( noise_clip, td_est, use_action_spec, + dropout, ): torch.manual_seed(self.seed) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor(device=device, dropout=dropout) value = self._create_mock_value(device=device) td = self._create_mock_data_td3(device=device) if use_action_spec: @@ -6056,7 +6063,7 @@ def zero_param(p): p.grad = None loss_objective.backward() named_parameters = loss_fn.named_parameters() - for (name, other_p) in named_parameters: + for name, other_p in named_parameters: p = params.get(tuple(name.split("."))) assert other_p.shape == p.shape assert other_p.dtype == p.dtype @@ -11118,7 +11125,6 @@ def test_set_deprecated_keys(self, adv, kwargs): ) with pytest.warns(DeprecationWarning): - if adv is VTrace: actor_net = TensorDictModule( nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] From f12282577959513d75ca9f511b4b3c40e41137e5 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 27 Dec 2023 12:55:02 +0100 Subject: [PATCH 14/18] fix --- torchrl/objectives/redq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index 3a9a7fc699d..d76f76ddc41 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -321,7 +321,7 @@ def __init__( self.qvalue_network, randomness=self.vmap_randomness ) self._vmap_getdist = _vmap_func( - self.actor_network, func="get_dist_params", randomess=self.vmap_randomness + self.actor_network, func="get_dist_params", randomness=self.vmap_randomness ) @property From d85d63baac451cc136cedfa1feb2512cfaa96497 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 3 Jan 2024 
09:12:05 +0100 Subject: [PATCH 15/18] add vmap randomness test --- test/test_cost.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/test/test_cost.py b/test/test_cost.py index b64641132bd..660a735cc90 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -120,6 +120,7 @@ from torchrl.objectives.redq import REDQLoss from torchrl.objectives.reinforce import ReinforceLoss from torchrl.objectives.utils import ( + _vmap_func, HardUpdate, hold_out_net, SoftUpdate, @@ -233,6 +234,38 @@ def set_advantage_keys_through_loss_test( ) +@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("vmap_randomness", (None, "different", "same")) +@pytest.mark.parametrize("dropout", (0.0, 0.1)) +def test_loss_vmap_random(device, vmap_randomness, dropout): + class VmapTestLoss(LossModule): + def __init__(self): + super().__init__() + layers = [nn.Linear(4, 4), nn.ReLU()] + if dropout > 0.0: + layers.append(nn.Dropout(dropout)) + layers.append(nn.Linear(4, 4)) + net = nn.Sequential(*layers).to(device) + model = TensorDictModule(net, in_keys=["obs"], out_keys=["action"]) + self.convert_to_functional(model, "model", expand_dim=4) + self.vmap_model = _vmap_func( + self.model, (None, 0), randomness=self.vmap_randomness + ) + + def forward(self, td): + out = self.vmap_model(td, self.model_params) + return {"loss": out["action"].mean()} + + loss_module = VmapTestLoss() + td = TensorDict({"obs": torch.randn(3, 4).to(device)}, [3]) + + # If user sets vmap randomness to a specific value + if vmap_randomness in ("different", "same") and dropout > 0.0: + loss_module.set_vmap_randomness(vmap_randomness) + + loss_module(td)["loss"] + + class TestDQN(LossModuleTestBase): seed = 0 @@ -4883,7 +4916,6 @@ def test_cql( device, td_est, ): - torch.manual_seed(self.seed) td = self._create_mock_data_cql(device=device) From 92c3e405bd451c8a15f78bc96e508973ad236647 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 4 Jan 2024 10:22:11 +0100 Subject: [PATCH 16/18] update ranodm_module_list --- torchrl/objectives/utils.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 8079b00856f..91305a6a777 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -14,6 +14,7 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torch import nn, Tensor from torch.nn import functional as F +from torch.nn.modules import dropout try: from torch import vmap @@ -30,13 +31,7 @@ "run `loss_module.make_value_estimator(ValueEstimators., gamma=val)`." 
) -RANDOM_MODULE_LIST = ( - nn.Dropout, - nn.Dropout2d, - nn.Dropout3d, - nn.AlphaDropout, - nn.FeatureAlphaDropout, -) +RANDOM_MODULE_LIST = (dropout._DropoutNd,) class ValueEstimators(Enum): From 7cdc1e6c95355be885e11bd6c628da62fe81e8ad Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 4 Jan 2024 12:11:44 +0100 Subject: [PATCH 17/18] add fail case for vmap --- test/test_cost.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 660a735cc90..715a1061062 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -235,7 +235,7 @@ def set_advantage_keys_through_loss_test( @pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize("vmap_randomness", (None, "different", "same")) +@pytest.mark.parametrize("vmap_randomness", (None, "different", "same", "error")) @pytest.mark.parametrize("dropout", (0.0, 0.1)) def test_loss_vmap_random(device, vmap_randomness, dropout): class VmapTestLoss(LossModule): @@ -249,7 +249,11 @@ def __init__(self): model = TensorDictModule(net, in_keys=["obs"], out_keys=["action"]) self.convert_to_functional(model, "model", expand_dim=4) self.vmap_model = _vmap_func( - self.model, (None, 0), randomness=self.vmap_randomness + self.model, + (None, 0), + randomness="error" + if vmap_randomness == "error" + else self.vmap_randomness, ) def forward(self, td): @@ -262,8 +266,14 @@ def forward(self, td): # If user sets vmap randomness to a specific value if vmap_randomness in ("different", "same") and dropout > 0.0: loss_module.set_vmap_randomness(vmap_randomness) + loss_module(td)["loss"] + # Fail case + elif vmap_randomness == "error" and dropout > 0.0: + with pytest.raises(RuntimeError): + loss_module(td)["loss"] - loss_module(td)["loss"] + else: + loss_module(td)["loss"] class TestDQN(LossModuleTestBase): From 1bfcc277b8d7111550c7e9d3cfb0705530ef33b2 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 8 Jan 2024 09:46:57 +0100 Subject: [PATCH 18/18] update vmap fail case test --- test/test_cost.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 715a1061062..dc9b75e7d87 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -8,6 +8,7 @@ import functools import itertools import operator +import re import warnings from copy import deepcopy from dataclasses import asdict, dataclass @@ -266,14 +267,18 @@ def forward(self, td): # If user sets vmap randomness to a specific value if vmap_randomness in ("different", "same") and dropout > 0.0: loss_module.set_vmap_randomness(vmap_randomness) - loss_module(td)["loss"] # Fail case elif vmap_randomness == "error" and dropout > 0.0: - with pytest.raises(RuntimeError): + with pytest.raises(RuntimeError) as exc_info: loss_module(td)["loss"] - else: - loss_module(td)["loss"] + # Accessing cause of the caught exception + cause = exc_info.value.__cause__ + assert re.match( + r"vmap: called random operation while in randomness error mode", str(cause) + ) + return + loss_module(td)["loss"] class TestDQN(LossModuleTestBase):
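
Reviewer note (not part of the patch series): a minimal usage sketch of the behaviour this series introduces, modelled on the dropout cases added to test_cost.py above. A TD3Loss built from networks that contain an nn.Dropout layer now picks randomness="different" for its internal vmap calls instead of hitting the vmap randomness error, and loss.set_vmap_randomness(...) can pin the mode explicitly. The layer sizes, batch size and the QNet helper below are illustrative assumptions, not taken from the PR.

    import torch
    from tensordict import TensorDict
    from torch import nn

    from torchrl.data import BoundedTensorSpec
    from torchrl.modules import Actor, ValueOperator
    from torchrl.objectives import TD3Loss

    obs_dim, action_dim, batch = 4, 2, 8
    action_spec = BoundedTensorSpec(
        -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
    )

    # Actor containing nn.Dropout: a random op that vmap rejects in its default "error" mode.
    actor = Actor(
        spec=action_spec,
        module=nn.Sequential(
            nn.Linear(obs_dim, obs_dim), nn.Dropout(0.1), nn.Linear(obs_dim, action_dim)
        ),
    )

    class QNet(nn.Module):
        # Toy Q-network that concatenates observation and action (assumption, for illustration).
        def __init__(self):
            super().__init__()
            self.net = nn.Linear(obs_dim + action_dim, 1)

        def forward(self, obs, act):
            return self.net(torch.cat([obs, act], dim=-1))

    qvalue = ValueOperator(module=QNet(), in_keys=["observation", "action"])

    # The loss scans its networks for dropout modules and builds its vmapped calls
    # with the detected randomness mode; set_vmap_randomness() overrides the detection.
    loss = TD3Loss(actor, qvalue, action_spec=action_spec)
    print(loss.vmap_randomness)  # expected: "different", because a Dropout module was found

    td = TensorDict(
        {
            "observation": torch.randn(batch, obs_dim),
            "action": action_spec.rand((batch,)),
            "next": TensorDict(
                {
                    "observation": torch.randn(batch, obs_dim),
                    "reward": torch.randn(batch, 1),
                    "done": torch.zeros(batch, 1, dtype=torch.bool),
                    "terminated": torch.zeros(batch, 1, dtype=torch.bool),
                },
                [batch],
            ),
        },
        [batch],
    )
    loss_vals = loss(td)  # runs despite the dropout layer inside the vmapped actor calls
    print(loss_vals["loss_actor"], loss_vals["loss_qvalue"])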