From 21b297eb174c26aca1471c8b6733c8d82ea9785c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Jul 2024 08:59:10 +0100 Subject: [PATCH] [BugFix] Reinitialize vmap callers after reset of vmap randomness ghstack-source-id: 0598afe1ed3bb118054f5021b35a8681ce964615 Pull Request resolved: https://github.com/pytorch/rl/pull/2315 --- test/test_cost.py | 3 +++ torchrl/objectives/common.py | 28 +++++++++++++++++----------- torchrl/objectives/cql.py | 4 +++- torchrl/objectives/crossq.py | 5 ++++- torchrl/objectives/deprecated.py | 6 ++++-- torchrl/objectives/iql.py | 5 ++++- torchrl/objectives/redq.py | 2 ++ torchrl/objectives/sac.py | 12 +++++++++--- torchrl/objectives/td3.py | 5 ++++- torchrl/objectives/td3_bc.py | 5 ++++- 10 files changed, 54 insertions(+), 21 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 9ac2bb6b950..090b32ac8e5 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -260,6 +260,9 @@ def __init__(self): net = nn.Sequential(*layers).to(device) model = TensorDictModule(net, in_keys=["obs"], out_keys=["action"]) self.convert_to_functional(model, "model", expand_dim=4) + self._make_vmap() + + def _make_vmap(self): self.vmap_model = _vmap_func( self.model, (None, 0), diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index f2b02825005..c62fd485e28 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -12,7 +12,6 @@ from dataclasses import dataclass from typing import Iterator, List, Optional, Tuple -import torch from tensordict import is_tensor_collection, TensorDict, TensorDictBase from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams @@ -541,16 +540,16 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams @property def vmap_randomness(self): if self._vmap_randomness is None: - do_break = False - for val in self.__dict__.values(): - if isinstance(val, torch.nn.Module): - for module in val.modules(): - if 
isinstance(module, RANDOM_MODULE_LIST): - self._vmap_randomness = "different" - do_break = True - break - if do_break: - # double break + main_modules = list(self.__dict__.values()) + list(self.children()) + modules = ( + module + for main_module in main_modules + if isinstance(main_module, nn.Module) + for module in main_module.modules() + ) + for val in modules: + if isinstance(val, RANDOM_MODULE_LIST): + self._vmap_randomness = "different" break else: self._vmap_randomness = "error" @@ -559,6 +558,7 @@ def vmap_randomness(self): def set_vmap_randomness(self, value): self._vmap_randomness = value + self._make_vmap() @staticmethod def _make_meta_params(param): @@ -570,6 +570,12 @@ def _make_meta_params(param): pd = nn.Parameter(pd, requires_grad=False) return pd + def _make_vmap(self): + """Caches the vmap callers to reduce the overhead at runtime.""" + raise NotImplementedError( + f"_make_vmap has been called but is not implemented for loss of type {type(self).__name__}." + ) + class _make_target_param: def __init__(self, clone): diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 0d2d869d1e1..d68a9fce782 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -373,14 +373,16 @@ def __init__( "log_alpha_prime", torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)), ) + self._make_vmap() + self.reduction = reduction + def _make_vmap(self): self._vmap_qvalue_networkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) - self.reduction = reduction @property def target_entropy(self): diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 355a33a4682..05499cb227d 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -338,10 +338,13 @@ def __init__( self._target_entropy = target_entropy self._action_spec = action_spec + self._make_vmap() + 
self.reduction = reduction + + def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) - self.reduction = reduction @property def target_entropy_buffer(self): diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 9e7115ac601..c1ed8b2cffe 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -221,13 +221,15 @@ def __init__( self._action_spec = action_spec self.target_entropy_buffer = None self.gSDE = gSDE + self._make_vmap() self.reduction = reduction - self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0)) - if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + def _make_vmap(self): + self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0)) + @property def target_entropy(self): target_entropy = self.target_entropy_buffer diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 7fab95a95ed..013435c9079 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -318,10 +318,13 @@ def __init__( self.loss_function = loss_function if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qvalue_networkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) - self.reduction = reduction @property def device(self) -> torch.device: diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index a0aaa96f7c5..db05063535a 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -336,7 +336,9 @@ def __init__( self.gSDE = gSDE if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self._make_vmap() + def _make_vmap(self): self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 
67ab7d7d8ce..65482a2b876 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -403,14 +403,17 @@ def __init__( ) if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) if self._version == 1: self._vmap_qnetwork00 = _vmap_func( - qvalue_network, randomness=self.vmap_randomness + self.qvalue_network, randomness=self.vmap_randomness ) - self.reduction = reduction @property def target_entropy_buffer(self): @@ -1101,10 +1104,13 @@ def __init__( self.register_buffer( "target_entropy", torch.tensor(target_entropy, device=device) ) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) - self.reduction = reduction def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index db99237d39e..b0026b0158d 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -317,13 +317,16 @@ def __init__( self.register_buffer("min_action", low) if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) self._vmap_actor_network00 = _vmap_func( self.actor_network, randomness=self.vmap_randomness ) - self.reduction = reduction def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py index d5529e0b859..bea101f4038 100644 --- a/torchrl/objectives/td3_bc.py +++ b/torchrl/objectives/td3_bc.py @@ -331,13 +331,16 @@ def __init__( high = high.to(device) 
self.register_buffer("max_action", high) self.register_buffer("min_action", low) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) self._vmap_actor_network00 = _vmap_func( self.actor_network, randomness=self.vmap_randomness ) - self.reduction = reduction def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: