diff --git a/sota-implementations/redq/utils.py b/sota-implementations/redq/utils.py index 8a093c8f0ac..dd922372cbb 100644 --- a/sota-implementations/redq/utils.py +++ b/sota-implementations/redq/utils.py @@ -57,7 +57,7 @@ ActorCriticOperator, ActorValueOperator, NoisyLinear, - NormalParamWrapper, + NormalParamExtractor, SafeModule, SafeSequential, ) @@ -483,10 +483,12 @@ def make_redq_model( } if not gSDE: - actor_net = NormalParamWrapper( + actor_net = nn.Sequential( actor_net, - scale_mapping=f"biased_softplus_{default_policy_scale}", - scale_lb=cfg.network.scale_lb, + NormalParamExtractor( + scale_mapping=f"biased_softplus_{default_policy_scale}", + scale_lb=cfg.network.scale_lb, + ), ) actor_module = SafeModule( actor_net, diff --git a/test/test_cost.py b/test/test_cost.py index 090b32ac8e5..871d9170aa1 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -72,11 +72,7 @@ SafeSequential, WorldModelWrapper, ) -from torchrl.modules.distributions.continuous import ( - NormalParamWrapper, - TanhDelta, - TanhNormal, -) +from torchrl.modules.distributions.continuous import TanhDelta, TanhNormal from torchrl.modules.models.model_based import ( DreamerActor, ObsDecoder, @@ -3462,7 +3458,7 @@ def _create_mock_actor( action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule( net, in_keys=[observation_key], out_keys=["loc", "scale"] ) @@ -4372,7 +4368,7 @@ def _create_mock_actor( ): # Actor action_spec = OneHotDiscreteTensorSpec(action_dim) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"]) actor = ProbabilisticActor( spec=action_spec, @@ -4960,7 +4956,7 @@ def _create_mock_actor( action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule( net, in_keys=[observation_key], out_keys=["loc", "scale"] ) @@ -5655,7 +5651,7 @@ def _create_mock_actor( action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule( net, in_keys=[observation_key], out_keys=["loc", "scale"] ) @@ -5763,7 +5759,9 @@ def forward(self, obs): class ActorClass(nn.Module): def __init__(self): super().__init__() - self.linear = NormalParamWrapper(nn.Linear(hidden_dim, 2 * action_dim)) + self.linear = nn.Sequential( + nn.Linear(hidden_dim, 2 * action_dim), NormalParamExtractor() + ) def forward(self, hidden): return self.linear(hidden) @@ -6598,7 +6596,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule( net, in_keys=["observation"], out_keys=["loc", "scale"] ) @@ -7556,7 +7554,7 @@ def _create_mock_actor( action_spec = BoundedTensorSpec( 
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule( net, in_keys=[observation_key], out_keys=["loc", "scale"] ) @@ -7593,8 +7591,8 @@ def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) base_layer = nn.Linear(obs_dim, 5) - net = NormalParamWrapper( - nn.Sequential(base_layer, nn.Linear(5, 2 * action_dim)) + net = nn.Sequential( + base_layer, nn.Linear(5, 2 * action_dim), NormalParamExtractor() ) module = TensorDictModule( net, in_keys=["observation"], out_keys=["loc", "scale"] @@ -8447,7 +8445,7 @@ def _create_mock_actor( action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule( net, in_keys=[observation_key], out_keys=["loc", "scale"] ) @@ -9144,7 +9142,7 @@ def test_reinforce_value_net( batch = 4 gamma = 0.9 value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"]) - net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) module = TensorDictModule( net, in_keys=["observation"], out_keys=["loc", "scale"] ) @@ -9254,7 +9252,7 @@ def test_reinforce_tensordict_keys(self, td_est): n_obs = 3 n_act = 5 value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"]) - net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) module = TensorDictModule( net, in_keys=["observation"], out_keys=["loc", "scale"] ) @@ -9448,7 +9446,7 @@ def test_reinforce_notensordict( n_act = 5 batch = 4 value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=[observation_key]) - net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) module = TensorDictModule( net, in_keys=[observation_key], out_keys=["loc", "scale"] ) @@ -10054,7 +10052,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule( net, in_keys=["observation"], out_keys=["loc", "scale"] ) @@ -10286,7 +10284,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule(net, in_keys=["observation"], out_keys=["param"]) actor = ProbabilisticActor( module=module, @@ -10479,7 +10477,7 @@ def _create_mock_actor( action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule( net, in_keys=[observation_key], out_keys=["loc", "scale"] ) @@ -11288,7 +11286,7 @@ def 
_create_mock_actor( ): # Actor action_spec = OneHotDiscreteTensorSpec(action_dim) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"]) actor = ProbabilisticActor( spec=action_spec, @@ -13989,7 +13987,7 @@ def test_shared_params(dest, expected_dtype, expected_device): out_keys=["hidden"], ) module_action = TensorDictModule( - NormalParamWrapper(torch.nn.Linear(4, 8)), + nn.Sequential(nn.Linear(4, 8), NormalParamExtractor()), in_keys=["hidden"], out_keys=["loc", "scale"], ) diff --git a/test/test_exploration.py b/test/test_exploration.py index af618f843e2..83ee4bc4220 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -31,10 +31,10 @@ from torchrl.envs.transforms.transforms import gSDENoise, InitTracker, TransformedEnv from torchrl.envs.utils import set_exploration_type from torchrl.modules import SafeModule, SafeSequential -from torchrl.modules.distributions import TanhNormal -from torchrl.modules.distributions.continuous import ( +from torchrl.modules.distributions import ( IndependentNormal, - NormalParamWrapper, + NormalParamExtractor, + TanhNormal, ) from torchrl.modules.models.exploration import LazygSDEModule from torchrl.modules.tensordict_module.actors import ( @@ -236,7 +236,9 @@ def test_ou( self, device, interface, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0 ): torch.manual_seed(seed) - net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device) + net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to( + device + ) module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) action_spec = BoundedTensorSpec(-torch.ones(d_act), torch.ones(d_act), (d_act,)) policy = ProbabilisticActor( @@ -308,7 +310,9 @@ def test_collector(self, device, parallel_spec, probabilistic, interface, seed=0 action_spec = ContinuousActionVecMockEnv(device=device).action_spec d_act = action_spec.shape[-1] if probabilistic: - net = NormalParamWrapper(nn.LazyLinear(2 * d_act)).to(device) + net = nn.Sequential(nn.LazyLinear(2 * d_act), NormalParamExtractor()).to( + device + ) module = SafeModule( net, in_keys=["observation"], @@ -449,7 +453,9 @@ def test_additivegaussian_sd( if interface == "module": exploratory_policy = AdditiveGaussianModule(action_spec).to(device) else: - net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device) + net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to( + device + ) module = SafeModule( net, in_keys=["observation"], @@ -531,7 +537,9 @@ def test_additivegaussian( pytest.skip("module raises an error if given spec=None") torch.manual_seed(seed) - net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device) + net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to( + device + ) module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) action_spec = BoundedTensorSpec( -torch.ones(d_act, device=device), @@ -593,7 +601,7 @@ def test_collector(self, device, parallel_spec, interface, seed=0): else: action_spec = ContinuousActionVecMockEnv(device=device).action_spec d_act = action_spec.shape[-1] - net = NormalParamWrapper(nn.LazyLinear(2 * d_act)).to(device) + net = nn.Sequential(nn.LazyLinear(2 * d_act), NormalParamExtractor()).to(device) module = SafeModule( net, in_keys=["observation"], @@ -658,7 +666,7 @@ def test_gsde( else: in_keys = ["observation"] model = 
torch.nn.LazyLinear(action_dim * 2, device=device) - wrapper = NormalParamWrapper(model) + wrapper = nn.Sequential(model, NormalParamExtractor()) module = SafeModule(wrapper, in_keys=in_keys, out_keys=["loc", "scale"]) distribution_class = TanhNormal distribution_kwargs = {"low": -bound, "high": bound} diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index a6f66291719..38360a464e0 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -34,7 +34,7 @@ LSTMModule, MLP, MultiStepActorWrapper, - NormalParamWrapper, + NormalParamExtractor, OnlineDTActor, ProbabilisticActor, SafeModule, @@ -201,7 +201,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys) in_keys = ["in"] net = SafeModule( - module=NormalParamWrapper(net), + module=nn.Sequential(net, NormalParamExtractor()), spec=None, in_keys=in_keys, out_keys=out_keys, @@ -363,7 +363,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy): net1 = nn.Linear(3, 4) dummy_net = nn.Linear(4, 4) net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) + net2 = nn.Sequential(net2, NormalParamExtractor()) if spec_type is None: spec = None @@ -474,11 +474,11 @@ def test_sequential_partial(self, stack): net1 = nn.Linear(3, 4) net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) + net2 = nn.Sequential(net2, NormalParamExtractor()) net2 = SafeModule(net2, in_keys=["b"], out_keys=["loc", "scale"]) net3 = nn.Linear(4, 4 * param_multiplier) - net3 = NormalParamWrapper(net3) + net3 = nn.Sequential(net3, NormalParamExtractor()) net3 = SafeModule(net3, in_keys=["c"], out_keys=["loc", "scale"]) spec = BoundedTensorSpec(-0.1, 0.1, 4) diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 0c6505602f3..0a06e5844a0 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -11,6 +11,7 @@ IndependentNormal, MaskedCategorical, MaskedOneHotCategorical, + NormalParamExtractor, NormalParamWrapper, OneHotCategorical, ReparamGradientStrategy, diff --git a/torchrl/modules/distributions/__init__.py b/torchrl/modules/distributions/__init__.py index a3c5d0d4774..367765812bb 100644 --- a/torchrl/modules/distributions/__init__.py +++ b/torchrl/modules/distributions/__init__.py @@ -3,8 +3,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from tensordict.nn import NormalParamExtractor + from .continuous import ( - __all__ as _all_continuous, Delta, IndependentNormal, NormalParamWrapper, @@ -13,7 +14,6 @@ TruncatedNormal, ) from .discrete import ( - __all__ as _all_discrete, MaskedCategorical, MaskedOneHotCategorical, OneHotCategorical, @@ -21,6 +21,15 @@ ) distributions_maps = { - distribution_class.lower(): eval(distribution_class) - for distribution_class in _all_continuous + _all_discrete + str(dist).lower(): dist + for dist in ( + Delta, + IndependentNormal, + TanhDelta, + TanhNormal, + TruncatedNormal, + MaskedCategorical, + MaskedOneHotCategorical, + OneHotCategorical, + ) } diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 38d8d1dfd02..fddc2f3415d 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -25,15 +25,6 @@ ) from torchrl.modules.utils import mappings -__all__ = [ - "NormalParamWrapper", - "TanhNormal", - "Delta", - "TanhDelta", - "TruncatedNormal", - "IndependentNormal", -] - # speeds up distribution construction D.Distribution.set_default_validate_args(False) @@ -153,6 +144,10 @@ def __init__( scale_mapping: str = "biased_softplus_1.0", scale_lb: Number = 1e-4, ) -> None: + warnings.warn( + "The NormalParamWrapper class will be deprecated in v0.7 in favor of :class:`~tensordict.nn.NormalParamExtractor`.", + category=DeprecationWarning, + ) super().__init__() self.operator = operator self.scale_mapping = scale_mapping @@ -759,7 +754,10 @@ def mean(self) -> torch.Tensor: raise AttributeError("TanhDelta mean has not analytical form.") -def uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor: +def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor: if size is None: size = torch.Size([]) return torch.randn_like(dist.sample(size)) + + +uniform_sample_delta = _uniform_sample_delta diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 83b6a8d1fb3..81b7ec1e605 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -206,11 +206,11 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules import ProbabilisticActor, NormalParamWrapper, TanhNormal + >>> from torchrl.modules import ProbabilisticActor, NormalParamExtractor, TanhNormal >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,]) >>> action_spec = BoundedTensorSpec(shape=torch.Size([4]), ... low=-1, high=1) - >>> module = NormalParamWrapper(torch.nn.Linear(4, 8)) + >>> module = nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()) >>> tensordict_module = TensorDictModule(module, in_keys=["observation"], out_keys=["loc", "scale"]) >>> td_module = ProbabilisticActor( ... module=tensordict_module, @@ -1379,7 +1379,7 @@ class ActorValueOperator(SafeSequential): >>> import torch >>> from tensordict import TensorDict >>> from torchrl.modules import ProbabilisticActor, SafeModule - >>> from torchrl.modules import ValueOperator, TanhNormal, ActorValueOperator, NormalParamWrapper + >>> from torchrl.modules import ValueOperator, TanhNormal, ActorValueOperator, NormalParamExtractor >>> module_hidden = torch.nn.Linear(4, 4) >>> td_module_hidden = SafeModule( ... module=module_hidden, @@ -1387,7 +1387,7 @@ class ActorValueOperator(SafeSequential): ... 
out_keys=["hidden"], ... ) >>> module_action = TensorDictModule( - ... NormalParamWrapper(torch.nn.Linear(4, 8)), + ... nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()), ... in_keys=["hidden"], ... out_keys=["loc", "scale"], ... ) @@ -1531,14 +1531,14 @@ class ActorCriticOperator(ActorValueOperator): >>> import torch >>> from tensordict import TensorDict >>> from torchrl.modules import ProbabilisticActor - >>> from torchrl.modules import ValueOperator, TanhNormal, ActorCriticOperator, NormalParamWrapper, MLP + >>> from torchrl.modules import ValueOperator, TanhNormal, ActorCriticOperator, NormalParamExtractor, MLP >>> module_hidden = torch.nn.Linear(4, 4) >>> td_module_hidden = SafeModule( ... module=module_hidden, ... in_keys=["observation"], ... out_keys=["hidden"], ... ) - >>> module_action = NormalParamWrapper(torch.nn.Linear(4, 8)) + >>> module_action = nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()) >>> module_action = TensorDictModule(module_action, in_keys=["hidden"], out_keys=["loc", "scale"]) >>> td_module_action = ProbabilisticActor( ... module=module_action, @@ -1677,12 +1677,12 @@ class ActorCriticWrapper(SafeSequential): >>> from torchrl.modules import ( ... ActorCriticWrapper, ... ProbabilisticActor, - ... NormalParamWrapper, + ... NormalParamExtractor, ... TanhNormal, ... ValueOperator, ... ) >>> action_module = TensorDictModule( - ... NormalParamWrapper(torch.nn.Linear(4, 8)), + ... nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()), ... in_keys=["observation"], ... out_keys=["loc", "scale"], ... ) diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py index 28f721ba6a1..41ddb55fb35 100644 --- a/torchrl/modules/tensordict_module/sequence.py +++ b/torchrl/modules/tensordict_module/sequence.py @@ -34,11 +34,11 @@ class SafeSequential(TensorDictSequential, SafeModule): >>> import torch >>> from tensordict import TensorDict >>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec - >>> from torchrl.modules import TanhNormal, SafeSequential, TensorDictModule, NormalParamWrapper + >>> from torchrl.modules import TanhNormal, SafeSequential, TensorDictModule, NormalParamExtractor >>> from torchrl.modules.tensordict_module import SafeProbabilisticModule >>> td = TensorDict({"input": torch.randn(3, 4)}, [3,]) >>> spec1 = CompositeSpec(hidden=UnboundedContinuousTensorSpec(4), loc=None, scale=None) - >>> net1 = NormalParamWrapper(torch.nn.Linear(4, 8)) + >>> net1 = nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()) >>> module1 = TensorDictModule(net1, in_keys=["input"], out_keys=["loc", "scale"]) >>> td_module1 = SafeProbabilisticModule( ... 
module=module1, diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index bedd91e2e56..4a0948e1bca 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -97,14 +97,14 @@ class A2CLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.a2c import A2CLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... module=module, @@ -148,14 +148,14 @@ class A2CLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.a2c import A2CLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... module=module, diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index d68a9fce782..f1e2aa9c532 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -101,14 +101,14 @@ class CQLLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.cql import CQLLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... 
module=module, @@ -161,14 +161,14 @@ class CQLLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.cql import CQLLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... module=module, diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 05499cb227d..e76e3438c09 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -99,14 +99,14 @@ class CrossQLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.crossq import CrossQLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... module=module, @@ -157,14 +157,14 @@ class CrossQLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives import CrossQLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... 
module=module, diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 013435c9079..74cfe504e78 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -74,14 +74,14 @@ class IQLLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.iql import IQLLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... module=module, @@ -137,14 +137,14 @@ class IQLLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.iql import IQLLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... module=module, diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 16e2776805b..eb7f14f43c4 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -152,7 +152,7 @@ class PPOLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data.tensor_specs import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.ppo import PPOLoss @@ -160,7 +160,7 @@ class PPOLoss(LossModule): >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> base_layer = nn.Linear(n_obs, 5) - >>> net = NormalParamWrapper(nn.Sequential(base_layer, nn.Linear(5, 2 * n_act))) + >>> net = nn.Sequential(base_layer, nn.Linear(5, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... 
module=module, @@ -205,14 +205,14 @@ class PPOLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data.tensor_specs import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.ppo import PPOLoss >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> base_layer = nn.Linear(n_obs, 5) - >>> net = NormalParamWrapper(nn.Sequential(base_layer, nn.Linear(5, 2 * n_act))) + >>> net = nn.Sequential(base_layer, nn.Linear(5, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... module=module, diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index db05063535a..1522fd7749e 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -94,14 +94,14 @@ class REDQLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.redq import REDQLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... module=module, @@ -156,13 +156,13 @@ class REDQLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.redq import REDQLoss >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... 
module=module, diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index d2d387e9a99..3d867b8cb99 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -101,14 +101,14 @@ class ReinforceLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.reinforce import ReinforceLoss >>> from tensordict import TensorDict >>> n_obs, n_act = 3, 5 >>> value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"]) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor_net = ProbabilisticActor( ... module, @@ -147,13 +147,13 @@ class ReinforceLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.reinforce import ReinforceLoss >>> n_obs, n_act = 3, 5 >>> value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"]) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor_net = ProbabilisticActor( ... module, diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 65482a2b876..df444eac053 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -118,14 +118,14 @@ class SACLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.sac import SACLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... 
module=module, @@ -181,14 +181,14 @@ class SACLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.sac import SACLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... module=module, @@ -853,8 +853,7 @@ class DiscreteSACLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper - >>> from torchrl.modules.distributions.discrete import OneHotCategorical + >>> from torchrl.modules.distributions import NormalParamExtractor, OneHotCategorical >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.sac import DiscreteSACLoss @@ -911,14 +910,13 @@ class DiscreteSACLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper - >>> from torchrl.modules.distributions.discrete import OneHotCategorical + >>> from torchrl.modules.distributions import NormalParamExtractor, OneHotCategorical >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.sac import DiscreteSACLoss >>> n_act, n_obs = 4, 3 >>> spec = OneHotDiscreteTensorSpec(n_act) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["logits"]) >>> actor = ProbabilisticActor( ... 
module=module, diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index b0026b0158d..eb1027ad936 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -84,7 +84,7 @@ class TD3Loss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.td3 import TD3Loss diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py index bea101f4038..aa87ea9aa1a 100644 --- a/torchrl/objectives/td3_bc.py +++ b/torchrl/objectives/td3_bc.py @@ -95,7 +95,7 @@ class TD3BCLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.td3_bc import TD3BCLoss diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 99ede5dd56f..9d25da0a4cd 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -605,10 +605,12 @@ def exec_sequence(params, data): ############################################################################### # Probabilistic modules -from torchrl.modules import NormalParamWrapper, TanhNormal +from torchrl.modules import NormalParamExtractor, TanhNormal td = TensorDict({"input": torch.randn(3, 5)}, [3]) -net = NormalParamWrapper(nn.Linear(5, 4)) # splits the output in loc and scale +net = nn.Sequential( + nn.Linear(5, 4), NormalParamExtractor() +) # splits the output in loc and scale module = TensorDictModule(net, in_keys=["input"], out_keys=["loc", "scale"]) td_module = ProbabilisticTensorDictSequential( module,
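# Illustrative sketch (not part of the patch above): the substitution this diff
# applies throughout the codebase replaces torchrl's deprecated
# NormalParamWrapper(module, ...) with a plain nn.Sequential whose final layer
# is tensordict.nn.NormalParamExtractor. Both split the wrapped network's last
# dimension into a (loc, scale) pair; the extractor simply sits after the
# network instead of wrapping it, and the scale_mapping / scale_lb keyword
# arguments move onto the extractor (as in the sota-implementations/redq
# change). Tensor sizes below are arbitrary.
import torch
from tensordict import TensorDict
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torch import nn

obs_dim, action_dim = 3, 4

# Old style (now emits a DeprecationWarning):
#     from torchrl.modules.distributions.continuous import NormalParamWrapper
#     net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
# New style: the extractor is just the last module of a Sequential.
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
loc, scale = net(torch.randn(obs_dim))
assert loc.shape == scale.shape == (action_dim,)
assert (scale > 0).all()  # scale goes through a biased softplus, so it stays positive

# Downstream usage is unchanged: a TensorDictModule maps the two outputs onto
# out_keys, exactly as in the test and docstring updates above.
module = TensorDictModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
td = module(TensorDict({"observation": torch.randn(2, obs_dim)}, batch_size=[2]))
assert set(td.keys()) >= {"loc", "scale"}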