Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Refactor composite spec keys to match tensordict #956

Merged
merged 5 commits into from
Mar 8, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
init
  • Loading branch information
vmoens committed Mar 8, 2023
commit 3f0357b3f63cad3524f383dacc904f7bc6cfd4bf
28 changes: 20 additions & 8 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,16 +521,24 @@ def test_nested_composite_spec(self, is_complete, device, dtype):
assert set(ts.keys()) == {
"obs",
"act",
"nested_cp",
}
assert set(ts.keys(include_nested=True)) == {
"obs",
"act",
"nested_cp",
("nested_cp", "obs"),
("nested_cp", "act"),
}
assert len(ts.keys()) == len(ts.keys(yield_nesting_keys=True)) - 1
assert set(ts.keys(yield_nesting_keys=True)) == {
assert set(ts.keys(include_nested=True, leaves_only=True)) == {
"obs",
"act",
("nested_cp", "obs"),
("nested_cp", "act"),
"nested_cp",
}
assert set(ts.keys(leaves_only=True)) == {
"obs",
"act",
}
td = ts.rand()
assert isinstance(td["nested_cp"], TensorDictBase)
Expand Down Expand Up @@ -577,9 +585,10 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
td2 = CompositeSpec(new=None)
ts.update(td2)
assert set(ts.keys()) == {
assert set(ts.keys(include_nested=True)) == {
"obs",
"act",
"nested_cp",
("nested_cp", "obs"),
("nested_cp", "act"),
"new",
Expand All @@ -589,9 +598,10 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
td2 = CompositeSpec(nested_cp=CompositeSpec(new=None).to(device))
ts.update(td2)
assert set(ts.keys()) == {
assert set(ts.keys(include_nested=True)) == {
"obs",
"act",
"nested_cp",
("nested_cp", "obs"),
("nested_cp", "act"),
("nested_cp", "new"),
Expand All @@ -601,9 +611,10 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
td2 = CompositeSpec(nested_cp=CompositeSpec(act=None).to(device))
ts.update(td2)
assert set(ts.keys()) == {
assert set(ts.keys(include_nested=True)) == {
"obs",
"act",
"nested_cp",
("nested_cp", "obs"),
("nested_cp", "act"),
}
Expand All @@ -617,9 +628,10 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
nested_cp=CompositeSpec(act=UnboundedContinuousTensorSpec(device=device))
)
ts.update(td2)
assert set(ts.keys()) == {
assert set(ts.keys(include_nested=True)) == {
"obs",
"act",
"nested_cp",
("nested_cp", "obs"),
("nested_cp", "act"),
}
Expand All @@ -629,7 +641,7 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
def test_keys_to_empty_composite_spec():
keys = [("key1", "out"), ("key1", "in"), "key2", ("key1", "subkey1", "subkey2")]
composite = _keys_to_empty_composite_spec(keys)
assert set(composite.keys()) == set(keys)
assert set(composite.keys(True, True)) == set(keys)


class TestEquality:
Expand Down
2 changes: 1 addition & 1 deletion torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def __init__(
hasattr(self.policy, "spec")
and self.policy.spec is not None
and all(v is not None for v in self.policy.spec.values())
and set(self.policy.spec.keys()) == set(self.policy.out_keys)
and set(self.policy.spec.keys(True, True)) == set(self.policy.out_keys)
):
# if policy spec is non-empty, all the values are not None and the keys
# match the out_keys we assume the user has given all relevant information
Expand Down
54 changes: 34 additions & 20 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1894,22 +1894,29 @@ def rand(self, shape=None) -> TensorDictBase:
)

def keys(
self, yield_nesting_keys: bool = False, nested_keys: bool = True
self,
include_nested: bool = False,
leaves_only: bool = False,
) -> KeysView:
"""Keys of the CompositeSpec.

The keys argument reflect those of :class:`tensordict.TensorDict`.

Args:
yield_nesting_keys (bool, optional): if :obj:`True`, the values returned
will contain every level of nesting, i.e. a :obj:`CompositeSpec(next=CompositeSpec(obs=None))`
will lead to the keys :obj:`["next", ("next", "obs")]`. Default is :obj:`False`, i.e.
only nested keys will be returned.
nested_keys (bool, optional): if :obj:`False`, the returned keys will not be nested. They will
include_nested (bool, optional): if ``False``, the returned keys will not be nested. They will
represent only the immediate children of the root, and not the whole nested sequence, i.e. a
:obj:`CompositeSpec(next=CompositeSpec(obs=None))` will lead to the keys
:obj:`["next"]. Default is :obj:`True`, i.e. nested keys will be returned.
:obj:`["next"]. Default is ``False``, i.e. nested keys will not
be returned.
leaves_only (bool, optional): if :obj:`False`, the values returned
will contain every level of nesting, i.e. a :obj:`CompositeSpec(next=CompositeSpec(obs=None))`
will lead to the keys :obj:`["next", ("next", "obs")]`.
Default is ``False``.
"""
return _CompositeSpecKeysView(
self, _yield_nesting_keys=yield_nesting_keys, nested_keys=nested_keys
self,
include_nested=include_nested,
leaves_only=leaves_only,
)

def items(self) -> ItemsView:
Expand Down Expand Up @@ -2014,13 +2021,14 @@ def expand(self, *shape):


def _keys_to_empty_composite_spec(keys):
"""Given a list of keys, creates a CompositeSpec tree where each leaf is assigned a None value."""
if not len(keys):
return
c = CompositeSpec()
for key in keys:
if isinstance(key, str):
c[key] = None
elif key[0] in c.keys(yield_nesting_keys=True):
elif key[0] in c.keys():
if c[key[0]] is None:
# if the value is None we just replace it
c[key[0]] = _keys_to_empty_composite_spec([key[1:]])
Expand All @@ -2042,28 +2050,34 @@ class _CompositeSpecKeysView:
def __init__(
self,
composite: CompositeSpec,
nested_keys: bool = True,
_yield_nesting_keys: bool = False,
include_nested,
leaves_only,
):
self.composite = composite
self._yield_nesting_keys = _yield_nesting_keys
self.nested_keys = nested_keys
self.leaves_only = leaves_only
self.include_nested = include_nested

def __iter__(
self,
):
for key, item in self.composite.items():
if self.nested_keys and isinstance(item, CompositeSpec):
for subkey in item.keys():
yield (key, *subkey) if isinstance(subkey, tuple) else (key, subkey)
if self._yield_nesting_keys:
yield key
else:
if not isinstance(item, CompositeSpec) or len(item):
if self.include_nested and isinstance(item, CompositeSpec):
for subkey in item.keys(
include_nested=True, leaves_only=self.leaves_only
):
if not isinstance(subkey, tuple):
subkey = (subkey,)
yield (key, *subkey)
if not self.leaves_only:
yield key
elif not isinstance(item, CompositeSpec) or not self.leaves_only:
yield key

def __len__(self):
i = 0
for _ in self:
i += 1
return i

def __repr__(self):
return f"_CompositeSpecKeysView(keys={list(self)})"
2 changes: 1 addition & 1 deletion torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
)
tensordict.unlock_()

obs_keys = self.observation_spec.keys(nested_keys=False)
obs_keys = self.observation_spec.keys(True, True)
# we deliberately do not update the input values, but we want to keep track of
# new keys considered as "input" by inverse transforms.
in_keys = self._get_in_keys_to_exclude(tensordict)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def read_obs(
# when queried with and without pixels
observations["observation"] = observations.pop("state")
if not isinstance(observations, (TensorDict, dict)):
(key,) = itertools.islice(self.observation_spec.keys(), 1)
(key,) = itertools.islice(self.observation_spec.keys(True, True), 1)
observations = {key: observations}
observations = self.observation_spec.encode(observations)
return observations
Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,10 @@ def _get_gym_envs(): # noqa: F811
def _is_from_pixels(env):
observation_spec = env.observation_space
if isinstance(observation_spec, (Dict,)):
if "pixels" in set(observation_spec.keys()):
if "pixels" in set(observation_spec.keys(True, True)):
return True
if isinstance(observation_spec, (gym.spaces.dict.Dict,)):
if "pixels" in set(observation_spec.spaces.keys()):
if "pixels" in set(observation_spec.spaces.keys(True, True)):
return True
elif (
isinstance(observation_spec, gym.spaces.Box)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/r3m.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
if not isinstance(observation_spec, CompositeSpec):
raise ValueError("_R3MNet can only infer CompositeSpec")

keys = [key for key in observation_spec._specs.keys() if key in self.in_keys]
keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys]
device = observation_spec[keys[0]].device
dim = observation_spec[keys[0]].shape[:-3]

Expand Down
16 changes: 8 additions & 8 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def new_fun(self, observation_spec):
if isinstance(observation_spec, CompositeSpec):
d = observation_spec._specs
for in_key, out_key in zip(self.in_keys, self.out_keys):
if in_key in observation_spec.keys():
if in_key in observation_spec.keys(True, True):
d[out_key] = function(self, observation_spec[in_key].clone())
return CompositeSpec(
d, shape=observation_spec.shape, device=observation_spec.device
Expand All @@ -85,7 +85,7 @@ def new_fun(self, input_spec):
if isinstance(input_spec, CompositeSpec):
d = input_spec._specs
for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
if in_key in input_spec.keys():
if in_key in input_spec.keys(True, True):
d[out_key] = function(self, input_spec[in_key].clone())
return CompositeSpec(d, shape=input_spec.shape, device=input_spec.device)
else:
Expand Down Expand Up @@ -2066,7 +2066,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
# by def, there must be only one key
return observation_spec

keys = [key for key in observation_spec._specs.keys() if key in self.in_keys]
keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys]

sum_shape = sum(
[
Expand Down Expand Up @@ -2849,7 +2849,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
raise KeyError(
f"The key {in_key} was not found in the parent "
f"observation_spec with keys "
f"{list(self.parent.observation_spec.keys())}. "
f"{list(self.parent.observation_spec.keys(True))}. "
) from err

return tensordict
Expand Down Expand Up @@ -2880,7 +2880,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
episode_specs = {}
if isinstance(reward_spec, CompositeSpec):
# If reward_spec is a CompositeSpec, all in_keys should be keys of reward_spec
if not all(k in reward_spec.keys() for k in self.in_keys):
if not all(k in reward_spec.keys(True, True) for k in self.in_keys):
raise KeyError("Not all in_keys are present in ´reward_spec´")

# Define episode specs for all out_keys
Expand Down Expand Up @@ -3042,7 +3042,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
return tensordict.exclude(*self.excluded_keys)

def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
if any(key in observation_spec.keys() for key in self.excluded_keys):
if any(key in observation_spec.keys(True, True) for key in self.excluded_keys):
return CompositeSpec(
**{
key: value
Expand Down Expand Up @@ -3074,7 +3074,7 @@ def __init__(self, *selected_keys):

def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.parent:
input_keys = self.parent.input_spec.keys()
input_keys = self.parent.input_spec.keys(True, True)
else:
input_keys = []
return tensordict.select(
Expand All @@ -3085,7 +3085,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:

def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.parent:
input_keys = self.parent.input_spec.keys()
input_keys = self.parent.input_spec.keys(True, True)
else:
input_keys = []
return tensordict.select(
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/vip.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
if not isinstance(observation_spec, CompositeSpec):
raise ValueError("_VIPNet can only infer CompositeSpec")

keys = [key for key in observation_spec._specs.keys() if key in self.in_keys]
keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys]
device = observation_spec[keys[0]].device
dim = observation_spec[keys[0]].shape[:-3]

Expand Down
8 changes: 5 additions & 3 deletions torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,22 +242,24 @@ def _check_isin(key, value, obs_spec, input_spec):
for _key, _value in value.items():
_check_isin(_key, _value, obs_spec, input_spec)
return
elif key in input_spec.keys(yield_nesting_keys=True):
elif key in input_spec.keys(True):
if not input_spec[key].is_in(value):
raise AssertionError(
f"input_spec.is_in failed for key {key}. "
f"Got input_spec={input_spec[key]} and real={value}."
)
return
elif key in obs_spec.keys(yield_nesting_keys=True):
elif key in obs_spec.keys(True):
if not obs_spec[key].is_in(value):
raise AssertionError(
f"obs_spec.is_in failed for key {key}. "
f"Got obs_spec={obs_spec[key]} and real={value}."
)
return
else:
raise KeyError(key)
raise KeyError(
f"key {key} was not found in input spec with keys {input_spec.keys(True)} or obs spec with keys {obs_spec.keys(True)}"
)


def _selective_unsqueeze(tensor: torch.Tensor, batch_size: torch.Size, dim: int = -1):
Expand Down
12 changes: 7 additions & 5 deletions torchrl/envs/vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,12 +430,14 @@ def _create_td(self) -> None:
)
else:
if self._single_task:
self.env_input_keys = sorted(self.input_spec.keys(), key=_sort_keys)
self.env_input_keys = sorted(
self.input_spec.keys(True, True), key=_sort_keys
)
else:
env_input_keys = set()
for meta_data in self.meta_data:
env_input_keys = env_input_keys.union(
meta_data.specs["input_spec"].keys()
meta_data.specs["input_spec"].keys(True, True)
)
self.env_input_keys = sorted(env_input_keys, key=_sort_keys)
if not len(self.env_input_keys):
Expand Down Expand Up @@ -603,7 +605,7 @@ def _step(
tensordict_in = tensordict.clone(False)
# update the shared tensordict to keep the input entries up-to-date
self.shared_tensordict_parent.update_(
tensordict_in.select(*self.input_spec.keys(), strict=False)
tensordict_in.select(*self.input_spec.keys(True, True), strict=False)
)
# If a key is both in input and output spec, we should keep it because it has been modified
input_keys = set(self.input_spec.keys(True)) - set(
Expand Down Expand Up @@ -788,7 +790,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
self._assert_tensordict_shape(tensordict)

self.shared_tensordict_parent.update_(
tensordict.select(*self.input_spec.keys(), strict=False)
tensordict.select(*self.input_spec.keys(True, True), strict=False)
)
for i in range(self.num_workers):
self.parent_channels[i].send(("step", None))
Expand Down Expand Up @@ -1051,7 +1053,7 @@ def _run_worker_pipe_shared_mem(
_td = tensordict.clone(recurse=False)
_td = env._step(_td)
if step_keys is None:
step_keys = set(env.observation_spec.keys()).union(
step_keys = set(env.observation_spec.keys(True, True)).union(
{"done", "terminated", "reward"}
)
if pin_memory:
Expand Down
2 changes: 1 addition & 1 deletion torchrl/modules/planners/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
"Environment is batch_locked. MPCPlanners need an environnement that accepts batched inputs with any batch size"
)
out_keys = [action_key]
in_keys = list(env.observation_spec.keys())
in_keys = list(env.observation_spec.keys(True, True))
super().__init__(env, in_keys=in_keys, out_keys=out_keys)
self.env = env
self.action_spec = env.action_spec
Expand Down
Loading