forked from pytorch/rl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
73 lines (61 loc) · 2.36 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import warnings
def get_primers_from_module(module):
    """Get all tensordict primers from all submodules of a module.

    This method is useful for retrieving primers from modules that are contained within a
    parent module. Every submodule visited by ``module.apply`` (which includes ``module``
    itself) that exposes a ``make_tensordict_primer`` method contributes one primer.

    Args:
        module (torch.nn.Module): The parent module.

    Returns:
        The collected primer transform(s):

        - ``None`` (after emitting a ``UserWarning``) if no submodule exposes
          ``make_tensordict_primer``;
        - the single primer (e.g. a ``TensorDictPrimer``) if exactly one was found;
        - a ``Compose`` transform wrapping all primers, in the order they were
          collected, if several were found.

    Example:
        >>> from torchrl.modules.utils import get_primers_from_module
        >>> from torchrl.modules import GRUModule, MLP
        >>> from tensordict.nn import TensorDictModule, TensorDictSequential
        >>> # Define a GRU module
        >>> gru_module = GRUModule(
        ...     input_size=10,
        ...     hidden_size=10,
        ...     num_layers=1,
        ...     in_keys=["input", "recurrent_state", "is_init"],
        ...     out_keys=["features", ("next", "recurrent_state")],
        ... )
        >>> # Define a head module
        >>> head = TensorDictModule(
        ...     MLP(
        ...         in_features=10,
        ...         out_features=10,
        ...         num_cells=[],
        ...     ),
        ...     in_keys=["features"],
        ...     out_keys=["output"],
        ... )
        >>> # Create a sequential model
        >>> model = TensorDictSequential(gru_module, head)
        >>> # Retrieve primers from the model
        >>> primers = get_primers_from_module(model)
        >>> print(primers)
        TensorDictPrimer(primers=Composite(
            recurrent_state: UnboundedContinuous(
                shape=torch.Size([1, 10]),
                space=None,
                device=cpu,
                dtype=torch.float32,
                domain=continuous), device=None, shape=torch.Size([])), default_value={'recurrent_state': 0.0}, random=None)
    """
    primers = []

    def make_primers(submodule):
        # Duck-typed check: any module that knows how to build a primer contributes one.
        if hasattr(submodule, "make_tensordict_primer"):
            primers.append(submodule.make_tensordict_primer())

    # ``apply`` calls ``make_primers`` recursively on every submodule and on ``module`` itself.
    module.apply(make_primers)

    if not primers:
        # stacklevel=2 attributes the warning to the caller instead of this helper.
        warnings.warn("No primers found in the module.", stacklevel=2)
        return None
    if len(primers) == 1:
        return primers[0]

    # Imported lazily, only when several primers must be combined; presumably this
    # avoids a circular import between torchrl.modules and torchrl.envs — TODO confirm.
    from torchrl.envs.transforms import Compose

    return Compose(*primers)