[Feature] _make_ordinal_device #2237

Merged: 5 commits, Jun 19, 2024
Changes from all commits
47 changes: 46 additions & 1 deletion test/test_specs.py
@@ -13,7 +13,7 @@
from scipy.stats import chisquare
from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
from tensordict.utils import _unravel_key_to_tuple

from torchrl._utils import _make_ordinal_device
from torchrl.data.tensor_specs import (
_keys_to_empty_composite_spec,
BinaryDiscreteTensorSpec,
@@ -3689,6 +3689,51 @@ def test_sample(self):
assert nts.zero((2,)).shape == (2, 3, 4)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="not cuda device")
def test_device_ordinal():
device = torch.device("cpu")
assert _make_ordinal_device(device) == torch.device("cpu")
device = torch.device("cuda")
assert _make_ordinal_device(device) == torch.device("cuda:0")
device = torch.device("cuda:0")
assert _make_ordinal_device(device) == torch.device("cuda:0")
device = None
assert _make_ordinal_device(device) is None

device = torch.device("cuda")
unb = UnboundedContinuousTensorSpec((-1, 1, 2), device=device)
assert unb.device == torch.device("cuda:0")
unbd = UnboundedDiscreteTensorSpec((-1, 1, 2), device=device)
assert unbd.device == torch.device("cuda:0")
bound = BoundedTensorSpec(shape=(-1, 1, 2), low=-1, high=1, device=device)
assert bound.device == torch.device("cuda:0")
oneh = OneHotDiscreteTensorSpec(shape=(-1, 1, 2, 4), n=4, device=device)
assert oneh.device == torch.device("cuda:0")
disc = DiscreteTensorSpec(shape=(-1, 1, 2), n=4, device=device)
assert disc.device == torch.device("cuda:0")
moneh = MultiOneHotDiscreteTensorSpec(
shape=(-1, 1, 2, 7), nvec=[3, 4], device=device
)
assert moneh.device == torch.device("cuda:0")
mdisc = MultiDiscreteTensorSpec(shape=(-1, 1, 2, 2), nvec=[3, 4], device=device)
assert mdisc.device == torch.device("cuda:0")
mdisc = NonTensorSpec(shape=(-1, 1, 2, 2), device=device)
assert mdisc.device == torch.device("cuda:0")

spec = CompositeSpec(
unb=unb,
unbd=unbd,
bound=bound,
oneh=oneh,
disc=disc,
moneh=moneh,
mdisc=mdisc,
shape=(-1, 1, 2),
device=device,
)
assert spec.device == torch.device("cuda:0")


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
8 changes: 8 additions & 0 deletions torchrl/_utils.py
@@ -778,3 +778,11 @@ def _can_be_pickled(obj):
return True
except (pickle.PickleError, AttributeError, TypeError):
return False


def _make_ordinal_device(device: torch.device):
if device is None:
return device
if device.type == "cuda" and device.index is None:
return torch.device("cuda", index=torch.cuda.current_device())
return device
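
The new helper is the core of this PR: it leaves `None`, CPU, and already-indexed devices untouched, and only pins an index-less `cuda` device to the current ordinal. A minimal usage sketch (the CUDA branch assumes at least one visible GPU):

```python
import torch

from torchrl._utils import _make_ordinal_device

# Non-CUDA and None inputs pass through unchanged.
assert _make_ordinal_device(torch.device("cpu")) == torch.device("cpu")
assert _make_ordinal_device(None) is None

if torch.cuda.is_available():
    # An index-less "cuda" device is pinned to the current ordinal.
    current = torch.device("cuda", torch.cuda.current_device())
    assert _make_ordinal_device(torch.device("cuda")) == current
    # A device that already carries an index is returned as-is.
    assert _make_ordinal_device(torch.device("cuda:0")) == torch.device("cuda:0")
```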
15 changes: 11 additions & 4 deletions torchrl/collectors/collectors.py
@@ -40,6 +40,7 @@
from torchrl._utils import (
_check_for_faulty_process,
_ends_with,
_make_ordinal_device,
_ProcessNoWarn,
_replace_last,
accept_remote_rref_udf_invocation,
@@ -820,10 +821,16 @@ def _get_devices(
env_device: torch.device,
device: torch.device,
):
device = torch.device(device) if device else device
storing_device = torch.device(storing_device) if storing_device else device
policy_device = torch.device(policy_device) if policy_device else device
env_device = torch.device(env_device) if env_device else device
device = _make_ordinal_device(torch.device(device) if device else device)
storing_device = _make_ordinal_device(
torch.device(storing_device) if storing_device else device
)
policy_device = _make_ordinal_device(
torch.device(policy_device) if policy_device else device
)
env_device = _make_ordinal_device(
torch.device(env_device) if env_device else device
)
if storing_device is None and (env_device == policy_device):
storing_device = env_device
return storing_device, policy_device, env_device
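
The collector change only wraps the existing fallback chain in `_make_ordinal_device`. The standalone sketch below (a hypothetical `resolve` helper, not part of the PR) mirrors that logic to show how a bare `"cuda"` request now propagates as an explicit ordinal:

```python
import torch

from torchrl._utils import _make_ordinal_device


def resolve(storing_device=None, policy_device=None, env_device=None, device=None):
    # Each unset device inherits from `device`; every result is ordinal-normalized.
    device = _make_ordinal_device(torch.device(device) if device else device)
    storing_device = _make_ordinal_device(
        torch.device(storing_device) if storing_device else device
    )
    policy_device = _make_ordinal_device(
        torch.device(policy_device) if policy_device else device
    )
    env_device = _make_ordinal_device(
        torch.device(env_device) if env_device else device
    )
    if storing_device is None and (env_device == policy_device):
        storing_device = env_device
    return storing_device, policy_device, env_device


if torch.cuda.is_available():
    # On a single-GPU host, all three resolved devices come back as cuda:0.
    print(resolve(device="cuda"))
```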
4 changes: 2 additions & 2 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -30,7 +30,7 @@
from tensordict.utils import expand_as_right, expand_right
from torch import Tensor

from torchrl._utils import accept_remote_rref_udf_invocation
from torchrl._utils import _make_ordinal_device, accept_remote_rref_udf_invocation
from torchrl.data.replay_buffers.samplers import (
PrioritizedSampler,
RandomSampler,
@@ -1457,7 +1457,7 @@ def __init__(self, device: DEVICE_TYPING | None = None):
self.out = None
if device is None:
device = "cpu"
self.device = torch.device(device)
self.device = _make_ordinal_device(torch.device(device))

def __call__(self, list_of_tds):
if self.out is None:
10 changes: 7 additions & 3 deletions torchrl/data/replay_buffers/storages.py
@@ -25,7 +25,7 @@
from tensordict.memmap import MemoryMappedTensor
from torch import multiprocessing as mp
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
from torchrl._utils import implement_for, logger as torchrl_logger
from torchrl._utils import _make_ordinal_device, implement_for, logger as torchrl_logger
from torchrl.data.replay_buffers.checkpointers import (
ListStorageCheckpointer,
StorageCheckpointerBase,
@@ -405,7 +405,7 @@ def __init__(
else:
self._len = 0
self.device = (
torch.device(device)
_make_ordinal_device(torch.device(device))
if device != "auto"
else storage.device
if storage is not None
@@ -983,7 +983,11 @@ def __init__(
self.scratch_dir = str(scratch_dir)
if self.scratch_dir[-1] != "/":
self.scratch_dir += "/"
self.device = torch.device(device) if device != "auto" else torch.device("cpu")
self.device = (
_make_ordinal_device(torch.device(device))
if device != "auto"
else torch.device("cpu")
)
if self.device.type != "cpu":
raise ValueError(
"Memory map device other than CPU isn't supported. To cast your data to the desired device, "
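
The `TensorStorage` hunk above means that a storage created with an index-less CUDA device now reports an ordinal one. A small sketch, assuming a CUDA-capable host:

```python
import torch

from torchrl.data import LazyTensorStorage

if torch.cuda.is_available():
    storage = LazyTensorStorage(100, device="cuda")
    # The storage device is normalized to an explicit ordinal, e.g. cuda:0.
    print(storage.device)
```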
19 changes: 14 additions & 5 deletions torchrl/data/tensor_specs.py
@@ -39,8 +39,7 @@
unravel_key,
)
from tensordict.utils import _getitem_batch_size, NestedKey

from torchrl._utils import get_binary_env_var
from torchrl._utils import _make_ordinal_device, get_binary_env_var

DEVICE_TYPING = Union[torch.device, str, int]

@@ -91,7 +90,7 @@ def _default_dtype_and_device(
if dtype is None:
dtype = torch.get_default_dtype()
if device is not None:
device = torch.device(device)
device = _make_ordinal_device(torch.device(device))
elif not allow_none_device:
device = torch.zeros(()).device
return dtype, device
@@ -536,6 +535,14 @@ def decorator(func):

return decorator

@property
def device(self) -> torch.device:
return self._device

@device.setter
def device(self, device: torch.device | None) -> None:
self._device = _make_ordinal_device(device)

def clear_device_(self):
"""A no-op for all leaf specs (which must have a device)."""
return self
@@ -3802,7 +3809,9 @@ def __init__(self, *args, shape=None, device=None, **kwargs):
for key, value in kwargs.items():
self.set(key, value)

_device = torch.device(device) if device is not None else device
_device = (
_make_ordinal_device(torch.device(device)) if device is not None else device
)
if len(kwargs):
for key, item in self.items():
if item is None:
@@ -3845,7 +3854,7 @@ def device(self, device: DEVICE_TYPING):
raise RuntimeError(
"To erase the device of a composite spec, call " "spec.clear_device_()."
)
device = torch.device(device)
device = _make_ordinal_device(torch.device(device))
self.to(device)

def clear_device_(self):
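
Mirroring the new test, leaf specs and `CompositeSpec` now expose the normalized device through the added property/setter. A minimal sketch, assuming at least one CUDA device and (as in the test) that the current device is ordinal 0:

```python
import torch

from torchrl.data import BoundedTensorSpec, CompositeSpec

if torch.cuda.is_available():
    bound = BoundedTensorSpec(shape=(3,), low=-1, high=1, device="cuda")
    spec = CompositeSpec(bound=bound, shape=(), device="cuda")
    # Both the leaf spec and the composite report an explicit ordinal.
    assert bound.device == torch.device("cuda:0")
    assert spec.device == torch.device("cuda:0")
```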
11 changes: 7 additions & 4 deletions torchrl/envs/batched_envs.py
@@ -31,6 +31,7 @@
from torch import multiprocessing as mp
from torchrl._utils import (
_check_for_faulty_process,
_make_ordinal_device,
_ProcessNoWarn,
logger as torchrl_logger,
VERBOSE,
@@ -346,7 +347,9 @@ def __init__(
"memmap and shared memory are mutually exclusive features."
)
self._batch_size = None
self._device = torch.device(device) if device is not None else device
self._device = (
_make_ordinal_device(torch.device(device)) if device is not None else device
)
self._dummy_env_str = None
self._seeds = None
self.__dict__["_input_spec"] = None
@@ -835,7 +838,7 @@ def start(self) -> None:

def to(self, device: DEVICE_TYPING):
self._non_blocking = None
device = torch.device(device)
device = _make_ordinal_device(torch.device(device))
if device == self.device:
return self
self._device = device
@@ -1114,7 +1117,7 @@ def __getattr__(self, attr: str) -> Any:
)

def to(self, device: DEVICE_TYPING):
device = torch.device(device)
device = _make_ordinal_device(torch.device(device))
if device == self.device:
return self
super().to(device)
@@ -1789,7 +1792,7 @@ def __getattr__(self, attr: str) -> Any:
)

def to(self, device: DEVICE_TYPING):
device = torch.device(device)
device = _make_ordinal_device(torch.device(device))
if device == self.device:
return self
super().to(device)
7 changes: 4 additions & 3 deletions torchrl/envs/common.py
@@ -19,6 +19,7 @@
from tensordict.utils import NestedKey
from torchrl._utils import (
_ends_with,
_make_ordinal_device,
_replace_last,
implement_for,
prod,
@@ -154,7 +155,7 @@ def clone(self):

def to(self, device: DEVICE_TYPING) -> EnvMetaData:
if device is not None:
device = torch.device(device)
device = _make_ordinal_device(torch.device(device))
device_map = {key: device for key in self.device_map}
tensordict = self.tensordict.contiguous().to(device)
specs = self.specs.to(device)
@@ -348,7 +349,7 @@ def __init__(
):
self.__dict__.setdefault("_batch_size", None)
if device is not None:
self.__dict__["_device"] = torch.device(device)
self.__dict__["_device"] = _make_ordinal_device(torch.device(device))
output_spec = self.__dict__.get("_output_spec")
if output_spec is not None:
self.__dict__["_output_spec"] = (
@@ -2947,7 +2948,7 @@ def __del__(self):
pass

def to(self, device: DEVICE_TYPING) -> EnvBase:
device = torch.device(device)
device = _make_ordinal_device(torch.device(device))
if device == self.device:
return self
self.__dict__["_input_spec"] = self.input_spec.to(device).lock_()
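
On the environment side, `EnvBase.to` now normalizes the target device before moving specs. A sketch assuming Gym and its `Pendulum-v1` environment are installed, plus a CUDA device:

```python
import torch

from torchrl.envs import GymEnv

if torch.cuda.is_available():
    env = GymEnv("Pendulum-v1").to("cuda")
    # The env device (and its specs) carry an explicit index after the move.
    print(env.device)
```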
4 changes: 2 additions & 2 deletions torchrl/envs/libs/habitat.py
@@ -6,7 +6,7 @@
import importlib.util

import torch

from torchrl._utils import _make_ordinal_device
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.common import EnvBase
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
@@ -118,7 +118,7 @@ def _build_gym_env(self, env, pixels_only):
return super()._build_gym_env(env, pixels_only)

def to(self, device: DEVICE_TYPING) -> EnvBase:
device = torch.device(device)
device = _make_ordinal_device(torch.device(device))
if device.type != "cuda":
raise ValueError("The device must be of type cuda for Habitat.")
device_num = device.index
26 changes: 17 additions & 9 deletions torchrl/envs/transforms/transforms.py
@@ -44,7 +44,7 @@
from torch import nn, Tensor
from torch.utils._pytree import tree_map

from torchrl._utils import _append_last, _ends_with, _replace_last
from torchrl._utils import _append_last, _ends_with, _make_ordinal_device, _replace_last

from torchrl.data.tensor_specs import (
BinaryDiscreteTensorSpec,
@@ -2084,13 +2084,21 @@ def __new__(cls, *args, **kwargs):

def __init__(
self,
unsqueeze_dim: int,
dim: int = None,
allow_positive_dim: bool = False,
in_keys: Sequence[NestedKey] | None = None,
out_keys: Sequence[NestedKey] | None = None,
in_keys_inv: Sequence[NestedKey] | None = None,
out_keys_inv: Sequence[NestedKey] | None = None,
**kwargs,
):
if "unsqueeze_dim" in kwargs:
warnings.warn(
"The `unsqueeze_dim` kwarg will be removed in v0.6. Please use `dim` instead."
)
dim = kwargs["unsqueeze_dim"]
elif dim is None:
raise TypeError("dim must be provided.")
if in_keys is None:
in_keys = [] # default
if out_keys is None:
@@ -2106,19 +2114,19 @@ def __init__(
out_keys_inv=out_keys_inv,
)
self.allow_positive_dim = allow_positive_dim
if unsqueeze_dim >= 0 and not allow_positive_dim:
if dim >= 0 and not allow_positive_dim:
raise RuntimeError(
"unsqueeze_dim should be smaller than 0 to accommodate for "
"dim should be smaller than 0 to accommodate for "
"envs of different batch_sizes. Turn allow_positive_dim to accommodate "
"for positive unsqueeze_dim."
)
self._unsqueeze_dim = unsqueeze_dim
self._dim = dim

@property
def unsqueeze_dim(self):
if self._unsqueeze_dim >= 0 and self.parent is not None:
return len(self.parent.batch_size) + self._unsqueeze_dim
return self._unsqueeze_dim
if self._dim >= 0 and self.parent is not None:
return len(self.parent.batch_size) + self._dim
return self._dim

def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
observation = observation.unsqueeze(self.unsqueeze_dim)
@@ -3808,7 +3816,7 @@ def __init__(
in_keys_inv=None,
out_keys_inv=None,
):
device = self.device = torch.device(device)
device = self.device = _make_ordinal_device(torch.device(device))
self.orig_device = (
torch.device(orig_device) if orig_device is not None else orig_device
)
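
Besides the device normalization at the bottom of this file, the diff renames `UnsqueezeTransform`'s `unsqueeze_dim` argument to `dim`, keeping the old keyword as a deprecated alias until v0.6. A short sketch of both spellings:

```python
from torchrl.envs.transforms import UnsqueezeTransform

# New-style call: dim must be negative unless allow_positive_dim=True.
t = UnsqueezeTransform(dim=-1, in_keys=["observation"], out_keys=["observation"])

# Legacy keyword still works but emits a deprecation warning.
t_legacy = UnsqueezeTransform(unsqueeze_dim=-1, in_keys=["observation"])
```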
3 changes: 2 additions & 1 deletion torchrl/trainers/helpers/replay_buffer.py
@@ -6,6 +6,7 @@
from typing import Optional

import torch
from torchrl._utils import _make_ordinal_device

from torchrl.data.replay_buffers.replay_buffers import (
ReplayBuffer,
@@ -20,7 +21,7 @@ def make_replay_buffer(
device: DEVICE_TYPING, cfg: "DictConfig" # noqa: F821
) -> ReplayBuffer: # noqa: F821
"""Builds a replay buffer using the config built from ReplayArgsConfig."""
device = torch.device(device)
device = _make_ordinal_device(torch.device(device))
if not cfg.prb:
sampler = RandomSampler()
else: