From d537dcb6347e2370fcdaed553bf3474d653cb5a6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 25 Nov 2024 21:42:10 +0000 Subject: [PATCH] [Feature] EnvBase.auto_specs_ ghstack-source-id: 329679238c5172d7ff13097ceaa189479d4f4145 Pull Request resolved: https://github.com/pytorch/rl/pull/2601 --- test/mocking_classes.py | 14 ++- test/test_env.py | 28 +++++ test/test_specs.py | 9 ++ torchrl/data/tensor_specs.py | 25 +++- torchrl/envs/common.py | 230 +++++++++++++++++++++++++++-------- torchrl/envs/utils.py | 30 +++-- torchrl/objectives/sac.py | 2 +- 7 files changed, 267 insertions(+), 71 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index eb517429c08..d78e2f27184 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1038,11 +1038,13 @@ def _step( tensordict: TensorDictBase, ) -> TensorDictBase: action = tensordict.get(self.action_key) + try: + device = self.full_action_spec[self.action_key].device + except KeyError: + device = self.device self.count += action.to( dtype=torch.int, - device=self.full_action_spec[self.action_key].device - if self.device is None - else self.device, + device=device if self.device is None else self.device, ) tensordict = TensorDict( source={ @@ -1275,8 +1277,10 @@ def __init__( max_steps = torch.tensor(5) if start_val is None: start_val = torch.zeros((), dtype=torch.int32) - if not max_steps.shape == self.batch_size: - raise RuntimeError("batch_size and max_steps shape must match.") + if max_steps.shape != self.batch_size: + raise RuntimeError( + f"batch_size and max_steps shape must match. Got self.batch_size={self.batch_size} and max_steps.shape={max_steps.shape}." + ) self.max_steps = max_steps diff --git a/test/test_env.py b/test/test_env.py index ab854a3b4be..81708b0b9a6 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -3526,6 +3526,34 @@ def test_single_env_spec(): assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape)) +def test_auto_spec(): + env = CountingEnv() + td = env.reset() + + policy = lambda td, action_spec=env.full_action_spec.clone(): td.update( + action_spec.rand() + ) + + env.full_observation_spec = Composite( + shape=env.full_observation_spec.shape, device=env.full_observation_spec.device + ) + env.full_action_spec = Composite( + shape=env.full_action_spec.shape, device=env.full_action_spec.device + ) + env.full_reward_spec = Composite( + shape=env.full_reward_spec.shape, device=env.full_reward_spec.device + ) + env.full_done_spec = Composite( + shape=env.full_done_spec.shape, device=env.full_done_spec.device + ) + env.full_state_spec = Composite( + shape=env.full_state_spec.shape, device=env.full_state_spec.device + ) + env._action_keys = ["action"] + env.auto_specs_(policy, tensordict=td.copy()) + env.check_env_specs(tensordict=td.copy()) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_specs.py b/test/test_specs.py index 39b09798ac2..3dedc6233a9 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -412,6 +412,15 @@ def test_getitem(self, shape, is_complete, device, dtype): with pytest.raises(KeyError): _ = ts["UNK"] + def test_setitem_newshape(self, shape, is_complete, device, dtype): + ts = self._composite_spec(shape, is_complete, device, dtype) + new_spec = ts.clone() + new_spec.shape = torch.Size(()) + new_spec.clear_device_() + ts["new_spec"] = new_spec + assert ts["new_spec"].shape == ts.shape + assert ts["new_spec"].device == ts.device + def test_setitem_forbidden_keys(self, shape, is_complete, device, dtype): ts = self._composite_spec(shape, is_complete, device, dtype) for key in {"shape", "device", "dtype", "space"}: diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 32e61bc3ede..ddf6ed41c99 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -4372,11 +4372,20 @@ def set(self, name, spec): if spec is not None: shape = spec.shape if shape[: self.ndim] != self.shape: - raise ValueError( - "The shape of the spec and the Composite mismatch: the first " - f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and " - f"Composite.shape={self.shape}." - ) + if ( + isinstance(spec, Composite) + and spec.ndim < self.ndim + and self.shape[: spec.ndim] == spec.shape + ): + # Try to set the composite shape + spec = spec.clone() + spec.shape = self.shape + else: + raise ValueError( + "The shape of the spec and the Composite mismatch: the first " + f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and " + f"Composite.shape={self.shape}." + ) self._specs[name] = spec def __init__( @@ -4448,6 +4457,8 @@ def clear_device_(self): """Clears the device of the Composite.""" self._device = None for spec in self._specs.values(): + if spec is None: + continue spec.clear_device_() return self @@ -4530,6 +4541,10 @@ def __setitem__(self, key, value): and value.device != self.device ): if isinstance(value, Composite) and value.device is None: + # We make a clone not to mess up the spec that was provided. + # in set() we do the same for shape - these two ops should be grouped. + # we don't care about the overhead of cloning twice though because in theory + # we don't set specs often. value = value.clone().to(self.device) else: raise RuntimeError( diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 8adf36b0019..d5a062bc11e 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -356,6 +356,9 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): .. note:: Learn more about dynamic specs and environments :ref:`here `. """ + _batch_size: torch.Size | None + _device: torch.device | None + def __init__( self, *, @@ -364,34 +367,178 @@ def __init__( run_type_checks: bool = False, allow_done_after_reset: bool = False, ): + super().__init__() + self.__dict__.setdefault("_batch_size", None) - if device is not None: - self.__dict__["_device"] = _make_ordinal_device(torch.device(device)) - output_spec = self.__dict__.get("_output_spec") - if output_spec is not None: - self.__dict__["_output_spec"] = ( - output_spec.to(self.device) - if self.device is not None - else output_spec - ) - input_spec = self.__dict__.get("_input_spec") - if input_spec is not None: - self.__dict__["_input_spec"] = ( - input_spec.to(self.device) - if self.device is not None - else input_spec - ) + self.__dict__.setdefault("_device", None) - super().__init__() - if "is_closed" not in self.__dir__(): - self.is_closed = True if batch_size is not None: # we want an error to be raised if we pass batch_size but # it's already been set - self.batch_size = torch.Size(batch_size) + batch_size = self.batch_size = torch.Size(batch_size) + else: + batch_size = torch.Size(()) + + if device is not None: + device = self.__dict__["_device"] = _make_ordinal_device( + torch.device(device) + ) + + output_spec = self.__dict__.get("_output_spec") + if output_spec is None: + output_spec = self.__dict__["_output_spec"] = Composite( + shape=batch_size, device=device + ).lock_() + elif self._output_spec.device != device and device is not None: + self.__dict__["_output_spec"] = self.__dict__["_output_spec"].to( + self.device + ) + input_spec = self.__dict__.get("_input_spec") + if input_spec is None: + input_spec = self.__dict__["_input_spec"] = Composite( + shape=batch_size, device=device + ).lock_() + elif self._input_spec.device != device and device is not None: + self.__dict__["_input_spec"] = self.__dict__["_input_spec"].to(self.device) + + output_spec.unlock_() + input_spec.unlock_() + if "full_observation_spec" not in output_spec: + output_spec["full_observation_spec"] = Composite() + if "full_done_spec" not in output_spec: + output_spec["full_done_spec"] = Composite() + if "full_reward_spec" not in output_spec: + output_spec["full_reward_spec"] = Composite() + if "full_state_spec" not in input_spec: + input_spec["full_state_spec"] = Composite() + if "full_action_spec" not in input_spec: + input_spec["full_action_spec"] = Composite() + output_spec.lock_() + input_spec.lock_() + + if "is_closed" not in self.__dir__(): + self.is_closed = True self._run_type_checks = run_type_checks self._allow_done_after_reset = allow_done_after_reset + def auto_specs_( + self, + policy: Callable[[TensorDictBase], TensorDictBase], + *, + tensordict: TensorDictBase | None = None, + action_key: NestedKey | List[NestedKey] = "action", + done_key: NestedKey | List[NestedKey] | None = None, + observation_key: NestedKey | List[NestedKey] = "observation", + reward_key: NestedKey | List[NestedKey] = "reward", + batch_size: torch.Size | None = None, + ): + """Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy. + + This method performs a rollout using the provided policy to infer the input and output specifications of the environment. + It updates the environment's specs for actions, observations, rewards, and done signals based on the data collected + during the rollout. + + Args: + policy (Callable[[TensorDictBase], TensorDictBase]): + A callable policy that takes a `TensorDictBase` as input and returns a `TensorDictBase` as output. + This policy is used to perform the rollout and determine the specs. + + Keyword Args: + tensordict (TensorDictBase, optional): + An optional `TensorDictBase` instance to be used as the initial state for the rollout. + If not provided, the environment's `reset` method will be called to obtain the initial state. + action_key (NestedKey or List[NestedKey], optional): + The key(s) used to identify actions in the `TensorDictBase`. Defaults to "action". + done_key (NestedKey or List[NestedKey], optional): + The key(s) used to identify done signals in the `TensorDictBase`. Defaults to ``None``, which will + attempt to use ["done", "terminated", "truncated"] as potential keys. + observation_key (NestedKey or List[NestedKey], optional): + The key(s) used to identify observations in the `TensorDictBase`. Defaults to "observation". + reward_key (NestedKey or List[NestedKey], optional): + The key(s) used to identify rewards in the `TensorDictBase`. Defaults to "reward". + + Returns: + EnvBase: The environment instance with updated specs. + + Raises: + RuntimeError: If there are keys in the output specs that are not accounted for in the provided keys. + """ + if self.batch_locked or tensordict is None: + batch_size = self.batch_size + else: + batch_size = tensordict.batch_size + if tensordict is None: + tensordict = self.reset() + + # Input specs + tensordict = policy(tensordict) + step_0 = self.step(tensordict.copy()) + tensordict2 = step_0.get("next").copy() + step_1 = self.step(policy(tensordict2).copy()) + nexts_0: TensorDictBase = step_0.pop("next") + nexts_1: TensorDictBase = step_1.pop("next") + + input_spec_stack = {} + tensordict.apply( + partial(_tensor_to_spec, stack=input_spec_stack), + tensordict2, + named=True, + nested_keys=True, + ) + input_spec = Composite(input_spec_stack, batch_size=batch_size) + if not self.batch_locked and batch_size != self.batch_size: + while input_spec.shape: + input_spec = input_spec[0] + if isinstance(action_key, NestedKey): + action_key = [action_key] + full_action_spec = input_spec.separates(*action_key, default=None) + + # Output specs + + output_spec_stack = {} + nexts_0.apply( + partial(_tensor_to_spec, stack=output_spec_stack), + nexts_1, + named=True, + nested_keys=True, + ) + + output_spec = Composite(output_spec_stack, batch_size=batch_size) + if not self.batch_locked and batch_size != self.batch_size: + while output_spec.shape: + output_spec = output_spec[0] + + if done_key is None: + done_key = ["done", "terminated", "truncated"] + full_done_spec = output_spec.separates(*done_key, default=None) + if full_done_spec is not None: + self.full_done_spec = full_done_spec + + if isinstance(reward_key, NestedKey): + reward_key = [reward_key] + full_reward_spec = output_spec.separates(*reward_key, default=None) + + if isinstance(observation_key, NestedKey): + observation_key = [observation_key] + full_observation_spec = output_spec.separates(*observation_key, default=None) + if not output_spec.is_empty(recurse=True): + raise RuntimeError( + f"Keys {list(output_spec.keys(True, True))} are unaccounted for." + ) + + if full_action_spec is not None: + self.full_action_spec = full_action_spec + if full_done_spec is not None: + self.full_done_specs = full_done_spec + if full_observation_spec is not None: + self.full_observation_spec = full_observation_spec + if full_reward_spec is not None: + self.full_reward_spec = full_reward_spec + full_state_spec = input_spec + self.full_state_spec = full_state_spec + + return self + @wraps(check_env_specs_func) def check_env_specs(self, *args, **kwargs): return check_env_specs_func(self, *args, **kwargs) @@ -475,7 +622,7 @@ def batch_size(self) -> torch.Size: in parallel). """ - _batch_size = self.__dict__["_batch_size"] + _batch_size = self.__dict__.get("_batch_size") if _batch_size is None: _batch_size = self._batch_size = torch.Size([]) return _batch_size @@ -667,8 +814,6 @@ def action_keys(self) -> List[NestedKey]: if action_keys is not None: return action_keys keys = self.full_action_spec.keys(True, True) - if not len(keys): - raise AttributeError("Could not find action spec") keys = sorted(keys, key=_repr_by_depth) self.__dict__["_action_keys"] = keys return keys @@ -827,15 +972,7 @@ def action_spec(self, value: TensorSpec) -> None: "Please use `env.action_spec_unbatched = value` to set unbatched versions instead." ) - if isinstance(value, Composite): - for _ in value.values(True, True): # noqa: B007 - break - else: - raise RuntimeError( - "An empty Composite was passed for the action spec. " - "This is currently not permitted." - ) - else: + if not isinstance(value, Composite): value = Composite( action=value.to(device), shape=self.batch_size, device=device ) @@ -892,7 +1029,6 @@ def reward_keys(self) -> List[NestedKey]: reward_keys = self.__dict__.get("_reward_keys") if reward_keys is not None: return reward_keys - reward_keys = sorted(self.full_reward_spec.keys(True, True), key=_repr_by_depth) self.__dict__["_reward_keys"] = reward_keys return reward_keys @@ -1030,15 +1166,7 @@ def reward_spec(self, value: TensorSpec) -> None: f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size}). " "Please use `env.reward_spec_unbatched = value` to set unbatched versions instead." ) - if isinstance(value, Composite): - for _ in value.values(True, True): # noqa: B007 - break - else: - raise RuntimeError( - "An empty Composite was passed for the reward spec. " - "This is currently not permitted." - ) - else: + if not isinstance(value, Composite): value = Composite( reward=value.to(device), shape=self.batch_size, device=device ) @@ -1319,15 +1447,7 @@ def done_spec(self, value: TensorSpec) -> None: raise ValueError( f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) - if isinstance(value, Composite): - for _ in value.values(True, True): # noqa: B007 - break - else: - raise RuntimeError( - "An empty Composite was passed for the done spec. " - "This is currently not permitted." - ) - else: + if not isinstance(value, Composite): value = Composite( done=value.to(device), terminated=value.to(device), @@ -3445,3 +3565,11 @@ def _has_dynamic_specs(spec: Composite): any(s == -1 for s in spec.shape) for spec in spec.values(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS) ) + + +def _tensor_to_spec(name, leaf, leaf_compare=None, *, stack): + shape = leaf.shape + if leaf_compare is not None: + shape_compare = leaf_compare.shape + shape = [s0 if s0 == s1 else -1 for s0, s1 in zip(shape, shape_compare)] + stack[name] = Unbounded(shape, device=leaf.device, dtype=leaf.dtype) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index c83591acb63..7454bce99b3 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -16,8 +16,6 @@ from enum import Enum from typing import Any, Dict, List, Union -import tensordict.base - import torch from tensordict import ( @@ -29,7 +27,7 @@ TensorDictBase, unravel_key, ) -from tensordict.base import _is_leaf_nontensor +from tensordict.base import _default_is_leaf, _is_leaf_nontensor from tensordict.nn import TensorDictModule, TensorDictModuleBase from tensordict.nn.probabilistic import ( # noqa interaction_type as exploration_type, @@ -691,7 +689,11 @@ def _per_level_env_check(data0, data1, check_dtype): def check_env_specs( - env, return_contiguous=True, check_dtype=True, seed: int | None = None + env, + return_contiguous=True, + check_dtype=True, + seed: int | None = None, + tensordict: TensorDictBase | None = None, ): """Tests an environment specs against the results of short rollout. @@ -715,6 +717,7 @@ def check_env_specs( setting the rng state back to what is was isn't a feature of most environment, we leave it to the user to accomplish that. Defaults to ``None``. + tensordict (TensorDict, optional): an optional tensordict instance to use for reset. Caution: this function resets the env seed. It should be used "offline" to check that an env is adequately constructed, but it may affect the seeding @@ -732,7 +735,16 @@ def check_env_specs( ) fake_tensordict = env.fake_tensordict() - real_tensordict = env.rollout(3, return_contiguous=return_contiguous) + if not env._batch_locked and tensordict is not None: + shape = torch.broadcast_shapes(fake_tensordict.shape, tensordict.shape) + fake_tensordict = fake_tensordict.expand(shape) + tensordict = tensordict.expand(shape) + real_tensordict = env.rollout( + 3, + return_contiguous=return_contiguous, + tensordict=tensordict, + auto_reset=tensordict is None, + ) if return_contiguous: fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1) @@ -743,17 +755,17 @@ def check_env_specs( ) # eliminate empty containers fake_tensordict_select = fake_tensordict.select( - *fake_tensordict.keys(True, True, is_leaf=tensordict.base._default_is_leaf) + *fake_tensordict.keys(True, True, is_leaf=_default_is_leaf) ) real_tensordict_select = real_tensordict.select( - *real_tensordict.keys(True, True, is_leaf=tensordict.base._default_is_leaf) + *real_tensordict.keys(True, True, is_leaf=_default_is_leaf) ) # check keys fake_tensordict_keys = set( - fake_tensordict.keys(True, True, is_leaf=tensordict.base._is_leaf_nontensor) + fake_tensordict.keys(True, True, is_leaf=_is_leaf_nontensor) ) real_tensordict_keys = set( - real_tensordict.keys(True, True, is_leaf=tensordict.base._is_leaf_nontensor) + real_tensordict.keys(True, True, is_leaf=_is_leaf_nontensor) ) if fake_tensordict_keys != real_tensordict_keys: raise AssertionError( diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 52efb3d312b..cd7039c323d 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -1243,7 +1243,7 @@ def _compute_target(self, tensordict) -> Tensor: # unlike in continuous SAC, we can compute the exact expectation over all discrete actions next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1) if next_tensordict_select is not next_tensordict: - mask = ~done.squeeze(-1) + mask = ~done next_state_value = next_state_value.new_zeros( mask.shape ).masked_scatter_(mask, next_state_value)