diff --git a/torchrl/__init__.py b/torchrl/__init__.py
index cbd7b66a65e..7a41bf0ab8f 100644
--- a/torchrl/__init__.py
+++ b/torchrl/__init__.py
@@ -27,6 +27,11 @@
 except ImportError:
     __version__ = None
 
+try:
+    from torch.compiler import is_dynamo_compiling
+except ImportError:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
 _init_extension()
 
 try:
@@ -69,7 +74,7 @@ def _inv(self):
         inv = self._inv()
     if inv is None:
         inv = _InverseTransform(self)
-        if not torch.compiler.is_dynamo_compiling():
+        if not is_dynamo_compiling():
             self._inv = weakref.ref(inv)
     return inv
 
@@ -84,7 +89,7 @@ def _inv(self):
         inv = self._inv()
     if inv is None:
         inv = ComposeTransform([p.inv for p in reversed(self.parts)])
-        if not torch.compiler.is_dynamo_compiling():
+        if not is_dynamo_compiling():
            self._inv = weakref.ref(inv)
            inv._inv = weakref.ref(self)
        else:
diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py
index debb836d6fa..8b0d5654b8d 100644
--- a/torchrl/modules/distributions/continuous.py
+++ b/torchrl/modules/distributions/continuous.py
@@ -36,6 +36,11 @@
 # speeds up distribution construction
 D.Distribution.set_default_validate_args(False)
 
+try:
+    from torch.compiler import is_dynamo_compiling
+except ImportError:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
 
 class IndependentNormal(D.Independent):
     """Implements a Normal distribution with location scaling.
@@ -112,7 +117,7 @@ def inv(self):
             inv = self._inv()
         if inv is None:
             inv = _InverseTransform(self)
-            if not torch.compiler.is_dynamo_compiling():
+            if not is_dynamo_compiling():
                 self._inv = weakref.ref(inv)
         return inv
 
@@ -320,7 +325,7 @@ def inv(self):
             inv = self._inv()
         if inv is None:
             inv = _PatchedComposeTransform([p.inv for p in reversed(self.parts)])
-            if not torch.compiler.is_dynamo_compiling():
+            if not is_dynamo_compiling():
                 self._inv = weakref.ref(inv)
                 inv._inv = weakref.ref(self)
         return inv
@@ -334,7 +339,7 @@ def inv(self):
             inv = self._inv()
         if inv is None:
             inv = _InverseTransform(self)
-            if not torch.compiler.is_dynamo_compiling():
+            if not is_dynamo_compiling():
                 self._inv = weakref.ref(inv)
         return inv
 
@@ -432,15 +437,13 @@ def __init__(
         self.high = high
 
         if safe_tanh:
-            if torch.compiler.is_dynamo_compiling():
+            if is_dynamo_compiling():
                 _err_compile_safetanh()
             t = SafeTanhTransform()
         else:
             t = D.TanhTransform()
         # t = D.TanhTransform()
-        if torch.compiler.is_dynamo_compiling() or (
-            self.non_trivial_max or self.non_trivial_min
-        ):
+        if is_dynamo_compiling() or (self.non_trivial_max or self.non_trivial_min):
             t = _PatchedComposeTransform(
                 [
                     t,
@@ -467,9 +470,7 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
         if self.tanh_loc:
             loc = (loc / self.upscale).tanh() * self.upscale
         # loc must be rescaled if tanh_loc
-        if torch.compiler.is_dynamo_compiling() or (
-            self.non_trivial_max or self.non_trivial_min
-        ):
+        if is_dynamo_compiling() or (self.non_trivial_max or self.non_trivial_min):
             loc = loc + (self.high - self.low) / 2 + self.low
         self.loc = loc
         self.scale = scale
diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py
index 66eae215e54..017394de04b 100644
--- a/torchrl/objectives/utils.py
+++ b/torchrl/objectives/utils.py
@@ -26,6 +26,11 @@
         raise err_ft from err
 from torchrl.envs.utils import step_mdp
 
+try:
+    from torch.compiler import is_dynamo_compiling
+except ImportError:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
 _GAMMA_LMBDA_DEPREC_ERROR = (
     "Passing gamma / lambda parameters through the loss constructor "
     "is a deprecated feature. To customize your value function, "
@@ -460,7 +465,7 @@ def _cache_values(func):
 
     @functools.wraps(func)
     def new_func(self, netname=None):
-        if torch.compiler.is_dynamo_compiling():
+        if is_dynamo_compiling():
            if netname is not None:
                return func(self, netname)
            else:
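
Note on the pattern above (illustrative, not part of the diff): `torch.compiler.is_dynamo_compiling` only exists on recent PyTorch releases, so each touched module now imports it behind a try/except and falls back to the private `torch._dynamo.is_compiling` on older versions. Below is a minimal standalone sketch of that shim together with the kind of guarded weakref caching it protects; the `_Cached` and `_Inverse` classes are hypothetical stand-ins for the patched transform properties, not torchrl code.

    import weakref

    # Prefer the public entry point (newer PyTorch); fall back to the private
    # helper on older releases, mirroring the import block added in the diff.
    try:
        from torch.compiler import is_dynamo_compiling
    except ImportError:
        from torch._dynamo import is_compiling as is_dynamo_compiling


    class _Inverse:
        """Stand-in for an expensive-to-build inverse transform (illustrative)."""


    class _Cached:
        """Hypothetical example: cache a derived object through a weakref,
        skipping the weakref bookkeeping while dynamo is tracing."""

        def __init__(self):
            self._inv = None

        @property
        def inv(self):
            inv = self._inv() if self._inv is not None else None
            if inv is None:
                inv = _Inverse()
                if not is_dynamo_compiling():
                    # Only cache in eager mode; the guard keeps weakref
                    # machinery out of the compiled/traced region.
                    self._inv = weakref.ref(inv)
            return inv

In eager mode the property caches the inverse via a weak reference as before; under torch.compile the guard evaluates to True and the assignment is skipped, presumably to keep weakref bookkeeping out of the traced graph.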