Commit
[Doc,Feature] Better doc for modules and list of kwargs when possible (
vmoens authored Mar 5, 2024
1 parent 5ad2436 commit fe6c070
Showing 7 changed files with 814 additions and 397 deletions.
150 changes: 92 additions & 58 deletions test/test_modules.py
@@ -59,65 +59,99 @@ def double_prec_fixture():
     torch.set_default_dtype(dtype)
 
 
-@pytest.mark.parametrize("in_features", [3, 10, None])
-@pytest.mark.parametrize("out_features", [3, (3, 10)])
-@pytest.mark.parametrize("depth, num_cells", [(3, 32), (None, (32, 32, 32))])
-@pytest.mark.parametrize(
-    "activation_class, activation_kwargs",
-    [(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})],
-)
-@pytest.mark.parametrize(
-    "norm_class, norm_kwargs",
-    [
-        (nn.LazyBatchNorm1d, {}),
-        (nn.BatchNorm1d, {"num_features": 32}),
-        (nn.LayerNorm, {"normalized_shape": 32}),
-    ],
-)
-@pytest.mark.parametrize("dropout", [0.0, 0.5])
-@pytest.mark.parametrize("bias_last_layer", [True, False])
-@pytest.mark.parametrize("single_bias_last_layer", [True, False])
-@pytest.mark.parametrize("layer_class", [nn.Linear, NoisyLinear])
-@pytest.mark.parametrize("device", get_default_devices())
-def test_mlp(
-    in_features,
-    out_features,
-    depth,
-    num_cells,
-    activation_class,
-    activation_kwargs,
-    dropout,
-    bias_last_layer,
-    norm_class,
-    norm_kwargs,
-    single_bias_last_layer,
-    layer_class,
-    device,
-    seed=0,
-):
-    torch.manual_seed(seed)
-    batch = 2
-    mlp = MLP(
-        in_features=in_features,
-        out_features=out_features,
-        depth=depth,
-        num_cells=num_cells,
-        activation_class=activation_class,
-        activation_kwargs=activation_kwargs,
-        norm_class=norm_class,
-        norm_kwargs=norm_kwargs,
-        dropout=dropout,
-        bias_last_layer=bias_last_layer,
-        single_bias_last_layer=False,
-        layer_class=layer_class,
-        device=device,
-    )
-    if in_features is None:
-        in_features = 5
-    x = torch.randn(batch, in_features, device=device)
-    y = mlp(x)
-    out_features = [out_features] if isinstance(out_features, Number) else out_features
-    assert y.shape == torch.Size([batch, *out_features])
+class TestMLP:
+    @pytest.mark.parametrize("in_features", [3, 10, None])
+    @pytest.mark.parametrize("out_features", [3, (3, 10)])
+    @pytest.mark.parametrize("depth, num_cells", [(3, 32), (None, (32, 32, 32))])
+    @pytest.mark.parametrize(
+        "activation_class, activation_kwargs",
+        [(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})],
+    )
+    @pytest.mark.parametrize(
+        "norm_class, norm_kwargs",
+        [
+            (nn.LazyBatchNorm1d, {}),
+            (nn.BatchNorm1d, {"num_features": 32}),
+            (nn.LayerNorm, {"normalized_shape": 32}),
+        ],
+    )
+    @pytest.mark.parametrize("dropout", [0.0, 0.5])
+    @pytest.mark.parametrize("bias_last_layer", [True, False])
+    @pytest.mark.parametrize("single_bias_last_layer", [True, False])
+    @pytest.mark.parametrize("layer_class", [nn.Linear, NoisyLinear])
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_mlp(
+        self,
+        in_features,
+        out_features,
+        depth,
+        num_cells,
+        activation_class,
+        activation_kwargs,
+        dropout,
+        bias_last_layer,
+        norm_class,
+        norm_kwargs,
+        single_bias_last_layer,
+        layer_class,
+        device,
+        seed=0,
+    ):
+        torch.manual_seed(seed)
+        batch = 2
+        mlp = MLP(
+            in_features=in_features,
+            out_features=out_features,
+            depth=depth,
+            num_cells=num_cells,
+            activation_class=activation_class,
+            activation_kwargs=activation_kwargs,
+            norm_class=norm_class,
+            norm_kwargs=norm_kwargs,
+            dropout=dropout,
+            bias_last_layer=bias_last_layer,
+            single_bias_last_layer=False,
+            layer_class=layer_class,
+            device=device,
+        )
+        if in_features is None:
+            in_features = 5
+        x = torch.randn(batch, in_features, device=device)
+        y = mlp(x)
+        out_features = (
+            [out_features] if isinstance(out_features, Number) else out_features
+        )
+        assert y.shape == torch.Size([batch, *out_features])
+
+    def test_kwargs(self):
+        def make_activation(shift):
+            return lambda x: x + shift
+
+        def layer(*args, **kwargs):
+            linear = nn.Linear(*args, **kwargs)
+            linear.weight.data.copy_(torch.eye(4))
+            return linear
+
+        in_features = 4
+        out_features = 4
+        num_cells = [4, 4, 4]
+        mlp = MLP(
+            in_features=in_features,
+            out_features=out_features,
+            num_cells=num_cells,
+            activation_class=make_activation,
+            activation_kwargs=[{"shift": 0}, {"shift": 1}, {"shift": 2}],
+            layer_class=layer,
+            layer_kwargs=[{"bias": False}] * 4,
+            bias_last_layer=False,
+        )
+        x = torch.zeros(4)
+        y = mlp(x)
+        for i, module in enumerate(mlp.modules()):
+            if isinstance(module, nn.Linear):
+                assert (module.weight == torch.eye(4)).all(), i
+                assert module.bias is None, i
+        assert (y == 3).all()
 
 
 @pytest.mark.parametrize("in_features", [3, 10, None])
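
The new test_kwargs above exercises the feature named in the commit title: MLP keyword arguments such as activation_kwargs and layer_kwargs can be supplied as a list with one entry per generated module instead of a single shared dict. A minimal usage sketch, assuming that per-module list semantics (one dict per activation, one per linear layer, as in the test):

# A minimal sketch of the list-of-kwargs usage exercised by test_kwargs above.
# Assumption: a list supplies one kwargs dict per generated module (three
# activations for the three hidden layers, four nn.Linear modules overall),
# whereas a single dict would be shared by all of them.
import torch
from torch import nn

from torchrl.modules import MLP

mlp = MLP(
    in_features=4,
    out_features=4,
    num_cells=[4, 4, 4],  # three hidden layers -> four nn.Linear modules
    activation_class=nn.ReLU,
    activation_kwargs=[{"inplace": True}, {}, {}],  # one entry per activation
    layer_class=nn.Linear,
    layer_kwargs=[{"bias": False}] * 4,  # one entry per linear layer
    bias_last_layer=False,
)
print(mlp(torch.zeros(4)).shape)  # torch.Size([4])

Mirroring the test, bias_last_layer=False is passed alongside the per-layer bias settings so the last layer's bias handling stays consistent.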
3 changes: 2 additions & 1 deletion torchrl/modules/__init__.py
@@ -3,6 +3,8 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from torchrl.modules.tensordict_module.common import DistributionalDQNnet
+
 from .distributions import (
     Delta,
     distributions_maps,
@@ -24,7 +26,6 @@
     DdpgMlpActor,
     DdpgMlpQNet,
     DecisionTransformer,
-    DistributionalDQNnet,
     DreamerActor,
     DTActor,
     DuelingCnnDQNet,
4 changes: 3 additions & 1 deletion torchrl/modules/models/__init__.py
@@ -4,6 +4,8 @@
 # LICENSE file in the root directory of this source tree.
 
 
+from torchrl.modules.tensordict_module.common import DistributionalDQNnet
+
 from .decision_transformer import DecisionTransformer
 from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise
 from .model_based import DreamerActor, ObsDecoder, ObsEncoder, RSSMPosterior, RSSMPrior
@@ -15,9 +17,9 @@
     DdpgCnnQNet,
     DdpgMlpActor,
     DdpgMlpQNet,
-    DistributionalDQNnet,
     DTActor,
     DuelingCnnDQNet,
+    DuelingMlpDQNet,
     LSTMNet,
     MLP,
     OnlineDTActor,
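
Both __init__.py hunks keep DistributionalDQNnet re-exported at the package level while its definition is now imported from torchrl.modules.tensordict_module.common, so either import path should resolve to the same class. A short sketch, assuming the re-exports shown above are the only change to the public surface:

# Assumes the re-exports added in the two hunks above; both import paths
# should point at the same DistributionalDQNnet class after this commit.
from torchrl.modules import DistributionalDQNnet
from torchrl.modules.tensordict_module.common import (
    DistributionalDQNnet as DistributionalDQNnetCommon,
)

assert DistributionalDQNnet is DistributionalDQNnetCommon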
(Diffs for the remaining 4 changed files are not shown.)
