[Feature] Add modules.AdditiveGaussianModule (pytorch#2296)
kurtamohler authored Jul 19, 2024
1 parent c771e6e commit bdc9784
Showing 5 changed files with 190 additions and 24 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/modules.rst
@@ -72,6 +72,7 @@ other cases, the action written in the tensordict is simply the network output.
:toctree: generated/
:template: rl_template_noinherit.rst

AdditiveGaussianModule
AdditiveGaussianWrapper
EGreedyModule
EGreedyWrapper
93 changes: 69 additions & 24 deletions test/test_exploration.py
@@ -44,6 +44,7 @@
)
from torchrl.modules.tensordict_module.exploration import (
_OrnsteinUhlenbeckProcess,
AdditiveGaussianModule,
AdditiveGaussianWrapper,
EGreedyModule,
EGreedyWrapper,
@@ -423,39 +424,51 @@ def test_no_spec_error(self, device):
@pytest.mark.parametrize("device", get_default_devices())
class TestAdditiveGaussian:
@pytest.mark.parametrize("spec_origin", ["spec", "policy", None])
@pytest.mark.parametrize("interface", ["module", "wrapper"])
def test_additivegaussian_sd(
self,
device,
spec_origin,
interface,
d_obs=4,
d_act=6,
batch=32,
n_steps=100,
seed=0,
):
if interface == "module" and spec_origin != "spec":
pytest.skip("module raises an error if given spec=None")

torch.manual_seed(seed)
net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device)
action_spec = BoundedTensorSpec(
-torch.ones(d_act, device=device),
torch.ones(d_act, device=device),
(d_act,),
device=device,
)
module = SafeModule(
net,
in_keys=["observation"],
out_keys=["loc", "scale"],
spec=None,
)
policy = ProbabilisticActor(
spec=CompositeSpec(action=action_spec) if spec_origin is not None else None,
module=module,
in_keys=["loc", "scale"],
distribution_class=TanhNormal,
default_interaction_type=InteractionType.RANDOM,
).to(device)
given_spec = action_spec if spec_origin == "spec" else None
exploratory_policy = AdditiveGaussianWrapper(policy, spec=given_spec).to(device)
if interface == "module":
exploratory_policy = AdditiveGaussianModule(action_spec).to(device)
else:
net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device)
module = SafeModule(
net,
in_keys=["observation"],
out_keys=["loc", "scale"],
spec=None,
)
policy = ProbabilisticActor(
spec=CompositeSpec(action=action_spec)
if spec_origin is not None
else None,
module=module,
in_keys=["loc", "scale"],
distribution_class=TanhNormal,
default_interaction_type=InteractionType.RANDOM,
).to(device)
given_spec = action_spec if spec_origin == "spec" else None
exploratory_policy = AdditiveGaussianWrapper(policy, spec=given_spec).to(
device
)
if spec_origin is not None:
sigma_init = (
action_spec.project(
@@ -473,9 +486,14 @@ def test_additivegaussian_sd(
sigma_init = exploratory_policy.sigma_init
sigma_end = exploratory_policy.sigma_end
if spec_origin is None:
class_name = (
"AdditiveGaussianModule"
if interface == "module"
else "AdditiveGaussianWrapper"
)
with pytest.raises(
RuntimeError,
match="the action spec must be provided to AdditiveGaussianWrapper",
match=f"the action spec must be provided to {class_name}",
):
exploratory_policy._add_noise(action_spec.rand((100000,)).zero_())
return
@@ -497,9 +515,21 @@ def test_additivegaussian_sd(
assert abs(noisy_action.std() - sigma_end) < 1e-1

@pytest.mark.parametrize("spec_origin", ["spec", "policy", None])
def test_additivegaussian_wrapper(
self, device, spec_origin, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0
@pytest.mark.parametrize("interface", ["module", "wrapper"])
def test_additivegaussian(
self,
device,
spec_origin,
interface,
d_obs=4,
d_act=6,
batch=32,
n_steps=100,
seed=0,
):
if interface == "module" and spec_origin != "spec":
pytest.skip("module raises an error if given spec=None")

torch.manual_seed(seed)
net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device)
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
@@ -517,9 +547,14 @@ def test_additivegaussian_wrapper(
default_interaction_type=InteractionType.RANDOM,
).to(device)
given_spec = action_spec if spec_origin == "spec" else None
exploratory_policy = AdditiveGaussianWrapper(
policy, spec=given_spec, safe=False
).to(device)
if interface == "module":
exploratory_policy = TensorDictSequential(
policy, AdditiveGaussianModule(spec=given_spec).to(device)
)
else:
exploratory_policy = AdditiveGaussianWrapper(
policy, spec=given_spec, safe=False
).to(device)

tensordict = TensorDict(
batch_size=[batch],
@@ -544,7 +579,8 @@ def test_additivegaussian_wrapper(
assert action_spec.is_in(out.get("action"))

@pytest.mark.parametrize("parallel_spec", [True, False])
def test_collector(self, device, parallel_spec, seed=0):
@pytest.mark.parametrize("interface", ["module", "wrapper"])
def test_collector(self, device, parallel_spec, interface, seed=0):
torch.manual_seed(seed)
env = SerialEnv(
2,
@@ -570,7 +606,12 @@ def test_collector(self, device, parallel_spec, seed=0):
default_interaction_type=InteractionType.RANDOM,
spec=action_spec,
).to(device)
exploratory_policy = AdditiveGaussianWrapper(policy, safe=False)
if interface == "module":
exploratory_policy = TensorDictSequential(
policy, AdditiveGaussianModule(spec=action_spec).to(device)
)
else:
exploratory_policy = AdditiveGaussianWrapper(policy, safe=False)
exploratory_policy(env.reset())
collector = SyncDataCollector(
create_env_fn=env,
@@ -584,6 +625,10 @@ def test_collector(self, device, parallel_spec, seed=0):
pass
return

def test_no_spec_error(self, device):
with pytest.raises(RuntimeError, match="spec cannot be None."):
AdditiveGaussianModule(spec=None).to(device)


@pytest.mark.parametrize("state_dim", [7])
@pytest.mark.parametrize("action_dim", [5, 11])
1 change: 1 addition & 0 deletions torchrl/modules/__init__.py
@@ -51,6 +51,7 @@
ActorCriticOperator,
ActorCriticWrapper,
ActorValueOperator,
AdditiveGaussianModule,
AdditiveGaussianWrapper,
DecisionTransformerInferenceWrapper,
DistributionalQValueActor,
1 change: 1 addition & 0 deletions torchrl/modules/tensordict_module/__init__.py
@@ -23,6 +23,7 @@
)
from .common import SafeModule, VmapModule
from .exploration import (
AdditiveGaussianModule,
AdditiveGaussianWrapper,
EGreedyModule,
EGreedyWrapper,
118 changes: 118 additions & 0 deletions torchrl/modules/tensordict_module/exploration.py
@@ -23,6 +23,7 @@
__all__ = [
"EGreedyWrapper",
"EGreedyModule",
"AdditiveGaussianModule",
"AdditiveGaussianWrapper",
"OrnsteinUhlenbeckProcessModule",
"OrnsteinUhlenbeckProcessWrapper",
@@ -300,6 +301,12 @@ def __init__(
spec: Optional[TensorSpec] = None,
safe: Optional[bool] = True,
):
warnings.warn(
"AdditiveGaussianWrapper is deprecated and will be removed "
"in v0.7. Please use torchrl.modules.AdditiveGaussianModule "
"instead.",
category=DeprecationWarning,
)
super().__init__(policy)
if sigma_end > sigma_init:
raise RuntimeError("sigma should decrease over time or be constant")
@@ -383,6 +390,117 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
return tensordict


class AdditiveGaussianModule(TensorDictModuleBase):
"""Additive Gaussian PO module.
Args:
spec (TensorSpec): the spec used for sampling actions. The sampled
action will be projected onto the valid action space once explored.
sigma_init (scalar, optional): initial epsilon value.
default: 1.0
sigma_end (scalar, optional): final epsilon value.
default: 0.1
annealing_num_steps (int, optional): number of steps it will take for
sigma to reach the :obj:`sigma_end` value.
default: 1000
mean (float, optional): mean of each output element’s normal distribution.
default: 0.0
std (float, optional): standard deviation of each output element’s normal distribution.
default: 1.0
Keyword Args:
action_key (NestedKey, optional): if the policy module has more than one output key,
its output spec will be of type CompositeSpec. One needs to know where to
find the action spec.
default: "action"
.. note::
It is
crucial to incorporate a call to :meth:`~.step` in the training loop
to update the exploration factor.
Since it is not easy to capture this omission no warning or exception
will be raised if this is ommitted!
"""

def __init__(
self,
spec: TensorSpec,
sigma_init: float = 1.0,
sigma_end: float = 0.1,
annealing_num_steps: int = 1000,
mean: float = 0.0,
std: float = 1.0,
*,
action_key: Optional[NestedKey] = "action",
):
if not isinstance(sigma_init, float):
warnings.warn("eps_init should be a float.")
if sigma_end > sigma_init:
raise RuntimeError("sigma should decrease over time or be constant")
self.action_key = action_key
self.in_keys = [self.action_key]
self.out_keys = [self.action_key]

super().__init__()

self.register_buffer("sigma_init", torch.tensor([sigma_init]))
self.register_buffer("sigma_end", torch.tensor([sigma_end]))
self.annealing_num_steps = annealing_num_steps
self.register_buffer("mean", torch.tensor([mean]))
self.register_buffer("std", torch.tensor([std]))
self.register_buffer("sigma", torch.tensor([sigma_init], dtype=torch.float32))

if spec is not None:
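# A non-composite spec is wrapped into a CompositeSpec keyed by the
# action key, so the safe-action forward hook registered below can look
# up the correct sub-spec when projecting actions.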
if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1:
spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1])
else:
raise RuntimeError("spec cannot be None.")
self._spec = spec
self.register_forward_hook(_forward_hook_safe_action)

@property
def spec(self):
return self._spec

def step(self, frames: int = 1) -> None:
"""A step of sigma decay.
After `self.annealing_num_steps` calls to this method, calls result in no-op.
Args:
frames (int): number of frames since last step. Defaults to ``1``.
"""
for _ in range(frames):
self.sigma.data[0] = max(
self.sigma_end.item(),
(
self.sigma
- (self.sigma_init - self.sigma_end) / self.annealing_num_steps
).item(),
)

def _add_noise(self, action: torch.Tensor) -> torch.Tensor:
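# Sample element-wise Gaussian noise with the configured mean/std, scale
# it by the annealed sigma, and project the noisy action back onto the
# action spec so it remains within bounds.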
sigma = self.sigma.item()
noise = torch.normal(
mean=torch.ones(action.shape) * self.mean.item(),
std=torch.ones(action.shape) * self.std.item(),
).to(action.device)
action = action + noise * sigma
spec = self.spec[self.action_key]
action = spec.project(action)
return action

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
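# Noise is only added when the global exploration mode is RANDOM (or
# unset); otherwise actions pass through unchanged.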
if exploration_type() is ExplorationType.RANDOM or exploration_type() is None:
out = tensordict.get(self.action_key)
out = self._add_noise(out)
tensordict.set(self.action_key, out)
return tensordict


class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper):
r"""Ornstein-Uhlenbeck exploration policy wrapper.
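Taken together with the deprecation warning above, the intended migration path is composition rather than wrapping: the new module is appended to a policy with TensorDictSequential, as the updated tests do. A minimal sketch of the before/after, assuming an existing policy object (e.g. a ProbabilisticActor) and its action_spec; the variable names are illustrative:

from tensordict.nn import TensorDictSequential
from torchrl.modules import AdditiveGaussianModule

# Before (deprecated, removal planned for v0.7):
#     exploratory_policy = AdditiveGaussianWrapper(policy, spec=action_spec)

# After: compose the policy with the standalone exploration module.
noise = AdditiveGaussianModule(
    spec=action_spec,  # required; passing spec=None raises a RuntimeError
    sigma_init=1.0,
    sigma_end=0.1,
    annealing_num_steps=1000,
)
exploratory_policy = TensorDictSequential(policy, noise)

# In the training loop, anneal sigma manually; as the docstring notes,
# no warning is raised if this call is forgotten.
noise.step()

Keeping a direct reference to the module, rather than indexing into the sequential, makes the step() call explicit in the training loop.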
