diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 1f479399853..10dd004f6ca 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -7,9 +7,10 @@ import pytest import torch -from tensordict import TensorDict, unravel_key_list +from tensordict import TensorDict from tensordict.nn import InteractionType, make_functional, TensorDictModule from torch import nn +from torchrl._utils import unravel_key_list from torchrl.data.tensor_specs import ( BoundedTensorSpec, CompositeSpec, @@ -1537,7 +1538,7 @@ def forward(self, in_1, in_2): in_keys=["x"], out_keys=["out_1", "out_2", "out_3"], ) - assert set(unravel_key_list(ensured_module.in_keys)) == {("x",)} + assert set(unravel_key_list(ensured_module.in_keys)) == {"x"} assert isinstance(ensured_module, TensorDictModule) diff --git a/test/test_transforms.py b/test/test_transforms.py index 3af9443c7d9..094a70b30df 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -2830,7 +2830,7 @@ def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), NoopResetEnv()) with pytest.raises( ValueError, - match="there is more than one done state in the parent environment", + match="The parent environment batch-size is non-null", ): check_env_specs(env) @@ -2904,7 +2904,7 @@ def test_noop_reset_env_error(self, random, device, compose): transformed_env.append_transform(noop_reset_env) with pytest.raises( ValueError, - match="there is more than one done state in the parent environment", + match="The parent environment batch-size is non-null", ): transformed_env.reset() diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 8d590b05210..d8eb5b5f5d3 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -17,6 +17,8 @@ from copy import copy from distutils.util import strtobool from functools import wraps + +# from tensordict._tensordict import unravel_keys from importlib import import_module from typing import Any, Callable, 
cast, TypeVar, Union @@ -529,3 +531,39 @@ def clone(self): def get_trace(): """A simple debugging util to spot where a function is being called.""" traceback.print_stack() + + +def unravel_key_list(key_list): + """Temporary fix for change in behaviour in unravel_key_list.""" + if isinstance(key_list, str): + raise TypeError("incompatible function arguments") + key_list_out = [] + for key in key_list: + key = unravel_key(key) + if isinstance(key, tuple) and len(key) == 1: + key_list_out.append(key[0]) + else: + key_list_out.append(key) + return key_list_out + + +def unravel_key(key): + """Temporary fix for change in behaviour in the tensordict version. + + The current behaviour is the behaviour after update in tensordict. + + This ensures that tests will be passing before and after merge on both parts. + """ + if not isinstance(key, (tuple, str)): + raise RuntimeError("key should be a Sequence") + if isinstance(key, str): + return key + out = [] + for subkey in key: + subkey = unravel_key(subkey) + if isinstance(subkey, str): + subkey = (subkey,) + out += subkey + if len(out) == 1: + return out[0] + return tuple(out) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 781ab9fd8e7..091ae2fbef4 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -32,9 +32,9 @@ import numpy as np import torch from tensordict.tensordict import TensorDict, TensorDictBase -from tensordict.utils import _getitem_batch_size, unravel_key +from tensordict.utils import _getitem_batch_size -from torchrl._utils import get_binary_env_var +from torchrl._utils import get_binary_env_var, unravel_key DEVICE_TYPING = Union[torch.device, str, int] diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index a7e43867fe4..78565a0c8c5 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -12,7 +12,8 @@ ProbabilisticTensorDictModule, repopulate_module, ) -from tensordict.utils import
_normalize_key, is_seq_of_nested_key +from tensordict.utils import is_seq_of_nested_key +from torchrl._utils import unravel_key from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec from torchrl.envs.transforms.transforms import Transform @@ -159,8 +160,8 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: output_spec = super().transform_output_spec(output_spec) # todo: here we'll need to use the reward_key once it's implemented # parent = self.parent - in_key = _normalize_key(self.in_keys[0]) - out_key = _normalize_key(self.out_keys[0]) + in_key = unravel_key(self.in_keys[0]) + out_key = unravel_key(self.out_keys[0]) if in_key == "reward" and out_key == "reward": parent = self.parent diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 821a69e931d..a8f5953e7ee 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -15,9 +15,11 @@ import torch from tensordict.nn import dispatch from tensordict.tensordict import TensorDict, TensorDictBase -from tensordict.utils import expand_as_right, unravel_key, unravel_key_list +from tensordict.utils import expand_as_right from torch import nn, Tensor +from torchrl._utils import unravel_key, unravel_key_list + from torchrl.data.tensor_specs import ( BinaryDiscreteTensorSpec, BoundedTensorSpec, @@ -2785,9 +2787,11 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: raise RuntimeError( "NoopResetEnv.parent not found. Make sure that the parent is set." ) - if tensordict.get("done").numel() > 1: + done_key = parent.done_key + reward_key = parent.reward_key + if parent.batch_size.numel() > 1: raise ValueError( - "there is more than one done state in the parent environment. " + "The parent environment batch-size is non-null. " "NoopResetEnv is designed to work on single env instances, as partial reset " "is currently not supported. 
If you feel like this is a missing feature, submit " "an issue on TorchRL github repo. " @@ -2795,6 +2799,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: "that you can have a transformed batch of transformed envs, such as: " "`TransformedEnv(ParallelEnv(3, lambda: TransformedEnv(MyEnv(), NoopResetEnv(3))), OtherTransform())`." ) + noops = ( self.noops if not self.random else torch.randint(self.noops, (1,)).item() ) @@ -2806,7 +2811,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: i += 1 tensordict = parent.rand_step(tensordict) tensordict = step_mdp(tensordict, exclude_done=False) - if tensordict.get("done"): + if tensordict.get(done_key): tensordict = parent.reset(td_reset.clone(False)) break else: @@ -2815,15 +2820,15 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: trial += 1 if trial > _MAX_NOOPS_TRIALS: tensordict = parent.rand_step(tensordict) - if tensordict.get(("next", "done")): + if tensordict.get(("next", done_key)): raise RuntimeError( f"parent is still done after a single random step (i={i})." 
) break - if tensordict.get("done"): + if tensordict.get(done_key): raise RuntimeError("NoopResetEnv concluded with done environment") - return tensordict + return tensordict.exclude(reward_key, inplace=True) def __repr__(self) -> str: random = self.random @@ -4129,9 +4134,9 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: output_spec["_observation_spec"][out_key] = output_spec[ "_done_spec" ].clone() - if "reward" in self.in_keys: + if ("reward",) in self.in_keys: for i, out_key in enumerate(self.out_keys): # noqa: B007 - if self.in_keys[i] == "reward": + if self.in_keys[i] == ("reward",): break else: raise RuntimeError("Expected one key to be 'reward'") diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 740ed5f2355..e7f36d40583 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -26,7 +26,9 @@ TensorDict, TensorDictBase, ) -from tensordict.utils import unravel_key + +# from tensordict.utils import unravel_keys +from torchrl._utils import unravel_key __all__ = [ "exploration_mode", @@ -193,9 +195,9 @@ def step_mdp( return next_tensordict return out - action_key = unravel_key((action_key,)) - done_key = unravel_key((done_key,)) - reward_key = unravel_key((reward_key,)) + action_key = unravel_key(action_key) + done_key = unravel_key(done_key) + reward_key = unravel_key(reward_key) excluded = set() if exclude_reward: @@ -216,7 +218,6 @@ def step_mdp( _set_single_key(tensordict, out, action_key) for key in next_td.keys(): _set(next_td, out, key, total_key, excluded) - if next_tensordict is not None: return next_tensordict.update(out) else: @@ -245,7 +246,7 @@ def _set_single_key(source, dest, key, clone=False): def _set(source, dest, key, total_key, excluded): total_key = total_key + (key,) non_empty = False - if total_key not in excluded: + if unravel_key(total_key) not in excluded: val = source.get(key) if is_tensor_collection(val): new_val = dest.get(key, None) diff --git a/torchrl/envs/vec_env.py 
b/torchrl/envs/vec_env.py index e7c1dc73806..3eac84d5cc8 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -20,10 +20,10 @@ import torch from tensordict import TensorDict from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase -from tensordict.utils import unravel_key from torch import multiprocessing as mp -from torchrl._utils import _check_for_faulty_process, VERBOSE +# from tensordict.utils import unravel_keys +from torchrl._utils import _check_for_faulty_process, unravel_key, VERBOSE from torchrl.data.tensor_specs import ( CompositeSpec, DiscreteTensorSpec, diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index aa0c34f98c4..7fdf391b14d 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -12,13 +12,15 @@ from typing import Iterable, Optional, Type, Union import torch -from tensordict import unravel_key_list from tensordict.nn import TensorDictModule, TensorDictModuleBase from tensordict.tensordict import TensorDictBase from torch import nn +# from tensordict import unravel_key_list +from torchrl._utils import unravel_key_list + from torchrl.data.tensor_specs import CompositeSpec, TensorSpec from torchrl.data.utils import DEVICE_TYPING diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 1b11e580d66..eb008f2903b 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -6,7 +6,7 @@ import warnings from typing import Optional, Sequence, Type, Union -from tensordict import TensorDictBase, unravel_key_list +from tensordict import TensorDictBase from tensordict.nn import ( InteractionType, @@ -14,6 +14,7 @@ ProbabilisticTensorDictSequential, TensorDictModule, ) +from torchrl._utils import unravel_key_list from torchrl.data.tensor_specs import CompositeSpec, TensorSpec from 
torchrl.modules.distributions import Delta from torchrl.modules.tensordict_module.common import _forward_hook_safe_action diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 33e71059d11..18a2fb01b36 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -9,9 +9,10 @@ from tensordict.nn import TensorDictModuleBase as ModuleBase from tensordict.tensordict import NO_DEFAULT, TensorDictBase -from tensordict.utils import prod, unravel_key_list +from tensordict.utils import prod from torch import nn +from torchrl._utils import unravel_key_list from torchrl.data import UnboundedContinuousTensorSpec from torchrl.objectives.value.functional import ( @@ -190,8 +191,7 @@ def __init__( in_keys = unravel_key_list(in_keys) out_keys = unravel_key_list(out_keys) if not isinstance(in_keys, (tuple, list)) or ( - len(in_keys) != 3 - and not (len(in_keys) == 4 and in_keys[-1] == ("is_init",)) + len(in_keys) != 3 and not (len(in_keys) == 4 and in_keys[-1] == "is_init") ): raise ValueError( f"LSTMModule expects 3 inputs: a value, and two hidden states (and potentially an 'is_init' marker). Got in_keys {in_keys} instead."