[Feature] Extract primers from modules that contain RNNs #2127

Merged: 11 commits, May 3, 2024

Commit: get_primers_from_module
albertbou92 committed Apr 28, 2024
commit 221756f3c30a17722d050f8b922ba65ecd30a2e2
10 changes: 8 additions & 2 deletions torchrl/data/tensor_specs.py
@@ -1597,12 +1597,18 @@ def __init__(
             if high is not None:
                 raise TypeError(self.CONFLICTING_KWARGS.format("high", "maximum"))
             high = kwargs.pop("maximum")
-            warnings.warn("Maximum is deprecated since v0.4.0, using high instead.", category=DeprecationWarning)
+            warnings.warn(
+                "Maximum is deprecated since v0.4.0, using high instead.",
+                category=DeprecationWarning,
+            )
         if "minimum" in kwargs:
             if low is not None:
                 raise TypeError(self.CONFLICTING_KWARGS.format("low", "minimum"))
             low = kwargs.pop("minimum")
-            warnings.warn("Minimum is deprecated since v0.4.0, using low instead.", category=DeprecationWarning)
+            warnings.warn(
+                "Minimum is deprecated since v0.4.0, using low instead.",
+                category=DeprecationWarning,
+            )
         domain = kwargs.pop("domain", "continuous")
         if len(kwargs):
             raise TypeError(f"Got unrecognised kwargs {tuple(kwargs.keys())}.")

Contributor Author (albertbou92): pre-commit reformatted this
13 changes: 8 additions & 5 deletions torchrl/envs/transforms/transforms.py
@@ -4696,15 +4696,18 @@ def _reset(
         spec shape is assumed to match the tensordict's.

         """
-        shape = (
-            ()
-            if (not self.parent or self.parent.batch_locked)
-            else tensordict.batch_size
-        )
         _reset = _get_reset(self.reset_key, tensordict)
         if _reset.any():
             for key, spec in self.primers.items(True, True):
+                if spec.shape[: len(tensordict.batch_size)] != tensordict.batch_size:
+                    expanded_spec = self._expand_shape(spec)
+                    self.primers[key] = spec = expanded_spec
                 if self.random:
+                    shape = (
+                        ()
+                        if (not self.parent or self.parent.batch_locked)
+                        else tensordict.batch_size
+                    )
                     value = spec.rand(shape)
                 else:
                     value = self.default_value[key]
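A rough usage sketch of what this change supports (env name and primer key are illustrative, and gym must be installed): a primer spec written without a batch dimension is expanded at reset time to match the batch size of a batched parent env.

from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs import GymEnv, SerialEnv, TransformedEnv
from torchrl.envs.transforms import TensorDictPrimer

# Batched env with batch_size (2,); the primer spec below carries no batch dimension.
env = TransformedEnv(
    SerialEnv(2, lambda: GymEnv("CartPole-v1")),
    TensorDictPrimer(hidden=UnboundedContinuousTensorSpec(shape=(4,))),
)
td = env.reset()
# Expected with the expansion above: the primed entry picks up the batch dimension, (2, 4).
print(td["hidden"].shape)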
20 changes: 20 additions & 0 deletions torchrl/modules/utils/utils.py
@@ -0,0 +1,20 @@
import warnings

from torchrl.envs import Compose


def get_primers_from_module(module):
    """Get all tensordict primers from all submodules of a module."""
    primers = []

    def make_primers(submodule):
        if hasattr(submodule, "make_tensordict_primer"):
            primers.append(submodule.make_tensordict_primer())

    module.apply(make_primers)
    if not primers:
        warnings.warn("No primers found in the module.")
    elif len(primers) == 1:
        return primers[0]
    else:
        return Compose(*primers)
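A sketch of the intended workflow for the new helper (module and env names are illustrative, and the import path of the helper is assumed): collect the TensorDictPrimer transforms declared by recurrent submodules and append them to the environment so that hidden states are initialised at reset.

from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.envs import GymEnv, TransformedEnv
from torchrl.modules import LSTMModule, MLP
from torchrl.modules.utils import get_primers_from_module  # assumed export of the new helper

env = TransformedEnv(GymEnv("Pendulum-v1"))
policy = TensorDictSequential(
    LSTMModule(
        input_size=env.observation_spec["observation"].shape[-1],
        hidden_size=64,
        in_key="observation",
        out_key="features",
    ),
    TensorDictModule(
        MLP(out_features=env.action_spec.shape[-1]),
        in_keys=["features"],
        out_keys=["action"],
    ),
)
# LSTMModule exposes make_tensordict_primer(), so the helper finds it while
# walking the submodules and returns a single primer (or a Compose of several).
primers = get_primers_from_module(policy)
env.append_transform(primers)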
Show resolved Hide resolved