
Commit

[Refactor, Doc] Refactor refs to SafeModule to TensorDictModule unless necessary (pytorch#986)
vmoens authored Mar 24, 2023
1 parent 8e03f6b commit 4bf1b37
Showing 35 changed files with 559 additions and 359 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -409,10 +409,10 @@ And it is `functorch` and `torch.compile` compatible!
in_keys=["hidden"],
out_keys=["loc", "scale"],
)
# Use a SafeProbabilisticSequential to combine the SafeModule with a
# Use a SafeProbabilisticTensorDictSequential to combine the SafeModule with a
# SafeProbabilisticModule, indicating how to build the
# torch.distribution.Distribution object and what to do with it
policy_module = SafeProbabilisticSequential( # stochastic policy
policy_module = SafeProbabilisticTensorDictSequential( # stochastic policy
policy_module,
SafeProbabilisticModule(
in_keys=["loc", "scale"],
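For context, a self-contained sketch of the pattern this README hunk renames. Only the `SafeProbabilisticTensorDictSequential` name comes from this commit; the trunk network, tensor sizes, and the `TanhNormal` choice are illustrative assumptions:

```python
import torch
from torch import nn
from tensordict import TensorDict
from torchrl.modules import (
    NormalParamWrapper,
    SafeModule,
    SafeProbabilisticModule,
    SafeProbabilisticTensorDictSequential,
    TanhNormal,
)

# Deterministic trunk mapping "hidden" to distribution parameters;
# NormalParamWrapper splits the last layer's output into loc/scale.
trunk = SafeModule(
    NormalParamWrapper(nn.Linear(4, 8)),  # illustrative sizes
    in_keys=["hidden"],
    out_keys=["loc", "scale"],
)
# The renamed class chains the trunk with a SafeProbabilisticModule that
# builds the torch.distributions object and samples an "action" from it.
policy_module = SafeProbabilisticTensorDictSequential(
    trunk,
    SafeProbabilisticModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=TanhNormal,
    ),
)

td = TensorDict({"hidden": torch.randn(3, 4)}, [3])
policy_module(td)  # writes "loc", "scale" and a sampled "action" into td
```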
2 changes: 1 addition & 1 deletion docs/source/reference/modules.rst
@@ -24,7 +24,7 @@ TensorDict modules
     EGreedyWrapper
     OrnsteinUhlenbeckProcessWrapper
     SafeProbabilisticModule
-    SafeProbabilisticSequential
+    SafeProbabilisticTensorDictSequential
     SafeSequential
     WorldModelWrapper
     common.is_tensordict_compatible
4 changes: 2 additions & 2 deletions test/test_cost.py
@@ -46,7 +46,7 @@
     QValueActor,
     SafeModule,
     SafeProbabilisticModule,
-    SafeProbabilisticSequential,
+    SafeProbabilisticTensorDictSequential,
     SafeSequential,
     WorldModelWrapper,
 )
@@ -2937,7 +2937,7 @@ def _create_actor_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
         num_cells=mlp_num_units,
         activation_class=nn.ELU,
     )
-    actor_model = SafeProbabilisticSequential(
+    actor_model = SafeProbabilisticTensorDictSequential(
         SafeModule(
             actor_module,
             in_keys=["state", "belief"],
34 changes: 18 additions & 16 deletions test/test_tensordictmodules.py
@@ -8,7 +8,7 @@
 import pytest
 import torch
 from tensordict import TensorDict
-from tensordict.nn.functional_modules import make_functional
+from tensordict.nn import make_functional, TensorDictModule
 from torch import nn
 from torchrl.data.tensor_specs import (
     BoundedTensorSpec,
@@ -23,13 +23,13 @@
 )
 from torchrl.modules.tensordict_module.probabilistic import (
     SafeProbabilisticModule,
-    SafeProbabilisticSequential,
+    SafeProbabilisticTensorDictSequential,
 )
 from torchrl.modules.tensordict_module.sequence import SafeSequential

 _has_functorch = False
 try:
-    from functorch import vmap
+    from torch import vmap

     _has_functorch = True
 except ImportError:
@@ -213,7 +213,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys)
             **kwargs,
         )

-        tensordict_module = SafeProbabilisticSequential(net, prob_module)
+        tensordict_module = SafeProbabilisticTensorDictSequential(net, prob_module)
         td = TensorDict({"in": torch.randn(3, 3)}, [3])
         with set_exploration_mode(exp_mode):
             tensordict_module(td)
@@ -267,7 +267,7 @@ def test_functional(self, safe, spec_type):
         )

         td = TensorDict({"in": torch.randn(3, 3)}, [3])
-        tensordict_module(td, params=params)
+        tensordict_module(td, params=TensorDict({"module": params}, []))
         assert td.shape == torch.Size([3])
         assert td.get("out").shape == torch.Size([3, 4])

@@ -324,7 +324,7 @@ def test_functional_probabilistic(self, safe, spec_type):
             **kwargs,
         )

-        tensordict_module = SafeProbabilisticSequential(tdnet, prob_module)
+        tensordict_module = SafeProbabilisticTensorDictSequential(tdnet, prob_module)
         params = make_functional(tensordict_module)

         td = TensorDict({"in": torch.randn(3, 3)}, [3])
@@ -378,7 +378,7 @@ def test_functional_with_buffer(self, safe, spec_type):
         )

         td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3])
-        tdmodule(td, params=params)
+        tdmodule(td, params=TensorDict({"module": params}, []))
         assert td.shape == torch.Size([3])
         assert td.get("out").shape == torch.Size([3, 32])
@@ -436,7 +436,7 @@ def test_functional_with_buffer_probabilistic(self, safe, spec_type):
             **kwargs,
         )

-        tdmodule = SafeProbabilisticSequential(tdnet, prob_module)
+        tdmodule = SafeProbabilisticTensorDictSequential(tdnet, prob_module)
         params = make_functional(tdmodule)

         td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3])
@@ -574,7 +574,7 @@ def test_vmap_probabilistic(self, safe, spec_type):
             **kwargs,
         )

-        tdmodule = SafeProbabilisticSequential(tdnet, prob_module)
+        tdmodule = SafeProbabilisticTensorDictSequential(tdnet, prob_module)
         params = make_functional(tdmodule)

         # vmap = True
@@ -757,7 +757,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy):
             safe=False,
             **kwargs,
         )
-        tdmodule = SafeProbabilisticSequential(
+        tdmodule = SafeProbabilisticTensorDictSequential(
             tdmodule1, dummy_tdmodule, tdmodule2, prob_module
         )

@@ -905,7 +905,7 @@ def test_functional_probabilistic(self, safe, spec_type):
             safe=safe,
             **kwargs,
         )
-        tdmodule = SafeProbabilisticSequential(
+        tdmodule = SafeProbabilisticTensorDictSequential(
             tdmodule1, dummy_tdmodule, tdmodule2, prob_module
         )

@@ -1067,7 +1067,7 @@ def test_functional_with_buffer_probabilistic(self, safe, spec_type):
             safe=safe,
             **kwargs,
         )
-        tdmodule = SafeProbabilisticSequential(
+        tdmodule = SafeProbabilisticTensorDictSequential(
             tdmodule1, dummy_tdmodule, tdmodule2, prob_module
         )

@@ -1248,7 +1248,7 @@ def test_vmap_probabilistic(self, safe, spec_type):
             safe=safe,
             **kwargs,
         )
-        tdmodule = SafeProbabilisticSequential(tdmodule1, tdmodule2, prob_module)
+        tdmodule = SafeProbabilisticTensorDictSequential(
+            tdmodule1, tdmodule2, prob_module
+        )

         params = make_functional(tdmodule)

@@ -1351,7 +1353,7 @@ def test_sequential_partial(self, stack, functional):
             spec=None,
             safe=False,
         )
-        tdmodule2 = SafeProbabilisticSequential(
+        tdmodule2 = SafeProbabilisticTensorDictSequential(
             net2,
             SafeProbabilisticModule(
                 in_keys=["loc", "scale"],
@@ -1361,7 +1363,7 @@
                 **kwargs,
             ),
         )
-        tdmodule3 = SafeProbabilisticSequential(
+        tdmodule3 = SafeProbabilisticTensorDictSequential(
             net3,
             SafeProbabilisticModule(
                 in_keys=["loc", "scale"],
@@ -1516,7 +1518,7 @@ def forward(self, in_1, in_2):
out_keys=["out_1", "out_2", "out_3"],
)
assert set(ensured_module.in_keys) == {"x"}
assert isinstance(ensured_module, SafeModule)
assert isinstance(ensured_module, TensorDictModule)


if __name__ == "__main__":
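The `params=TensorDict({"module": params}, [])` changes above follow from how `TensorDictModule` stores its wrapped network under the attribute name `module`. A minimal sketch of the idea, with an illustrative linear layer and assuming the `tensordict.nn.make_functional` API of this era:

```python
import torch
from tensordict import TensorDict
from tensordict.nn import make_functional, TensorDictModule
from torch import nn

layer = nn.Linear(3, 4)
tdmodule = TensorDictModule(layer, in_keys=["in"], out_keys=["out"])

# Extract the bare layer's parameters as a TensorDict.
params = make_functional(layer)

td = TensorDict({"in": torch.randn(3, 3)}, [3])
# The wrapper holds the layer under the attribute name "module", so
# bare-module params must be nested under that key when calling it.
tdmodule(td, params=TensorDict({"module": params}, []))
assert td["out"].shape == torch.Size([3, 4])
```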
8 changes: 4 additions & 4 deletions torchrl/envs/model_based/common.py
@@ -10,10 +10,10 @@
 import numpy as np
 import torch
 from tensordict import TensorDict
+from tensordict.nn import TensorDictModule

 from torchrl.data.utils import DEVICE_TYPING
 from torchrl.envs.common import EnvBase
-from torchrl.modules.tensordict_module import SafeModule


 class ModelBasedEnvBase(EnvBase, metaclass=abc.ABCMeta):
@@ -53,12 +53,12 @@ class ModelBasedEnvBase(EnvBase, metaclass=abc.ABCMeta):
>>> import torch.nn as nn
>>> from torchrl.modules import MLP, WorldModelWrapper
>>> world_model = WorldModelWrapper(
... SafeModule(
... TensorDictModule(
... MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0),
... in_keys=["hidden_observation", "action"],
... out_keys=["hidden_observation"],
... ),
... SafeModule(
... TensorDictModule(
... nn.Linear(4, 1),
... in_keys=["hidden_observation"],
... out_keys=["reward"],
@@ -113,7 +113,7 @@ class ModelBasedEnvBase(EnvBase, metaclass=abc.ABCMeta):

     def __init__(
         self,
-        world_model: SafeModule,
+        world_model: TensorDictModule,
         params: Optional[List[torch.Tensor]] = None,
         buffers: Optional[List[torch.Tensor]] = None,
         device: DEVICE_TYPING = "cpu",
6 changes: 3 additions & 3 deletions torchrl/envs/model_based/dreamer.py
@@ -8,23 +8,23 @@
 import numpy as np
 import torch
 from tensordict import TensorDict
+from tensordict.nn import TensorDictModule

 from torchrl.data.tensor_specs import CompositeSpec
 from torchrl.data.utils import DEVICE_TYPING
 from torchrl.envs import EnvBase
 from torchrl.envs.model_based import ModelBasedEnvBase
-from torchrl.modules.tensordict_module import SafeModule


 class DreamerEnv(ModelBasedEnvBase):
     """Dreamer simulation environment."""

     def __init__(
         self,
-        world_model: SafeModule,
+        world_model: TensorDictModule,
         prior_shape: Tuple[int, ...],
         belief_shape: Tuple[int, ...],
-        obs_decoder: SafeModule = None,
+        obs_decoder: TensorDictModule = None,
         device: DEVICE_TYPING = "cpu",
         dtype: Optional[Union[torch.dtype, np.dtype]] = None,
         batch_size: Optional[torch.Size] = None,
2 changes: 1 addition & 1 deletion torchrl/modules/__init__.py
@@ -47,7 +47,7 @@
     QValueActor,
     SafeModule,
     SafeProbabilisticModule,
-    SafeProbabilisticSequential,
+    SafeProbabilisticTensorDictSequential,
     SafeSequential,
     ValueOperator,
     WorldModelWrapper,
6 changes: 3 additions & 3 deletions torchrl/modules/distributions/continuous.py
@@ -54,7 +54,7 @@ class IndependentNormal(D.Independent):
             Default is 5.0
-        tanh_loc (bool, optional): if True, the above formula is used for the location scaling, otherwise the raw value
+        tanh_loc (bool, optional): if ``True``, the above formula is used for the location scaling, otherwise the raw value
             is kept.
             Default is :obj:`True`;
     """
@@ -180,7 +180,7 @@ class TruncatedNormal(D.Independent):
         min (torch.Tensor or number, optional): minimum value of the distribution. Default = -1.0;
         max (torch.Tensor or number, optional): maximum value of the distribution. Default = 1.0;
-        tanh_loc (bool, optional): if True, the above formula is used for the location scaling, otherwise the raw value
+        tanh_loc (bool, optional): if ``True``, the above formula is used for the location scaling, otherwise the raw value
             is kept.
             Default is :obj:`True`;
     """
@@ -298,7 +298,7 @@ class TanhNormal(D.TransformedDistribution):
         max (torch.Tensor or number, optional): maximum value of the distribution. Default is 1.0;
         event_dims (int, optional): number of dimensions describing the action.
             Default is 1;
-        tanh_loc (bool, optional): if True, the above formula is used for the location scaling, otherwise the raw
+        tanh_loc (bool, optional): if ``True``, the above formula is used for the location scaling, otherwise the raw
             value is kept. Default is :obj:`True`;
     """

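As a reader's aid, the "above formula" these `tanh_loc` docstrings reference is the scaled-tanh squashing of the raw location parameter. A sketch of the transform, paraphrased under the assumption that `upscale` defaults to 5.0 as the first docstring states (not quoted verbatim from this diff):

```python
import torch

def squash_loc(loc: torch.Tensor, upscale: float = 5.0) -> torch.Tensor:
    # tanh_loc=True: squash the raw network output through a scaled tanh,
    # keeping the location within (-upscale, upscale).
    # tanh_loc=False: the raw value is kept unchanged.
    return upscale * (loc / upscale).tanh()
```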
37 changes: 19 additions & 18 deletions torchrl/modules/models/exploration.py
@@ -32,7 +32,7 @@ class NoisyLinear(nn.Linear):
     Args:
         in_features (int): input features dimension
         out_features (int): out features dimension
-        bias (bool): if True, a bias term will be added to the matrix multiplication: Ax + b.
+        bias (bool): if ``True``, a bias term will be added to the matrix multiplication: Ax + b.
             default: True
         device (DEVICE_TYPING, optional): device of the layer.
             default: "cpu"
@@ -154,7 +154,7 @@ class NoisyLazyLinear(LazyModuleMixin, NoisyLinear):
     Args:
         out_features (int): out features dimension
-        bias (bool): if True, a bias term will be added to the matrix multiplication: Ax + b.
+        bias (bool): if ``True``, a bias term will be added to the matrix multiplication: Ax + b.
             default: True
         device (DEVICE_TYPING, optional): device of the layer.
         dtype (torch.dtype, optional): dtype of the parameters.
@@ -264,37 +264,38 @@ class gSDEModule(nn.Module):
     Examples:
         >>> from tensordict import TensorDict
-        >>> from torchrl.modules import SafeModule, SafeSequential, ProbabilisticActor, TanhNormal
+        >>> from torchrl.modules import ProbabilisticActor, TanhNormal
+        >>> from tensordict.nn import TensorDictModule, ProbabilisticTensorDictSequential
         >>> batch, state_dim, action_dim = 3, 7, 5
         >>> model = nn.Linear(state_dim, action_dim)
-        >>> deterministic_policy = SafeModule(model, in_keys=["obs"], out_keys=["action"])
-        >>> stochatstic_part = SafeModule(
+        >>> deterministic_policy = TensorDictModule(model, in_keys=["obs"], out_keys=["action"])
+        >>> stochastic_part = TensorDictModule(
         ...     gSDEModule(action_dim, state_dim),
         ...     in_keys=["action", "obs", "_eps_gSDE"],
         ...     out_keys=["loc", "scale", "action", "_eps_gSDE"])
-        >>> stochatstic_part = ProbabilisticActor(stochatstic_part,
-        ...     dist_in_keys=["loc", "scale"],
+        >>> stochastic_part = ProbabilisticActor(stochastic_part,
+        ...     in_keys=["loc", "scale"],
         ...     distribution_class=TanhNormal)
-        >>> stochatstic_policy = SafeSequential(deterministic_policy, stochatstic_part)
+        >>> stochastic_policy = ProbabilisticTensorDictSequential(deterministic_policy, *stochastic_part)
         >>> tensordict = TensorDict({'obs': torch.randn(state_dim), '_epx_gSDE': torch.zeros(1)}, [])
-        >>> _ = stochatstic_policy(tensordict)
+        >>> _ = stochastic_policy(tensordict)
         >>> print(tensordict)
         TensorDict(
             fields={
-                obs: Tensor(torch.Size([7]), dtype=torch.float32),
-                _epx_gSDE: Tensor(torch.Size([1]), dtype=torch.float32),
-                action: Tensor(torch.Size([5]), dtype=torch.float32),
-                loc: Tensor(torch.Size([5]), dtype=torch.float32),
-                scale: Tensor(torch.Size([5]), dtype=torch.float32),
-                _eps_gSDE: Tensor(torch.Size([5, 7]), dtype=torch.float32)},
+                _eps_gSDE: Tensor(shape=torch.Size([5, 7]), device=cpu, dtype=torch.float32, is_shared=False),
+                _epx_gSDE: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                action: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
+                loc: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
+                obs: Tensor(shape=torch.Size([7]), device=cpu, dtype=torch.float32, is_shared=False),
+                scale: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False)},
             batch_size=torch.Size([]),
-            device=cpu,
+            device=None,
             is_shared=False)
         >>> action_first_call = tensordict.get("action").clone()
-        >>> dist, *_ = stochatstic_policy.get_dist(tensordict)
+        >>> dist = stochastic_policy.get_dist(tensordict)
         >>> print(dist)
         TanhNormal(loc: torch.Size([5]), scale: torch.Size([5]))
-        >>> _ = stochatstic_policy(tensordict)
+        >>> _ = stochastic_policy(tensordict)
         >>> action_second_call = tensordict.get("action").clone()
         >>> assert (action_second_call == action_first_call).all() # actions are the same
         >>> assert (action_first_call != dist.base_dist.base_dist.loc).all() # actions are truly stochastic
8 changes: 4 additions & 4 deletions torchrl/modules/models/model_based.py
@@ -6,12 +6,12 @@

 import torch
 from packaging import version
+from tensordict.nn import TensorDictModule
 from torch import nn

 from torchrl.envs.utils import step_mdp
 from torchrl.modules.distributions import NormalParamWrapper
 from torchrl.modules.models.models import MLP
-from torchrl.modules.tensordict_module.common import SafeModule
 from torchrl.modules.tensordict_module.sequence import SafeSequential


@@ -151,13 +151,13 @@ class RSSMRollout(nn.Module):
     Reference: https://arxiv.org/abs/1811.04551

     Args:
-        rssm_prior (SafeModule): Prior network.
-        rssm_posterior (SafeModule): Posterior network.
+        rssm_prior (TensorDictModule): Prior network.
+        rssm_posterior (TensorDictModule): Posterior network.
     """

-    def __init__(self, rssm_prior: SafeModule, rssm_posterior: SafeModule):
+    def __init__(self, rssm_prior: TensorDictModule, rssm_posterior: TensorDictModule):
         super().__init__()
         _module = SafeSequential(rssm_prior, rssm_posterior)
         self.in_keys = _module.in_keys
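For orientation, a hypothetical sketch of constructing `RSSMRollout` with the retyped arguments. The stub networks and key names below are invented for illustration (the real RSSM prior and posterior have richer signatures), and only construction is exercised, not the rollout itself:

```python
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.modules.models.model_based import RSSMRollout

# Stub prior/posterior: plain TensorDictModules with made-up keys.
rssm_prior = TensorDictModule(
    nn.Linear(8, 8), in_keys=["state"], out_keys=["next_state"]
)
rssm_posterior = TensorDictModule(
    nn.Linear(8, 8), in_keys=["next_state"], out_keys=["posterior_state"]
)

rollout = RSSMRollout(rssm_prior, rssm_posterior)
# in_keys/out_keys are derived from the internal SafeSequential that
# chains the two modules, as the constructor above shows.
print(rollout.in_keys)
```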