Skip to content

Commit

Permalink
[Refactor] Follow-up on tensordict PR 473 (pytorch#1361)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jul 6, 2023
1 parent 0ef740e commit b5cddf0
Show file tree
Hide file tree
Showing 10 changed files with 14 additions and 59 deletions.
3 changes: 1 addition & 2 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@

import pytest
import torch
from tensordict import TensorDict
from tensordict import TensorDict, unravel_key_list
from tensordict.nn import InteractionType, make_functional, TensorDictModule
from torch import nn
from torchrl._utils import unravel_key_list
from torchrl.data.tensor_specs import (
BoundedTensorSpec,
CompositeSpec,
Expand Down
38 changes: 0 additions & 38 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
from copy import copy
from distutils.util import strtobool
from functools import wraps

# from tensordict._tensordict import unravel_keys
from importlib import import_module
from typing import Any, Callable, cast, TypeVar, Union

Expand Down Expand Up @@ -531,39 +529,3 @@ def clone(self):
def get_trace():
"""A simple debugging util to spot where a function is being called."""
traceback.print_stack()


def unravel_key_list(key_list):
"""Temporary fix for change in behaviour in unravel_key_list."""
if isinstance(key_list, str):
raise TypeError("incompatible function arguments")
key_list_out = []
for key in key_list:
key = unravel_key(key)
if isinstance(key, tuple) and len(key) == 1:
key_list_out.append(key[0])
else:
key_list_out.append(key)
return key_list_out


def unravel_key(key):
"""Temporary fix for change in behaviour in the tensordict version.
The current behaviour is the behavious after update in tensordict.
This ensures that tests will be passing before and after merge on both parts.
"""
if not isinstance(key, (tuple, str)):
raise RuntimeError("key should be a Sequence<NestedKey>")
if isinstance(key, str):
return key
out = []
for subkey in key:
subkey = unravel_key(subkey)
if isinstance(subkey, str):
subkey = (subkey,)
out += subkey
if len(out) == 1:
return out[0]
return tuple(out)
3 changes: 2 additions & 1 deletion torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@

import numpy as np
import torch
from tensordict import unravel_key
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import _getitem_batch_size

from torchrl._utils import get_binary_env_var, unravel_key
from torchrl._utils import get_binary_env_var

DEVICE_TYPING = Union[torch.device, str, int]

Expand Down
3 changes: 1 addition & 2 deletions torchrl/envs/transforms/rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@

import torch

from tensordict import TensorDictBase
from tensordict import TensorDictBase, unravel_key
from tensordict.nn import (
make_functional,
ProbabilisticTensorDictModule,
repopulate_module,
)
from tensordict.utils import is_seq_of_nested_key
from torchrl._utils import unravel_key
from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs.transforms.transforms import Transform

Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
from typing import Any, List, Optional, OrderedDict, Sequence, Tuple, Union

import torch

from tensordict import unravel_key, unravel_key_list
from tensordict.nn import dispatch
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import expand_as_right
from torch import nn, Tensor

from torchrl._utils import unravel_key, unravel_key_list

from torchrl.data.tensor_specs import (
BinaryDiscreteTensorSpec,
BoundedTensorSpec,
Expand Down
5 changes: 1 addition & 4 deletions torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pkg_resources
import torch

from tensordict import is_tensor_collection
from tensordict import is_tensor_collection, unravel_key
from tensordict.nn.probabilistic import ( # noqa
# Note: the `set_interaction_mode` and their associated arg `default_interaction_mode` are being deprecated!
# Please use the `set_/interaction_type` ones above with the InteractionType enum instead.
Expand All @@ -27,9 +27,6 @@
TensorDictBase,
)

# from tensordict.utils import unravel_keys
from torchrl._utils import unravel_key

__all__ = [
"exploration_mode",
"exploration_type",
Expand Down
7 changes: 3 additions & 4 deletions torchrl/envs/vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@

import numpy as np
import torch
from tensordict import TensorDict

from tensordict import TensorDict, unravel_key
from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
from torch import multiprocessing as mp

# from tensordict.utils import unravel_keys
from torchrl._utils import _check_for_faulty_process, unravel_key, VERBOSE
from torchrl._utils import _check_for_faulty_process, VERBOSE
from torchrl.data.tensor_specs import (
CompositeSpec,
DiscreteTensorSpec,
Expand Down
5 changes: 2 additions & 3 deletions torchrl/modules/tensordict_module/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@

import torch

from tensordict import unravel_key_list

from tensordict.nn import TensorDictModule, TensorDictModuleBase
from tensordict.tensordict import TensorDictBase

from torch import nn

# from tensordict import unravel_key_list
from torchrl._utils import unravel_key_list

from torchrl.data.tensor_specs import CompositeSpec, TensorSpec

from torchrl.data.utils import DEVICE_TYPING
Expand Down
3 changes: 1 addition & 2 deletions torchrl/modules/tensordict_module/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@
import warnings
from typing import Optional, Sequence, Type, Union

from tensordict import TensorDictBase
from tensordict import TensorDictBase, unravel_key_list

from tensordict.nn import (
InteractionType,
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
TensorDictModule,
)
from torchrl._utils import unravel_key_list
from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
from torchrl.modules.distributions import Delta
from torchrl.modules.tensordict_module.common import _forward_hook_safe_action
Expand Down
2 changes: 1 addition & 1 deletion torchrl/modules/tensordict_module/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from typing import Optional, Tuple

import torch
from tensordict import unravel_key_list

from tensordict.nn import TensorDictModuleBase as ModuleBase

from tensordict.tensordict import NO_DEFAULT, TensorDictBase
from tensordict.utils import prod

from torch import nn
from torchrl._utils import unravel_key_list

from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.objectives.value.functional import (
Expand Down

0 comments on commit b5cddf0

Please sign in to comment.