diff --git a/test/test_specs.py b/test/test_specs.py index 4f371dfcded..fbc790fa872 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -521,16 +521,24 @@ def test_nested_composite_spec(self, is_complete, device, dtype): assert set(ts.keys()) == { "obs", "act", + "nested_cp", + } + assert set(ts.keys(include_nested=True)) == { + "obs", + "act", + "nested_cp", ("nested_cp", "obs"), ("nested_cp", "act"), } - assert len(ts.keys()) == len(ts.keys(yield_nesting_keys=True)) - 1 - assert set(ts.keys(yield_nesting_keys=True)) == { + assert set(ts.keys(include_nested=True, leaves_only=True)) == { "obs", "act", ("nested_cp", "obs"), ("nested_cp", "act"), - "nested_cp", + } + assert set(ts.keys(leaves_only=True)) == { + "obs", + "act", } td = ts.rand() assert isinstance(td["nested_cp"], TensorDictBase) @@ -577,9 +585,10 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype): ts["nested_cp"] = self._composite_spec(is_complete, device, dtype) td2 = CompositeSpec(new=None) ts.update(td2) - assert set(ts.keys()) == { + assert set(ts.keys(include_nested=True)) == { "obs", "act", + "nested_cp", ("nested_cp", "obs"), ("nested_cp", "act"), "new", @@ -589,9 +598,10 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype): ts["nested_cp"] = self._composite_spec(is_complete, device, dtype) td2 = CompositeSpec(nested_cp=CompositeSpec(new=None).to(device)) ts.update(td2) - assert set(ts.keys()) == { + assert set(ts.keys(include_nested=True)) == { "obs", "act", + "nested_cp", ("nested_cp", "obs"), ("nested_cp", "act"), ("nested_cp", "new"), @@ -601,9 +611,10 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype): ts["nested_cp"] = self._composite_spec(is_complete, device, dtype) td2 = CompositeSpec(nested_cp=CompositeSpec(act=None).to(device)) ts.update(td2) - assert set(ts.keys()) == { + assert set(ts.keys(include_nested=True)) == { "obs", "act", + "nested_cp", ("nested_cp", "obs"), ("nested_cp", "act"), } @@ -617,9 +628,10 
@@ def test_nested_composite_spec_update(self, is_complete, device, dtype): nested_cp=CompositeSpec(act=UnboundedContinuousTensorSpec(device=device)) ) ts.update(td2) - assert set(ts.keys()) == { + assert set(ts.keys(include_nested=True)) == { "obs", "act", + "nested_cp", ("nested_cp", "obs"), ("nested_cp", "act"), } @@ -629,7 +641,7 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype): def test_keys_to_empty_composite_spec(): keys = [("key1", "out"), ("key1", "in"), "key2", ("key1", "subkey1", "subkey2")] composite = _keys_to_empty_composite_spec(keys) - assert set(composite.keys()) == set(keys) + assert set(composite.keys(True, True)) == set(keys) class TestEquality: diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index b5204ff06e8..65d6b6bba1a 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -447,7 +447,7 @@ def __init__( hasattr(self.policy, "spec") and self.policy.spec is not None and all(v is not None for v in self.policy.spec.values()) - and set(self.policy.spec.keys()) == set(self.policy.out_keys) + and set(self.policy.spec.keys(True, True)) == set(self.policy.out_keys) ): # if policy spec is non-empty, all the values are not None and the keys # match the out_keys we assume the user has given all relevant information diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 9c9e25485c3..5409aecec6a 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1862,13 +1862,12 @@ def type_check( self._specs[_key].type_check(value[_key], _key) def is_in(self, val: Union[dict, TensorDictBase]) -> bool: - return all( - [ - item.is_in(val.get(key)) - for (key, item) in self._specs.items() - if item is not None - ] - ) + for (key, item) in self._specs.items(): + if item is None: + continue + if not item.is_in(val.get(key)): + return False + return True def project(self, val: TensorDictBase) -> TensorDictBase: for key, item in 
self.items(): @@ -1894,22 +1893,29 @@ def rand(self, shape=None) -> TensorDictBase: ) def keys( - self, yield_nesting_keys: bool = False, nested_keys: bool = True + self, + include_nested: bool = False, + leaves_only: bool = False, ) -> KeysView: """Keys of the CompositeSpec. + The keys arguments reflect those of :class:`tensordict.TensorDict`. + Args: - yield_nesting_keys (bool, optional): if :obj:`True`, the values returned - will contain every level of nesting, i.e. a :obj:`CompositeSpec(next=CompositeSpec(obs=None))` - will lead to the keys :obj:`["next", ("next", "obs")]`. Default is :obj:`False`, i.e. - only nested keys will be returned. - nested_keys (bool, optional): if :obj:`False`, the returned keys will not be nested. They will + include_nested (bool, optional): if ``False``, the returned keys will not be nested. They will represent only the immediate children of the root, and not the whole nested sequence, i.e. a :obj:`CompositeSpec(next=CompositeSpec(obs=None))` will lead to the keys - :obj:`["next"]. Default is :obj:`True`, i.e. nested keys will be returned. + :obj:`["next"]`. Default is ``False``, i.e. nested keys will not + be returned. + leaves_only (bool, optional): if :obj:`False`, the values returned + will contain every level of nesting, i.e. a :obj:`CompositeSpec(next=CompositeSpec(obs=None))` + will lead to the keys :obj:`["next", ("next", "obs")]`. + Default is ``False``. 
""" return _CompositeSpecKeysView( - self, _yield_nesting_keys=yield_nesting_keys, nested_keys=nested_keys + self, + include_nested=include_nested, + leaves_only=leaves_only, ) def items(self) -> ItemsView: @@ -2014,13 +2020,14 @@ def expand(self, *shape): def _keys_to_empty_composite_spec(keys): + """Given a list of keys, creates a CompositeSpec tree where each leaf is assigned a None value.""" if not len(keys): return c = CompositeSpec() for key in keys: if isinstance(key, str): c[key] = None - elif key[0] in c.keys(yield_nesting_keys=True): + elif key[0] in c.keys(): if c[key[0]] is None: # if the value is None we just replace it c[key[0]] = _keys_to_empty_composite_spec([key[1:]]) @@ -2042,28 +2049,34 @@ class _CompositeSpecKeysView: def __init__( self, composite: CompositeSpec, - nested_keys: bool = True, - _yield_nesting_keys: bool = False, + include_nested, + leaves_only, ): self.composite = composite - self._yield_nesting_keys = _yield_nesting_keys - self.nested_keys = nested_keys + self.leaves_only = leaves_only + self.include_nested = include_nested def __iter__( self, ): for key, item in self.composite.items(): - if self.nested_keys and isinstance(item, CompositeSpec): - for subkey in item.keys(): - yield (key, *subkey) if isinstance(subkey, tuple) else (key, subkey) - if self._yield_nesting_keys: - yield key - else: - if not isinstance(item, CompositeSpec) or len(item): + if self.include_nested and isinstance(item, CompositeSpec): + for subkey in item.keys( + include_nested=True, leaves_only=self.leaves_only + ): + if not isinstance(subkey, tuple): + subkey = (subkey,) + yield (key, *subkey) + if not self.leaves_only: yield key + elif not isinstance(item, CompositeSpec) or not self.leaves_only: + yield key def __len__(self): i = 0 for _ in self: i += 1 return i + + def __repr__(self): + return f"_CompositeSpecKeysView(keys={list(self)})" diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 7651bb9b6b7..bfada558d23 100644 --- 
a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -389,7 +389,7 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: ) tensordict.unlock_() - obs_keys = self.observation_spec.keys(nested_keys=False) + obs_keys = self.observation_spec.keys(True) # we deliberately do not update the input values, but we want to keep track of # new keys considered as "input" by inverse transforms. in_keys = self._get_in_keys_to_exclude(tensordict) diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 0d4ad0df03f..e0f24e5431f 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -174,7 +174,7 @@ def read_obs( # when queried with and without pixels observations["observation"] = observations.pop("state") if not isinstance(observations, (TensorDict, dict)): - (key,) = itertools.islice(self.observation_spec.keys(), 1) + (key,) = itertools.islice(self.observation_spec.keys(True, True), 1) observations = {key: observations} observations = self.observation_spec.encode(observations) return observations diff --git a/torchrl/envs/transforms/r3m.py b/torchrl/envs/transforms/r3m.py index 76841fdd504..e5bcb10d832 100644 --- a/torchrl/envs/transforms/r3m.py +++ b/torchrl/envs/transforms/r3m.py @@ -107,7 +107,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec if not isinstance(observation_spec, CompositeSpec): raise ValueError("_R3MNet can only infer CompositeSpec") - keys = [key for key in observation_spec._specs.keys() if key in self.in_keys] + keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys] device = observation_spec[keys[0]].device dim = observation_spec[keys[0]].shape[:-3] diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 61cf45af1a6..1e21e900ff8 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -61,7 +61,7 @@ def new_fun(self, observation_spec): if isinstance(observation_spec, 
CompositeSpec): d = observation_spec._specs for in_key, out_key in zip(self.in_keys, self.out_keys): - if in_key in observation_spec.keys(): + if in_key in observation_spec.keys(True, True): d[out_key] = function(self, observation_spec[in_key].clone()) return CompositeSpec( d, shape=observation_spec.shape, device=observation_spec.device @@ -85,7 +85,7 @@ def new_fun(self, input_spec): if isinstance(input_spec, CompositeSpec): d = input_spec._specs for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): - if in_key in input_spec.keys(): + if in_key in input_spec.keys(True, True): d[out_key] = function(self, input_spec[in_key].clone()) return CompositeSpec(d, shape=input_spec.shape, device=input_spec.device) else: @@ -2066,7 +2066,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec # by def, there must be only one key return observation_spec - keys = [key for key in observation_spec._specs.keys() if key in self.in_keys] + keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys] sum_shape = sum( [ @@ -2849,7 +2849,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: raise KeyError( f"The key {in_key} was not found in the parent " f"observation_spec with keys " - f"{list(self.parent.observation_spec.keys())}. " + f"{list(self.parent.observation_spec.keys(True))}. 
" ) from err return tensordict @@ -2880,7 +2880,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec episode_specs = {} if isinstance(reward_spec, CompositeSpec): # If reward_spec is a CompositeSpec, all in_keys should be keys of reward_spec - if not all(k in reward_spec.keys() for k in self.in_keys): + if not all(k in reward_spec.keys(True, True) for k in self.in_keys): raise KeyError("Not all in_keys are present in ´reward_spec´") # Define episode specs for all out_keys @@ -3042,7 +3042,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict.exclude(*self.excluded_keys) def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - if any(key in observation_spec.keys() for key in self.excluded_keys): + if any(key in observation_spec.keys(True, True) for key in self.excluded_keys): return CompositeSpec( **{ key: value @@ -3074,7 +3074,7 @@ def __init__(self, *selected_keys): def _call(self, tensordict: TensorDictBase) -> TensorDictBase: if self.parent: - input_keys = self.parent.input_spec.keys() + input_keys = self.parent.input_spec.keys(True, True) else: input_keys = [] return tensordict.select( @@ -3085,7 +3085,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: def reset(self, tensordict: TensorDictBase) -> TensorDictBase: if self.parent: - input_keys = self.parent.input_spec.keys() + input_keys = self.parent.input_spec.keys(True, True) else: input_keys = [] return tensordict.select( diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py index e66505b61a3..2e6771d0399 100644 --- a/torchrl/envs/transforms/vip.py +++ b/torchrl/envs/transforms/vip.py @@ -96,7 +96,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec if not isinstance(observation_spec, CompositeSpec): raise ValueError("_VIPNet can only infer CompositeSpec") - keys = [key for key in observation_spec._specs.keys() if key in self.in_keys] + keys = 
[key for key in observation_spec.keys(True, True) if key in self.in_keys] device = observation_spec[keys[0]].device dim = observation_spec[keys[0]].shape[:-3] diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index e9d6c034b00..2cdcc0e9f3d 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -242,14 +242,14 @@ def _check_isin(key, value, obs_spec, input_spec): for _key, _value in value.items(): _check_isin(_key, _value, obs_spec, input_spec) return - elif key in input_spec.keys(yield_nesting_keys=True): + elif key in input_spec.keys(True): if not input_spec[key].is_in(value): raise AssertionError( f"input_spec.is_in failed for key {key}. " f"Got input_spec={input_spec[key]} and real={value}." ) return - elif key in obs_spec.keys(yield_nesting_keys=True): + elif key in obs_spec.keys(True): if not obs_spec[key].is_in(value): raise AssertionError( f"obs_spec.is_in failed for key {key}. " @@ -257,7 +257,9 @@ def _check_isin(key, value, obs_spec, input_spec): ) return else: - raise KeyError(key) + raise KeyError( + f"key {key} was not found in input spec with keys {input_spec.keys(True)} or obs spec with keys {obs_spec.keys(True)}" + ) def _selective_unsqueeze(tensor: torch.Tensor, batch_size: torch.Size, dim: int = -1): diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index ef237a49afc..59330035ec6 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -392,7 +392,7 @@ def _create_td(self) -> None: ) raise_no_selected_keys = False if self.selected_keys is None: - self.selected_keys = list(shared_tensordict_parent.keys()) + self.selected_keys = list(shared_tensordict_parent.keys(True)) if self.excluded_keys is not None: self.selected_keys = set(self.selected_keys).difference( self.excluded_keys @@ -430,12 +430,12 @@ def _create_td(self) -> None: ) else: if self._single_task: - self.env_input_keys = sorted(self.input_spec.keys(), key=_sort_keys) + self.env_input_keys = sorted(self.input_spec.keys(True), key=_sort_keys) 
else: env_input_keys = set() for meta_data in self.meta_data: env_input_keys = env_input_keys.union( - meta_data.specs["input_spec"].keys() + meta_data.specs["input_spec"].keys(True) ) self.env_input_keys = sorted(env_input_keys, key=_sort_keys) if not len(self.env_input_keys): @@ -603,7 +603,7 @@ def _step( tensordict_in = tensordict.clone(False) # update the shared tensordict to keep the input entries up-to-date self.shared_tensordict_parent.update_( - tensordict_in.select(*self.input_spec.keys(), strict=False) + tensordict_in.select(*self.input_spec.keys(True, True), strict=False) ) # If a key is both in input and output spec, we should keep it because it has been modified input_keys = set(self.input_spec.keys(True)) - set( @@ -788,7 +788,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: self._assert_tensordict_shape(tensordict) self.shared_tensordict_parent.update_( - tensordict.select(*self.input_spec.keys(), strict=False) + tensordict.select(*self.input_spec.keys(True, True), strict=False) ) for i in range(self.num_workers): self.parent_channels[i].send(("step", None)) @@ -1051,7 +1051,7 @@ def _run_worker_pipe_shared_mem( _td = tensordict.clone(recurse=False) _td = env._step(_td) if step_keys is None: - step_keys = set(env.observation_spec.keys()).union( + step_keys = set(env.observation_spec.keys(True)).union( {"done", "terminated", "reward"} ) if pin_memory: diff --git a/torchrl/modules/planners/common.py b/torchrl/modules/planners/common.py index a9d1e4ca942..057efa1ef3a 100644 --- a/torchrl/modules/planners/common.py +++ b/torchrl/modules/planners/common.py @@ -34,7 +34,7 @@ def __init__( "Environment is batch_locked. 
MPCPlanners need an environnement that accepts batched inputs with any batch size" ) out_keys = [action_key] - in_keys = list(env.observation_spec.keys()) + in_keys = list(env.observation_spec.keys(True, True)) super().__init__(env, in_keys=in_keys, out_keys=out_keys) self.env = env self.action_spec = env.action_spec diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 641301e9159..90929a0388f 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -67,7 +67,7 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out): keys = [out_key] values = [spec] else: - keys = list(spec.keys()) + keys = list(spec.keys(True, True)) values = [spec[key] for key in keys] for _spec, _key in zip(values, keys): if _spec is None: @@ -207,15 +207,15 @@ def __init__( elif spec is None: spec = CompositeSpec() - if set(spec.keys()) != set(self.out_keys): + if set(spec.keys(True, True)) != set(self.out_keys): # then assume that all the non indicated specs are None for key in self.out_keys: if key not in spec: spec[key] = None - if set(spec.keys()) != set(self.out_keys): + if set(spec.keys(True, True)) != set(self.out_keys): raise RuntimeError( - f"spec keys and out_keys do not match, got: {set(spec.keys())} and {set(self.out_keys)} respectively" + f"spec keys and out_keys do not match, got: {set(spec.keys(True))} and {set(self.out_keys)} respectively" ) self._spec = spec diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index a399edde0b5..334df18394a 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -132,15 +132,15 @@ def __init__( elif spec is None: spec = CompositeSpec() - if set(spec.keys()) != set(self.out_keys): + if set(spec.keys(True, True)) != set(self.out_keys): # then assume that all the non indicated specs are None for 
key in self.out_keys: if key not in spec: spec[key] = None - if set(spec.keys()) != set(self.out_keys): + if set(spec.keys(True, True)) != set(self.out_keys): raise RuntimeError( - f"spec keys and out_keys do not match, got: {set(spec.keys())} and {set(self.out_keys)} respectively" + f"spec keys and out_keys do not match, got: {set(spec.keys(True, True))} and {set(self.out_keys)} respectively" ) self._spec = spec diff --git a/torchrl/trainers/helpers/envs.py b/torchrl/trainers/helpers/envs.py index f12ad5949bc..fac07ee0afd 100644 --- a/torchrl/trainers/helpers/envs.py +++ b/torchrl/trainers/helpers/envs.py @@ -153,8 +153,8 @@ def make_env_transforms( if not from_pixels: selected_keys = [ key - for key in env.observation_spec.keys() - if ("pixels" not in key) and (key not in env.input_spec.keys()) + for key in env.observation_spec.keys(True, True) + if ("pixels" not in key) and (key not in env.input_spec.keys(True, True)) ] # even if there is a single tensor, it'll be renamed in "observation_vector" @@ -409,7 +409,7 @@ def get_stats_random_rollout( val_stats = torch.cat(val_stats, 0) if key is None: - keys = list(proof_environment.observation_spec.keys()) + keys = list(proof_environment.observation_spec.keys(True, True)) key = keys.pop() if len(keys): raise RuntimeError( @@ -474,7 +474,7 @@ def initialize_observation_norm_transforms( return if key is None: - keys = list(proof_environment.base_env.observation_spec.keys()) + keys = list(proof_environment.base_env.observation_spec.keys(True, True)) key = keys.pop() if len(keys): raise RuntimeError(