Skip to content

Commit

Permalink
[Refactor] Migration due to tensordict 473 and 474 (pytorch#1354)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jul 4, 2023
1 parent d12afa0 commit 75a45be
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 52 deletions.
14 changes: 14 additions & 0 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2634,6 +2634,20 @@ def test_valid_indexing(spec_class):
assert spec_2d[1, ..., None]["k2"].shape == torch.Size([3, 1, 4, 6, 7])


def test_composite_contains():
spec = CompositeSpec(
a=CompositeSpec(b=CompositeSpec(c=UnboundedContinuousTensorSpec()))
)
assert "a" in spec.keys()
assert "a" in spec.keys(True)
assert ("a",) in spec.keys()
assert ("a",) in spec.keys(True)
assert ("a", "b", "c") in spec.keys(True)
assert ("a", "b", "c") in spec.keys(True, True)
assert ("a", ("b", ("c",))) in spec.keys(True)
assert ("a", ("b", ("c",))) in spec.keys(True, True)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
27 changes: 14 additions & 13 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pytest
import torch
from tensordict import TensorDict
from tensordict import TensorDict, unravel_key_list
from tensordict.nn import InteractionType, make_functional, TensorDictModule
from torch import nn
from torchrl.data.tensor_specs import (
Expand Down Expand Up @@ -619,15 +619,16 @@ def test_vmap_probabilistic(self, safe, spec_type):


class TestTDSequence:
def test_in_key_warning(self):
with pytest.warns(UserWarning, match='key "_" is for ignoring output'):
tensordict_module = SafeModule(
nn.Linear(3, 4), in_keys=["_"], out_keys=["out1"]
)
with pytest.warns(UserWarning, match='key "_" is for ignoring output'):
tensordict_module = SafeModule(
nn.Linear(3, 4), in_keys=["_", "key2"], out_keys=["out1"]
)
# Temporarily disabling this test until 473 is merged in tensordict
# def test_in_key_warning(self):
# with pytest.warns(UserWarning, match='key "_" is for ignoring output'):
# tensordict_module = SafeModule(
# nn.Linear(3, 4), in_keys=["_"], out_keys=["out1"]
# )
# with pytest.warns(UserWarning, match='key "_" is for ignoring output'):
# tensordict_module = SafeModule(
# nn.Linear(3, 4), in_keys=["_", "key2"], out_keys=["out1"]
# )

@pytest.mark.parametrize("safe", [True, False])
@pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"])
Expand Down Expand Up @@ -1529,7 +1530,7 @@ def forward(self, in_1, in_2):
in_keys=["x"],
out_keys=["out_1", "out_2", "out_3"],
)
assert set(ensured_module.in_keys) == {"x"}
assert set(unravel_key_list(ensured_module.in_keys)) == {("x",)}
assert isinstance(ensured_module, TensorDictModule)


Expand All @@ -1554,7 +1555,7 @@ def test_errs(self):
],
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
)
with pytest.raises(ValueError, match="in_keys"):
with pytest.raises(TypeError, match="incompatible function arguments"):
lstm_module = LSTMModule(
input_size=3,
hidden_size=12,
Expand Down Expand Up @@ -1582,7 +1583,7 @@ def test_errs(self):
in_keys=["observation", "hidden0", "hidden1"],
out_keys=["intermediate", ("next", "hidden0")],
)
with pytest.raises(ValueError, match="out_keys"):
with pytest.raises(TypeError, match="incompatible function arguments"):
lstm_module = LSTMModule(
input_size=3,
hidden_size=12,
Expand Down
12 changes: 11 additions & 1 deletion torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import numpy as np
import torch
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import _getitem_batch_size
from tensordict.utils import _getitem_batch_size, unravel_key

from torchrl._utils import get_binary_env_var

Expand Down Expand Up @@ -3384,3 +3384,13 @@ def __len__(self):

def __repr__(self):
return f"_CompositeSpecKeysView(keys={list(self)})"

def __contains__(self, item):
item = unravel_key(item)
if len(item) == 1:
item = item[0]
for key in self.__iter__():
if key == item:
return True
else:
return False
22 changes: 11 additions & 11 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch
from tensordict.nn import dispatch
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import expand_as_right, unravel_keys
from tensordict.utils import expand_as_right, unravel_key, unravel_key_list
from torch import nn, Tensor

from torchrl.data.tensor_specs import (
Expand Down Expand Up @@ -3644,9 +3644,9 @@ class ExcludeTransform(Transform):
def __init__(self, *excluded_keys):
super().__init__(in_keys=[], in_keys_inv=[], out_keys=[], out_keys_inv=[])
try:
excluded_keys = [unravel_keys(key) for key in excluded_keys]
except ValueError:
raise ValueError(
excluded_keys = unravel_key_list(excluded_keys)
except TypeError:
raise TypeError(
"excluded keys must be a list or tuple of strings or tuples of strings."
)
self.excluded_keys = excluded_keys
Expand All @@ -3664,10 +3664,10 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
if any(key in observation_spec.keys(True, True) for key in self.excluded_keys):
return CompositeSpec(
**{
{
key: value
for key, value in observation_spec.items()
if key not in self.excluded_keys
if unravel_key(key) not in self.excluded_keys
},
shape=observation_spec.shape,
)
Expand All @@ -3690,9 +3690,9 @@ class SelectTransform(Transform):
def __init__(self, *selected_keys):
super().__init__(in_keys=[], in_keys_inv=[], out_keys=[], out_keys_inv=[])
try:
selected_keys = [unravel_keys(key) for key in selected_keys]
except ValueError:
raise ValueError(
selected_keys = unravel_key_list(selected_keys)
except TypeError:
raise TypeError(
"selected keys must be a list or tuple of strings or tuples of strings."
)
self.selected_keys = selected_keys
Expand All @@ -3719,10 +3719,10 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:

def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
return CompositeSpec(
**{
{
key: value
for key, value in observation_spec.items()
if key in self.selected_keys
if unravel_key(key) in self.selected_keys
},
shape=observation_spec.shape,
)
Expand Down
11 changes: 5 additions & 6 deletions torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
TensorDict,
TensorDictBase,
)
from tensordict.utils import unravel_keys
from tensordict.utils import unravel_key

__all__ = [
"exploration_mode",
Expand Down Expand Up @@ -193,9 +193,9 @@ def step_mdp(
return next_tensordict
return out

action_key = unravel_keys((action_key,))
done_key = unravel_keys((done_key,))
reward_key = unravel_keys((reward_key,))
action_key = unravel_key((action_key,))
done_key = unravel_key((done_key,))
reward_key = unravel_key((reward_key,))

excluded = set()
if exclude_reward:
Expand Down Expand Up @@ -406,7 +406,6 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0):
fake_tensordict = fake_tensordict.expand(*real_tensordict.shape)
else:
fake_tensordict = torch.stack([fake_tensordict.clone() for _ in range(3)], -1)

if (
fake_tensordict.apply(lambda x: torch.zeros_like(x))
!= real_tensordict.apply(lambda x: torch.zeros_like(x))
Expand Down Expand Up @@ -487,7 +486,7 @@ def __get__(self, owner_self, owner_cls):

def _sort_keys(element):
if isinstance(element, tuple):
element = unravel_keys(element)
element = unravel_key(element)
return "_-|-_".join(element)
return element

Expand Down
14 changes: 7 additions & 7 deletions torchrl/envs/vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import torch
from tensordict import TensorDict
from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
from tensordict.utils import unravel_keys
from tensordict.utils import unravel_key
from torch import multiprocessing as mp

from torchrl._utils import _check_for_faulty_process, VERBOSE
Expand Down Expand Up @@ -331,10 +331,10 @@ def _create_td(self) -> None:
self.env_output_keys = []
self.env_obs_keys = []
for key in self.output_spec["_observation_spec"].keys(True, True):
self.env_output_keys.append(unravel_keys(("next", key)))
self.env_output_keys.append(unravel_key(("next", key)))
self.env_obs_keys.append(key)
self.env_output_keys.append(unravel_keys(("next", self.reward_key)))
self.env_output_keys.append(unravel_keys(("next", self.done_key)))
self.env_output_keys.append(unravel_key(("next", self.reward_key)))
self.env_output_keys.append(unravel_key(("next", self.done_key)))
else:
env_input_keys = set()
for meta_data in self.meta_data:
Expand All @@ -355,15 +355,15 @@ def _create_td(self) -> None:
)
)
env_output_keys = env_output_keys.union(
unravel_keys(("next", key))
unravel_key(("next", key))
for key in meta_data.specs["output_spec"]["_observation_spec"].keys(
True, True
)
)
env_output_keys = env_output_keys.union(
{
unravel_keys(("next", self.reward_key)),
unravel_keys(("next", self.done_key)),
unravel_key(("next", self.reward_key)),
unravel_key(("next", self.done_key)),
}
)
self.env_obs_keys = sorted(env_obs_keys, key=_sort_keys)
Expand Down
14 changes: 9 additions & 5 deletions torchrl/modules/tensordict_module/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Iterable, Optional, Type, Union

import torch
from tensordict import unravel_key_list

from tensordict.nn import TensorDictModule, TensorDictModuleBase
from tensordict.tensordict import TensorDictBase
Expand Down Expand Up @@ -217,22 +218,25 @@ def register_spec(self, safe, spec):
f"got more than one out_key for the TensorDictModule: {self.out_keys},\nbut only one spec. "
"Consider using a CompositeSpec object or no spec at all."
)
spec = CompositeSpec(**{self.out_keys[0]: spec})
spec = CompositeSpec({self.out_keys[0]: spec})
elif spec is not None and isinstance(spec, CompositeSpec):
if "_" in spec.keys() and spec["_"] is not None:
warnings.warn('got a spec with key "_": it will be ignored')
elif spec is None:
spec = CompositeSpec()

if set(spec.keys(True, True)) != set(self.out_keys):
# unravel_key_list(self.out_keys) can be removed once 473 is merged in tensordict
spec_keys = set(unravel_key_list(list(spec.keys(True, True))))
out_keys = set(unravel_key_list(self.out_keys))
if spec_keys != out_keys:
# then assume that all the non indicated specs are None
for key in self.out_keys:
if key not in spec:
spec[key] = None

if set(spec.keys(True, True)) != set(self.out_keys):
spec_keys = set(unravel_key_list(list(spec.keys(True, True))))
if spec_keys != out_keys:
raise RuntimeError(
f"spec keys and out_keys do not match, got: {set(spec.keys(True))} and {set(self.out_keys)} respectively"
f"spec keys and out_keys do not match, got: {spec_keys} and {out_keys} respectively"
)

self._spec = spec
Expand Down
16 changes: 9 additions & 7 deletions torchrl/modules/tensordict_module/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import warnings
from typing import Optional, Sequence, Type, Union

from tensordict import TensorDictBase, unravel_key_list

from tensordict.nn import (
InteractionType,
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
TensorDictModule,
)
from tensordict.tensordict import TensorDictBase

from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
from torchrl.modules.distributions import Delta
from torchrl.modules.tensordict_module.common import _forward_hook_safe_action
Expand Down Expand Up @@ -129,22 +129,24 @@ def __init__(
f"got more than one out_key for the SafeModule: {self.out_keys},\nbut only one spec. "
"Consider using a CompositeSpec object or no spec at all."
)
spec = CompositeSpec(**{self.out_keys[0]: spec})
spec = CompositeSpec({self.out_keys[0]: spec})
elif spec is not None and isinstance(spec, CompositeSpec):
if "_" in spec.keys():
warnings.warn('got a spec with key "_": it will be ignored')
elif spec is None:
spec = CompositeSpec()

if set(spec.keys(True, True)) != set(self.out_keys):
spec_keys = set(unravel_key_list(list(spec.keys(True, True))))
out_keys = set(unravel_key_list(self.out_keys))
if spec_keys != out_keys:
# then assume that all the non indicated specs are None
for key in self.out_keys:
if key not in spec:
spec[key] = None
spec_keys = set(unravel_key_list(list(spec.keys(True, True))))

if set(spec.keys(True, True)) != set(self.out_keys):
if spec_keys != out_keys:
raise RuntimeError(
f"spec keys and out_keys do not match, got: {set(spec.keys(True, True))} and {set(self.out_keys)} respectively"
f"spec keys and out_keys do not match, got: {spec_keys} and {out_keys} respectively"
)

self._spec = spec
Expand Down
7 changes: 5 additions & 2 deletions torchrl/modules/tensordict_module/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from tensordict.nn import TensorDictModuleBase as ModuleBase

from tensordict.tensordict import NO_DEFAULT, TensorDictBase
from tensordict.utils import prod
from tensordict.utils import prod, unravel_key_list

from torch import nn

Expand Down Expand Up @@ -187,8 +187,11 @@ def __init__(
elif out_key:
out_keys = [out_key, *self.DEFAULT_OUT_KEYS]

in_keys = unravel_key_list(in_keys)
out_keys = unravel_key_list(out_keys)
if not isinstance(in_keys, (tuple, list)) or (
len(in_keys) != 3 and not (len(in_keys) == 4 and in_keys[-1] == "is_init")
len(in_keys) != 3
and not (len(in_keys) == 4 and in_keys[-1] == ("is_init",))
):
raise ValueError(
f"LSTMModule expects 3 inputs: a value, and two hidden states (and potentially an 'is_init' marker). Got in_keys {in_keys} instead."
Expand Down

0 comments on commit 75a45be

Please sign in to comment.