[Refacto] Migration due to tensordict 473 and 474 #1354

Merged (8 commits) on Jul 4, 2023
Changes from 1 commit
init
vmoens committed Jul 4, 2023
commit 23d2e7be2f42b26c5ccfa103b82d6d4cdfbe0708
14 changes: 14 additions & 0 deletions test/test_specs.py
@@ -2634,6 +2634,20 @@ def test_valid_indexing(spec_class):
    assert spec_2d[1, ..., None]["k2"].shape == torch.Size([3, 1, 4, 6, 7])


def test_composite_contains():
    spec = CompositeSpec(
        a=CompositeSpec(b=CompositeSpec(c=UnboundedContinuousTensorSpec()))
    )
    assert "a" in spec.keys()
    assert "a" in spec.keys(True)
    assert ("a",) in spec.keys()
    assert ("a",) in spec.keys(True)
    assert ("a", "b", "c") in spec.keys(True)
    assert ("a", "b", "c") in spec.keys(True, True)
    assert ("a", ("b", ("c",))) in spec.keys(True)
    assert ("a", ("b", ("c",))) in spec.keys(True, True)


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
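A quick note on the keys-view flags the new test relies on: the two positional booleans appear to be `include_nested` and `leaves_only`, so `keys(True)` also yields nested keys as tuples and `keys(True, True)` keeps only leaf specs. A minimal sketch with the same spec as the test (the flag names are an assumption here):

```python
# Minimal sketch, assuming the positional flags of CompositeSpec.keys are
# include_nested and leaves_only (mirroring TensorDict.keys in tensordict).
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

spec = CompositeSpec(
    a=CompositeSpec(b=CompositeSpec(c=UnboundedContinuousTensorSpec()))
)

list(spec.keys())            # top-level keys only: ["a"]
list(spec.keys(True))        # nested keys as tuples too, e.g. ("a", "b", "c")
list(spec.keys(True, True))  # leaf specs only: [("a", "b", "c")]
```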
12 changes: 11 additions & 1 deletion torchrl/data/tensor_specs.py
@@ -32,7 +32,7 @@
import numpy as np
import torch
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import _getitem_batch_size
from tensordict.utils import _getitem_batch_size, unravel_key

from torchrl._utils import get_binary_env_var

@@ -3384,3 +3384,13 @@ def __len__(self):

    def __repr__(self):
        return f"_CompositeSpecKeysView(keys={list(self)})"

    def __contains__(self, item):
        item = unravel_key(item)
        if len(item) == 1:
            item = item[0]
        for key in self.__iter__():
            if key == item:
                return True
        else:
            return False
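The new `__contains__` leans on `unravel_key` to normalize whatever the caller passes before comparing it against the iterated keys. A hedged sketch of that normalization (behavior as of tensordict around this PR; worth double-checking against the installed version):

```python
# Sketch of the key normalization __contains__ relies on: unravel_key
# flattens arbitrarily nested key tuples into a plain tuple of strings,
# and leaves plain strings untouched.
from tensordict.utils import unravel_key

assert unravel_key("a") == "a"
assert unravel_key(("a", "b", "c")) == ("a", "b", "c")
assert unravel_key(("a", ("b", ("c",)))) == ("a", "b", "c")

# Hence ("a", ("b", ("c",))) and ("a", "b", "c") resolve to the same
# membership test against spec.keys(True) in the test above.
```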
6 changes: 3 additions & 3 deletions torchrl/envs/transforms/transforms.py
@@ -15,7 +15,7 @@
import torch
from tensordict.nn import dispatch
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import expand_as_right, unravel_keys
from tensordict.utils import expand_as_right, unravel_key
from torch import nn, Tensor

from torchrl.data.tensor_specs import (
@@ -3644,7 +3644,7 @@ class ExcludeTransform(Transform):
    def __init__(self, *excluded_keys):
        super().__init__(in_keys=[], in_keys_inv=[], out_keys=[], out_keys_inv=[])
        try:
            excluded_keys = [unravel_keys(key) for key in excluded_keys]
            excluded_keys = [unravel_key(key) for key in excluded_keys]
        except ValueError:
            raise ValueError(
                "excluded keys must be a list or tuple of strings or tuples of strings."
@@ -3690,7 +3690,7 @@ class SelectTransform(Transform):
    def __init__(self, *selected_keys):
        super().__init__(in_keys=[], in_keys_inv=[], out_keys=[], out_keys_inv=[])
        try:
            selected_keys = [unravel_keys(key) for key in selected_keys]
            selected_keys = [unravel_key(key) for key in selected_keys]
        except ValueError:
            raise ValueError(
                "selected keys must be a list or tuple of strings or tuples of strings."
10 changes: 5 additions & 5 deletions torchrl/envs/utils.py
@@ -26,7 +26,7 @@
    TensorDict,
    TensorDictBase,
)
from tensordict.utils import unravel_keys
from tensordict.utils import unravel_key

__all__ = [
    "exploration_mode",
@@ -193,9 +193,9 @@ def step_mdp(
            return next_tensordict
        return out

    action_key = unravel_keys((action_key,))
    done_key = unravel_keys((done_key,))
    reward_key = unravel_keys((reward_key,))
    action_key = unravel_key((action_key,))
    done_key = unravel_key((done_key,))
    reward_key = unravel_key((reward_key,))

    excluded = set()
    if exclude_reward:
@@ -487,7 +487,7 @@ def __get__(self, owner_self, owner_cls):

def _sort_keys(element):
    if isinstance(element, tuple):
        element = unravel_keys(element)
        element = unravel_key(element)
        return "_-|-_".join(element)
    return element
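For context, `_sort_keys` exists because a key collection can mix plain strings and nested tuples, which Python cannot sort directly; joining the unraveled parts with a sentinel gives every key a comparable string. A small sketch with made-up keys:

```python
# Sketch of the total order _sort_keys provides over mixed key types
# (the function body is copied from the diff above; the key list is made up).
from tensordict.utils import unravel_key


def _sort_keys(element):
    if isinstance(element, tuple):
        element = unravel_key(element)
        return "_-|-_".join(element)
    return element


keys = ["done", ("next", "reward"), ("next", ("agent", "obs")), "action"]
print(sorted(keys, key=_sort_keys))
# ['action', 'done', ('next', ('agent', 'obs')), ('next', 'reward')]
```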

14 changes: 7 additions & 7 deletions torchrl/envs/vec_env.py
@@ -20,7 +20,7 @@
import torch
from tensordict import TensorDict
from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
from tensordict.utils import unravel_keys
from tensordict.utils import unravel_key
from torch import multiprocessing as mp

from torchrl._utils import _check_for_faulty_process, VERBOSE
@@ -331,10 +331,10 @@ def _create_td(self) -> None:
            self.env_output_keys = []
            self.env_obs_keys = []
            for key in self.output_spec["_observation_spec"].keys(True, True):
                self.env_output_keys.append(unravel_keys(("next", key)))
                self.env_output_keys.append(unravel_key(("next", key)))
                self.env_obs_keys.append(key)
            self.env_output_keys.append(unravel_keys(("next", self.reward_key)))
            self.env_output_keys.append(unravel_keys(("next", self.done_key)))
            self.env_output_keys.append(unravel_key(("next", self.reward_key)))
            self.env_output_keys.append(unravel_key(("next", self.done_key)))
        else:
            env_input_keys = set()
            for meta_data in self.meta_data:
@@ -355,15 +355,15 @@ def _create_td(self) -> None:
                )
            )
            env_output_keys = env_output_keys.union(
                unravel_keys(("next", key))
                unravel_key(("next", key))
                for key in meta_data.specs["output_spec"]["_observation_spec"].keys(
                    True, True
                )
            )
            env_output_keys = env_output_keys.union(
                {
                    unravel_keys(("next", self.reward_key)),
                    unravel_keys(("next", self.done_key)),
                    unravel_key(("next", self.reward_key)),
                    unravel_key(("next", self.done_key)),
                }
            )
        self.env_obs_keys = sorted(env_obs_keys, key=_sort_keys)
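The `("next", key)` pairs go through `unravel_key` because `key` itself may already be a nested tuple coming out of `keys(True, True)`; without unraveling, the result would be a tuple nested inside a tuple. A short sketch with illustrative keys:

```python
# Sketch: composing "next" with keys that may themselves be tuples.
from tensordict.utils import unravel_key

for key in ["observation", ("agent", "obs")]:
    print(unravel_key(("next", key)))
# ('next', 'observation')
# ('next', 'agent', 'obs')
```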
9 changes: 7 additions & 2 deletions torchrl/modules/tensordict_module/common.py
@@ -12,6 +12,7 @@
from typing import Iterable, Optional, Type, Union

import torch
from tensordict import unravel_key_list

from tensordict.nn import TensorDictModule, TensorDictModuleBase
from tensordict.tensordict import TensorDictBase
@@ -224,13 +225,17 @@ def register_spec(self, safe, spec):
        elif spec is None:
            spec = CompositeSpec()

        if set(spec.keys(True, True)) != set(self.out_keys):
        if sorted(unravel_key_list(list(spec.keys(True, True))), key=str) != sorted(
            self.out_keys, key=str
        ):
            # then assume that all the non indicated specs are None
            for key in self.out_keys:
                if key not in spec:
                    spec[key] = None

        if set(spec.keys(True, True)) != set(self.out_keys):
        if sorted(unravel_key_list(spec.keys(True, True)), key=str) != sorted(
            unravel_key_list(self.out_keys), key=str
        ):
            raise RuntimeError(
                f"spec keys and out_keys do not match, got: {set(spec.keys(True))} and {set(self.out_keys)} respectively"
            )
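The switch from set equality to `sorted(unravel_key_list(...), key=str)` guards against two issues: equivalent nested keys can be spelled differently (nested tuples vs. flat tuples), and a mix of strings and tuples cannot be sorted without a common comparison key. A hedged sketch with made-up keys:

```python
# Sketch of why the comparison changed (key names are illustrative).
from tensordict import unravel_key_list

spec_keys = [("a", ("b", "c")), "d"]   # as a spec might report its leaves
out_keys = ["d", ("a", "b", "c")]      # as the module might store out_keys

# Old check: spurious mismatch, since ("a", ("b", "c")) != ("a", "b", "c").
assert set(spec_keys) != set(out_keys)

# New check: normalize both sides, then compare order-insensitively;
# key=str avoids the TypeError of comparing str with tuple.
assert sorted(unravel_key_list(spec_keys), key=str) == sorted(
    unravel_key_list(out_keys), key=str
)
```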