Skip to content

Commit

Permalink
[Feature] Add modules.OrnsteinUhlenbeckProcessModule (pytorch#2297)
Browse files Browse the repository at this point in the history
  • Loading branch information
kurtamohler authored Jul 19, 2024
1 parent c4b2eb0 commit c771e6e
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 14 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ other cases, the action written in the tensordict is simply the network output.
AdditiveGaussianWrapper
EGreedyModule
EGreedyWrapper
OrnsteinUhlenbeckProcessModule
OrnsteinUhlenbeckProcessWrapper

Probabilistic actors
Expand Down
59 changes: 45 additions & 14 deletions test/test_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
AdditiveGaussianWrapper,
EGreedyModule,
EGreedyWrapper,
OrnsteinUhlenbeckProcessModule,
OrnsteinUhlenbeckProcessWrapper,
)

Expand Down Expand Up @@ -203,8 +204,8 @@ def test_wrong_action_shape(self, module):


@pytest.mark.parametrize("device", get_default_devices())
class TestOrnsteinUhlenbeckProcessWrapper:
def test_ou(self, device, seed=0):
class TestOrnsteinUhlenbeckProcess:
def test_ou_process(self, device, seed=0):
torch.manual_seed(seed)
td = TensorDict({"action": torch.randn(3) / 10}, batch_size=[], device=device)
ou = _OrnsteinUhlenbeckProcess(10.0, mu=2.0, x0=-4, sigma=0.1, sigma_min=0.01)
Expand All @@ -229,7 +230,10 @@ def test_ou(self, device, seed=0):
assert pval_acc > 0.05
assert pval_reg < 0.1

def test_ou_wrapper(self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0):
@pytest.mark.parametrize("interface", ["module", "wrapper"])
def test_ou(
self, device, interface, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0
):
torch.manual_seed(seed)
net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device)
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
Expand All @@ -241,7 +245,13 @@ def test_ou_wrapper(self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed=
distribution_class=TanhNormal,
default_interaction_type=InteractionType.RANDOM,
).to(device)
exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy)

if interface == "module":
ou = OrnsteinUhlenbeckProcessModule(spec=action_spec).to(device)
exploratory_policy = TensorDictSequential(policy, ou)
else:
exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy)
ou = exploratory_policy

