Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add modules.OrnsteinUhlenbeckProcessModule #2297

Merged
merged 1 commit into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ other cases, the action written in the tensordict is simply the network output.
AdditiveGaussianWrapper
EGreedyModule
EGreedyWrapper
OrnsteinUhlenbeckProcessModule
OrnsteinUhlenbeckProcessWrapper
Comment on lines +78 to 79
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
OrnsteinUhlenbeckProcessModule
OrnsteinUhlenbeckProcessWrapper
OrnsteinUhlenbeckProcessModule

since we want to deprecate it

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm upon reflection, let's keep this one, people may still see it in their code

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I'd think it's typically a good idea to keep the old documentation and only remove it at the point when the deprecated feature is actually removed


Probabilistic actors
Expand Down
59 changes: 45 additions & 14 deletions test/test_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
AdditiveGaussianWrapper,
EGreedyModule,
EGreedyWrapper,
OrnsteinUhlenbeckProcessModule,
OrnsteinUhlenbeckProcessWrapper,
)

Expand Down Expand Up @@ -203,8 +204,8 @@ def test_wrong_action_shape(self, module):


@pytest.mark.parametrize("device", get_default_devices())
class TestOrnsteinUhlenbeckProcessWrapper:
def test_ou(self, device, seed=0):
class TestOrnsteinUhlenbeckProcess:
def test_ou_process(self, device, seed=0):
torch.manual_seed(seed)
td = TensorDict({"action": torch.randn(3) / 10}, batch_size=[], device=device)
ou = _OrnsteinUhlenbeckProcess(10.0, mu=2.0, x0=-4, sigma=0.1, sigma_min=0.01)
Expand All @@ -229,7 +230,10 @@ def test_ou(self, device, seed=0):
assert pval_acc > 0.05
assert pval_reg < 0.1

def test_ou_wrapper(self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0):
@pytest.mark.parametrize("interface", ["module", "wrapper"])
def test_ou(
self, device, interface, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0
):
torch.manual_seed(seed)
net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device)
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
Expand All @@ -241,7 +245,13 @@ def test_ou_wrapper(self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed=
distribution_class=TanhNormal,
default_interaction_type=InteractionType.RANDOM,
).to(device)
exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy)

if interface == "module":
ou = OrnsteinUhlenbeckProcessModule(spec=action_spec).to(device)
exploratory_policy = TensorDictSequential(policy, ou)
else:
exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy)
ou = exploratory_policy

