diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst
index 6c331aa8d46..170251a81fa 100644
--- a/docs/source/reference/modules.rst
+++ b/docs/source/reference/modules.rst
@@ -72,6 +72,7 @@ other cases, the action written in the tensordict is simply the network output.
     :toctree: generated/
     :template: rl_template_noinherit.rst
 
+    AdditiveGaussianModule
     AdditiveGaussianWrapper
     EGreedyModule
     EGreedyWrapper
diff --git a/test/test_exploration.py b/test/test_exploration.py
index d0448fa5cb5..af618f843e2 100644
--- a/test/test_exploration.py
+++ b/test/test_exploration.py
@@ -44,6 +44,7 @@
 )
 from torchrl.modules.tensordict_module.exploration import (
     _OrnsteinUhlenbeckProcess,
+    AdditiveGaussianModule,
     AdditiveGaussianWrapper,
     EGreedyModule,
     EGreedyWrapper,
@@ -423,39 +424,51 @@ def test_no_spec_error(self, device):
 @pytest.mark.parametrize("device", get_default_devices())
 class TestAdditiveGaussian:
     @pytest.mark.parametrize("spec_origin", ["spec", "policy", None])
+    @pytest.mark.parametrize("interface", ["module", "wrapper"])
     def test_additivegaussian_sd(
         self,
         device,
         spec_origin,
+        interface,
         d_obs=4,
         d_act=6,
         batch=32,
         n_steps=100,
         seed=0,
     ):
+        if interface == "module" and spec_origin != "spec":
+            pytest.skip("module raises an error if given spec=None")
+
         torch.manual_seed(seed)
-        net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device)
         action_spec = BoundedTensorSpec(
             -torch.ones(d_act, device=device),
             torch.ones(d_act, device=device),
             (d_act,),
             device=device,
         )