tensordict = TensorDict(
batch_size=[batch],
Expand All @@ -261,13 +271,11 @@ def test_ou_wrapper(self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed=
)
tensordict = exploratory_policy(tensordict.clone())
if i == 0:
assert (tensordict[exploratory_policy.ou.steps_key] == 1).all()
assert (tensordict[ou.ou.steps_key] == 1).all()
elif i == n_steps // 2 + 1:
assert (
tensordict[exploratory_policy.ou.steps_key][: batch // 2] == 1
).all()
assert (tensordict[ou.ou.steps_key][: batch // 2] == 1).all()
else:
assert not (tensordict[exploratory_policy.ou.steps_key] == 1).any()
assert not (tensordict[ou.ou.steps_key] == 1).any()

out.append(tensordict.clone())
out_noexp.append(tensordict_noexp.clone())
Expand All @@ -284,7 +292,8 @@ def test_ou_wrapper(self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed=

@pytest.mark.parametrize("parallel_spec", [True, False])
@pytest.mark.parametrize("probabilistic", [True, False])
def test_collector(self, device, parallel_spec, probabilistic, seed=0):
@pytest.mark.parametrize("interface", ["module", "wrapper"])
def test_collector(self, device, parallel_spec, probabilistic, interface, seed=0):
torch.manual_seed(seed)
env = SerialEnv(
2,
Expand Down Expand Up @@ -317,7 +326,12 @@ def test_collector(self, device, parallel_spec, probabilistic, seed=0):
net, in_keys=["observation"], out_keys=["action"], spec=action_spec
)

exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy)
if interface == "module":
exploratory_policy = TensorDictSequential(
policy, OrnsteinUhlenbeckProcessModule(spec=action_spec).to(device)
)
else:
exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy)
exploratory_policy(env.reset())
collector = SyncDataCollector(
create_env_fn=env,
Expand All @@ -334,12 +348,14 @@ def test_collector(self, device, parallel_spec, probabilistic, seed=0):
@pytest.mark.parametrize("nested_obs_action", [True, False])
@pytest.mark.parametrize("nested_done", [True, False])
@pytest.mark.parametrize("is_init_key", ["some"])
@pytest.mark.parametrize("interface", ["module", "wrapper"])
def test_nested(
self,
device,
nested_obs_action,
nested_done,
is_init_key,
interface,
seed=0,
n_envs=2,
nested_dim=5,
Expand Down Expand Up @@ -368,9 +384,20 @@ def test_nested(
in_keys=[("data", "states") if nested_obs_action else "observation"],
out_keys=[env.action_key],
)
exploratory_policy = OrnsteinUhlenbeckProcessWrapper(
policy, spec=action_spec, action_key=env.action_key, is_init_key=is_init_key
)
if interface == "module":
exploratory_policy = TensorDictSequential(
policy,
OrnsteinUhlenbeckProcessModule(
spec=action_spec, action_key=env.action_key, is_init_key=is_init_key
).to(device),
)
else:
exploratory_policy = OrnsteinUhlenbeckProcessWrapper(
policy,
spec=action_spec,
action_key=env.action_key,
is_init_key=is_init_key,
)
collector = SyncDataCollector(
create_env_fn=env,
policy=exploratory_policy,
Expand All @@ -388,6 +415,10 @@ def test_nested(

return

def test_no_spec_error(self, device):
with pytest.raises(RuntimeError, match="spec cannot be None."):
OrnsteinUhlenbeckProcessModule(spec=None).to(device)


@pytest.mark.parametrize("device", get_default_devices())
class TestAdditiveGaussian:
Expand Down
1 change: 1 addition & 0 deletions torchrl/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
LSTMCell,
LSTMModule,
MultiStepActorWrapper,
OrnsteinUhlenbeckProcessModule,
OrnsteinUhlenbeckProcessWrapper,
ProbabilisticActor,
QValueActor,
Expand Down
1 change: 1 addition & 0 deletions torchrl/modules/tensordict_module/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
AdditiveGaussianWrapper,
EGreedyModule,
EGreedyWrapper,
OrnsteinUhlenbeckProcessModule,
OrnsteinUhlenbeckProcessWrapper,
)
from .probabilistic import (
Expand Down
200 changes: 200 additions & 0 deletions torchrl/modules/tensordict_module/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"EGreedyWrapper",
"EGreedyModule",
"AdditiveGaussianWrapper",
"OrnsteinUhlenbeckProcessModule",
"OrnsteinUhlenbeckProcessWrapper",
]

Expand Down Expand Up @@ -491,6 +492,12 @@ def __init__(
safe: bool = True,
key: Optional[NestedKey] = None,
):
warnings.warn(
"OrnsteinUhlenbeckProcessWrapper is deprecated and will be removed "
"in v0.7. Please use torchrl.modules.OrnsteinUhlenbeckProcessModule "
"instead.",
category=DeprecationWarning,
)
if key is not None:
action_key = key
warnings.warn(
Expand Down Expand Up @@ -593,6 +600,199 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
return tensordict


class OrnsteinUhlenbeckProcessModule(TensorDictModuleBase):
r"""Ornstein-Uhlenbeck exploration policy module.
Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf.
The OU exploration is to be used with continuous control policies and introduces a auto-correlated exploration
noise. This enables a sort of 'structured' exploration.
Noise equation:
.. math::
noise_t = noise_{t-1} + \theta * (mu - noise_{t-1}) * dt + \sigma_t * \sqrt{dt} * W
Sigma equation:
.. math::
\sigma_t = max(\sigma^{min, (-(\sigma_{t-1} - \sigma^{min}) / (n^{\text{steps annealing}}) * n^{\text{steps}} + \sigma))
To keep track of the steps and noise from sample to sample, an :obj:`"ou_prev_noise{id}"` and :obj:`"ou_steps{id}"` keys
will be written in the input/output tensordict. It is expected that the tensordict will be zeroed at reset,
indicating that a new trajectory is being collected. If not, and is the same tensordict is used for consecutive
trajectories, the step count will keep on increasing across rollouts. Note that the collector classes take care of
zeroing the tensordict at reset time.
.. note::
It is
crucial to incorporate a call to :meth:`~.step` in the training loop
to update the exploration factor.
Since it is not easy to capture this omission no warning or exception
will be raised if this is ommitted!
Args:
spec (TensorSpec): the spec used for sampling actions. The sampled
action will be projected onto the valid action space once explored.
eps_init (scalar): initial epsilon value, determining the amount of noise to be added.
default: 1.0
eps_end (scalar): final epsilon value, determining the amount of noise to be added.
default: 0.1
annealing_num_steps (int): number of steps it will take for epsilon to reach the eps_end value.
default: 1000
theta (scalar): theta factor in the noise equation
default: 0.15
mu (scalar): OU average (mu in the noise equation).
default: 0.0
sigma (scalar): sigma value in the sigma equation.
default: 0.2
dt (scalar): dt in the noise equation.
default: 0.01
x0 (Tensor, ndarray, optional): initial value of the process.
default: 0.0
sigma_min (number, optional): sigma_min in the sigma equation.
default: None
n_steps_annealing (int): number of steps for the sigma annealing.
default: 1000
Keyword Args:
action_key (NestedKey, optional): key of the action to be modified.
default: "action"
is_init_key (NestedKey, optional): key where to find the is_init flag used to reset the noise steps.
default: "is_init"
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictSequential
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules import OrnsteinUhlenbeckProcessModule, Actor
>>> torch.manual_seed(0)
>>> spec = BoundedTensorSpec(-1, 1, torch.Size([4]))
>>> module = torch.nn.Linear(4, 4, bias=False)
>>> policy = Actor(module=module, spec=spec)
>>> ou = OrnsteinUhlenbeckProcessModule(spec=spec)
>>> explorative_policy = TensorDictSequential(policy, ou)
>>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
>>> print(explorative_policy(td))
TensorDict(
fields={
_ou_prev_noise: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
_ou_steps: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
action: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([10]),
device=None,
is_shared=False)
"""

def __init__(
self,
spec: TensorSpec,
eps_init: float = 1.0,
eps_end: float = 0.1,
annealing_num_steps: int = 1000,
theta: float = 0.15,
mu: float = 0.0,
sigma: float = 0.2,
dt: float = 1e-2,
x0: Optional[Union[torch.Tensor, np.ndarray]] = None,
sigma_min: Optional[float] = None,
n_steps_annealing: int = 1000,
*,
action_key: Optional[NestedKey] = "action",
is_init_key: Optional[NestedKey] = "is_init",
):
super().__init__()

self.ou = _OrnsteinUhlenbeckProcess(
theta=theta,
mu=mu,
sigma=sigma,
dt=dt,
x0=x0,
sigma_min=sigma_min,
n_steps_annealing=n_steps_annealing,
key=action_key,
)

self.register_buffer("eps_init", torch.tensor([eps_init]))
self.register_buffer("eps_end", torch.tensor([eps_end]))
if self.eps_end > self.eps_init:
raise ValueError(
"eps should decrease over time or be constant, "
f"got eps_init={eps_init} and eps_end={eps_end}"
)
self.annealing_num_steps = annealing_num_steps
self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32))

self.in_keys = [self.ou.key]
self.out_keys = [self.ou.key] + self.ou.out_keys
self.is_init_key = is_init_key
noise_key = self.ou.noise_key
steps_key = self.ou.steps_key

if spec is not None:
if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1:
spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1])
self._spec = spec
else:
raise RuntimeError("spec cannot be None.")
ou_specs = {
noise_key: None,
steps_key: None,
}
self._spec.update(ou_specs)
if len(set(self.out_keys)) != len(self.out_keys):
raise RuntimeError(f"Got multiple identical output keys: {self.out_keys}")
self.register_forward_hook(_forward_hook_safe_action)

@property
def spec(self):
return self._spec

def step(self, frames: int = 1) -> None:
"""Updates the eps noise factor.
Args:
frames (int): number of frames of the current batch (corresponding to the number of updates to be made).
"""
for _ in range(frames):
if self.annealing_num_steps > 0:
self.eps.data[0] = max(
self.eps_end.item(),
(
self.eps
- (self.eps_init - self.eps_end) / self.annealing_num_steps
).item(),
)
else:
raise ValueError(
f"{self.__class__.__name__}.step() called when "
f"self.annealing_num_steps={self.annealing_num_steps}. Expected a strictly positive "
f"number of frames."
)

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if exploration_type() == ExplorationType.RANDOM or exploration_type() is None:
is_init = tensordict.get(self.is_init_key, None)
if is_init is None:
warnings.warn(
f"The tensordict passed to {self.__class__.__name__} appears to be "
f"missing the '{self.is_init_key}' entry. This entry is used to "
f"reset the noise at the beginning of a trajectory, without it "
f"the behaviour of this exploration method is undefined. "
f"This is allowed for BC compatibility purposes but it will be deprecated soon! "
f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker "
f"transform to your environment with `env = TransformedEnv(env, InitTracker())`."
)
tensordict = self.ou.add_sample(
tensordict, self.eps.item(), is_init=is_init
)
return tensordict


# Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
class _OrnsteinUhlenbeckProcess:
def __init__(
Expand Down

0 comments on commit c771e6e

Please sign in to comment.