
[Feature] Nested keys in OrnsteinUhlenbeckProcess #1305

Merged · 22 commits · Jul 6, 2023
amend
Signed-off-by: Matteo Bettini <matbet@meta.com>
matteobettini committed Jul 6, 2023
commit 14be604cb406fd44d9c436313e6bcada4a5523b8
56 changes: 36 additions & 20 deletions torchrl/modules/tensordict_module/exploration.py
@@ -11,6 +11,7 @@
from tensordict.tensordict import TensorDictBase
from tensordict.utils import expand_as_right, expand_right, NestedKey

+from torchrl._utils import unravel_key
from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
from torchrl.envs.utils import exploration_type, ExplorationType
from torchrl.modules.tensordict_module.common import _forward_hook_safe_action
@@ -34,7 +35,7 @@ class EGreedyWrapper(TensorDictModuleWrapper):
eps_end (scalar, optional): final epsilon value.
default: 0.1
annealing_num_steps (int, optional): number of steps it will take for epsilon to reach the eps_end value
-action_key (str, Tuple[str], optional): if the policy module has more than one output key,
+action_key (NestedKey, optional): if the policy module has more than one output key,
its output spec will be of type CompositeSpec. One needs to know where to
find the action spec.
Default is "action".
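For context, a minimal sketch of how a nested action_key could be passed to the wrapper (the policy, its keys, and the layer sizes below are illustrative assumptions, not part of this diff):

    import torch
    from tensordict.nn import TensorDictModule
    from torchrl.modules import EGreedyWrapper

    # A policy that writes its action under a nested ("agents", "action") key.
    policy = TensorDictModule(
        torch.nn.Linear(4, 2),
        in_keys=["observation"],
        out_keys=[("agents", "action")],
    )
    # A NestedKey is a string or a tuple of strings pointing into nested
    # tensordicts; here it tells the wrapper where the action lives.
    explorative_policy = EGreedyWrapper(
        policy,
        eps_init=1.0,
        eps_end=0.1,
        annealing_num_steps=1000,
        action_key=("agents", "action"),
    )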
@@ -81,7 +82,7 @@ def __init__(
eps_init: float = 1.0,
eps_end: float = 0.1,
annealing_num_steps: int = 1000,
-action_key: NestedKey = "action",
+action_key: Optional[NestedKey] = "action",
spec: Optional[TensorSpec] = None,
):
super().__init__(policy)
@@ -346,7 +347,7 @@ class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper):
default: None
n_steps_annealing (int): number of steps for the sigma annealing.
default: 1000
-action_key (str): key of the action to be modified.
+action_key (NestedKey, optional): key of the action to be modified.
default: "action"
spec (TensorSpec, optional): if provided, the sampled action will be
projected onto the valid action space once explored. If not provided,
@@ -392,10 +393,11 @@ def __init__(
x0: Optional[Union[torch.Tensor, np.ndarray]] = None,
sigma_min: Optional[float] = None,
n_steps_annealing: int = 1000,
-action_key: str = "action",
+action_key: Optional[NestedKey] = "action",
+is_init_key: Optional[NestedKey] = "is_init",
spec: TensorSpec = None,
safe: bool = True,
-key: str = None,
+key: Optional[NestedKey] = None,
):
if key is not None:
action_key = key
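A hedged usage sketch of the constructor above, with the nested action_key described in the docstring (reusing the illustrative policy from the earlier sketch; the key names are assumptions):

    from torchrl.modules import OrnsteinUhlenbeckProcessWrapper

    explorative_policy = OrnsteinUhlenbeckProcessWrapper(
        policy,
        annealing_num_steps=1000,
        action_key=("agents", "action"),  # nested NestedKey
        is_init_key="is_init",            # where trajectory starts are flagged
    )
    # The deprecated `key` argument is still accepted and, per the shim above,
    # simply forwarded to action_key.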
@@ -423,6 +425,7 @@ def __init__(
self.annealing_num_steps = annealing_num_steps
self.register_buffer("eps", torch.tensor([eps_init]))
self.out_keys = list(self.td_module.out_keys) + self.ou.out_keys
+self.is_init_key = is_init_key
noise_key = self.ou.noise_key
steps_key = self.ou.steps_key

@@ -432,11 +435,11 @@ def __init__(
self._spec = spec
elif hasattr(self.td_module, "_spec"):
self._spec = self.td_module._spec.clone()
-if action_key not in self._spec.keys():
+if action_key not in self._spec.keys(True, True):
self._spec[action_key] = None
elif hasattr(self.td_module, "spec"):
self._spec = self.td_module.spec.clone()
-if action_key not in self._spec.keys():
+if action_key not in self._spec.keys(True, True):
self._spec[action_key] = None
else:
self._spec = CompositeSpec({key: None for key in policy.out_keys})
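The keys(True, True) calls above ask the spec for nested leaf keys (include_nested=True, leaves_only=True), mirroring TensorDict.keys; a small sketch with an assumed spec:

    from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

    spec = CompositeSpec(
        agents=CompositeSpec(action=UnboundedContinuousTensorSpec(shape=(2,)))
    )
    list(spec.keys())            # top level only: ["agents"]
    list(spec.keys(True, True))  # nested leaves: [("agents", "action")]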
@@ -481,18 +484,19 @@ def step(self, frames: int = 1) -> None:
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict = super().forward(tensordict)
if exploration_type() == ExplorationType.RANDOM or exploration_type() is None:
if "is_init" not in tensordict.keys():
if self.is_init_key not in tensordict.keys(True, True):
warnings.warn(
f"The tensordict passed to {self.__class__.__name__} appears to be "
f"missing the 'is_init' entry. This entry is used to "
f"missing the '{self.is_init_key}' entry. This entry is used to "
f"reset the noise at the beginning of a trajectory, without it "
f"the behaviour of this exploration method is undefined. "
f"This is allowed for BC compatibility purposes but it will be deprecated soon! "
f"To create a 'is_init' entry, simply append an torchrl.envs.InitTracker "
f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker "
f"transform to your environment with `env = TransformedEnv(env, InitTracker())`."
)
tensordict.set(
"is_init", torch.zeros(*tensordict.shape, 1, dtype=torch.bool)
self.is_init_key,
torch.zeros(*tensordict.shape, 1, dtype=torch.bool),
)
tensordict = self.ou.add_sample(tensordict, self.eps.item())
return tensordict
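As the warning message suggests, the is_init entry is normally produced by an InitTracker transform appended to the environment; a minimal sketch (the environment name is an assumption):

    from torchrl.envs import GymEnv, InitTracker, TransformedEnv

    env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
    td = env.reset()
    assert "is_init" in td.keys()  # entry present; True at the start of a trajectory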
@@ -509,7 +513,8 @@ def __init__(
x0: Optional[Union[torch.Tensor, np.ndarray]] = None,
sigma_min: Optional[float] = None,
n_steps_annealing: int = 1000,
key: str = "action",
key: Optional[NestedKey] = "action",
is_init_key: Optional[NestedKey] = "is_init",
):
self.mu = mu
self.sigma = sigma
@@ -528,8 +533,12 @@ def __init__(
self.dt = dt
self.x0 = x0 if x0 is not None else 0.0
self.key = key
+self.is_init_key = is_init_key
self._noise_key = "_ou_prev_noise"
self._steps_key = "_ou_steps"
+if isinstance(self.key, tuple) and len(self.key) > 1:
+self._noise_key = unravel_key((self.key[:-1], self._noise_key))
+self._steps_key = unravel_key((self.key[:-1], self._steps_key))
self.out_keys = [self.noise_key, self.steps_key]
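A sketch of what the unravel_key lines above produce for a nested action key (expected values inferred from unravel_key's flattening of nested tuples into a flat key):

    from torchrl._utils import unravel_key

    key = ("agents", "action")
    noise_key = unravel_key((key[:-1], "_ou_prev_noise"))
    steps_key = unravel_key((key[:-1], "_ou_steps"))
    # noise_key == ("agents", "_ou_prev_noise")
    # steps_key == ("agents", "_ou_steps")
    # i.e. the OU buffers live next to the action, inside the same nested tensordict.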

@property
@@ -564,27 +573,34 @@ def add_sample(
self, tensordict: TensorDictBase, eps: float = 1.0
) -> TensorDictBase:

# Get the nested tensordict where the action lives
if isinstance(self.key, tuple) and len(self.key) > 1:
action_tensordict = tensordict.get(self.key[:-1])
action_key = self.key[-1]
else:
action_tensordict = tensordict
action_key = self.key

is_init = tensordict.get("is_init", None)
if is_init is not None:
if (action_tensordict.ndim > is_init.ndim) or (
action_tensordict.ndim <= is_init.ndim
and is_init.shape[: action_tensordict.ndim] != action_tensordict.shape
): # if is_init has less dimensions than action_tensordict or
# if the leading dims of is_init do not correspond to the batch_size of action_tensordict
# we expand it
+is_init = tensordict.get(self.is_init_key, None)
+if (
+is_init is not None
+): # is_init has the shape of done_spec, let's bring it to the action_tensordict shape
+if is_init.ndim > 1 and is_init.shape[-1] == 1:
+is_init = is_init.squeeze(-1) # Squeeze dangling dim
+if (
+action_tensordict.ndim >= is_init.ndim
+): # if is_init has fewer dimensions than action_tensordict we expand it
+is_init = expand_right(is_init, action_tensordict.shape)
+else:
+is_init = is_init.sum(
+tuple(range(action_tensordict.batch_dims, is_init.ndim)),
+dtype=torch.bool,
+) # otherwise we reduce it to that batch_size
+if is_init.shape != action_tensordict.shape:
+raise ValueError(
+f"'{self.is_init_key}' shape not compatible with action tensordict shape, "
+f"got {tensordict.get(self.is_init_key).shape} and {action_tensordict.shape}"
+)
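A worked sketch of the alignment above, under assumed shapes: the env writes is_init with the done shape (B, 1), while the action lives in a nested tensordict of batch size (B, n_agents):

    import torch
    from tensordict.utils import expand_right

    is_init = torch.zeros(4, 1, dtype=torch.bool)  # done-shaped: (B=4, 1)
    is_init = is_init.squeeze(-1)                  # squeeze dangling dim -> (4,)
    is_init = expand_right(is_init, (4, 3))        # expand to (B, n_agents) = (4, 3)
    assert is_init.shape == torch.Size([4, 3])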

if self.noise_key not in action_tensordict.keys():
self._make_noise_pair(action_tensordict, action_key)