From c771e6e858faa84a9084c7e3993c3d0a134d8ea3 Mon Sep 17 00:00:00 2001 From: kurtamohler Date: Fri, 19 Jul 2024 09:41:14 -0700 Subject: [PATCH] [Feature] Add `modules.OrnsteinUhlenbeckProcessModule` (#2297) --- docs/source/reference/modules.rst | 1 + test/test_exploration.py | 59 ++++-- torchrl/modules/__init__.py | 1 + torchrl/modules/tensordict_module/__init__.py | 1 + .../modules/tensordict_module/exploration.py | 200 ++++++++++++++++++ 5 files changed, 248 insertions(+), 14 deletions(-) diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index b46d789ed15..6c331aa8d46 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -75,6 +75,7 @@ other cases, the action written in the tensordict is simply the network output. AdditiveGaussianWrapper EGreedyModule EGreedyWrapper + OrnsteinUhlenbeckProcessModule OrnsteinUhlenbeckProcessWrapper Probabilistic actors diff --git a/test/test_exploration.py b/test/test_exploration.py index f65ea655de2..d0448fa5cb5 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -47,6 +47,7 @@ AdditiveGaussianWrapper, EGreedyModule, EGreedyWrapper, + OrnsteinUhlenbeckProcessModule, OrnsteinUhlenbeckProcessWrapper, ) @@ -203,8 +204,8 @@ def test_wrong_action_shape(self, module): @pytest.mark.parametrize("device", get_default_devices()) -class TestOrnsteinUhlenbeckProcessWrapper: - def test_ou(self, device, seed=0): +class TestOrnsteinUhlenbeckProcess: + def test_ou_process(self, device, seed=0): torch.manual_seed(seed) td = TensorDict({"action": torch.randn(3) / 10}, batch_size=[], device=device) ou = _OrnsteinUhlenbeckProcess(10.0, mu=2.0, x0=-4, sigma=0.1, sigma_min=0.01) @@ -229,7 +230,10 @@ def test_ou(self, device, seed=0): assert pval_acc > 0.05 assert pval_reg < 0.1 - def test_ou_wrapper(self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0): + @pytest.mark.parametrize("interface", ["module", "wrapper"]) + def test_ou( + self, device, interface, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0 + ): torch.manual_seed(seed) net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device) module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) @@ -241,7 +245,13 @@ def test_ou_wrapper(self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed= distribution_class=TanhNormal, default_interaction_type=InteractionType.RANDOM, ).to(device) - exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy) + + if interface == "module": + ou = OrnsteinUhlenbeckProcessModule(spec=action_spec).to(device) + exploratory_policy = TensorDictSequential(policy, ou) + else: + exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy) + ou = exploratory_policy tensordict = TensorDict( batch_size=[batch], @@ -261,13 +271,11 @@ def test_ou_wrapper(self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed= ) tensordict = exploratory_policy(tensordict.clone()) if i == 0: - assert (tensordict[exploratory_policy.ou.steps_key] == 1).all() + assert (tensordict[ou.ou.steps_key] == 1).all() elif i == n_steps // 2 + 1: - assert ( - tensordict[exploratory_policy.ou.steps_key][: batch // 2] == 1 - ).all() + assert (tensordict[ou.ou.steps_key][: batch // 2] == 1).all() else: - assert not (tensordict[exploratory_policy.ou.steps_key] == 1).any() + assert not (tensordict[ou.ou.steps_key] == 1).any() out.append(tensordict.clone()) out_noexp.append(tensordict_noexp.clone()) @@ -284,7 +292,8 @@ def test_ou_wrapper(self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed= @pytest.mark.parametrize("parallel_spec", [True, False]) @pytest.mark.parametrize("probabilistic", [True, False]) - def test_collector(self, device, parallel_spec, probabilistic, seed=0): + @pytest.mark.parametrize("interface", ["module", "wrapper"]) + def test_collector(self, device, parallel_spec, probabilistic, interface, seed=0): torch.manual_seed(seed) env = SerialEnv( 2, @@ -317,7 +326,12 @@ def test_collector(self, device, parallel_spec, probabilistic, seed=0): net, in_keys=["observation"], out_keys=["action"], spec=action_spec ) - exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy) + if interface == "module": + exploratory_policy = TensorDictSequential( + policy, OrnsteinUhlenbeckProcessModule(spec=action_spec).to(device) + ) + else: + exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy) exploratory_policy(env.reset()) collector = SyncDataCollector( create_env_fn=env, @@ -334,12 +348,14 @@ def test_collector(self, device, parallel_spec, probabilistic, seed=0): @pytest.mark.parametrize("nested_obs_action", [True, False]) @pytest.mark.parametrize("nested_done", [True, False]) @pytest.mark.parametrize("is_init_key", ["some"]) + @pytest.mark.parametrize("interface", ["module", "wrapper"]) def test_nested( self, device, nested_obs_action, nested_done, is_init_key, + interface, seed=0, n_envs=2, nested_dim=5, @@ -368,9 +384,20 @@ def test_nested( in_keys=[("data", "states") if nested_obs_action else "observation"], out_keys=[env.action_key], ) - exploratory_policy = OrnsteinUhlenbeckProcessWrapper( - policy, spec=action_spec, action_key=env.action_key, is_init_key=is_init_key - ) + if interface == "module": + exploratory_policy = TensorDictSequential( + policy, + OrnsteinUhlenbeckProcessModule( + spec=action_spec, action_key=env.action_key, is_init_key=is_init_key + ).to(device), + ) + else: + exploratory_policy = OrnsteinUhlenbeckProcessWrapper( + policy, + spec=action_spec, + action_key=env.action_key, + is_init_key=is_init_key, + ) collector = SyncDataCollector( create_env_fn=env, policy=exploratory_policy, @@ -388,6 +415,10 @@ def test_nested( return + def test_no_spec_error(self, device): + with pytest.raises(RuntimeError, match="spec cannot be None."): + OrnsteinUhlenbeckProcessModule(spec=None).to(device) + @pytest.mark.parametrize("device", get_default_devices()) class TestAdditiveGaussian: diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 4a3c5e716e8..782dd0bacd2 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -66,6 +66,7 @@ LSTMCell, LSTMModule, MultiStepActorWrapper, + OrnsteinUhlenbeckProcessModule, OrnsteinUhlenbeckProcessWrapper, ProbabilisticActor, QValueActor, diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index 98dfcf80f3b..fb796f12438 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -26,6 +26,7 @@ AdditiveGaussianWrapper, EGreedyModule, EGreedyWrapper, + OrnsteinUhlenbeckProcessModule, OrnsteinUhlenbeckProcessWrapper, ) from .probabilistic import ( diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index e8a1e94698d..b23096c1280 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -24,6 +24,7 @@ "EGreedyWrapper", "EGreedyModule", "AdditiveGaussianWrapper", + "OrnsteinUhlenbeckProcessModule", "OrnsteinUhlenbeckProcessWrapper", ] @@ -491,6 +492,12 @@ def __init__( safe: bool = True, key: Optional[NestedKey] = None, ): + warnings.warn( + "OrnsteinUhlenbeckProcessWrapper is deprecated and will be removed " + "in v0.7. Please use torchrl.modules.OrnsteinUhlenbeckProcessModule " + "instead.", + category=DeprecationWarning, + ) if key is not None: action_key = key warnings.warn( @@ -593,6 +600,199 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict +class OrnsteinUhlenbeckProcessModule(TensorDictModuleBase): + r"""Ornstein-Uhlenbeck exploration policy module. + + Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf. + + The OU exploration is to be used with continuous control policies and introduces a auto-correlated exploration + noise. This enables a sort of 'structured' exploration. + + Noise equation: + + .. math:: + noise_t = noise_{t-1} + \theta * (mu - noise_{t-1}) * dt + \sigma_t * \sqrt{dt} * W + + Sigma equation: + + .. math:: + \sigma_t = max(\sigma^{min, (-(\sigma_{t-1} - \sigma^{min}) / (n^{\text{steps annealing}}) * n^{\text{steps}} + \sigma)) + + To keep track of the steps and noise from sample to sample, an :obj:`"ou_prev_noise{id}"` and :obj:`"ou_steps{id}"` keys + will be written in the input/output tensordict. It is expected that the tensordict will be zeroed at reset, + indicating that a new trajectory is being collected. If not, and is the same tensordict is used for consecutive + trajectories, the step count will keep on increasing across rollouts. Note that the collector classes take care of + zeroing the tensordict at reset time. + + .. note:: + It is + crucial to incorporate a call to :meth:`~.step` in the training loop + to update the exploration factor. + Since it is not easy to capture this omission no warning or exception + will be raised if this is ommitted! + + Args: + spec (TensorSpec): the spec used for sampling actions. The sampled + action will be projected onto the valid action space once explored. + eps_init (scalar): initial epsilon value, determining the amount of noise to be added. + default: 1.0 + eps_end (scalar): final epsilon value, determining the amount of noise to be added. + default: 0.1 + annealing_num_steps (int): number of steps it will take for epsilon to reach the eps_end value. + default: 1000 + theta (scalar): theta factor in the noise equation + default: 0.15 + mu (scalar): OU average (mu in the noise equation). + default: 0.0 + sigma (scalar): sigma value in the sigma equation. + default: 0.2 + dt (scalar): dt in the noise equation. + default: 0.01 + x0 (Tensor, ndarray, optional): initial value of the process. + default: 0.0 + sigma_min (number, optional): sigma_min in the sigma equation. + default: None + n_steps_annealing (int): number of steps for the sigma annealing. + default: 1000 + + Keyword Args: + action_key (NestedKey, optional): key of the action to be modified. + default: "action" + is_init_key (NestedKey, optional): key where to find the is_init flag used to reset the noise steps. + default: "is_init" + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictSequential + >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.modules import OrnsteinUhlenbeckProcessModule, Actor + >>> torch.manual_seed(0) + >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4])) + >>> module = torch.nn.Linear(4, 4, bias=False) + >>> policy = Actor(module=module, spec=spec) + >>> ou = OrnsteinUhlenbeckProcessModule(spec=spec) + >>> explorative_policy = TensorDictSequential(policy, ou) + >>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10]) + >>> print(explorative_policy(td)) + TensorDict( + fields={ + _ou_prev_noise: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False), + _ou_steps: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False), + action: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False), + observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([10]), + device=None, + is_shared=False) + """ + + def __init__( + self, + spec: TensorSpec, + eps_init: float = 1.0, + eps_end: float = 0.1, + annealing_num_steps: int = 1000, + theta: float = 0.15, + mu: float = 0.0, + sigma: float = 0.2, + dt: float = 1e-2, + x0: Optional[Union[torch.Tensor, np.ndarray]] = None, + sigma_min: Optional[float] = None, + n_steps_annealing: int = 1000, + *, + action_key: Optional[NestedKey] = "action", + is_init_key: Optional[NestedKey] = "is_init", + ): + super().__init__() + + self.ou = _OrnsteinUhlenbeckProcess( + theta=theta, + mu=mu, + sigma=sigma, + dt=dt, + x0=x0, + sigma_min=sigma_min, + n_steps_annealing=n_steps_annealing, + key=action_key, + ) + + self.register_buffer("eps_init", torch.tensor([eps_init])) + self.register_buffer("eps_end", torch.tensor([eps_end])) + if self.eps_end > self.eps_init: + raise ValueError( + "eps should decrease over time or be constant, " + f"got eps_init={eps_init} and eps_end={eps_end}" + ) + self.annealing_num_steps = annealing_num_steps + self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32)) + + self.in_keys = [self.ou.key] + self.out_keys = [self.ou.key] + self.ou.out_keys + self.is_init_key = is_init_key + noise_key = self.ou.noise_key + steps_key = self.ou.steps_key + + if spec is not None: + if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: + spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) + self._spec = spec + else: + raise RuntimeError("spec cannot be None.") + ou_specs = { + noise_key: None, + steps_key: None, + } + self._spec.update(ou_specs) + if len(set(self.out_keys)) != len(self.out_keys): + raise RuntimeError(f"Got multiple identical output keys: {self.out_keys}") + self.register_forward_hook(_forward_hook_safe_action) + + @property + def spec(self): + return self._spec + + def step(self, frames: int = 1) -> None: + """Updates the eps noise factor. + + Args: + frames (int): number of frames of the current batch (corresponding to the number of updates to be made). + + """ + for _ in range(frames): + if self.annealing_num_steps > 0: + self.eps.data[0] = max( + self.eps_end.item(), + ( + self.eps + - (self.eps_init - self.eps_end) / self.annealing_num_steps + ).item(), + ) + else: + raise ValueError( + f"{self.__class__.__name__}.step() called when " + f"self.annealing_num_steps={self.annealing_num_steps}. Expected a strictly positive " + f"number of frames." + ) + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + if exploration_type() == ExplorationType.RANDOM or exploration_type() is None: + is_init = tensordict.get(self.is_init_key, None) + if is_init is None: + warnings.warn( + f"The tensordict passed to {self.__class__.__name__} appears to be " + f"missing the '{self.is_init_key}' entry. This entry is used to " + f"reset the noise at the beginning of a trajectory, without it " + f"the behaviour of this exploration method is undefined. " + f"This is allowed for BC compatibility purposes but it will be deprecated soon! " + f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker " + f"transform to your environment with `env = TransformedEnv(env, InitTracker())`." + ) + tensordict = self.ou.add_sample( + tensordict, self.eps.item(), is_init=is_init + ) + return tensordict + + # Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab class _OrnsteinUhlenbeckProcess: def __init__(