[Doc] Better doc for make_tensordict_primer #2324

Merged Jul 28, 2024 · 2 commits
Changes from 1 commit
Commit: init
vmoens committed Jul 26, 2024
commit 4b3b09e80584aba02b49ccb4055a2e562f2fd52d
84 changes: 84 additions & 0 deletions torchrl/modules/tensordict_module/rnn.py
@@ -381,6 +381,8 @@ class LSTMModule(ModuleBase):
Methods:
set_recurrent_mode: controls whether the module should be executed in
recurrent mode.
make_tensordict_primer: creates a TensorDictPrimer transform that makes the environment aware of the
recurrent states of the RNN.

.. note:: This module relies on specific ``recurrent_state`` keys being present in the input
TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
@@ -521,6 +523,46 @@ def __init__(
self._recurrent_mode = False

def make_tensordict_primer(self):
"""Makes a tensordict primer for the environment.

A :class:`~torchrl.envs.TensorDictPrimer` object will ensure that the policy is aware of the supplementary
inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across
processes and dealt with properly.

Not including a ``TensorDictPrimer`` in the environment may result in ill-defined behaviours, for instance
in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root
tensordict, which the :meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states
are not registered within the environment specs.

Examples:
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP, LSTMModule
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>>
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> assert env.base_env.batch_locked
>>> lstm_module = LSTMModule(
... input_size=env.observation_spec["observation"].shape[-1],
... hidden_size=64,
... in_keys=["observation", "rs_h", "rs_c"],
... out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
>>> env = env.append_transform(lstm_module.make_tensordict_primer())
>>> data_collector = SyncDataCollector(
... env,
... policy,
... frames_per_batch=10
... )
>>> for data in data_collector:
... print(data)
... break

"""
from torchrl.envs.transforms.transforms import TensorDictPrimer

def make_tuple(key):
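The body of the ``make_tuple`` helper is elided in this diff view. A plausible reconstruction, purely for illustration (the actual torchrl implementation may differ): it normalizes a key, which may be a plain string or a nested-key tuple, into tuple form so both shapes can be handled uniformly.

```python
def make_tuple(key):
    # Hypothetical sketch of the elided helper: normalize a key into
    # tuple form, so "rs_h" and ("next", "rs_h") are handled alike.
    if isinstance(key, tuple):
        return key
    return (key,)
```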
@@ -1065,6 +1107,8 @@ class GRUModule(ModuleBase):
Methods:
set_recurrent_mode: controls whether the module should be executed in
recurrent mode.
make_tensordict_primer: creates a TensorDictPrimer transform that makes the environment aware of the
recurrent states of the RNN.

.. note:: This module relies on specific ``recurrent_state`` keys being present in the input
TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
@@ -1230,6 +1274,46 @@ def __init__(
self._recurrent_mode = False

def make_tensordict_primer(self):
"""Makes a tensordict primer for the environment.

A :class:`~torchrl.envs.TensorDictPrimer` object will ensure that the policy is aware of the supplementary
inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across
processes and dealt with properly.

Not including a ``TensorDictPrimer`` in the environment may result in ill-defined behaviours, for instance
in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root
tensordict, which the :meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states
are not registered within the environment specs.

Examples:
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP, LSTMModule
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>>
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> assert env.base_env.batch_locked
>>> gru_module = GRUModule(
... input_size=env.observation_spec["observation"].shape[-1],
... hidden_size=64,
... in_keys=["observation", "rs"],
... out_keys=["intermediate", ("next", "rs")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
>>> env = env.append_transform(gru_module.make_tensordict_primer())
>>> data_collector = SyncDataCollector(
... env,
... policy,
... frames_per_batch=10
... )
>>> for data in data_collector:
... print(data)
... break

"""
from torchrl.envs import TensorDictPrimer

def make_tuple(key):
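The rationale in the new docstrings can be illustrated without torchrl installed: a step-mdp-style function only carries keys it knows about from the ``"next"`` sub-dict back to the root, so recurrent states that are not registered in the specs are silently dropped. Below is a minimal, self-contained sketch of that idea using plain dicts; ``toy_step_mdp`` and ``extra_keys`` are hypothetical names for illustration, not torchrl API.

```python
# Why unregistered recurrent states get lost: a step_mdp-like function
# only copies keys it knows about from the "next" sub-dict to the root.
# "toy_step_mdp" and "extra_keys" are illustrative names, not torchrl API.

def toy_step_mdp(td, extra_keys=()):
    """Build the root dict for the next step from td["next"]."""
    known = ("observation",) + tuple(extra_keys)
    return {k: v for k, v in td["next"].items() if k in known}

td = {
    "observation": 0.0,
    "next": {"observation": 1.0, "rs_h": "h1", "rs_c": "c1"},
}

# Without registering the recurrent-state keys, they are dropped:
dropped = toy_step_mdp(td)
# With the keys registered (what make_tensordict_primer enables),
# they are carried over to the next step:
carried = toy_step_mdp(td, extra_keys=("rs_h", "rs_c"))
print(sorted(dropped))   # ['observation']
print(sorted(carried))   # ['observation', 'rs_c', 'rs_h']
```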