Skip to content

Commit

Permalink
[Refactor] Deprecate NormalParamWrapper (pytorch#2308)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jul 25, 2024
1 parent 94abb50 commit 474e837
Show file tree
Hide file tree
Showing 20 changed files with 122 additions and 106 deletions.
10 changes: 6 additions & 4 deletions sota-implementations/redq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
ActorCriticOperator,
ActorValueOperator,
NoisyLinear,
NormalParamWrapper,
NormalParamExtractor,
SafeModule,
SafeSequential,
)
Expand Down Expand Up @@ -483,10 +483,12 @@ def make_redq_model(
}

if not gSDE:
actor_net = NormalParamWrapper(
actor_net = nn.Sequential(
actor_net,
scale_mapping=f"biased_softplus_{default_policy_scale}",
scale_lb=cfg.network.scale_lb,
NormalParamExtractor(
scale_mapping=f"biased_softplus_{default_policy_scale}",
scale_lb=cfg.network.scale_lb,
),
)
actor_module = SafeModule(
actor_net,
Expand Down
44 changes: 21 additions & 23 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,7 @@
SafeSequential,
WorldModelWrapper,
)
from torchrl.modules.distributions.continuous import (
NormalParamWrapper,
TanhDelta,
TanhNormal,
)
from torchrl.modules.distributions.continuous import TanhDelta, TanhNormal
from torchrl.modules.models.model_based import (
DreamerActor,
ObsDecoder,
Expand Down Expand Up @@ -3462,7 +3458,7 @@ def _create_mock_actor(
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(
net, in_keys=[observation_key], out_keys=["loc", "scale"]
)
Expand Down Expand Up @@ -4372,7 +4368,7 @@ def _create_mock_actor(
):
# Actor
action_spec = OneHotDiscreteTensorSpec(action_dim)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"])
actor = ProbabilisticActor(
spec=action_spec,
Expand Down Expand Up @@ -4960,7 +4956,7 @@ def _create_mock_actor(
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(
net, in_keys=[observation_key], out_keys=["loc", "scale"]
)
Expand Down Expand Up @@ -5655,7 +5651,7 @@ def _create_mock_actor(
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(
net, in_keys=[observation_key], out_keys=["loc", "scale"]
)
Expand Down Expand Up @@ -5763,7 +5759,9 @@ def forward(self, obs):
class ActorClass(nn.Module):
def __init__(self):
super().__init__()
self.linear = NormalParamWrapper(nn.Linear(hidden_dim, 2 * action_dim))
self.linear = nn.Sequential(
nn.Linear(hidden_dim, 2 * action_dim), NormalParamExtractor()
)

def forward(self, hidden):
return self.linear(hidden)
Expand Down Expand Up @@ -6598,7 +6596,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(
net, in_keys=["observation"], out_keys=["loc", "scale"]
)
Expand Down Expand Up @@ -7556,7 +7554,7 @@ def _create_mock_actor(
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(
net, in_keys=[observation_key], out_keys=["loc", "scale"]
)
Expand Down Expand Up @@ -7593,8 +7591,8 @@ def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
base_layer = nn.Linear(obs_dim, 5)
net = NormalParamWrapper(
nn.Sequential(base_layer, nn.Linear(5, 2 * action_dim))
net = nn.Sequential(
base_layer, nn.Linear(5, 2 * action_dim), NormalParamExtractor()
)
module = TensorDictModule(
net, in_keys=["observation"], out_keys=["loc", "scale"]
Expand Down Expand Up @@ -8447,7 +8445,7 @@ def _create_mock_actor(
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(
net, in_keys=[observation_key], out_keys=["loc", "scale"]
)
Expand Down Expand Up @@ -9144,7 +9142,7 @@ def test_reinforce_value_net(
batch = 4
gamma = 0.9
value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])
net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
module = TensorDictModule(
net, in_keys=["observation"], out_keys=["loc", "scale"]
)
Expand Down Expand Up @@ -9254,7 +9252,7 @@ def test_reinforce_tensordict_keys(self, td_est):
n_obs = 3
n_act = 5
value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])
net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
module = TensorDictModule(
net, in_keys=["observation"], out_keys=["loc", "scale"]
)
Expand Down Expand Up @@ -9448,7 +9446,7 @@ def test_reinforce_notensordict(
n_act = 5
batch = 4
value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=[observation_key])
net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
module = TensorDictModule(
net, in_keys=[observation_key], out_keys=["loc", "scale"]
)
Expand Down Expand Up @@ -10054,7 +10052,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(
net, in_keys=["observation"], out_keys=["loc", "scale"]
)
Expand Down Expand Up @@ -10286,7 +10284,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(net, in_keys=["observation"], out_keys=["param"])
actor = ProbabilisticActor(
module=module,
Expand Down Expand Up @@ -10479,7 +10477,7 @@ def _create_mock_actor(
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(
net, in_keys=[observation_key], out_keys=["loc", "scale"]
)
Expand Down Expand Up @@ -11288,7 +11286,7 @@ def _create_mock_actor(
):
# Actor
action_spec = OneHotDiscreteTensorSpec(action_dim)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"])
actor = ProbabilisticActor(
spec=action_spec,
Expand Down Expand Up @@ -13989,7 +13987,7 @@ def test_shared_params(dest, expected_dtype, expected_device):
out_keys=["hidden"],
)
module_action = TensorDictModule(
NormalParamWrapper(torch.nn.Linear(4, 8)),
nn.Sequential(nn.Linear(4, 8), NormalParamExtractor()),
in_keys=["hidden"],
out_keys=["loc", "scale"],
)
Expand Down
26 changes: 17 additions & 9 deletions test/test_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
from torchrl.envs.transforms.transforms import gSDENoise, InitTracker, TransformedEnv
from torchrl.envs.utils import set_exploration_type
from torchrl.modules import SafeModule, SafeSequential
from torchrl.modules.distributions import TanhNormal
from torchrl.modules.distributions.continuous import (
from torchrl.modules.distributions import (
IndependentNormal,
NormalParamWrapper,
NormalParamExtractor,
TanhNormal,
)
from torchrl.modules.models.exploration import LazygSDEModule
from torchrl.modules.tensordict_module.actors import (
Expand Down Expand Up @@ -236,7 +236,9 @@ def test_ou(
self, device, interface, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0
):
torch.manual_seed(seed)
net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device)
net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to(
device
)
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
action_spec = BoundedTensorSpec(-torch.ones(d_act), torch.ones(d_act), (d_act,))
policy = ProbabilisticActor(
Expand Down Expand Up @@ -308,7 +310,9 @@ def test_collector(self, device, parallel_spec, probabilistic, interface, seed=0
action_spec = ContinuousActionVecMockEnv(device=device).action_spec
d_act = action_spec.shape[-1]
if probabilistic:
net = NormalParamWrapper(nn.LazyLinear(2 * d_act)).to(device)
net = nn.Sequential(nn.LazyLinear(2 * d_act), NormalParamExtractor()).to(
device
)
module = SafeModule(
net,
in_keys=["observation"],
Expand Down Expand Up @@ -449,7 +453,9 @@ def test_additivegaussian_sd(
if interface == "module":
exploratory_policy = AdditiveGaussianModule(action_spec).to(device)
else:
net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device)
net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to(
device
)
module = SafeModule(
net,
in_keys=["observation"],
Expand Down Expand Up @@ -531,7 +537,9 @@ def test_additivegaussian(
pytest.skip("module raises an error if given spec=None")

torch.manual_seed(seed)
net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device)
net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to(
device
)
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
action_spec = BoundedTensorSpec(
-torch.ones(d_act, device=device),
Expand Down Expand Up @@ -593,7 +601,7 @@ def test_collector(self, device, parallel_spec, interface, seed=0):
else:
action_spec = ContinuousActionVecMockEnv(device=device).action_spec
d_act = action_spec.shape[-1]
net = NormalParamWrapper(nn.LazyLinear(2 * d_act)).to(device)
net = nn.Sequential(nn.LazyLinear(2 * d_act), NormalParamExtractor()).to(device)
module = SafeModule(
net,
in_keys=["observation"],
Expand Down Expand Up @@ -658,7 +666,7 @@ def test_gsde(
else:
in_keys = ["observation"]
model = torch.nn.LazyLinear(action_dim * 2, device=device)
wrapper = NormalParamWrapper(model)
wrapper = nn.Sequential(model, NormalParamExtractor())
module = SafeModule(wrapper, in_keys=in_keys, out_keys=["loc", "scale"])
distribution_class = TanhNormal
distribution_kwargs = {"low": -bound, "high": bound}
Expand Down
10 changes: 5 additions & 5 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
LSTMModule,
MLP,
MultiStepActorWrapper,
NormalParamWrapper,
NormalParamExtractor,
OnlineDTActor,
ProbabilisticActor,
SafeModule,
Expand Down Expand Up @@ -201,7 +201,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys)

in_keys = ["in"]
net = SafeModule(
module=NormalParamWrapper(net),
module=nn.Sequential(net, NormalParamExtractor()),
spec=None,
in_keys=in_keys,
out_keys=out_keys,
Expand Down Expand Up @@ -363,7 +363,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy):
net1 = nn.Linear(3, 4)
dummy_net = nn.Linear(4, 4)
net2 = nn.Linear(4, 4 * param_multiplier)
net2 = NormalParamWrapper(net2)
net2 = nn.Sequential(net2, NormalParamExtractor())

if spec_type is None:
spec = None
Expand Down Expand Up @@ -474,11 +474,11 @@ def test_sequential_partial(self, stack):
net1 = nn.Linear(3, 4)

net2 = nn.Linear(4, 4 * param_multiplier)
net2 = NormalParamWrapper(net2)
net2 = nn.Sequential(net2, NormalParamExtractor())
net2 = SafeModule(net2, in_keys=["b"], out_keys=["loc", "scale"])

net3 = nn.Linear(4, 4 * param_multiplier)
net3 = NormalParamWrapper(net3)
net3 = nn.Sequential(net3, NormalParamExtractor())
net3 = SafeModule(net3, in_keys=["c"], out_keys=["loc", "scale"])

spec = BoundedTensorSpec(-0.1, 0.1, 4)
Expand Down
1 change: 1 addition & 0 deletions torchrl/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
IndependentNormal,
MaskedCategorical,
MaskedOneHotCategorical,
NormalParamExtractor,
NormalParamWrapper,
OneHotCategorical,
ReparamGradientStrategy,
Expand Down
17 changes: 13 additions & 4 deletions torchrl/modules/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from tensordict.nn import NormalParamExtractor

from .continuous import (
__all__ as _all_continuous,
Delta,
IndependentNormal,
NormalParamWrapper,
Expand All @@ -13,14 +14,22 @@
TruncatedNormal,
)
from .discrete import (
__all__ as _all_discrete,
MaskedCategorical,
MaskedOneHotCategorical,
OneHotCategorical,
ReparamGradientStrategy,
)

distributions_maps = {
distribution_class.lower(): eval(distribution_class)
for distribution_class in _all_continuous + _all_discrete
str(dist).lower(): dist
for dist in (
Delta,
IndependentNormal,
TanhDelta,
TanhNormal,
TruncatedNormal,
MaskedCategorical,
MaskedOneHotCategorical,
OneHotCategorical,
)
}
18 changes: 8 additions & 10 deletions torchrl/modules/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,6 @@
)
from torchrl.modules.utils import mappings

__all__ = [
"NormalParamWrapper",
"TanhNormal",
"Delta",
"TanhDelta",
"TruncatedNormal",
"IndependentNormal",
]

# speeds up distribution construction
D.Distribution.set_default_validate_args(False)

Expand Down Expand Up @@ -153,6 +144,10 @@ def __init__(
scale_mapping: str = "biased_softplus_1.0",
scale_lb: Number = 1e-4,
) -> None:
warnings.warn(
"The NormalParamWrapper class will be deprecated in v0.7 in favor of :class:`~tensordict.nn.NormalParamExtractor`.",
category=DeprecationWarning,
)
super().__init__()
self.operator = operator
self.scale_mapping = scale_mapping
Expand Down Expand Up @@ -759,7 +754,10 @@ def mean(self) -> torch.Tensor:
raise AttributeError("TanhDelta mean has not analytical form.")


def uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor:
def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor:
if size is None:
size = torch.Size([])
return torch.randn_like(dist.sample(size))


uniform_sample_delta = _uniform_sample_delta
Loading

0 comments on commit 474e837

Please sign in to comment.