tensordict = TensorDict(
batch_size=[batch],
Expand All @@ -261,13 +271,11 @@ def test_ou_wrapper(self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed=
)
tensordict = exploratory_policy(tensordict.clone())
if i == 0:
assert (tensordict[exploratory_policy.ou.steps_key] == 1).all()
assert (tensordict[ou.ou.steps_key] == 1).all()
elif i == n_steps // 2 + 1:
assert (
tensordict[exploratory_policy.ou.steps_key][: batch // 2] == 1
).all()
assert (tensordict[ou.ou.steps_key][: batch // 2] == 1).all()
else:
assert not (tensordict[exploratory_policy.ou.steps_key] == 1).any()
assert not (tensordict[ou.ou.steps_key] == 1).any()

out.append(tensordict.clone())
out_noexp.append(tensordict_noexp.clone())
Expand All @@ -284,7 +292,8 @@ def test_ou_wrapper(self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed=

@pytest.mark.parametrize("parallel_spec", [True, False])
@pytest.mark.parametrize("probabilistic", [True, False])
def test_collector(self, device, parallel_spec, probabilistic, seed=0):
@pytest.mark.parametrize("interface", ["module", "wrapper"])
def test_collector(self, device, parallel_spec, probabilistic, interface, seed=0):
torch.manual_seed(seed)
env = SerialEnv(
2,
Expand Down Expand Up @@ -317,7 +326,12 @@ def test_collector(self, device, parallel_spec, probabilistic, seed=0):
net, in_keys=["observation"], out_keys=["action"], spec=action_spec
)

exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy)
if interface == "module":
exploratory_policy = TensorDictSequential(
policy, OrnsteinUhlenbeckProcessModule(spec=action_spec).to(device)
)
else:
exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy)
exploratory_policy(env.reset())
collector = SyncDataCollector(
create_env_fn=env,
Expand All @@ -334,12 +348,14 @@ def test_collector(self, device, parallel_spec, probabilistic, seed=0):
@pytest.mark.parametrize("nested_obs_action", [True, False])
@pytest.mark.parametrize("nested_done", [True, False])
@pytest.mark.parametrize("is_init_key", ["some"])
@pytest.mark.parametrize("interface", ["module", "wrapper"])
def test_nested(
self,
device,
nested_obs_action,
nested_done,
is_init_key,
interface,
seed=0,
n_envs=2,
nested_dim=5,
Expand Down Expand Up @@ -368,9 +384,20 @@ def test_nested(
in_keys=[("data", "states") if nested_obs_action else "observation"],
out_keys=[env.action_key],
)
exploratory_policy = OrnsteinUhlenbeckProcessWrapper(
policy, spec=action_spec, action_key=env.action_key, is_init_key=is_init_key
)
if interface == "module":
exploratory_policy = TensorDictSequential(
policy,
OrnsteinUhlenbeckProcessModule(
spec=action_spec, action_key=env.action_key, is_init_key=is_init_key
).to(device),
)
else:
exploratory_policy = OrnsteinUhlenbeckProcessWrapper(
policy,
spec=action_spec,
action_key=env.action_key,
is_init_key=is_init_key,
)
collector = SyncDataCollector(
create_env_fn=env,
policy=exploratory_policy,
Expand All @@ -388,6 +415,10 @@ def test_nested(

return

# A spec is mandatory for OrnsteinUhlenbeckProcessModule: building the module
# with spec=None must raise a RuntimeError whose message matches
# "spec cannot be None." (the `match` argument is applied as a regex search).
def test_no_spec_error(self, device):
with pytest.raises(RuntimeError, match="spec cannot be None."):
OrnsteinUhlenbeckProcessModule(spec=None).to(device)


@pytest.mark.parametrize("device", get_default_devices())
class TestAdditiveGaussian:
Expand Down
1 change: 1 addition & 0 deletions torchrl/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
LSTMCell,
LSTMModule,
MultiStepActorWrapper,
OrnsteinUhlenbeckProcessModule,
OrnsteinUhlenbeckProcessWrapper,
ProbabilisticActor,
QValueActor,
Expand Down
1 change: 1 addition & 0 deletions torchrl/modules/tensordict_module/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
AdditiveGaussianWrapper,
EGreedyModule,
EGreedyWrapper,
OrnsteinUhlenbeckProcessModule,
OrnsteinUhlenbeckProcessWrapper,
)
from .probabilistic import (
Expand Down
200 changes: 200 additions & 0 deletions torchrl/modules/tensordict_module/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"EGreedyWrapper",
"EGreedyModule",
"AdditiveGaussianWrapper",
"OrnsteinUhlenbeckProcessModule",
"OrnsteinUhlenbeckProcessWrapper",
]

Expand Down Expand Up @@ -491,6 +492,12 @@ def __init__(
safe: bool = True,
key: Optional[NestedKey] = None,
):
warnings.warn(
"OrnsteinUhlenbeckProcessWrapper is deprecated and will be removed "
"in v0.7. Please use torchrl.modules.OrnsteinUhlenbeckProcessModule "
"instead.",
category=DeprecationWarning,
)
if key is not None:
action_key = key
warnings.warn(
Expand Down Expand Up @@ -593,6 +600,199 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
return tensordict


class OrnsteinUhlenbeckProcessModule(TensorDictModuleBase):
    r"""Ornstein-Uhlenbeck exploration policy module.

    Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf.

    The OU exploration is to be used with continuous control policies and introduces an auto-correlated exploration
    noise. This enables a sort of 'structured' exploration.

        Noise equation:

        .. math::
            noise_t = noise_{t-1} + \theta * (\mu - noise_{t-1}) * dt + \sigma_t * \sqrt{dt} * W

        Sigma equation:

        .. math::
            \sigma_t = \max(\sigma^{min}, (-(\sigma_{t-1} - \sigma^{min}) / (n^{\text{steps annealing}}) * n^{\text{steps}} + \sigma))

    To keep track of the steps and noise from sample to sample, an :obj:`"ou_prev_noise{id}"` and :obj:`"ou_steps{id}"` keys
    will be written in the input/output tensordict. It is expected that the tensordict will be zeroed at reset,
    indicating that a new trajectory is being collected. If not, and if the same tensordict is used for consecutive
    trajectories, the step count will keep on increasing across rollouts. Note that the collector classes take care of
    zeroing the tensordict at reset time.

    .. note::
        It is
        crucial to incorporate a call to :meth:`~.step` in the training loop
        to update the exploration factor.
        Since it is not easy to capture this omission no warning or exception
        will be raised if this is omitted!

    Args:
        spec (TensorSpec): the spec used for sampling actions. The sampled
            action will be projected onto the valid action space once explored.
            Must not be ``None``.
        eps_init (scalar): initial epsilon value, determining the amount of noise to be added.
            default: 1.0
        eps_end (scalar): final epsilon value, determining the amount of noise to be added.
            default: 0.1
        annealing_num_steps (int): number of steps it will take for epsilon to reach the eps_end value.
            default: 1000
        theta (scalar): theta factor in the noise equation
            default: 0.15
        mu (scalar): OU average (mu in the noise equation).
            default: 0.0
        sigma (scalar): sigma value in the sigma equation.
            default: 0.2
        dt (scalar): dt in the noise equation.
            default: 0.01
        x0 (Tensor, ndarray, optional): initial value of the process.
            default: None (forwarded unchanged to the underlying process)
        sigma_min (number, optional): sigma_min in the sigma equation.
            default: None
        n_steps_annealing (int): number of steps for the sigma annealing.
            default: 1000

    Keyword Args:
        action_key (NestedKey, optional): key of the action to be modified.
            default: "action"
        is_init_key (NestedKey, optional): key where to find the is_init flag used to reset the noise steps.
            default: "is_init"

    Raises:
        ValueError: if ``eps_end`` is greater than ``eps_init``.
        RuntimeError: if ``spec`` is ``None`` or if the output keys are not unique.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from tensordict.nn import TensorDictSequential
        >>> from torchrl.data import BoundedTensorSpec
        >>> from torchrl.modules import OrnsteinUhlenbeckProcessModule, Actor
        >>> torch.manual_seed(0)
        >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4]))
        >>> module = torch.nn.Linear(4, 4, bias=False)
        >>> policy = Actor(module=module, spec=spec)
        >>> ou = OrnsteinUhlenbeckProcessModule(spec=spec)
        >>> explorative_policy = TensorDictSequential(policy, ou)
        >>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
        >>> print(explorative_policy(td))
        TensorDict(
            fields={
                _ou_prev_noise: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                _ou_steps: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
                action: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)
    """

    def __init__(
        self,
        spec: TensorSpec,
        eps_init: float = 1.0,
        eps_end: float = 0.1,
        annealing_num_steps: int = 1000,
        theta: float = 0.15,
        mu: float = 0.0,
        sigma: float = 0.2,
        dt: float = 1e-2,
        x0: Optional[Union[torch.Tensor, np.ndarray]] = None,
        sigma_min: Optional[float] = None,
        n_steps_annealing: int = 1000,
        *,
        action_key: Optional[NestedKey] = "action",
        is_init_key: Optional[NestedKey] = "is_init",
    ):
        super().__init__()

        # The noise process itself; it owns the action key and the names of
        # the bookkeeping entries (noise / steps) written to the tensordict.
        self.ou = _OrnsteinUhlenbeckProcess(
            theta=theta,
            mu=mu,
            sigma=sigma,
            dt=dt,
            x0=x0,
            sigma_min=sigma_min,
            n_steps_annealing=n_steps_annealing,
            key=action_key,
        )

        # The epsilon schedule is stored in buffers so that it follows the
        # module across `.to(device)` calls.
        self.register_buffer("eps_init", torch.tensor([eps_init]))
        self.register_buffer("eps_end", torch.tensor([eps_end]))
        if self.eps_end > self.eps_init:
            raise ValueError(
                "eps should decrease over time or be constant, "
                f"got eps_init={eps_init} and eps_end={eps_end}"
            )
        self.annealing_num_steps = annealing_num_steps
        self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32))

        self.in_keys = [self.ou.key]
        self.out_keys = [self.ou.key] + self.ou.out_keys
        self.is_init_key = is_init_key
        noise_key = self.ou.noise_key
        steps_key = self.ou.steps_key

        if spec is not None:
            # A leaf spec is wrapped in a composite keyed by the action key so
            # that the bookkeeping entries can be registered next to it.
            if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1:
                spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1])
            self._spec = spec
        else:
            raise RuntimeError("spec cannot be None.")
        # The noise and step-count entries carry no spec of their own.
        ou_specs = {
            noise_key: None,
            steps_key: None,
        }
        self._spec.update(ou_specs)
        if len(set(self.out_keys)) != len(self.out_keys):
            raise RuntimeError(f"Got multiple identical output keys: {self.out_keys}")
        self.register_forward_hook(_forward_hook_safe_action)

    @property
    def spec(self):
        """The spec registered at construction, extended with the OU bookkeeping entries."""
        return self._spec

    def step(self, frames: int = 1) -> None:
        """Updates the eps noise factor.

        ``eps`` is decreased linearly by ``(eps_init - eps_end) / annealing_num_steps``
        per frame and never goes below ``eps_end``.

        Args:
            frames (int): number of frames of the current batch (corresponding to the number of updates to be made).

        Raises:
            ValueError: if at least one update is requested while
                ``annealing_num_steps`` is not strictly positive.

        """
        if frames <= 0:
            # Nothing to update (mirrors iterating over an empty range).
            return
        if self.annealing_num_steps <= 0:
            raise ValueError(
                f"{self.__class__.__name__}.step() called when "
                f"self.annealing_num_steps={self.annealing_num_steps}. Expected a strictly positive "
                f"number of frames."
            )
        # Both quantities are constant across iterations: compute them once.
        eps_floor = self.eps_end.item()
        decrement = (self.eps_init - self.eps_end) / self.annealing_num_steps
        for _ in range(frames):
            self.eps.data[0] = max(eps_floor, (self.eps - decrement).item())

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Adds OU noise to the action entry when exploration is random or unset.

        Args:
            tensordict (TensorDictBase): data containing the action to perturb
                and, ideally, the ``is_init_key`` entry.

        Returns:
            The tensordict returned by the OU process when noise is added,
            the input tensordict otherwise.
        """
        mode = exploration_type()
        if mode == ExplorationType.RANDOM or mode is None:
            is_init = tensordict.get(self.is_init_key, None)
            if is_init is None:
                warnings.warn(
                    f"The tensordict passed to {self.__class__.__name__} appears to be "
                    f"missing the '{self.is_init_key}' entry. This entry is used to "
                    f"reset the noise at the beginning of a trajectory, without it "
                    f"the behaviour of this exploration method is undefined. "
                    f"This is allowed for BC compatibility purposes but it will be deprecated soon! "
                    f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker "
                    f"transform to your environment with `env = TransformedEnv(env, InitTracker())`."
                )
            tensordict = self.ou.add_sample(
                tensordict, self.eps.item(), is_init=is_init
            )
        return tensordict


# Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
class _OrnsteinUhlenbeckProcess:
def __init__(
Expand Down
Loading