diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index ee78c68835f..c79e4f42c49 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -373,6 +373,8 @@ algorithms, such as DQN, DDPG or Dreamer. OnlineDTActor RSSMPosterior RSSMPrior + set_recurrent_mode + recurrent_mode Multi-agent-specific modules ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/test/test_cost.py b/test/test_cost.py index 1e157fd7a2f..598b9ba004d 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -47,6 +47,7 @@ DistributionalQValueActor, OneHotCategorical, QValueActor, + recurrent_mode, SafeSequential, WorldModelWrapper, ) @@ -15507,6 +15508,29 @@ def test_set_deprecated_keys(self, adv, kwargs): class TestBase: + def test_decorators(self): + class MyLoss(LossModule): + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + assert recurrent_mode() + assert exploration_type() is ExplorationType.DETERMINISTIC + return TensorDict() + + def actor_loss(self, tensordict: TensorDictBase) -> TensorDictBase: + assert recurrent_mode() + assert exploration_type() is ExplorationType.DETERMINISTIC + return TensorDict() + + def something_loss(self, tensordict: TensorDictBase) -> TensorDictBase: + assert recurrent_mode() + assert exploration_type() is ExplorationType.DETERMINISTIC + return TensorDict() + + loss = MyLoss() + loss.forward(None) + loss.actor_loss(None) + loss.something_loss(None) + assert not recurrent_mode() + @pytest.mark.parametrize("expand_dim", [None, 2]) @pytest.mark.parametrize("compare_against", [True, False]) @pytest.mark.skipif(not _has_functorch, reason="functorch is needed for expansion") diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index ec9322500b4..d3b7b7850f4 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -36,6 +36,7 @@ OnlineDTActor, ProbabilisticActor, SafeModule, + set_recurrent_mode, TanhDelta, TanhNormal, ValueOperator, @@ -729,6 +730,31 @@ def test_errs(self): with pytest.raises(KeyError, match="is_init"): lstm_module(td) + @pytest.mark.parametrize("default_val", [False, True, None]) + def test_set_recurrent_mode(self, default_val): + lstm_module = LSTMModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden0", "hidden1"], + out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")], + default_recurrent_mode=default_val, + ) + assert lstm_module.recurrent_mode is bool(default_val) + with set_recurrent_mode(True): + assert lstm_module.recurrent_mode + with set_recurrent_mode(False): + assert not lstm_module.recurrent_mode + with set_recurrent_mode("recurrent"): + assert lstm_module.recurrent_mode + with set_recurrent_mode("sequential"): + assert not lstm_module.recurrent_mode + assert lstm_module.recurrent_mode + assert not lstm_module.recurrent_mode + assert lstm_module.recurrent_mode + assert lstm_module.recurrent_mode is bool(default_val) + + @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_set_temporal_mode(self): lstm_module = LSTMModule( input_size=3, @@ -754,7 +780,8 @@ def test_python_cudnn(self): num_layers=2, in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")], - ).set_recurrent_mode(True) + default_recurrent_mode=True, + ) obs = torch.rand(10, 20, 3) hidden0 = torch.rand(10, 20, 2, 12) @@ -1109,6 +1136,31 @@ def test_errs(self): with pytest.raises(KeyError, match="is_init"): gru_module(td) + @pytest.mark.parametrize("default_val", [False, True, None]) + def test_set_recurrent_mode(self, default_val): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden"], + out_keys=["intermediate", ("next", "hidden")], + default_recurrent_mode=default_val, + ) + assert gru_module.recurrent_mode is bool(default_val) + with set_recurrent_mode(True): + assert gru_module.recurrent_mode + with set_recurrent_mode(False): + assert not gru_module.recurrent_mode + with set_recurrent_mode("recurrent"): + assert gru_module.recurrent_mode + with set_recurrent_mode("sequential"): + assert not gru_module.recurrent_mode + assert gru_module.recurrent_mode + assert not gru_module.recurrent_mode + assert gru_module.recurrent_mode + assert gru_module.recurrent_mode is bool(default_val) + + @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_set_temporal_mode(self): gru_module = GRUModule( input_size=3, diff --git a/test/test_transforms.py b/test/test_transforms.py index 56a39218f5f..8b2ada8c93a 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -10885,7 +10885,8 @@ def _make_gru_module(self, input_size=4, hidden_size=4, device="cpu"): in_keys=["observation", "rhs", "is_init"], out_keys=["output", ("next", "rhs")], device=device, - ).set_recurrent_mode(True) + default_recurrent_mode=True, + ) def _make_lstm_module(self, input_size=4, hidden_size=4, device="cpu"): return LSTMModule( @@ -10895,7 +10896,8 @@ def _make_lstm_module(self, input_size=4, hidden_size=4, device="cpu"): in_keys=["observation", "rhs_h", "rhs_c", "is_init"], out_keys=["output", ("next", "rhs_h"), ("next", "rhs_c")], device=device, - ).set_recurrent_mode(True) + default_recurrent_mode=True, + ) def _make_batch(self, batch_size: int = 2, sequence_length: int = 5): observation = torch.randn(batch_size, sequence_length + 1, 4) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 0b4dd03a636..d37aebb862f 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -15,9 +15,11 @@ import os import pickle import sys +import threading import time import traceback import warnings +from contextlib import nullcontext from copy import copy from distutils.util import strtobool from functools import wraps @@ -32,6 +34,11 @@ from tensordict.utils import NestedKey from torch import multiprocessing as mp +try: + from torch.compiler import is_compiling +except ImportError: + from torch._dynamo import is_compiling + LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO") logger = logging.getLogger("torchrl") logger.setLevel(getattr(logging, LOGGING_LEVEL)) @@ -827,3 +834,19 @@ def _make_ordinal_device(device: torch.device): if device.type == "mps" and device.index is None: return torch.device("mps", index=0) return device + + +class _ContextManager: + def __init__(self): + self._mode: Any | None = None + self._lock = threading.Lock() + + def get_mode(self) -> Any | None: + cm = self._lock if not is_compiling() else nullcontext() + with cm: + return self._mode + + def set_mode(self, type: Any | None) -> None: + cm = self._lock if not is_compiling() else nullcontext() + with cm: + self._mode = type diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index e02c88c5330..7bdd25591cd 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -7411,7 +7411,8 @@ class BurnInTransform(Transform): ... hidden_size=10, ... in_keys=["observation", "hidden"], ... out_keys=["intermediate", ("next", "hidden")], - ... ).set_recurrent_mode(True) + ... default_recurrent_mode=True, + ... ) >>> burn_in_transform = BurnInTransform( ... modules=[gru_module], ... burn_in=5, diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 4cb6366f817..edf90a4e85b 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -80,10 +80,12 @@ QValueActor, QValueHook, QValueModule, + recurrent_mode, SafeModule, SafeProbabilisticModule, SafeProbabilisticTensorDictSequential, SafeSequential, + set_recurrent_mode, TanhModule, ValueOperator, VmapModule, diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index 202f84fd173..3fb1559833a 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -34,6 +34,15 @@ SafeProbabilisticModule, SafeProbabilisticTensorDictSequential, ) -from .rnn import GRU, GRUCell, GRUModule, LSTM, LSTMCell, LSTMModule +from .rnn import ( + GRU, + GRUCell, + GRUModule, + LSTM, + LSTMCell, + LSTMModule, + recurrent_mode, + set_recurrent_mode, +) from .sequence import SafeSequential from .world_models import WorldModelWrapper diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 6a99e85812b..f4ceb648665 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -4,7 +4,9 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -from typing import Optional, Tuple +import typing +import warnings +from typing import Any, Optional, Tuple import torch import torch.nn.functional as F @@ -18,6 +20,7 @@ from torch import nn, Tensor from torch.nn.modules.rnn import RNNCellBase +from torchrl._utils import _ContextManager, _DecoratorContextManager from torchrl.data.tensor_specs import Unbounded from torchrl.objectives.value.functional import ( _inv_pad_sequence, @@ -376,6 +379,9 @@ class LSTMModule(ModuleBase): device (torch.device or compatible): the device of the module. lstm (torch.nn.LSTM, optional): an LSTM instance to be wrapped. Exclusive with other nn.LSTM arguments. + default_recurrent_mode (bool, optional): if provided, the recurrent mode if it hasn't been overridden + by the :class:`~torchrl.modules.set_recurrent_mode` context manager / decorator. + Defaults to ``False``. Attributes: recurrent_mode: Returns the recurrent mode of the module. @@ -451,6 +457,7 @@ def __init__( out_keys=None, device=None, lstm=None, + default_recurrent_mode: bool | None = None, ): super().__init__() if lstm is not None: @@ -524,7 +531,7 @@ def __init__( in_keys = in_keys + ["is_init"] self.in_keys = in_keys self.out_keys = out_keys - self._recurrent_mode = False + self._recurrent_mode = default_recurrent_mode def make_python_based(self) -> LSTMModule: """Transforms the LSTM layer in its python-based version. @@ -647,12 +654,15 @@ def make_tuple(key): @property def recurrent_mode(self): - return self._recurrent_mode + rm = recurrent_mode() + if rm is None: + return bool(self._recurrent_mode) + return rm @recurrent_mode.setter def recurrent_mode(self, value): raise RuntimeError( - "recurrent_mode cannot be changed in-place. Call `module.set" + "recurrent_mode cannot be changed in-place. Please use the set_recurrent_mode context manager." ) @property @@ -662,7 +672,7 @@ def temporal_mode(self): ) def set_recurrent_mode(self, mode: bool = True): - """Returns a new copy of the module that shares the same lstm model but with a different ``recurrent_mode`` attribute (if it differs). + """[DEPRECATED - use :class:`torchrl.modules.set_recurrent_mode` context manager instead] Returns a new copy of the module that shares the same lstm model but with a different ``recurrent_mode`` attribute (if it differs). A copy is created such that the module can be used with divergent behavior in various parts of the code (inference vs training): @@ -692,7 +702,13 @@ def set_recurrent_mode(self, mode: bool = True): ... >>> torch.testing.assert_close(td_inf["hidden0"], traj_td[..., -1]["next", "hidden0"]) """ - if mode is self._recurrent_mode: + warnings.warn( + "The lstm.set_recurrent_mode() API is deprecated and will be removed in v0.8. " + "To set the recurent mode, use the :class:`~torchrl.modules.set_recurrent_mode` context manager or " + "the `default_recurrent_mode` keyword argument in the constructor.", + category=DeprecationWarning, + ) + if mode is self.recurrent_mode: return self out = LSTMModule(lstm=self.lstm, in_keys=self.in_keys, out_keys=self.out_keys) out._recurrent_mode = mode @@ -1155,6 +1171,9 @@ class GRUModule(ModuleBase): device (torch.device or compatible): the device of the module. gru (torch.nn.GRU, optional): a GRU instance to be wrapped. Exclusive with other nn.GRU arguments. + default_recurrent_mode (bool, optional): if provided, the recurrent mode if it hasn't been overridden + by the :class:`~torchrl.modules.set_recurrent_mode` context manager / decorator. + Defaults to ``False``. Attributes: recurrent_mode: Returns the recurrent mode of the module. @@ -1256,6 +1275,7 @@ def __init__( out_keys=None, device=None, gru=None, + default_recurrent_mode: bool | None = None, ): super().__init__() if gru is not None: @@ -1326,7 +1346,7 @@ def __init__( in_keys = in_keys + ["is_init"] self.in_keys = in_keys self.out_keys = out_keys - self._recurrent_mode = False + self._recurrent_mode = default_recurrent_mode def make_python_based(self) -> GRUModule: """Transforms the GRU layer in its python-based version. @@ -1444,12 +1464,15 @@ def make_tuple(key): @property def recurrent_mode(self): - return self._recurrent_mode + rm = recurrent_mode() + if rm is None: + return bool(self._recurrent_mode) + return rm @recurrent_mode.setter def recurrent_mode(self, value): raise RuntimeError( - "recurrent_mode cannot be changed in-place. Call `module.set" + "recurrent_mode cannot be changed in-place. Please use the set_recurrent_mode context manager." ) @property @@ -1459,7 +1482,7 @@ def temporal_mode(self): ) def set_recurrent_mode(self, mode: bool = True): - """Returns a new copy of the module that shares the same gru model but with a different ``recurrent_mode`` attribute (if it differs). + """[DEPRECATED - use :class:`torchrl.modules.set_recurrent_mode` context manager instead] Returns a new copy of the module that shares the same gru model but with a different ``recurrent_mode`` attribute (if it differs). A copy is created such that the module can be used with divergent behavior in various parts of the code (inference vs training): @@ -1488,7 +1511,13 @@ def set_recurrent_mode(self, mode: bool = True): ... >>> torch.testing.assert_close(td_inf["hidden"], traj_td[..., -1]["next", "hidden"]) """ - if mode is self._recurrent_mode: + warnings.warn( + "The gru.set_recurrent_mode() API is deprecated and will be removed in v0.8. " + "To set the recurent mode, use the :class:`~torchrl.modules.set_recurrent_mode` context manager or " + "the `default_recurrent_mode` keyword argument in the constructor.", + category=DeprecationWarning, + ) + if mode is self.recurrent_mode: return self out = GRUModule(gru=self.gru, in_keys=self.in_keys, out_keys=self.out_keys) out._recurrent_mode = mode @@ -1598,3 +1627,57 @@ def _gru( ) out = [y, hidden] return tuple(out) + + +# Recurrent mode manager +recurrent_mode_state_manager = _ContextManager() + + +def recurrent_mode() -> bool | None: + """Returns the current sampling type.""" + return recurrent_mode_state_manager.get_mode() + + +class set_recurrent_mode(_DecoratorContextManager): + """Context manager for setting RNNs recurrent mode. + + Args: + mode (bool, "recurrent" or "stateful"): the recurrent mode to be used within the context manager. + `"recurrent"` leads to `mode=True` and `"stateful"` leads to `mode=False`. + An RNN executed with recurrent_mode "on" assumes that the data comes in time batches, otherwise + it is assumed that each data element in a tensordict is independent of the others. + The default value of this context manager is ``True``. + The default recurrent mode is ``None``, i.e., the default recurrent mode of the RNN is used + (see :class:`~torchrl.modules.LSTMModule` and :class:`~torchrl.modules.GRUModule` constructors). + + .. seealso:: :class:`~torchrl.modules.recurrent_mode``. + + .. note:: All of TorchRL methods are decorated with ``set_recurrent_mode(True)`` by default. + + """ + + def __init__( + self, mode: bool | typing.Literal["recurrent", "sequential"] | None = True + ) -> None: + super().__init__() + if isinstance(mode, str): + if mode.lower() in ("recurrent",): + mode = True + elif mode.lower() in ("sequential",): + mode = False + else: + raise ValueError( + f"Unsupported recurrent mode. Must be a bool, or one of {('recurrent', 'sequential')}" + ) + self.mode = mode + + def clone(self) -> set_recurrent_mode: + # override this method if your children class takes __init__ parameters + return type(self)(self.mode) + + def __enter__(self) -> None: + self.prev = recurrent_mode_state_manager.get_mode() + recurrent_mode_state_manager.set_mode(self.mode) + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + recurrent_mode_state_manager.set_mode(self.prev) diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 57310a5fc3d..d54671f569b 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -21,6 +21,7 @@ from torch.nn import Parameter from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import set_recurrent_mode from torchrl.objectives.utils import RANDOM_MODULE_LIST, ValueEstimators from torchrl.objectives.value import ValueEstimatorBase @@ -46,7 +47,9 @@ def _updater_check_forward_prehook(module, *args, **kwargs): def _forward_wrapper(func): @functools.wraps(func) def new_forward(self, *args, **kwargs): - with set_exploration_type(self.deterministic_sampling_mode): + with set_exploration_type(self.deterministic_sampling_mode), set_recurrent_mode( + True + ): return func(self, *args, **kwargs) return new_forward @@ -56,6 +59,9 @@ class _LossMeta(abc.ABCMeta): def __init__(cls, name, bases, attr_dict): super().__init__(name, bases, attr_dict) cls.forward = _forward_wrapper(cls.forward) + for name, value in cls.__dict__.items(): + if not name.startswith("_") and name.endswith("loss"): + setattr(cls, name, _forward_wrapper(value)) class LossModule(TensorDictModuleBase, metaclass=_LossMeta): diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py index 8931f483384..58c47f68321 100644 --- a/tutorials/sphinx-tutorials/dqn_with_rnn.py +++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py @@ -317,7 +317,7 @@ # # We can now put things together in a :class:`~tensordict.nn.TensorDictSequential` # -stoch_policy = Seq(feature, lstm, mlp, qval) +policy = Seq(feature, lstm, mlp, qval) ###################################################################### # DQN being a deterministic algorithm, exploration is a crucial part of it. @@ -330,7 +330,7 @@ annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2 ) stoch_policy = TensorDictSequential( - stoch_policy, + policy, exploration_module, ) @@ -338,20 +338,17 @@ # Using the model for the loss # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# The model as we've built it is well equipped to be used in sequential settings. +# The model as we've built it is well-equipped to be used in sequential settings. # However, the class :class:`torch.nn.LSTM` can use a cuDNN-optimized backend # to run the RNN sequence faster on GPU device. We would not want to miss # such an opportunity to speed up our training loop! -# To use it, we just need to tell the LSTM module to run on "recurrent-mode" -# when used by the loss. -# As we'll usually want to have two copies of the LSTM module, we do this by -# calling a :meth:`~torchrl.modules.LSTMModule.set_recurrent_mode` method that -# will return a new instance of the LSTM (with shared weights) that will -# assume that the input data is sequential in nature. # -policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval) - -###################################################################### +# By default, torchrl losses will use this when executing any +# :class:`~torchrl.modules.LSTMModule` or :class:`~torchrl.modules.GRUModule` +# forward call. If you need to control this manually, the RNN modules are sensitive +# to a context manager/decorator, :class:`~torchrl.modules.set_recurrent_mode`, +# that handles the behaviour of the underlying RNN module. +# # Because we still have a couple of uninitialized parameters we should # initialize them before creating an optimizer and such. # diff --git a/tutorials/sphinx-tutorials/export.py b/tutorials/sphinx-tutorials/export.py index 0a4390abdfc..48dd8723ffc 100644 --- a/tutorials/sphinx-tutorials/export.py +++ b/tutorials/sphinx-tutorials/export.py @@ -265,10 +265,6 @@ in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", "hidden0", "hidden1"], ) -##################################### -# We set the recurrent mode to ``False`` to allow the module to read inputs one-by-one and not in batch. -# -lstm = lstm.set_recurrent_mode(False) ##################################### # If the LSTM module is not python based but CuDNN (:class:`~torch.nn.LSTM`), the :meth:`~torchrl.modules.LSTMModule.make_python_based`