Skip to content

Commit

Permalink
[Feature] Nested keys in OrnsteinUhlenbeckProcess (pytorch#1305)
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Bettini <matbet@meta.com>
  • Loading branch information
matteobettini authored Jul 6, 2023
1 parent c44c25a commit 2a477a2
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 52 deletions.
11 changes: 8 additions & 3 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,13 +944,18 @@ def forward(self, observation, action):
return self.linear(torch.cat([observation, action], dim=-1))


class CountingEnvCountPolicy:
class CountingEnvCountPolicy(nn.Module):
def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"):
super().__init__()
self.action_spec = action_spec
self.action_key = action_key

def __call__(self, td: TensorDictBase) -> TensorDictBase:
return td.set(self.action_key, self.action_spec.zero() + 1)
def __call__(self, t):
action = self.action_spec.zero() + 1
if isinstance(t, torch.Tensor):
return action
elif isinstance(t, TensorDictBase):
return t.set(self.action_key, action)


class CountingEnv(EnvBase):
Expand Down
61 changes: 59 additions & 2 deletions test/test_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
import pytest
import torch
from _utils_internal import get_default_devices
from mocking_classes import ContinuousActionVecMockEnv
from mocking_classes import (
ContinuousActionVecMockEnv,
CountingEnvCountPolicy,
NestedCountingEnv,
)
from scipy.stats import ttest_1samp
from tensordict.nn import InteractionType
from tensordict.nn import InteractionType, TensorDictModule
from tensordict.tensordict import TensorDict
from torch import nn

Expand Down Expand Up @@ -180,6 +184,59 @@ def test_collector(self, device, parallel_spec, probabilistic, seed=0):
pass
return

@pytest.mark.parametrize("nested_obs_action", [True, False])
@pytest.mark.parametrize("nested_done", [True, False])
@pytest.mark.parametrize("is_init_key", ["some", ("one", "nested")])
def test_nested(
self,
device,
nested_obs_action,
nested_done,
is_init_key,
seed=0,
n_envs=2,
nested_dim=5,
frames_per_batch=100,
):
torch.manual_seed(seed)

env = SerialEnv(
n_envs,
lambda: TransformedEnv(
NestedCountingEnv(
nest_obs_action=nested_obs_action,
nest_done=nested_done,
nested_dim=nested_dim,
).to(device),
InitTracker(init_key=is_init_key),
),
)

action_spec = env.action_spec
d_act = action_spec.shape[-1]

net = nn.LazyLinear(d_act).to(device)
policy = TensorDictModule(
CountingEnvCountPolicy(action_spec=action_spec, action_key=env.action_key),
in_keys=[("data", "states") if nested_obs_action else "observation"],
out_keys=[env.action_key],
)
exploratory_policy = OrnsteinUhlenbeckProcessWrapper(
policy, spec=action_spec, action_key=env.action_key, is_init_key=is_init_key
)
collector = SyncDataCollector(
create_env_fn=env,
policy=exploratory_policy,
frames_per_batch=frames_per_batch,
total_frames=1000,
device=device,
)
for _td in collector:
assert _td[is_init_key].shape == _td[env.done_key].shape
break

return


@pytest.mark.parametrize("device", get_default_devices())
class TestAdditiveGaussian:
Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3953,7 +3953,7 @@ class InitTracker(Transform):
that is set to ``True`` whenever :meth:`~.reset` is called.
Args:
init_key (str, optional): the key to be used for the tracker entry.
init_key (NestedKey, optional): the key to be used for the tracker entry.
Examples:
>>> from torchrl.envs.libs.gym import GymEnv
Expand All @@ -3971,7 +3971,7 @@ def __init__(self, init_key: bool = "is_init"):
super().__init__(in_keys=[], out_keys=[init_key])

def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.out_keys[0] not in tensordict.keys():
if self.out_keys[0] not in tensordict.keys(True, True):
device = tensordict.device
if device is None:
device = torch.device("cpu")
Expand Down
126 changes: 81 additions & 45 deletions torchrl/modules/tensordict_module/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from tensordict.nn import TensorDictModule, TensorDictModuleWrapper
from tensordict.tensordict import TensorDictBase
from tensordict.utils import expand_as_right, NestedKey
from tensordict.utils import expand_as_right, expand_right, NestedKey

from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
from torchrl.envs.utils import exploration_type, ExplorationType
Expand All @@ -34,7 +34,7 @@ class EGreedyWrapper(TensorDictModuleWrapper):
eps_end (scalar, optional): final epsilon value.
default: 0.1
annealing_num_steps (int, optional): number of steps it will take for epsilon to reach the eps_end value
action_key (str, Tuple[str], optional): if the policy module has more than one output key,
action_key (NestedKey, optional): if the policy module has more than one output key,
its output spec will be of type CompositeSpec. One needs to know where to
find the action spec.
Default is "action".
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(
eps_init: float = 1.0,
eps_end: float = 0.1,
annealing_num_steps: int = 1000,
action_key: NestedKey = "action",
action_key: Optional[NestedKey] = "action",
spec: Optional[TensorSpec] = None,
):
super().__init__(policy)
Expand Down Expand Up @@ -173,7 +173,7 @@ class AdditiveGaussianWrapper(TensorDictModuleWrapper):
sigma to reach the :obj:`sigma_end` value.
mean (float, optional): mean of each output element’s normal distribution.
std (float, optional): standard deviation of each output element’s normal distribution.
action_key (str, optional): if the policy module has more than one output key,
action_key (NestedKey, optional): if the policy module has more than one output key,
its output spec will be of type CompositeSpec. One needs to know where to
find the action spec.
Default is "action".
Expand Down Expand Up @@ -204,7 +204,7 @@ def __init__(
annealing_num_steps: int = 1000,
mean: float = 0.0,
std: float = 1.0,
action_key: str = "action",
action_key: Optional[NestedKey] = "action",
spec: Optional[TensorSpec] = None,
safe: Optional[bool] = True,
):
Expand Down Expand Up @@ -346,8 +346,10 @@ class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper):
default: None
n_steps_annealing (int): number of steps for the sigma annealing.
default: 1000
action_key (str): key of the action to be modified.
action_key (NestedKey, optional): key of the action to be modified.
default: "action"
is_init_key (NestedKey, optional): key where to find the is_init flag used to reset the noise steps.
default: "is_init"
spec (TensorSpec, optional): if provided, the sampled action will be
projected onto the valid action space once explored. If not provided,
the exploration wrapper will attempt to recover it from the policy.
Expand Down Expand Up @@ -392,10 +394,11 @@ def __init__(
x0: Optional[Union[torch.Tensor, np.ndarray]] = None,
sigma_min: Optional[float] = None,
n_steps_annealing: int = 1000,
action_key: str = "action",
action_key: Optional[NestedKey] = "action",
is_init_key: Optional[NestedKey] = "is_init",
spec: TensorSpec = None,
safe: bool = True,
key: str = None,
key: Optional[NestedKey] = None,
):
if key is not None:
action_key = key
Expand Down Expand Up @@ -423,6 +426,7 @@ def __init__(
self.annealing_num_steps = annealing_num_steps
self.register_buffer("eps", torch.tensor([eps_init]))
self.out_keys = list(self.td_module.out_keys) + self.ou.out_keys
self.is_init_key = is_init_key
noise_key = self.ou.noise_key
steps_key = self.ou.steps_key

Expand All @@ -432,11 +436,11 @@ def __init__(
self._spec = spec
elif hasattr(self.td_module, "_spec"):
self._spec = self.td_module._spec.clone()
if action_key not in self._spec.keys():
if action_key not in self._spec.keys(True, True):
self._spec[action_key] = None
elif hasattr(self.td_module, "spec"):
self._spec = self.td_module.spec.clone()
if action_key not in self._spec.keys():
if action_key not in self._spec.keys(True, True):
self._spec[action_key] = None
else:
self._spec = CompositeSpec({key: None for key in policy.out_keys})
Expand Down Expand Up @@ -481,20 +485,20 @@ def step(self, frames: int = 1) -> None:
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict = super().forward(tensordict)
if exploration_type() == ExplorationType.RANDOM or exploration_type() is None:
if "is_init" not in tensordict.keys():
is_init = tensordict.get(self.is_init_key, None)
if is_init is None:
warnings.warn(
f"The tensordict passed to {self.__class__.__name__} appears to be "
f"missing the 'is_init' entry. This entry is used to "
f"missing the '{self.is_init_key}' entry. This entry is used to "
f"reset the noise at the beginning of a trajectory, without it "
f"the behaviour of this exploration method is undefined. "
f"This is allowed for BC compatibility purposes but it will be deprecated soon! "
f"To create a 'is_init' entry, simply append an torchrl.envs.InitTracker "
f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker "
f"transform to your environment with `env = TransformedEnv(env, InitTracker())`."
)
tensordict.set(
"is_init", torch.zeros(*tensordict.shape, 1, dtype=torch.bool)
)
tensordict = self.ou.add_sample(tensordict, self.eps.item())
tensordict = self.ou.add_sample(
tensordict, self.eps.item(), is_init=is_init
)
return tensordict


Expand All @@ -509,7 +513,8 @@ def __init__(
x0: Optional[Union[torch.Tensor, np.ndarray]] = None,
sigma_min: Optional[float] = None,
n_steps_annealing: int = 1000,
key: str = "action",
key: Optional[NestedKey] = "action",
is_init_key: Optional[NestedKey] = "is_init",
):
self.mu = mu
self.sigma = sigma
Expand All @@ -528,6 +533,7 @@ def __init__(
self.dt = dt
self.x0 = x0 if x0 is not None else 0.0
self.key = key
self.is_init_key = is_init_key
self._noise_key = "_ou_prev_noise"
self._steps_key = "_ou_steps"
self.out_keys = [self.noise_key, self.steps_key]
Expand All @@ -540,43 +546,73 @@ def noise_key(self):
def steps_key(self):
return self._steps_key # + str(id(self))

def _make_noise_pair(self, tensordict: TensorDictBase, is_init=None) -> None:
def _make_noise_pair(
self,
action_tensordict: TensorDictBase,
tensordict: TensorDictBase,
is_init: torch.Tensor,
):
if self.steps_key not in tensordict.keys():
noise = torch.zeros(
tensordict.get(self.key).shape, device=tensordict.device
)
steps = torch.zeros(
action_tensordict.batch_size, dtype=torch.long, device=tensordict.device
)
tensordict.set(self.noise_key, noise)
tensordict.set(self.steps_key, steps)
else:
noise = tensordict.get(self.noise_key)
steps = tensordict.get(self.steps_key)
if is_init is not None:
tensordict = tensordict.get_sub_tensordict(is_init.view(tensordict.shape))
tensordict.set(
self.noise_key,
torch.zeros(tensordict.get(self.key).shape, device=tensordict.device),
inplace=is_init is not None,
)
tensordict.set(
self.steps_key,
torch.zeros(
torch.Size([*tensordict.batch_size, 1]),
dtype=torch.long,
device=tensordict.device,
),
inplace=is_init is not None,
)
noise[is_init] = 0
steps[is_init] = 0
return noise, steps

def add_sample(
self, tensordict: TensorDictBase, eps: float = 1.0
self,
tensordict: TensorDictBase,
eps: float = 1.0,
is_init: Optional[torch.Tensor] = None,
) -> TensorDictBase:

if self.noise_key not in tensordict.keys():
self._make_noise_pair(tensordict)
is_init = tensordict.get("is_init", None)
if is_init is not None and is_init.any():
self._make_noise_pair(tensordict, is_init.view(tensordict.shape))

prev_noise = tensordict.get(self.noise_key)
prev_noise = prev_noise + self.x0
# Get the nested tensordict where the action lives
if isinstance(self.key, tuple) and len(self.key) > 1:
action_tensordict = tensordict.get(self.key[:-1])
else:
action_tensordict = tensordict

if is_init is None:
is_init = tensordict.get(self.is_init_key, None)
if (
is_init is not None
): # is_init has the shape of done_spec, let's bring it to the action_tensordict shape
if is_init.ndim > 1 and is_init.shape[-1] == 1:
is_init = is_init.squeeze(-1) # Squeeze dangling dim
if (
action_tensordict.ndim >= is_init.ndim
): # if is_init has less dimensions than action_tensordict we expand it
is_init = expand_right(is_init, action_tensordict.shape)
else:
is_init = is_init.sum(
tuple(range(action_tensordict.batch_dims, is_init.ndim)),
dtype=torch.bool,
) # otherwise we reduce it to that batch_size
if is_init.shape != action_tensordict.shape:
raise ValueError(
f"'{self.is_init_key}' shape not compatible with action tensordict shape, "
f"got {tensordict.get(self.is_init_key).shape} and {action_tensordict.shape}"
)

n_steps = tensordict.get(self.steps_key)
prev_noise, n_steps = self._make_noise_pair(
action_tensordict, tensordict, is_init
)

prev_noise = prev_noise + self.x0
noise = (
prev_noise
+ self.theta * (self.mu - prev_noise) * self.dt
+ self.current_sigma(n_steps)
+ self.current_sigma(expand_as_right(n_steps, prev_noise))
* np.sqrt(self.dt)
* torch.randn_like(prev_noise)
)
Expand Down

0 comments on commit 2a477a2

Please sign in to comment.