-        module = SafeModule(
-            net,
-            in_keys=["observation"],
-            out_keys=["loc", "scale"],
-            spec=None,
-        )
-        policy = ProbabilisticActor(
-            spec=CompositeSpec(action=action_spec) if spec_origin is not None else None,
-            module=module,
-            in_keys=["loc", "scale"],
-            distribution_class=TanhNormal,
-            default_interaction_type=InteractionType.RANDOM,
-        ).to(device)
-        given_spec = action_spec if spec_origin == "spec" else None
-        exploratory_policy = AdditiveGaussianWrapper(policy, spec=given_spec).to(device)
+        if interface == "module":
+            exploratory_policy = AdditiveGaussianModule(action_spec).to(device)
+        else:
+            net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device)
+            module = SafeModule(
+                net,
+                in_keys=["observation"],
+                out_keys=["loc", "scale"],
+                spec=None,
+            )
+            policy = ProbabilisticActor(
+                spec=CompositeSpec(action=action_spec)
+                if spec_origin is not None
+                else None,
+                module=module,
+                in_keys=["loc", "scale"],
+                distribution_class=TanhNormal,
+                default_interaction_type=InteractionType.RANDOM,
+            ).to(device)
+            given_spec = action_spec if spec_origin == "spec" else None
+            exploratory_policy = AdditiveGaussianWrapper(policy, spec=given_spec).to(
+                device
+            )
         if spec_origin is not None:
             sigma_init = (
                 action_spec.project(
@@ -473,9 +486,14 @@ def test_additivegaussian_sd(
             sigma_init = exploratory_policy.sigma_init
             sigma_end = exploratory_policy.sigma_end
         if spec_origin is None:
+            class_name = (
+                "AdditiveGaussianModule"
+                if interface == "module"
+                else "AdditiveGaussianWrapper"
+            )
             with pytest.raises(
                 RuntimeError,
-                match="the action spec must be provided to AdditiveGaussianWrapper",
+                match=f"the action spec must be provided to {class_name}",
             ):
                 exploratory_policy._add_noise(action_spec.rand((100000,)).zero_())
             return
@@ -497,9 +515,21 @@ def test_additivegaussian_sd(
         assert abs(noisy_action.std() - sigma_end) < 1e-1
 
     @pytest.mark.parametrize("spec_origin", ["spec", "policy", None])
-    def test_additivegaussian_wrapper(
-        self, device, spec_origin, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0
+    @pytest.mark.parametrize("interface", ["module", "wrapper"])
+    def test_additivegaussian(
+        self,
+        device,
+        spec_origin,
+        interface,
+        d_obs=4,
+        d_act=6,
+        batch=32,
+        n_steps=100,
+        seed=0,
     ):
+        if interface == "module" and spec_origin != "spec":
+            pytest.skip("module raises an error if given spec=None")
+
         torch.manual_seed(seed)
         net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device)
         module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
@@ -517,9 +547,14 @@ def test_additivegaussian_wrapper(
             default_interaction_type=InteractionType.RANDOM,
         ).to(device)
         given_spec = action_spec if spec_origin == "spec" else None
-        exploratory_policy = AdditiveGaussianWrapper(
-            policy, spec=given_spec, safe=False
-        ).to(device)
+        if interface == "module":
+            exploratory_policy = TensorDictSequential(
+                policy, AdditiveGaussianModule(spec=given_spec).to(device)
+            )
+        else:
+            exploratory_policy = AdditiveGaussianWrapper(
+                policy, spec=given_spec, safe=False
+            ).to(device)
 
         tensordict = TensorDict(
             batch_size=[batch],
@@ -544,7 +579,8 @@ def test_additivegaussian_wrapper(
             assert action_spec.is_in(out.get("action"))
 
     @pytest.mark.parametrize("parallel_spec", [True, False])
-    def test_collector(self, device, parallel_spec, seed=0):
+    @pytest.mark.parametrize("interface", ["module", "wrapper"])
+    def test_collector(self, device, parallel_spec, interface, seed=0):
         torch.manual_seed(seed)
         env = SerialEnv(
             2,
@@ -570,7 +606,12 @@ def test_collector(self, device, parallel_spec, seed=0):
             default_interaction_type=InteractionType.RANDOM,
             spec=action_spec,
         ).to(device)
-        exploratory_policy = AdditiveGaussianWrapper(policy, safe=False)
+        if interface == "module":
+            exploratory_policy = TensorDictSequential(
+                policy, AdditiveGaussianModule(spec=action_spec).to(device)
+            )
+        else:
+            exploratory_policy = AdditiveGaussianWrapper(policy, safe=False)
         exploratory_policy(env.reset())
         collector = SyncDataCollector(
             create_env_fn=env,
@@ -584,6 +625,10 @@ def test_collector(self, device, parallel_spec, seed=0):
                 pass
         return
 
+    def test_no_spec_error(self, device):
+        with pytest.raises(RuntimeError, match="spec cannot be None."):
+            AdditiveGaussianModule(spec=None).to(device)
+
 
 @pytest.mark.parametrize("state_dim", [7])
 @pytest.mark.parametrize("action_dim", [5, 11])
diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py
index 782dd0bacd2..944210386e9 100644
--- a/torchrl/modules/__init__.py
+++ b/torchrl/modules/__init__.py
@@ -51,6 +51,7 @@
     ActorCriticOperator,
     ActorCriticWrapper,
     ActorValueOperator,
+    AdditiveGaussianModule,
     AdditiveGaussianWrapper,
     DecisionTransformerInferenceWrapper,
     DistributionalQValueActor,
diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py
index fb796f12438..202f84fd173 100644
--- a/torchrl/modules/tensordict_module/__init__.py
+++ b/torchrl/modules/tensordict_module/__init__.py
@@ -23,6 +23,7 @@
 )
 from .common import SafeModule, VmapModule
 from .exploration import (
+    AdditiveGaussianModule,
     AdditiveGaussianWrapper,
     EGreedyModule,
     EGreedyWrapper,
diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py
index b23096c1280..5a41f11bf76 100644
--- a/torchrl/modules/tensordict_module/exploration.py
+++ b/torchrl/modules/tensordict_module/exploration.py
@@ -23,6 +23,7 @@
 __all__ = [
     "EGreedyWrapper",
     "EGreedyModule",
+    "AdditiveGaussianModule",
     "AdditiveGaussianWrapper",
     "OrnsteinUhlenbeckProcessModule",
"OrnsteinUhlenbeckProcessWrapper", @@ -300,6 +301,12 @@ def __init__( spec: Optional[TensorSpec] = None, safe: Optional[bool] = True, ): + warnings.warn( + "AdditiveGaussianWrapper is deprecated and will be removed " + "in v0.7. Please use torchrl.modules.AdditiveGaussianModule " + "instead.", + category=DeprecationWarning, + ) super().__init__(policy) if sigma_end > sigma_init: raise RuntimeError("sigma should decrease over time or be constant") @@ -383,6 +390,117 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict +class AdditiveGaussianModule(TensorDictModuleBase): + """Additive Gaussian PO module. + + Args: + spec (TensorSpec): the spec used for sampling actions. The sampled + action will be projected onto the valid action space once explored. + sigma_init (scalar, optional): initial epsilon value. + default: 1.0 + sigma_end (scalar, optional): final epsilon value. + default: 0.1 + annealing_num_steps (int, optional): number of steps it will take for + sigma to reach the :obj:`sigma_end` value. + default: 1000 + mean (float, optional): mean of each output element’s normal distribution. + default: 0.0 + std (float, optional): standard deviation of each output element’s normal distribution. + default: 1.0 + + Keyword Args: + action_key (NestedKey, optional): if the policy module has more than one output key, + its output spec will be of type CompositeSpec. One needs to know where to + find the action spec. + default: "action" + + .. note:: + It is + crucial to incorporate a call to :meth:`~.step` in the training loop + to update the exploration factor. + Since it is not easy to capture this omission no warning or exception + will be raised if this is ommitted! + + + """ + + def __init__( + self, + spec: TensorSpec, + sigma_init: float = 1.0, + sigma_end: float = 0.1, + annealing_num_steps: int = 1000, + mean: float = 0.0, + std: float = 1.0, + *, + action_key: Optional[NestedKey] = "action", + ): + if not isinstance(sigma_init, float): + warnings.warn("eps_init should be a float.") + if sigma_end > sigma_init: + raise RuntimeError("sigma should decrease over time or be constant") + self.action_key = action_key + self.in_keys = [self.action_key] + self.out_keys = [self.action_key] + + super().__init__() + + self.register_buffer("sigma_init", torch.tensor([sigma_init])) + self.register_buffer("sigma_end", torch.tensor([sigma_end])) + self.annealing_num_steps = annealing_num_steps + self.register_buffer("mean", torch.tensor([mean])) + self.register_buffer("std", torch.tensor([std])) + self.register_buffer("sigma", torch.tensor([sigma_init], dtype=torch.float32)) + + if spec is not None: + if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: + spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) + else: + raise RuntimeError("spec cannot be None.") + self._spec = spec + self.register_forward_hook(_forward_hook_safe_action) + + @property + def spec(self): + return self._spec + + def step(self, frames: int = 1) -> None: + """A step of sigma decay. + + After `self.annealing_num_steps` calls to this method, calls result in no-op. + + Args: + frames (int): number of frames since last step. Defaults to ``1``. 
+ + """ + for _ in range(frames): + self.sigma.data[0] = max( + self.sigma_end.item(), + ( + self.sigma + - (self.sigma_init - self.sigma_end) / self.annealing_num_steps + ).item(), + ) + + def _add_noise(self, action: torch.Tensor) -> torch.Tensor: + sigma = self.sigma.item() + noise = torch.normal( + mean=torch.ones(action.shape) * self.mean.item(), + std=torch.ones(action.shape) * self.std.item(), + ).to(action.device) + action = action + noise * sigma + spec = self.spec[self.action_key] + action = spec.project(action) + return action + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + if exploration_type() is ExplorationType.RANDOM or exploration_type() is None: + out = tensordict.get(self.action_key) + out = self._add_noise(out) + tensordict.set(self.action_key, out) + return tensordict + + class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper): r"""Ornstein-Uhlenbeck exploration policy wrapper.