Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] _make_ordinal_device #2237

Merged
merged 5 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
amend
  • Loading branch information
vmoens committed Jun 18, 2024
commit 7c0a7c83b44a540dced9f60128c2128c3d0b55d2
8 changes: 2 additions & 6 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from scipy.stats import chisquare
from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
from tensordict.utils import _unravel_key_to_tuple

from torchrl._utils import _make_ordinal_device
from torchrl.data.tensor_specs import (
_keys_to_empty_composite_spec,
BinaryDiscreteTensorSpec,
Expand All @@ -29,11 +29,7 @@
UnboundedContinuousTensorSpec,
UnboundedDiscreteTensorSpec,
)
from torchrl.data.utils import (
_make_ordinal_device,
check_no_exclusive_keys,
consolidate_spec,
)
from torchrl.data.utils import check_no_exclusive_keys, consolidate_spec


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None])
Expand Down
8 changes: 8 additions & 0 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,3 +778,11 @@ def _can_be_pickled(obj):
return True
except (pickle.PickleError, AttributeError, TypeError):
return False


def _make_ordinal_device(device: torch.device):
if device is None:
return device
if device.type == "cuda" and device.index is None:
return torch.device("cuda", index=torch.cuda.current_device())
return device
2 changes: 1 addition & 1 deletion torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
)
from torchrl.collectors.utils import split_trajectories
from torchrl.data.tensor_specs import TensorSpec
from torchrl.data.utils import _make_ordinal_device, CloudpickleWrapper, DEVICE_TYPING
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs.common import _do_nothing, EnvBase
from torchrl.envs.transforms import StepCounter, TransformedEnv
from torchrl.envs.utils import (
Expand Down
4 changes: 2 additions & 2 deletions torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from tensordict.utils import expand_as_right, expand_right
from torch import Tensor

from torchrl._utils import accept_remote_rref_udf_invocation
from torchrl._utils import _make_ordinal_device, accept_remote_rref_udf_invocation
from torchrl.data.replay_buffers.samplers import (
PrioritizedSampler,
RandomSampler,
Expand Down Expand Up @@ -58,7 +58,7 @@
Writer,
WriterEnsemble,
)
from torchrl.data.utils import _make_ordinal_device, DEVICE_TYPING
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.transforms.transforms import _InvertTransform


Expand Down
3 changes: 1 addition & 2 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from tensordict.memmap import MemoryMappedTensor
from torch import multiprocessing as mp
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
from torchrl._utils import implement_for, logger as torchrl_logger
from torchrl._utils import _make_ordinal_device, implement_for, logger as torchrl_logger
from torchrl.data.replay_buffers.checkpointers import (
ListStorageCheckpointer,
StorageCheckpointerBase,
Expand All @@ -38,7 +38,6 @@
INT_CLASSES,
tree_iter,
)
from torchrl.data.utils import _make_ordinal_device


class Storage:
Expand Down
3 changes: 1 addition & 2 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@
unravel_key,
)
from tensordict.utils import _getitem_batch_size, NestedKey
from torchrl._utils import get_binary_env_var
from torchrl.data.utils import _make_ordinal_device
from torchrl._utils import _make_ordinal_device, get_binary_env_var

DEVICE_TYPING = Union[torch.device, str, int]

Expand Down
8 changes: 0 additions & 8 deletions torchrl/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,3 @@ def _find_action_space(action_space):
f"action_space was not specified/not compatible and could not be retrieved from the value network. Got action_space={action_space}."
)
return action_space


def _make_ordinal_device(device: torch.device):
if device is None:
return device
if device.type == "cuda" and device.index is None:
return torch.device("cuda", index=torch.cuda.current_device())
return device
8 changes: 2 additions & 6 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,13 @@
from torch import multiprocessing as mp
from torchrl._utils import (
_check_for_faulty_process,
_make_ordinal_device,
_ProcessNoWarn,
logger as torchrl_logger,
VERBOSE,
)
from torchrl.data.tensor_specs import CompositeSpec
from torchrl.data.utils import (
_make_ordinal_device,
CloudpickleWrapper,
contains_lazy_spec,
DEVICE_TYPING,
)
from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING
from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, EnvMetaData
from torchrl.envs.env_creator import get_env_metadata

Expand Down
3 changes: 2 additions & 1 deletion torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from tensordict.utils import NestedKey
from torchrl._utils import (
_ends_with,
_make_ordinal_device,
_replace_last,
implement_for,
prod,
Expand All @@ -31,7 +32,7 @@
TensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.data.utils import _make_ordinal_device, DEVICE_TYPING
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.utils import (
_make_compatible_policy,
_repr_by_depth,
Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/libs/habitat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import importlib.util

import torch

from torchrl.data.utils import _make_ordinal_device, DEVICE_TYPING
from torchrl._utils import _make_ordinal_device
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.common import EnvBase
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import _classproperty
Expand Down
3 changes: 1 addition & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from torch import nn, Tensor
from torch.utils._pytree import tree_map

from torchrl._utils import _append_last, _ends_with, _replace_last
from torchrl._utils import _append_last, _ends_with, _make_ordinal_device, _replace_last

from torchrl.data.tensor_specs import (
BinaryDiscreteTensorSpec,
Expand All @@ -58,7 +58,6 @@
TensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.data.utils import _make_ordinal_device
from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, make_tensordict
from torchrl.envs.transforms import functional as F
from torchrl.envs.transforms.utils import (
Expand Down
3 changes: 2 additions & 1 deletion torchrl/trainers/helpers/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from typing import Optional

import torch
from torchrl._utils import _make_ordinal_device

from torchrl.data.replay_buffers.replay_buffers import (
ReplayBuffer,
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.data.utils import _make_ordinal_device, DEVICE_TYPING
from torchrl.data.utils import DEVICE_TYPING


def make_replay_buffer(
Expand Down
Loading