Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Extract primers from modules that contain RNNs #2127

Merged
merged 11 commits into from
May 3, 2024
Prev Previous commit
Next Next commit
add cross refs
  • Loading branch information
albertbou92 committed Apr 30, 2024
commit 97696d313c11b89fd921b0aa5edc32cbd9518ff5
6 changes: 6 additions & 0 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4511,6 +4511,12 @@ class TensorDictPrimer(Transform):
tensor([[1., 1., 1.],
[1., 1., 1.]])

Note:
Some TorchRL modules rely on specific keys being present in the environment TensorDicts,
like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`.
To facilitate this process, the method :func:`~torchrl.models.utils.get_primers_from_module`
automatically checks for required primer transforms in a module and its submodules and
generates them.
vmoens marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(
Expand Down
14 changes: 14 additions & 0 deletions torchrl/modules/tensordict_module/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,13 @@ class LSTMModule(ModuleBase):
set_recurrent_mode: controls whether the module should be executed in
recurrent mode.

Note:
This module relies on specific ``recurrent_state`` keys being present in the input
TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
add hidden states to the environment TensorDicts, use the method :func:`~torchrl.modules.rnn.LSTMModule.make_tensordict_primer`.
If this class is a submodule in a larger module, the method :func:`~torchrl.models.utils.get_primers_from_module` can be called
on the parent module to automatically generate the primer transforms required for all submodules, including this one.
vmoens marked this conversation as resolved.
Show resolved Hide resolved

Examples:
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
Expand Down Expand Up @@ -1059,6 +1066,13 @@ class GRUModule(ModuleBase):
set_recurrent_mode: controls whether the module should be executed in
recurrent mode.

Note:
This module relies on specific ``recurrent_state`` keys being present in the input
TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
add hidden states to the environment TensorDicts, use the method :func:`~torchrl.modules.rnn.GRUModule.make_tensordict_primer`.
If this class is a submodule in a larger module, the method :func:`~torchrl.models.utils.get_primers_from_module` can be called
on the parent module to automatically generate the primer transforms required for all submodules, including this one.
vmoens marked this conversation as resolved.
Show resolved Hide resolved

Examples:
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
Expand Down