Skip to content

Commit

Permalink
[Refactor] Refactor composite spec keys to match tensordict (pytorch#956
Browse files Browse the repository at this point in the history
)
  • Loading branch information
vmoens authored Mar 8, 2023
1 parent 1d0c335 commit cdc6798
Show file tree
Hide file tree
Showing 14 changed files with 96 additions and 69 deletions.
28 changes: 20 additions & 8 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,16 +521,24 @@ def test_nested_composite_spec(self, is_complete, device, dtype):
assert set(ts.keys()) == {
"obs",
"act",
"nested_cp",
}
assert set(ts.keys(include_nested=True)) == {
"obs",
"act",
"nested_cp",
("nested_cp", "obs"),
("nested_cp", "act"),
}
assert len(ts.keys()) == len(ts.keys(yield_nesting_keys=True)) - 1
assert set(ts.keys(yield_nesting_keys=True)) == {
assert set(ts.keys(include_nested=True, leaves_only=True)) == {
"obs",
"act",
("nested_cp", "obs"),
("nested_cp", "act"),
"nested_cp",
}
assert set(ts.keys(leaves_only=True)) == {
"obs",
"act",
}
td = ts.rand()
assert isinstance(td["nested_cp"], TensorDictBase)
Expand Down Expand Up @@ -577,9 +585,10 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
td2 = CompositeSpec(new=None)
ts.update(td2)
assert set(ts.keys()) == {
assert set(ts.keys(include_nested=True)) == {
"obs",
"act",
"nested_cp",
("nested_cp", "obs"),
("nested_cp", "act"),
"new",
Expand All @@ -589,9 +598,10 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
td2 = CompositeSpec(nested_cp=CompositeSpec(new=None).to(device))
ts.update(td2)
assert set(ts.keys()) == {
assert set(ts.keys(include_nested=True)) == {
"obs",
"act",
"nested_cp",
("nested_cp", "obs"),
("nested_cp", "act"),
("nested_cp", "new"),
Expand All @@ -601,9 +611,10 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
td2 = CompositeSpec(nested_cp=CompositeSpec(act=None).to(device))
ts.update(td2)
assert set(ts.keys()) == {
assert set(ts.keys(include_nested=True)) == {
"obs",
"act",
"nested_cp",
("nested_cp", "obs"),
("nested_cp", "act"),
}
Expand All @@ -617,9 +628,10 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
nested_cp=CompositeSpec(act=UnboundedContinuousTensorSpec(device=device))
)
ts.update(td2)
assert set(ts.keys()) == {
assert set(ts.keys(include_nested=True)) == {
"obs",
"act",
"nested_cp",
("nested_cp", "obs"),
("nested_cp", "act"),
}
Expand All @@ -629,7 +641,7 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
def test_keys_to_empty_composite_spec():
    """Nested key tuples round-trip through ``_keys_to_empty_composite_spec``.

    Builds a spec tree from a mix of flat and nested keys, then checks that
    enumerating the leaf keys (include_nested=True, leaves_only=True) yields
    exactly the original key set.
    """
    keys = [("key1", "out"), ("key1", "in"), "key2", ("key1", "subkey1", "subkey2")]
    composite = _keys_to_empty_composite_spec(keys)
    # Positional args are (include_nested, leaves_only): only leaves come back.
    assert set(composite.keys(True, True)) == set(keys)


class TestEquality:
Expand Down
2 changes: 1 addition & 1 deletion torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def __init__(
hasattr(self.policy, "spec")
and self.policy.spec is not None
and all(v is not None for v in self.policy.spec.values())
and set(self.policy.spec.keys()) == set(self.policy.out_keys)
and set(self.policy.spec.keys(True, True)) == set(self.policy.out_keys)
):
# if policy spec is non-empty, all the values are not None and the keys
# match the out_keys we assume the user has given all relevant information
Expand Down
67 changes: 40 additions & 27 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1862,13 +1862,12 @@ def type_check(
self._specs[_key].type_check(value[_key], _key)

def is_in(self, val: Union[dict, TensorDictBase]) -> bool:
return all(
[
item.is_in(val.get(key))
for (key, item) in self._specs.items()
if item is not None
]
)
for (key, item) in self._specs.items():
if item is None:
continue
if not item.is_in(val.get(key)):
return False
return True

def project(self, val: TensorDictBase) -> TensorDictBase:
for key, item in self.items():
Expand All @@ -1894,22 +1893,29 @@ def rand(self, shape=None) -> TensorDictBase:
)

def keys(
    self,
    include_nested: bool = False,
    leaves_only: bool = False,
) -> KeysView:
    """Keys of the CompositeSpec.

    The keys argument reflect those of :class:`tensordict.TensorDict`.

    Args:
        include_nested (bool, optional): if ``False``, the returned keys will not be nested. They will
            represent only the immediate children of the root, and not the whole nested sequence, i.e. a
            :obj:`CompositeSpec(next=CompositeSpec(obs=None))` will lead to the keys
            :obj:`["next"]`. Default is ``False``, i.e. nested keys will not
            be returned.
        leaves_only (bool, optional): if :obj:`False`, the values returned
            will contain every level of nesting, i.e. a :obj:`CompositeSpec(next=CompositeSpec(obs=None))`
            will lead to the keys :obj:`["next", ("next", "obs")]`.
            Default is ``False``.
    """
    return _CompositeSpecKeysView(
        self,
        include_nested=include_nested,
        leaves_only=leaves_only,
    )

def items(self) -> ItemsView:
Expand Down Expand Up @@ -2014,13 +2020,14 @@ def expand(self, *shape):


def _keys_to_empty_composite_spec(keys):
"""Given a list of keys, creates a CompositeSpec tree where each leaf is assigned a None value."""
if not len(keys):
return
c = CompositeSpec()
for key in keys:
if isinstance(key, str):
c[key] = None
elif key[0] in c.keys(yield_nesting_keys=True):
elif key[0] in c.keys():
if c[key[0]] is None:
# if the value is None we just replace it
c[key[0]] = _keys_to_empty_composite_spec([key[1:]])
Expand All @@ -2042,28 +2049,34 @@ class _CompositeSpecKeysView:
def __init__(
    self,
    composite: CompositeSpec,
    include_nested,
    leaves_only,
):
    # Lazy view over the keys of ``composite``; the two flags mirror the
    # semantics of tensordict's TensorDict.keys(include_nested, leaves_only).
    self.composite = composite
    self.leaves_only = leaves_only
    self.include_nested = include_nested

def __iter__(
    self,
):
    """Yield the keys of the wrapped CompositeSpec.

    Nested keys are yielded as tuples (``(key, subkey, ...)``), matching
    tensordict's key convention. With ``leaves_only`` set, intermediate
    CompositeSpec nodes are suppressed; with ``include_nested`` unset, only
    the immediate children of the root are yielded.
    """
    for key, item in self.composite.items():
        if self.include_nested and isinstance(item, CompositeSpec):
            # Recurse into the sub-spec; normalize string subkeys to
            # 1-tuples so the concatenation below is uniform.
            for subkey in item.keys(
                include_nested=True, leaves_only=self.leaves_only
            ):
                if not isinstance(subkey, tuple):
                    subkey = (subkey,)
                yield (key, *subkey)
            # The intermediate node itself is a key unless only leaves
            # were requested.
            if not self.leaves_only:
                yield key
        elif not isinstance(item, CompositeSpec) or not self.leaves_only:
            yield key

def __len__(self):
    """Number of keys yielded by this view."""
    return sum(1 for _ in self)

def __repr__(self):
    """Debug representation listing every key in the view."""
    all_keys = list(self)
    return f"_CompositeSpecKeysView(keys={all_keys})"
2 changes: 1 addition & 1 deletion torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
)
tensordict.unlock_()

obs_keys = self.observation_spec.keys(nested_keys=False)
obs_keys = self.observation_spec.keys(True)
# we deliberately do not update the input values, but we want to keep track of
# new keys considered as "input" by inverse transforms.
in_keys = self._get_in_keys_to_exclude(tensordict)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def read_obs(
# when queried with and without pixels
observations["observation"] = observations.pop("state")
if not isinstance(observations, (TensorDict, dict)):
(key,) = itertools.islice(self.observation_spec.keys(), 1)
(key,) = itertools.islice(self.observation_spec.keys(True, True), 1)
observations = {key: observations}
observations = self.observation_spec.encode(observations)
return observations
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/r3m.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
if not isinstance(observation_spec, CompositeSpec):
raise ValueError("_R3MNet can only infer CompositeSpec")

keys = [key for key in observation_spec._specs.keys() if key in self.in_keys]
keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys]
device = observation_spec[keys[0]].device
dim = observation_spec[keys[0]].shape[:-3]

Expand Down
16 changes: 8 additions & 8 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def new_fun(self, observation_spec):
if isinstance(observation_spec, CompositeSpec):
d = observation_spec._specs
for in_key, out_key in zip(self.in_keys, self.out_keys):
if in_key in observation_spec.keys():
if in_key in observation_spec.keys(True, True):
d[out_key] = function(self, observation_spec[in_key].clone())
return CompositeSpec(
d, shape=observation_spec.shape, device=observation_spec.device
Expand All @@ -85,7 +85,7 @@ def new_fun(self, input_spec):
if isinstance(input_spec, CompositeSpec):
d = input_spec._specs
for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
if in_key in input_spec.keys():
if in_key in input_spec.keys(True, True):
d[out_key] = function(self, input_spec[in_key].clone())
return CompositeSpec(d, shape=input_spec.shape, device=input_spec.device)
else:
Expand Down Expand Up @@ -2066,7 +2066,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
# by def, there must be only one key
return observation_spec

keys = [key for key in observation_spec._specs.keys() if key in self.in_keys]
keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys]

sum_shape = sum(
[
Expand Down Expand Up @@ -2849,7 +2849,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
raise KeyError(
f"The key {in_key} was not found in the parent "
f"observation_spec with keys "
f"{list(self.parent.observation_spec.keys())}. "
f"{list(self.parent.observation_spec.keys(True))}. "
) from err

return tensordict
Expand Down Expand Up @@ -2880,7 +2880,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
episode_specs = {}
if isinstance(reward_spec, CompositeSpec):
# If reward_spec is a CompositeSpec, all in_keys should be keys of reward_spec
if not all(k in reward_spec.keys() for k in self.in_keys):
if not all(k in reward_spec.keys(True, True) for k in self.in_keys):
raise KeyError("Not all in_keys are present in ´reward_spec´")

# Define episode specs for all out_keys
Expand Down Expand Up @@ -3042,7 +3042,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
return tensordict.exclude(*self.excluded_keys)

def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
if any(key in observation_spec.keys() for key in self.excluded_keys):
if any(key in observation_spec.keys(True, True) for key in self.excluded_keys):
return CompositeSpec(
**{
key: value
Expand Down Expand Up @@ -3074,7 +3074,7 @@ def __init__(self, *selected_keys):

def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.parent:
input_keys = self.parent.input_spec.keys()
input_keys = self.parent.input_spec.keys(True, True)
else:
input_keys = []
return tensordict.select(
Expand All @@ -3085,7 +3085,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:

def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.parent:
input_keys = self.parent.input_spec.keys()
input_keys = self.parent.input_spec.keys(True, True)
else:
input_keys = []
return tensordict.select(
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/vip.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
if not isinstance(observation_spec, CompositeSpec):
raise ValueError("_VIPNet can only infer CompositeSpec")

keys = [key for key in observation_spec._specs.keys() if key in self.in_keys]
keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys]
device = observation_spec[keys[0]].device
dim = observation_spec[keys[0]].shape[:-3]

Expand Down
8 changes: 5 additions & 3 deletions torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,22 +242,24 @@ def _check_isin(key, value, obs_spec, input_spec):
for _key, _value in value.items():
_check_isin(_key, _value, obs_spec, input_spec)
return
elif key in input_spec.keys(yield_nesting_keys=True):
elif key in input_spec.keys(True):
if not input_spec[key].is_in(value):
raise AssertionError(
f"input_spec.is_in failed for key {key}. "
f"Got input_spec={input_spec[key]} and real={value}."
)
return
elif key in obs_spec.keys(yield_nesting_keys=True):
elif key in obs_spec.keys(True):
if not obs_spec[key].is_in(value):
raise AssertionError(
f"obs_spec.is_in failed for key {key}. "
f"Got obs_spec={obs_spec[key]} and real={value}."
)
return
else:
raise KeyError(key)
raise KeyError(
f"key {key} was not found in input spec with keys {input_spec.keys(True)} or obs spec with keys {obs_spec.keys(True)}"
)


def _selective_unsqueeze(tensor: torch.Tensor, batch_size: torch.Size, dim: int = -1):
Expand Down
Loading

0 comments on commit cdc6798

Please sign in to comment.