[Refactor] More unravel fixes (pytorch#1357)
vmoens authored Jul 6, 2023
1 parent ad370cb commit f9f975c
Showing 11 changed files with 80 additions and 31 deletions.
5 changes: 3 additions & 2 deletions test/test_tensordictmodules.py
@@ -7,9 +7,10 @@

import pytest
import torch
from tensordict import TensorDict, unravel_key_list
from tensordict import TensorDict
from tensordict.nn import InteractionType, make_functional, TensorDictModule
from torch import nn
from torchrl._utils import unravel_key_list
from torchrl.data.tensor_specs import (
BoundedTensorSpec,
CompositeSpec,
@@ -1537,7 +1538,7 @@ def forward(self, in_1, in_2):
in_keys=["x"],
out_keys=["out_1", "out_2", "out_3"],
)
assert set(unravel_key_list(ensured_module.in_keys)) == {("x",)}
assert set(unravel_key_list(ensured_module.in_keys)) == {"x"}
assert isinstance(ensured_module, TensorDictModule)


4 changes: 2 additions & 2 deletions test/test_transforms.py
@@ -2830,7 +2830,7 @@ def test_trans_serial_env_check(self):
env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), NoopResetEnv())
with pytest.raises(
ValueError,
match="there is more than one done state in the parent environment",
match="The parent environment batch-size is non-null",
):
check_env_specs(env)

@@ -2904,7 +2904,7 @@ def test_noop_reset_env_error(self, random, device, compose):
transformed_env.append_transform(noop_reset_env)
with pytest.raises(
ValueError,
match="there is more than one done state in the parent environment",
match="The parent environment batch-size is non-null",
):
transformed_env.reset()

38 changes: 38 additions & 0 deletions torchrl/_utils.py
@@ -17,6 +17,8 @@
from copy import copy
from distutils.util import strtobool
from functools import wraps

# from tensordict._tensordict import unravel_keys
from importlib import import_module
from typing import Any, Callable, cast, TypeVar, Union

@@ -529,3 +531,39 @@ def clone(self):
def get_trace():
"""A simple debugging util to spot where a function is being called."""
traceback.print_stack()


def unravel_key_list(key_list):
    """Temporary fix for change in behaviour in unravel_key_list."""
    if isinstance(key_list, str):
        raise TypeError("incompatible function arguments")
    key_list_out = []
    for key in key_list:
        key = unravel_key(key)
        if isinstance(key, tuple) and len(key) == 1:
            key_list_out.append(key[0])
        else:
            key_list_out.append(key)
    return key_list_out


def unravel_key(key):
    """Temporary fix for change in behaviour in the tensordict version.

    The current behaviour is the behaviour after the update in tensordict.
    This ensures that tests keep passing both before and after the merge, on both sides.
    """
    if not isinstance(key, (tuple, str)):
        raise RuntimeError("key should be a Sequence<NestedKey>")
    if isinstance(key, str):
        return key
    out = []
    for subkey in key:
        subkey = unravel_key(subkey)
        if isinstance(subkey, str):
            subkey = (subkey,)
        out += subkey
    if len(out) == 1:
        return out[0]
    return tuple(out)
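
For orientation, a minimal sketch of how these two temporary helpers behave; the calls and expected values below are my reading of the code above, not tests shipped with this commit:

    from torchrl._utils import unravel_key, unravel_key_list

    # Strings pass through unchanged; nested tuples are flattened.
    unravel_key("reward")                         # "reward"
    unravel_key(("next", ("done",)))              # ("next", "done")
    unravel_key(("reward",))                      # "reward" -- a 1-element tuple collapses to its string

    # Lists of keys are normalized element-wise; a bare string is rejected.
    unravel_key_list(["x", ("next", ("done",))])  # ["x", ("next", "done")]
    unravel_key_list("x")                         # raises TypeError("incompatible function arguments")

This collapse of 1-element tuples is why the test above now expects {"x"} rather than {("x",)}.
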
4 changes: 2 additions & 2 deletions torchrl/data/tensor_specs.py
@@ -32,9 +32,9 @@
import numpy as np
import torch
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import _getitem_batch_size, unravel_key
from tensordict.utils import _getitem_batch_size

from torchrl._utils import get_binary_env_var
from torchrl._utils import get_binary_env_var, unravel_key

DEVICE_TYPING = Union[torch.device, str, int]

7 changes: 4 additions & 3 deletions torchrl/envs/transforms/rlhf.py
@@ -12,7 +12,8 @@
ProbabilisticTensorDictModule,
repopulate_module,
)
from tensordict.utils import _normalize_key, is_seq_of_nested_key
from tensordict.utils import is_seq_of_nested_key
from torchrl._utils import unravel_key
from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs.transforms.transforms import Transform

@@ -159,8 +160,8 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
output_spec = super().transform_output_spec(output_spec)
# todo: here we'll need to use the reward_key once it's implemented
# parent = self.parent
in_key = _normalize_key(self.in_keys[0])
out_key = _normalize_key(self.out_keys[0])
in_key = unravel_key(self.in_keys[0])
out_key = unravel_key(self.out_keys[0])

if in_key == "reward" and out_key == "reward":
parent = self.parent
23 changes: 14 additions & 9 deletions torchrl/envs/transforms/transforms.py
@@ -15,9 +15,11 @@
import torch
from tensordict.nn import dispatch
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import expand_as_right, unravel_key, unravel_key_list
from tensordict.utils import expand_as_right
from torch import nn, Tensor

from torchrl._utils import unravel_key, unravel_key_list

from torchrl.data.tensor_specs import (
BinaryDiscreteTensorSpec,
BoundedTensorSpec,
@@ -2785,16 +2787,19 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
raise RuntimeError(
"NoopResetEnv.parent not found. Make sure that the parent is set."
)
if tensordict.get("done").numel() > 1:
done_key = parent.done_key
reward_key = parent.reward_key
if parent.batch_size.numel() > 1:
raise ValueError(
"there is more than one done state in the parent environment. "
"The parent environment batch-size is non-null. "
"NoopResetEnv is designed to work on single env instances, as partial reset "
"is currently not supported. If you feel like this is a missing feature, submit "
"an issue on TorchRL github repo. "
"In case you are trying to use NoopResetEnv over a batch of environments, know "
"that you can have a transformed batch of transformed envs, such as: "
"`TransformedEnv(ParallelEnv(3, lambda: TransformedEnv(MyEnv(), NoopResetEnv(3))), OtherTransform())`."
)

noops = (
self.noops if not self.random else torch.randint(self.noops, (1,)).item()
)
@@ -2806,7 +2811,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
i += 1
tensordict = parent.rand_step(tensordict)
tensordict = step_mdp(tensordict, exclude_done=False)
if tensordict.get("done"):
if tensordict.get(done_key):
tensordict = parent.reset(td_reset.clone(False))
break
else:
@@ -2815,15 +2820,15 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
trial += 1
if trial > _MAX_NOOPS_TRIALS:
tensordict = parent.rand_step(tensordict)
if tensordict.get(("next", "done")):
if tensordict.get(("next", done_key)):
raise RuntimeError(
f"parent is still done after a single random step (i={i})."
)
break

if tensordict.get("done"):
if tensordict.get(done_key):
raise RuntimeError("NoopResetEnv concluded with done environment")
return tensordict
return tensordict.exclude(reward_key, inplace=True)

def __repr__(self) -> str:
random = self.random
@@ -4129,9 +4134,9 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
output_spec["_observation_spec"][out_key] = output_spec[
"_done_spec"
].clone()
if "reward" in self.in_keys:
if ("reward",) in self.in_keys:
for i, out_key in enumerate(self.out_keys): # noqa: B007
if self.in_keys[i] == "reward":
if self.in_keys[i] == ("reward",):
break
else:
raise RuntimeError("Expected one key to be 'reward'")
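
As a usage note, the error message introduced above suggests wrapping each single env in its own NoopResetEnv and only then batching. A rough sketch of that pattern, assuming MyEnv and OtherTransform are placeholders for a user's env and transform (hypothetical names taken from the message, not defined in this commit):

    from torchrl.envs import ParallelEnv, TransformedEnv
    from torchrl.envs.transforms import NoopResetEnv

    def make_env():
        # NoopResetEnv has to sit on a single env instance: partial (per-env)
        # resets are not supported, which is what the new batch-size check enforces.
        return TransformedEnv(MyEnv(), NoopResetEnv(3))

    env = TransformedEnv(ParallelEnv(3, make_env), OtherTransform())
    env.reset()  # each worker runs its own noop steps during its reset
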
13 changes: 7 additions & 6 deletions torchrl/envs/utils.py
@@ -26,7 +26,9 @@
TensorDict,
TensorDictBase,
)
from tensordict.utils import unravel_key

# from tensordict.utils import unravel_keys
from torchrl._utils import unravel_key

__all__ = [
"exploration_mode",
@@ -193,9 +195,9 @@ def step_mdp(
return next_tensordict
return out

action_key = unravel_key((action_key,))
done_key = unravel_key((done_key,))
reward_key = unravel_key((reward_key,))
action_key = unravel_key(action_key)
done_key = unravel_key(done_key)
reward_key = unravel_key(reward_key)

excluded = set()
if exclude_reward:
@@ -216,7 +218,6 @@ def step_mdp(
_set_single_key(tensordict, out, action_key)
for key in next_td.keys():
_set(next_td, out, key, total_key, excluded)

if next_tensordict is not None:
return next_tensordict.update(out)
else:
@@ -245,7 +246,7 @@ def _set_single_key(source, dest, key, clone=False):
def _set(source, dest, key, total_key, excluded):
total_key = total_key + (key,)
non_empty = False
if total_key not in excluded:
if unravel_key(total_key) not in excluded:
val = source.get(key)
if is_tensor_collection(val):
new_val = dest.get(key, None)
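
A quick sketch of the key normalization step_mdp now relies on; these calls are my reading of the unravel_key helper added earlier in this commit, not code from the diff:

    action_key = "action"
    unravel_key((action_key,))           # "action" -- old form, relying on the 1-tuple collapse
    unravel_key(action_key)              # "action" -- new form, strings pass through directly

    done_key = ("agent", ("done",))
    unravel_key(done_key)                # ("agent", "done") -- nested spellings are flattened

The _set change reads the same way: comparing unravel_key(total_key) against the excluded set lets equivalent spellings of a key hit the same exclusion.
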
4 changes: 2 additions & 2 deletions torchrl/envs/vec_env.py
@@ -20,10 +20,10 @@
import torch
from tensordict import TensorDict
from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
from tensordict.utils import unravel_key
from torch import multiprocessing as mp

from torchrl._utils import _check_for_faulty_process, VERBOSE
# from tensordict.utils import unravel_keys
from torchrl._utils import _check_for_faulty_process, unravel_key, VERBOSE
from torchrl.data.tensor_specs import (
CompositeSpec,
DiscreteTensorSpec,
4 changes: 3 additions & 1 deletion torchrl/modules/tensordict_module/common.py
@@ -12,13 +12,15 @@
from typing import Iterable, Optional, Type, Union

import torch
from tensordict import unravel_key_list

from tensordict.nn import TensorDictModule, TensorDictModuleBase
from tensordict.tensordict import TensorDictBase

from torch import nn

# from tensordict import unravel_key_list
from torchrl._utils import unravel_key_list

from torchrl.data.tensor_specs import CompositeSpec, TensorSpec

from torchrl.data.utils import DEVICE_TYPING
3 changes: 2 additions & 1 deletion torchrl/modules/tensordict_module/probabilistic.py
@@ -6,14 +6,15 @@
import warnings
from typing import Optional, Sequence, Type, Union

from tensordict import TensorDictBase, unravel_key_list
from tensordict import TensorDictBase

from tensordict.nn import (
InteractionType,
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
TensorDictModule,
)
from torchrl._utils import unravel_key_list
from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
from torchrl.modules.distributions import Delta
from torchrl.modules.tensordict_module.common import _forward_hook_safe_action
6 changes: 3 additions & 3 deletions torchrl/modules/tensordict_module/rnn.py
@@ -9,9 +9,10 @@
from tensordict.nn import TensorDictModuleBase as ModuleBase

from tensordict.tensordict import NO_DEFAULT, TensorDictBase
from tensordict.utils import prod, unravel_key_list
from tensordict.utils import prod

from torch import nn
from torchrl._utils import unravel_key_list

from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.objectives.value.functional import (
@@ -190,8 +191,7 @@ def __init__(
in_keys = unravel_key_list(in_keys)
out_keys = unravel_key_list(out_keys)
if not isinstance(in_keys, (tuple, list)) or (
len(in_keys) != 3
and not (len(in_keys) == 4 and in_keys[-1] == ("is_init",))
len(in_keys) != 3 and not (len(in_keys) == 4 and in_keys[-1] == "is_init")
):
raise ValueError(
f"LSTMModule expects 3 inputs: a value, and two hidden states (and potentially an 'is_init' marker). Got in_keys {in_keys} instead."
