[Feature] _make_ordinal_device (#2237)

vmoens authored Jun 19, 2024
1 parent 038a615 commit c44a521
Showing 11 changed files with 120 additions and 34 deletions.
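Note: the motivation for this helper is that PyTorch treats an index-less
"cuda" device and the explicit "cuda:0" as distinct objects, so equality
checks between them fail even when both refer to the same GPU (e.g. the
`env_device == policy_device` comparison in the collectors below). A minimal
illustration, assuming a CUDA host:

    import torch

    # An index-less CUDA device has index=None, so it never compares equal
    # to its explicit-ordinal counterpart, even on a single-GPU machine:
    assert torch.device("cuda") != torch.device("cuda:0")
    # Normalizing to an explicit ordinal makes such comparisons reliable:
    assert torch.device("cuda", index=0) == torch.device("cuda:0")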
47 changes: 46 additions & 1 deletion test/test_specs.py
@@ -13,7 +13,7 @@
 from scipy.stats import chisquare
 from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
 from tensordict.utils import _unravel_key_to_tuple
-
+from torchrl._utils import _make_ordinal_device
 from torchrl.data.tensor_specs import (
     _keys_to_empty_composite_spec,
     BinaryDiscreteTensorSpec,
@@ -3689,6 +3689,51 @@ def test_sample(self):
         assert nts.zero((2,)).shape == (2, 3, 4)
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="no CUDA device")
+def test_device_ordinal():
+    device = torch.device("cpu")
+    assert _make_ordinal_device(device) == torch.device("cpu")
+    device = torch.device("cuda")
+    assert _make_ordinal_device(device) == torch.device("cuda:0")
+    device = torch.device("cuda:0")
+    assert _make_ordinal_device(device) == torch.device("cuda:0")
+    device = None
+    assert _make_ordinal_device(device) is None
+
+    device = torch.device("cuda")
+    unb = UnboundedContinuousTensorSpec((-1, 1, 2), device=device)
+    assert unb.device == torch.device("cuda:0")
+    unbd = UnboundedDiscreteTensorSpec((-1, 1, 2), device=device)
+    assert unbd.device == torch.device("cuda:0")
+    bound = BoundedTensorSpec(shape=(-1, 1, 2), low=-1, high=1, device=device)
+    assert bound.device == torch.device("cuda:0")
+    oneh = OneHotDiscreteTensorSpec(shape=(-1, 1, 2, 4), n=4, device=device)
+    assert oneh.device == torch.device("cuda:0")
+    disc = DiscreteTensorSpec(shape=(-1, 1, 2), n=4, device=device)
+    assert disc.device == torch.device("cuda:0")
+    moneh = MultiOneHotDiscreteTensorSpec(
+        shape=(-1, 1, 2, 7), nvec=[3, 4], device=device
+    )
+    assert moneh.device == torch.device("cuda:0")
+    mdisc = MultiDiscreteTensorSpec(shape=(-1, 1, 2, 2), nvec=[3, 4], device=device)
+    assert mdisc.device == torch.device("cuda:0")
+    nts = NonTensorSpec(shape=(-1, 1, 2, 2), device=device)
+    assert nts.device == torch.device("cuda:0")
+
+    spec = CompositeSpec(
+        unb=unb,
+        unbd=unbd,
+        bound=bound,
+        oneh=oneh,
+        disc=disc,
+        moneh=moneh,
+        mdisc=mdisc,
+        shape=(-1, 1, 2),
+        device=device,
+    )
+    assert spec.device == torch.device("cuda:0")
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
8 changes: 8 additions & 0 deletions torchrl/_utils.py
@@ -778,3 +778,11 @@ def _can_be_pickled(obj):
         return True
     except (pickle.PickleError, AttributeError, TypeError):
         return False
+
+
+def _make_ordinal_device(device: torch.device):
+    if device is None:
+        return device
+    if device.type == "cuda" and device.index is None:
+        return torch.device("cuda", index=torch.cuda.current_device())
+    return device
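A quick sketch of the helper's contract, assuming a CUDA host whose current
device is ordinal 0:

    import torch
    from torchrl._utils import _make_ordinal_device

    assert _make_ordinal_device(torch.device("cuda")) == torch.device("cuda:0")
    # Devices that already carry an index, non-CUDA devices and None all
    # pass through unchanged:
    assert _make_ordinal_device(torch.device("cuda:1")) == torch.device("cuda:1")
    assert _make_ordinal_device(torch.device("cpu")) == torch.device("cpu")
    assert _make_ordinal_device(None) is None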
15 changes: 11 additions & 4 deletions torchrl/collectors/collectors.py
@@ -40,6 +40,7 @@
 from torchrl._utils import (
     _check_for_faulty_process,
     _ends_with,
+    _make_ordinal_device,
     _ProcessNoWarn,
     _replace_last,
     accept_remote_rref_udf_invocation,
@@ -822,10 +823,16 @@ def _get_devices(
         env_device: torch.device,
         device: torch.device,
     ):
-        device = torch.device(device) if device else device
-        storing_device = torch.device(storing_device) if storing_device else device
-        policy_device = torch.device(policy_device) if policy_device else device
-        env_device = torch.device(env_device) if env_device else device
+        device = _make_ordinal_device(torch.device(device) if device else device)
+        storing_device = _make_ordinal_device(
+            torch.device(storing_device) if storing_device else device
+        )
+        policy_device = _make_ordinal_device(
+            torch.device(policy_device) if policy_device else device
+        )
+        env_device = _make_ordinal_device(
+            torch.device(env_device) if env_device else device
+        )
         if storing_device is None and (env_device == policy_device):
             storing_device = env_device
         return storing_device, policy_device, env_device
4 changes: 2 additions & 2 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -30,7 +30,7 @@
 from tensordict.utils import expand_as_right, expand_right
 from torch import Tensor
 
-from torchrl._utils import accept_remote_rref_udf_invocation
+from torchrl._utils import _make_ordinal_device, accept_remote_rref_udf_invocation
 from torchrl.data.replay_buffers.samplers import (
     PrioritizedSampler,
     RandomSampler,
@@ -1457,7 +1457,7 @@ def __init__(self, device: DEVICE_TYPING | None = None):
         self.out = None
         if device is None:
             device = "cpu"
-        self.device = torch.device(device)
+        self.device = _make_ordinal_device(torch.device(device))
 
     def __call__(self, list_of_tds):
         if self.out is None:
10 changes: 7 additions & 3 deletions torchrl/data/replay_buffers/storages.py
@@ -25,7 +25,7 @@
 from tensordict.memmap import MemoryMappedTensor
 from torch import multiprocessing as mp
 from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
-from torchrl._utils import implement_for, logger as torchrl_logger
+from torchrl._utils import _make_ordinal_device, implement_for, logger as torchrl_logger
 from torchrl.data.replay_buffers.checkpointers import (
     ListStorageCheckpointer,
     StorageCheckpointerBase,
@@ -405,7 +405,7 @@ def __init__(
         else:
             self._len = 0
         self.device = (
-            torch.device(device)
+            _make_ordinal_device(torch.device(device))
             if device != "auto"
             else storage.device
             if storage is not None
@@ -983,7 +983,11 @@ def __init__(
             self.scratch_dir = str(scratch_dir)
             if self.scratch_dir[-1] != "/":
                 self.scratch_dir += "/"
-        self.device = torch.device(device) if device != "auto" else torch.device("cpu")
+        self.device = (
+            _make_ordinal_device(torch.device(device))
+            if device != "auto"
+            else torch.device("cpu")
+        )
         if self.device.type != "cpu":
             raise ValueError(
                 "Memory map device other than CPU isn't supported. To cast your data to the desired device, "
19 changes: 14 additions & 5 deletions torchrl/data/tensor_specs.py
@@ -39,8 +39,7 @@
     unravel_key,
 )
 from tensordict.utils import _getitem_batch_size, NestedKey
-
-from torchrl._utils import get_binary_env_var
+from torchrl._utils import _make_ordinal_device, get_binary_env_var
 
 DEVICE_TYPING = Union[torch.device, str, int]
 
@@ -91,7 +90,7 @@ def _default_dtype_and_device(
     if dtype is None:
         dtype = torch.get_default_dtype()
     if device is not None:
-        device = torch.device(device)
+        device = _make_ordinal_device(torch.device(device))
     elif not allow_none_device:
         device = torch.zeros(()).device
     return dtype, device
@@ -536,6 +535,14 @@ def decorator(func):
 
         return decorator
 
+    @property
+    def device(self) -> torch.device:
+        return self._device
+
+    @device.setter
+    def device(self, device: torch.device | None) -> None:
+        self._device = _make_ordinal_device(device)
+
     def clear_device_(self):
         """A no-op for all leaf specs (which must have a device)."""
         return self
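With this setter in place, every leaf spec normalizes its device at assignment
time, as exercised by `test_device_ordinal` above. A minimal sketch, assuming
a CUDA host with current device 0:

    import torch
    from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec

    spec = UnboundedContinuousTensorSpec((3, 4), device=torch.device("cuda"))
    # The device setter routes through _make_ordinal_device:
    assert spec.device == torch.device("cuda:0")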
@@ -3804,7 +3811,9 @@ def __init__(self, *args, shape=None, device=None, **kwargs):
         for key, value in kwargs.items():
             self.set(key, value)
 
-        _device = torch.device(device) if device is not None else device
+        _device = (
+            _make_ordinal_device(torch.device(device)) if device is not None else device
+        )
         if len(kwargs):
             for key, item in self.items():
                 if item is None:
@@ -3847,7 +3856,7 @@ def device(self, device: DEVICE_TYPING):
             raise RuntimeError(
                 "To erase the device of a composite spec, call spec.clear_device_()."
             )
-        device = torch.device(device)
+        device = _make_ordinal_device(torch.device(device))
         self.to(device)
 
     def clear_device_(self):
11 changes: 7 additions & 4 deletions torchrl/envs/batched_envs.py
@@ -31,6 +31,7 @@
 from torch import multiprocessing as mp
 from torchrl._utils import (
     _check_for_faulty_process,
+    _make_ordinal_device,
     _ProcessNoWarn,
     logger as torchrl_logger,
     VERBOSE,
@@ -346,7 +347,9 @@ def __init__(
                 "memmap and shared memory are mutually exclusive features."
             )
         self._batch_size = None
-        self._device = torch.device(device) if device is not None else device
+        self._device = (
+            _make_ordinal_device(torch.device(device)) if device is not None else device
+        )
         self._dummy_env_str = None
         self._seeds = None
         self.__dict__["_input_spec"] = None
@@ -859,7 +862,7 @@ def start(self) -> None:
 
     def to(self, device: DEVICE_TYPING):
         self._non_blocking = None
-        device = torch.device(device)
+        device = _make_ordinal_device(torch.device(device))
         if device == self.device:
             return self
         self._device = device
@@ -1152,7 +1155,7 @@ def __getattr__(self, attr: str) -> Any:
         )
 
     def to(self, device: DEVICE_TYPING):
-        device = torch.device(device)
+        device = _make_ordinal_device(torch.device(device))
         if device == self.device:
             return self
         super().to(device)
@@ -1885,7 +1888,7 @@ def __getattr__(self, attr: str) -> Any:
         )
 
     def to(self, device: DEVICE_TYPING):
-        device = torch.device(device)
+        device = _make_ordinal_device(torch.device(device))
         if device == self.device:
             return self
         super().to(device)
7 changes: 4 additions & 3 deletions torchrl/envs/common.py
@@ -19,6 +19,7 @@
 from tensordict.utils import NestedKey
 from torchrl._utils import (
     _ends_with,
+    _make_ordinal_device,
     _replace_last,
     implement_for,
     prod,
@@ -154,7 +155,7 @@ def clone(self):
 
     def to(self, device: DEVICE_TYPING) -> EnvMetaData:
         if device is not None:
-            device = torch.device(device)
+            device = _make_ordinal_device(torch.device(device))
         device_map = {key: device for key in self.device_map}
         tensordict = self.tensordict.contiguous().to(device)
         specs = self.specs.to(device)
@@ -348,7 +349,7 @@ def __init__(
     ):
         self.__dict__.setdefault("_batch_size", None)
         if device is not None:
-            self.__dict__["_device"] = torch.device(device)
+            self.__dict__["_device"] = _make_ordinal_device(torch.device(device))
         output_spec = self.__dict__.get("_output_spec")
         if output_spec is not None:
             self.__dict__["_output_spec"] = (
@@ -2947,7 +2948,7 @@ def __del__(self):
         pass
 
     def to(self, device: DEVICE_TYPING) -> EnvBase:
-        device = torch.device(device)
+        device = _make_ordinal_device(torch.device(device))
         if device == self.device:
             return self
         self.__dict__["_input_spec"] = self.input_spec.to(device).lock_()
4 changes: 2 additions & 2 deletions torchrl/envs/libs/habitat.py
@@ -6,7 +6,7 @@
 import importlib.util
 
 import torch
-
+from torchrl._utils import _make_ordinal_device
 from torchrl.data.utils import DEVICE_TYPING
 from torchrl.envs.common import EnvBase
 from torchrl.envs.libs.gym import GymEnv, set_gym_backend
@@ -118,7 +118,7 @@ def _build_gym_env(self, env, pixels_only):
         return super()._build_gym_env(env, pixels_only)
 
     def to(self, device: DEVICE_TYPING) -> EnvBase:
-        device = torch.device(device)
+        device = _make_ordinal_device(torch.device(device))
         if device.type != "cuda":
             raise ValueError("The device must be of type cuda for Habitat.")
         device_num = device.index
26 changes: 17 additions & 9 deletions torchrl/envs/transforms/transforms.py
@@ -44,7 +44,7 @@
 from torch import nn, Tensor
 from torch.utils._pytree import tree_map
 
-from torchrl._utils import _append_last, _ends_with, _replace_last
+from torchrl._utils import _append_last, _ends_with, _make_ordinal_device, _replace_last
 
 from torchrl.data.tensor_specs import (
     BinaryDiscreteTensorSpec,
@@ -2084,13 +2084,21 @@ def __new__(cls, *args, **kwargs):
 
     def __init__(
         self,
-        unsqueeze_dim: int,
+        dim: int = None,
         allow_positive_dim: bool = False,
         in_keys: Sequence[NestedKey] | None = None,
         out_keys: Sequence[NestedKey] | None = None,
         in_keys_inv: Sequence[NestedKey] | None = None,
         out_keys_inv: Sequence[NestedKey] | None = None,
+        **kwargs,
     ):
+        if "unsqueeze_dim" in kwargs:
+            warnings.warn(
+                "The `unsqueeze_dim` kwarg will be removed in v0.6. Please use `dim` instead."
+            )
+            dim = kwargs["unsqueeze_dim"]
+        elif dim is None:
+            raise TypeError("dim must be provided.")
         if in_keys is None:
             in_keys = []  # default
         if out_keys is None:
@@ -2106,19 +2114,19 @@ def __init__(
             out_keys_inv=out_keys_inv,
         )
         self.allow_positive_dim = allow_positive_dim
-        if unsqueeze_dim >= 0 and not allow_positive_dim:
+        if dim >= 0 and not allow_positive_dim:
             raise RuntimeError(
-                "unsqueeze_dim should be smaller than 0 to accommodate for "
+                "dim should be smaller than 0 to accommodate for "
                 "envs of different batch_sizes. Turn allow_positive_dim to accommodate "
                 "for positive unsqueeze_dim."
             )
-        self._unsqueeze_dim = unsqueeze_dim
+        self._dim = dim
 
     @property
     def unsqueeze_dim(self):
-        if self._unsqueeze_dim >= 0 and self.parent is not None:
-            return len(self.parent.batch_size) + self._unsqueeze_dim
-        return self._unsqueeze_dim
+        if self._dim >= 0 and self.parent is not None:
+            return len(self.parent.batch_size) + self._dim
+        return self._dim
 
     def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
         observation = observation.unsqueeze(self.unsqueeze_dim)
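The renamed argument is a drop-in replacement; the old keyword keeps working
until v0.6 but emits a deprecation warning. A usage sketch, with hypothetical
keys and dims:

    from torchrl.envs.transforms import UnsqueezeTransform

    t = UnsqueezeTransform(dim=-1, in_keys=["observation"])  # new spelling
    # Old spelling is still accepted but warns that it will be removed in v0.6:
    t_legacy = UnsqueezeTransform(unsqueeze_dim=-1, in_keys=["observation"])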
@@ -3808,7 +3816,7 @@ def __init__(
         in_keys_inv=None,
         out_keys_inv=None,
     ):
-        device = self.device = torch.device(device)
+        device = self.device = _make_ordinal_device(torch.device(device))
         self.orig_device = (
             torch.device(orig_device) if orig_device is not None else orig_device
         )
3 changes: 2 additions & 1 deletion torchrl/trainers/helpers/replay_buffer.py
@@ -6,6 +6,7 @@
 from typing import Optional
 
 import torch
+from torchrl._utils import _make_ordinal_device
 
 from torchrl.data.replay_buffers.replay_buffers import (
     ReplayBuffer,
@@ -20,7 +21,7 @@ def make_replay_buffer(
     device: DEVICE_TYPING, cfg: "DictConfig"  # noqa: F821
 ) -> ReplayBuffer:  # noqa: F821
     """Builds a replay buffer using the config built from ReplayArgsConfig."""
-    device = torch.device(device)
+    device = _make_ordinal_device(torch.device(device))
     if not cfg.prb:
         sampler = RandomSampler()
     else: