
[Feature] Extract primers from modules that contain RNNs #2127

Merged
merged 11 commits on May 3, 2024
1 change: 1 addition & 0 deletions docs/source/reference/modules.rst
@@ -427,6 +427,7 @@ Utils
mappings
inv_softplus
biased_softplus
get_primers_from_module

.. currentmodule:: torchrl.modules

19 changes: 14 additions & 5 deletions torchrl/envs/transforms/transforms.py
@@ -4511,6 +4511,12 @@ class TensorDictPrimer(Transform):
tensor([[1., 1., 1.],
[1., 1., 1.]])

Note:
Some TorchRL modules, such as :class:`~torchrl.modules.LSTMModule` and :class:`~torchrl.modules.GRUModule`,
rely on specific keys being present in the environment TensorDicts. To facilitate this process,
the function :func:`~torchrl.modules.utils.get_primers_from_module` checks a module and its
submodules for the required primer transforms and generates them automatically.
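
For example, a minimal sketch (here ``model`` stands for any hypothetical module that contains
an :class:`~torchrl.modules.LSTMModule` or :class:`~torchrl.modules.GRUModule` submodule, and
``env`` for a :class:`~torchrl.envs.TransformedEnv` carrying an :class:`~torchrl.envs.InitTracker`
transform):

>>> from torchrl.modules.utils import get_primers_from_module
>>> primers = get_primers_from_module(model)
>>> env.append_transform(primers)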
"""

def __init__(
@@ -4696,15 +4702,18 @@ def _reset(
spec shape is assumed to match the tensordict's.

"""
shape = (
()
if (not self.parent or self.parent.batch_locked)
else tensordict.batch_size
)
_reset = _get_reset(self.reset_key, tensordict)
if _reset.any():
for key, spec in self.primers.items(True, True):
if spec.shape[: len(tensordict.batch_size)] != tensordict.batch_size:
expanded_spec = self._expand_shape(spec)
self.primers[key] = spec = expanded_spec
if self.random:
shape = (
()
if (not self.parent or self.parent.batch_locked)
else tensordict.batch_size
)
value = spec.rand(shape)
else:
value = self.default_value[key]
14 changes: 14 additions & 0 deletions torchrl/modules/tensordict_module/rnn.py
@@ -382,6 +382,13 @@ class LSTMModule(ModuleBase):
set_recurrent_mode: controls whether the module should be executed in
recurrent mode.

Note:
This module relies on specific ``recurrent_state`` keys being present in the input
TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that automatically
adds the hidden states to the environment TensorDicts, use the method :meth:`~torchrl.modules.LSTMModule.make_tensordict_primer`.
If this class is a submodule of a larger module, :func:`~torchrl.modules.utils.get_primers_from_module` can be called
on the parent module to generate the primer transforms required by all submodules, including this one.
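
A minimal sketch of the single-module case (the environment name and sizes are illustrative):

>>> from torchrl.envs import GymEnv, InitTracker, TransformedEnv
>>> from torchrl.modules import LSTMModule
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> lstm_module = LSTMModule(
...     input_size=env.observation_spec["observation"].shape[-1],
...     hidden_size=64,
...     in_keys=["observation", "rs_h", "rs_c"],
...     out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")],
... )
>>> env.append_transform(lstm_module.make_tensordict_primer())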

Examples:
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
@@ -1059,6 +1066,13 @@ class GRUModule(ModuleBase):
set_recurrent_mode: controls whether the module should be executed in
recurrent mode.

Note:
This module relies on specific ``recurrent_state`` keys being present in the input
TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that automatically
adds the hidden states to the environment TensorDicts, use the method :meth:`~torchrl.modules.GRUModule.make_tensordict_primer`.
If this class is a submodule of a larger module, :func:`~torchrl.modules.utils.get_primers_from_module` can be called
on the parent module to generate the primer transforms required by all submodules, including this one.
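
A minimal sketch of the single-module case (the environment name and sizes are illustrative):

>>> from torchrl.envs import GymEnv, InitTracker, TransformedEnv
>>> from torchrl.modules import GRUModule
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> gru_module = GRUModule(
...     input_size=env.observation_spec["observation"].shape[-1],
...     hidden_size=64,
...     in_keys=["observation", "recurrent_state", "is_init"],
...     out_keys=["intermediate", ("next", "recurrent_state")],
... )
>>> env.append_transform(gru_module.make_tensordict_primer())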

Examples:
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
1 change: 1 addition & 0 deletions torchrl/modules/utils/__init__.py
@@ -25,3 +25,4 @@ def __instancecheck__(self, instance):


from .mappings import biased_softplus, inv_softplus, mappings
from .utils import get_primers_from_module
66 changes: 66 additions & 0 deletions torchrl/modules/utils/utils.py
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import warnings


def get_primers_from_module(module):
"""Get all tensordict primers from all submodules of a module.

This function is useful for retrieving the primer transforms required by submodules
that are nested inside a parent module.

Args:
module (torch.nn.Module): The parent module.

Returns:
TensorDictPrimer: a single primer transform, or a :class:`~torchrl.envs.transforms.Compose` of
primers if several submodules define one. If no primer is found, a warning is issued and
``None`` is returned.

Example:
>>> from torchrl.modules.utils import get_primers_from_module
>>> from torchrl.modules import GRUModule, MLP
>>> from tensordict.nn import TensorDictModule, TensorDictSequential

>>> # Define a GRU module
>>> gru_module = GRUModule(
... input_size=10,
... hidden_size=10,
... num_layers=1,
... in_keys=["input", "recurrent_state", "is_init"],
... out_keys=["features", ("next", "recurrent_state")],
... )

>>> # Define a head module
>>> head = TensorDictModule(
... MLP(
... in_features=10,
... out_features=10,
... num_cells=[],
... ),
... in_keys=["features"],
... out_keys=["output"],
... )

>>> # Create a sequential model
>>> model = TensorDictSequential(gru_module, head)

>>> # Retrieve primers from the model
>>> primers = get_primers_from_module(model)
"""
    primers = []

    def make_primers(submodule):
        if hasattr(submodule, "make_tensordict_primer"):
            primers.append(submodule.make_tensordict_primer())

    module.apply(make_primers)
    if not primers:
        warnings.warn("No primers found in the module.")
        return None
    elif len(primers) == 1:
        return primers[0]
    else:
        from torchrl.envs.transforms import Compose

        return Compose(*primers)
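
A brief usage sketch (the environment choice is an illustrative assumption): the returned
primer is meant to be appended to a transformed environment, typically alongside an
:class:`~torchrl.envs.InitTracker` transform.

from torchrl.envs import GymEnv, InitTracker, TransformedEnv
from torchrl.modules.utils import get_primers_from_module

# `model` is the TensorDictSequential built in the docstring example above.
env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
env.append_transform(get_primers_from_module(model))  # primes "recurrent_state" at reset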