Skip to content

Commit

Permalink
[Feature] Consistent Dropout (#2399)
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Moens <vmoens@meta.com>
  • Loading branch information
N00bcak and vmoens authored Sep 10, 2024
1 parent 6aa4b53 commit 0ad8e59
Show file tree
Hide file tree
Showing 7 changed files with 401 additions and 16 deletions.
23 changes: 15 additions & 8 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,29 @@ projected (in a L1-manner) into the desired domain.
SafeSequential
TanhModule

Exploration wrappers
~~~~~~~~~~~~~~~~~~~~
Exploration wrappers and modules
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To efficiently explore the environment, TorchRL proposes a series of wrappers
To efficiently explore the environment, TorchRL proposes a series of modules
that will override the action sampled by the policy by a noisier version.
Their behavior is controlled by :func:`~torchrl.envs.utils.exploration_mode`:
if the exploration is set to ``"random"``, the exploration is active. In all
other cases, the action written in the tensordict is simply the network output.

.. currentmodule:: torchrl.modules.tensordict_module
.. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule`
uses the ``train``/``eval`` mode to comply with the regular `Dropout` API in PyTorch.
The :func:`~torchrl.envs.utils.set_exploration_mode` context manager will have no effect on
this module.

.. currentmodule:: torchrl.modules

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

AdditiveGaussianModule
AdditiveGaussianWrapper
ConsistentDropoutModule
EGreedyModule
EGreedyWrapper
OrnsteinUhlenbeckProcessModule
Expand Down Expand Up @@ -438,12 +444,13 @@ Regular modules
:toctree: generated/
:template: rl_template_noinherit.rst

MLP
ConvNet
BatchRenorm1d
ConsistentDropout
Conv3dNet
SqueezeLayer
ConvNet
MLP
Squeeze2dLayer
BatchRenorm1d
SqueezeLayer

Algorithm-specific modules
~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
152 changes: 151 additions & 1 deletion test/test_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
NormalParamExtractor,
TanhNormal,
)
from torchrl.modules.models.exploration import LazygSDEModule
from torchrl.modules.models.exploration import ConsistentDropoutModule, LazygSDEModule
from torchrl.modules.tensordict_module.actors import (
Actor,
ProbabilisticActor,
Expand Down Expand Up @@ -738,6 +738,156 @@ def test_gsde_init(sigma_init, state_dim, action_dim, mean, std, device, learn_s
), f"failed: mean={mean}, std={std}, sigma_init={sigma_init}, actual: {sigma.mean()}"


class TestConsistentDropout:
@pytest.mark.parametrize("dropout_p", [0.0, 0.1, 0.5])
@pytest.mark.parametrize("parallel_spec", [False, True])
@pytest.mark.parametrize("device", get_default_devices())
def test_consistent_dropout(self, dropout_p, parallel_spec, device):
"""
This preliminary test seeks to ensure two things for both
ConsistentDropout and ConsistentDropoutModule:
1. Rollout transitions generate a dropout mask as desired.
- We can easily verify the existence of a mask
2. The dropout mask is correctly applied.
- We will check with stochastic policies whether or not
the loc and scale are the same.
"""
torch.manual_seed(0)

# NOTE: Please only put a module with one dropout layer.
# That's how this test is constructed anyways.
@torch.no_grad
def inner_verify_routine(module, env):
# Perform transitions.
collector = SyncDataCollector(
create_env_fn=env,
policy=module,
frames_per_batch=1,
total_frames=10,
device=device,
)
for frames in collector:
masks = [
(key, value)
for key, value in frames.items()
if key.startswith("mask_")
]
# Assert rollouts do indeed correctly generate the masks.
assert len(masks) == 1, (
"Expected exactly ONE mask since we only put "
f"one dropout module, got {len(masks)}."
)

# Verify that the result for this batch is the same.
# Kind of Monte Carlo, to be honest.
sentinel_mask = masks[0][1].clone()
sentinel_outputs = frames.select("loc", "scale").clone()

desired_dropout_mask = torch.full_like(
sentinel_mask, 1 / (1 - dropout_p)
)
desired_dropout_mask[sentinel_mask == 0.0] = 0.0
# As of 15/08/24, :meth:`~torch.nn.functional.dropout`
# is being used. Never hurts to be safe.
assert torch.allclose(
sentinel_mask, desired_dropout_mask
), "Dropout was not scaled properly."

new_frames = module(frames.clone())
infer_mask = new_frames[masks[0][0]]
infer_outputs = new_frames.select("loc", "scale")
assert (infer_mask == sentinel_mask).all(), "Mask does not match"

assert all(
[
torch.allclose(infer_outputs[key], sentinel_outputs[key])
for key in ("loc", "scale")
]
), (
"Outputs do not match:\n "
f"{infer_outputs['loc']}\n--- vs ---\n{sentinel_outputs['loc']}"
f"{infer_outputs['scale']}\n--- vs ---\n{sentinel_outputs['scale']}"
)

env = SerialEnv(
2,
ContinuousActionVecMockEnv,
)
env = TransformedEnv(env.to(device), InitTracker())
env = env.to(device)
# the module must work with the action spec of a single env or a serial env
if parallel_spec:
action_spec = env.action_spec
else:
action_spec = ContinuousActionVecMockEnv(device=device).action_spec
d_act = action_spec.shape[-1]

# NOTE: Please only put a module with one dropout layer.
# That's how this test is constructed anyways.
module_td_seq = TensorDictSequential(
TensorDictModule(
nn.LazyLinear(2 * d_act), in_keys=["observation"], out_keys=["out"]
),
ConsistentDropoutModule(p=dropout_p, in_keys="out"),
TensorDictModule(
NormalParamExtractor(), in_keys=["out"], out_keys=["loc", "scale"]
),
)

policy_td_seq = ProbabilisticActor(
module=module_td_seq,
in_keys=["loc", "scale"],
distribution_class=TanhNormal,
default_interaction_type=InteractionType.RANDOM,
spec=action_spec,
).to(device)

# Wake up the policies
policy_td_seq(env.reset())

# Test.
inner_verify_routine(policy_td_seq, env)

def test_consistent_dropout_primer(self):
import torch

from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torchrl.envs import SerialEnv, StepCounter
from torchrl.modules import ConsistentDropoutModule, get_primers_from_module

torch.manual_seed(0)

m = Seq(
Mod(
torch.nn.Linear(7, 4),
in_keys=["observation"],
out_keys=["intermediate"],
),
ConsistentDropoutModule(
p=0.5,
input_shape=(
2,
4,
),
in_keys="intermediate",
),
Mod(torch.nn.Linear(4, 7), in_keys=["intermediate"], out_keys=["action"]),
)
primer = get_primers_from_module(m)
env0 = ContinuousActionVecMockEnv().append_transform(StepCounter(5))
env1 = ContinuousActionVecMockEnv().append_transform(StepCounter(6))
env = SerialEnv(2, [lambda env=env0: env, lambda env=env1: env])
env = env.append_transform(primer)
r = env.rollout(10, m, break_when_any_done=False)
mask = [k for k in r.keys() if k.startswith("mask")][0]
assert (r[mask][0, :5] != r[mask][0, 5:6]).any()
assert (r[mask][0, :4] == r[mask][0, 4:5]).all()

assert (r[mask][1, :6] != r[mask][1, 6:7]).any()
assert (r[mask][1, :5] == r[mask][1, 5:6]).all()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
9 changes: 7 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4597,7 +4597,7 @@ class TensorDictPrimer(Transform):
.. note:: Some TorchRL modules rely on specific keys being present in the environment TensorDicts,
like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`.
To facilitate this process, the method :func:`~torchrl.models.utils.get_primers_from_module`
To facilitate this process, the method :func:`~torchrl.modules.utils.get_primers_from_module`
automatically checks for required primer transforms in a module and its submodules and
generates them.
"""
Expand Down Expand Up @@ -4664,10 +4664,15 @@ def __init__(
def reset_key(self):
reset_key = self.__dict__.get("_reset_key", None)
if reset_key is None:
if self.parent is None:
raise RuntimeError(
"Missing parent, cannot infer reset_key automatically."
)
reset_keys = self.parent.reset_keys
if len(reset_keys) > 1:
raise RuntimeError(
f"Got more than one reset key in env {self.container}, cannot infer which one to use. Consider providing the reset key in the {type(self)} constructor."
f"Got more than one reset key in env {self.container}, cannot infer which one to use. "
f"Consider providing the reset key in the {type(self)} constructor."
)
reset_key = self._reset_key = reset_keys[0]
return reset_key
Expand Down
2 changes: 2 additions & 0 deletions torchrl/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from .models import (
BatchRenorm1d,
ConsistentDropoutModule,
Conv3dNet,
ConvNet,
DdpgCnnActor,
Expand Down Expand Up @@ -85,4 +86,5 @@
VmapModule,
WorldModelWrapper,
)
from .utils import get_primers_from_module
from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner # usort:skip
7 changes: 6 additions & 1 deletion torchrl/modules/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from .batchrenorm import BatchRenorm1d

from .decision_transformer import DecisionTransformer
from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise
from .exploration import (
ConsistentDropoutModule,
NoisyLazyLinear,
NoisyLinear,
reset_noise,
)
from .model_based import (
DreamerActor,
ObsDecoder,
Expand Down
Loading

0 comments on commit 0ad8e59

Please sign in to comment.