diff --git a/test/test_specs.py b/test/test_specs.py index 5cc2ed97226..6b779811f1d 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -13,7 +13,7 @@ from scipy.stats import chisquare from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase from tensordict.utils import _unravel_key_to_tuple - +from torchrl._utils import _make_ordinal_device from torchrl.data.tensor_specs import ( _keys_to_empty_composite_spec, BinaryDiscreteTensorSpec, @@ -3689,6 +3689,51 @@ def test_sample(self): assert nts.zero((2,)).shape == (2, 3, 4) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="not cuda device") +def test_device_ordinal(): + device = torch.device("cpu") + assert _make_ordinal_device(device) == torch.device("cpu") + device = torch.device("cuda") + assert _make_ordinal_device(device) == torch.device("cuda:0") + device = torch.device("cuda:0") + assert _make_ordinal_device(device) == torch.device("cuda:0") + device = None + assert _make_ordinal_device(device) is None + + device = torch.device("cuda") + unb = UnboundedContinuousTensorSpec((-1, 1, 2), device=device) + assert unb.device == torch.device("cuda:0") + unbd = UnboundedDiscreteTensorSpec((-1, 1, 2), device=device) + assert unbd.device == torch.device("cuda:0") + bound = BoundedTensorSpec(shape=(-1, 1, 2), low=-1, high=1, device=device) + assert bound.device == torch.device("cuda:0") + oneh = OneHotDiscreteTensorSpec(shape=(-1, 1, 2, 4), n=4, device=device) + assert oneh.device == torch.device("cuda:0") + disc = DiscreteTensorSpec(shape=(-1, 1, 2), n=4, device=device) + assert disc.device == torch.device("cuda:0") + moneh = MultiOneHotDiscreteTensorSpec( + shape=(-1, 1, 2, 7), nvec=[3, 4], device=device + ) + assert moneh.device == torch.device("cuda:0") + mdisc = MultiDiscreteTensorSpec(shape=(-1, 1, 2, 2), nvec=[3, 4], device=device) + assert mdisc.device == torch.device("cuda:0") + mdisc = NonTensorSpec(shape=(-1, 1, 2, 2), device=device) + assert mdisc.device == torch.device("cuda:0") + + spec = CompositeSpec( + unb=unb, + unbd=unbd, + bound=bound, + oneh=oneh, + disc=disc, + moneh=moneh, + mdisc=mdisc, + shape=(-1, 1, 2), + device=device, + ) + assert spec.device == torch.device("cuda:0") + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 453eb7fe085..895f3d80fdc 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -778,3 +778,11 @@ def _can_be_pickled(obj): return True except (pickle.PickleError, AttributeError, TypeError): return False + + +def _make_ordinal_device(device: torch.device): + if device is None: + return device + if device.type == "cuda" and device.index is None: + return torch.device("cuda", index=torch.cuda.current_device()) + return device diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index a068536ec67..11befbf0ee3 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -40,6 +40,7 @@ from torchrl._utils import ( _check_for_faulty_process, _ends_with, + _make_ordinal_device, _ProcessNoWarn, _replace_last, accept_remote_rref_udf_invocation, @@ -822,10 +823,16 @@ def _get_devices( env_device: torch.device, device: torch.device, ): - device = torch.device(device) if device else device - storing_device = torch.device(storing_device) if storing_device else device - policy_device = torch.device(policy_device) if policy_device else device - env_device = torch.device(env_device) if env_device else device + device = _make_ordinal_device(torch.device(device) if device else device) + storing_device = _make_ordinal_device( + torch.device(storing_device) if storing_device else device + ) + policy_device = _make_ordinal_device( + torch.device(policy_device) if policy_device else device + ) + env_device = _make_ordinal_device( + torch.device(env_device) if env_device else device + ) if storing_device is None and (env_device == policy_device): storing_device = env_device return storing_device, policy_device, env_device diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 52d547297ac..ded5e77579c 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -30,7 +30,7 @@ from tensordict.utils import expand_as_right, expand_right from torch import Tensor -from torchrl._utils import accept_remote_rref_udf_invocation +from torchrl._utils import _make_ordinal_device, accept_remote_rref_udf_invocation from torchrl.data.replay_buffers.samplers import ( PrioritizedSampler, RandomSampler, @@ -1457,7 +1457,7 @@ def __init__(self, device: DEVICE_TYPING | None = None): self.out = None if device is None: device = "cpu" - self.device = torch.device(device) + self.device = _make_ordinal_device(torch.device(device)) def __call__(self, list_of_tds): if self.out is None: diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 5ab81112ff1..3c540c7ff3e 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -25,7 +25,7 @@ from tensordict.memmap import MemoryMappedTensor from torch import multiprocessing as mp from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten -from torchrl._utils import implement_for, logger as torchrl_logger +from torchrl._utils import _make_ordinal_device, implement_for, logger as torchrl_logger from torchrl.data.replay_buffers.checkpointers import ( ListStorageCheckpointer, StorageCheckpointerBase, @@ -405,7 +405,7 @@ def __init__( else: self._len = 0 self.device = ( - torch.device(device) + _make_ordinal_device(torch.device(device)) if device != "auto" else storage.device if storage is not None @@ -983,7 +983,11 @@ def __init__( self.scratch_dir = str(scratch_dir) if self.scratch_dir[-1] != "/": self.scratch_dir += "/" - self.device = torch.device(device) if device != "auto" else torch.device("cpu") + self.device = ( + _make_ordinal_device(torch.device(device)) + if device != "auto" + else torch.device("cpu") + ) if self.device.type != "cpu": raise ValueError( "Memory map device other than CPU isn't supported. To cast your data to the desired device, " diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 6681bab4cea..002ca9c5fde 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -39,8 +39,7 @@ unravel_key, ) from tensordict.utils import _getitem_batch_size, NestedKey - -from torchrl._utils import get_binary_env_var +from torchrl._utils import _make_ordinal_device, get_binary_env_var DEVICE_TYPING = Union[torch.device, str, int] @@ -91,7 +90,7 @@ def _default_dtype_and_device( if dtype is None: dtype = torch.get_default_dtype() if device is not None: - device = torch.device(device) + device = _make_ordinal_device(torch.device(device)) elif not allow_none_device: device = torch.zeros(()).device return dtype, device @@ -536,6 +535,14 @@ def decorator(func): return decorator + @property + def device(self) -> torch.device: + return self._device + + @device.setter + def device(self, device: torch.device | None) -> None: + self._device = _make_ordinal_device(device) + def clear_device_(self): """A no-op for all leaf specs (which must have a device).""" return self @@ -3804,7 +3811,9 @@ def __init__(self, *args, shape=None, device=None, **kwargs): for key, value in kwargs.items(): self.set(key, value) - _device = torch.device(device) if device is not None else device + _device = ( + _make_ordinal_device(torch.device(device)) if device is not None else device + ) if len(kwargs): for key, item in self.items(): if item is None: @@ -3847,7 +3856,7 @@ def device(self, device: DEVICE_TYPING): raise RuntimeError( "To erase the device of a composite spec, call " "spec.clear_device_()." ) - device = torch.device(device) + device = _make_ordinal_device(torch.device(device)) self.to(device) def clear_device_(self): diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 0f9f8c9a6d7..e5aaa873870 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -31,6 +31,7 @@ from torch import multiprocessing as mp from torchrl._utils import ( _check_for_faulty_process, + _make_ordinal_device, _ProcessNoWarn, logger as torchrl_logger, VERBOSE, @@ -346,7 +347,9 @@ def __init__( "memmap and shared memory are mutually exclusive features." ) self._batch_size = None - self._device = torch.device(device) if device is not None else device + self._device = ( + _make_ordinal_device(torch.device(device)) if device is not None else device + ) self._dummy_env_str = None self._seeds = None self.__dict__["_input_spec"] = None @@ -859,7 +862,7 @@ def start(self) -> None: def to(self, device: DEVICE_TYPING): self._non_blocking = None - device = torch.device(device) + device = _make_ordinal_device(torch.device(device)) if device == self.device: return self self._device = device @@ -1152,7 +1155,7 @@ def __getattr__(self, attr: str) -> Any: ) def to(self, device: DEVICE_TYPING): - device = torch.device(device) + device = _make_ordinal_device(torch.device(device)) if device == self.device: return self super().to(device) @@ -1885,7 +1888,7 @@ def __getattr__(self, attr: str) -> Any: ) def to(self, device: DEVICE_TYPING): - device = torch.device(device) + device = _make_ordinal_device(torch.device(device)) if device == self.device: return self super().to(device) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 61ee51939e2..c965e7dedf3 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -19,6 +19,7 @@ from tensordict.utils import NestedKey from torchrl._utils import ( _ends_with, + _make_ordinal_device, _replace_last, implement_for, prod, @@ -154,7 +155,7 @@ def clone(self): def to(self, device: DEVICE_TYPING) -> EnvMetaData: if device is not None: - device = torch.device(device) + device = _make_ordinal_device(torch.device(device)) device_map = {key: device for key in self.device_map} tensordict = self.tensordict.contiguous().to(device) specs = self.specs.to(device) @@ -348,7 +349,7 @@ def __init__( ): self.__dict__.setdefault("_batch_size", None) if device is not None: - self.__dict__["_device"] = torch.device(device) + self.__dict__["_device"] = _make_ordinal_device(torch.device(device)) output_spec = self.__dict__.get("_output_spec") if output_spec is not None: self.__dict__["_output_spec"] = ( @@ -2947,7 +2948,7 @@ def __del__(self): pass def to(self, device: DEVICE_TYPING) -> EnvBase: - device = torch.device(device) + device = _make_ordinal_device(torch.device(device)) if device == self.device: return self self.__dict__["_input_spec"] = self.input_spec.to(device).lock_() diff --git a/torchrl/envs/libs/habitat.py b/torchrl/envs/libs/habitat.py index 894d56ef5b6..53752147acc 100644 --- a/torchrl/envs/libs/habitat.py +++ b/torchrl/envs/libs/habitat.py @@ -6,7 +6,7 @@ import importlib.util import torch - +from torchrl._utils import _make_ordinal_device from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase from torchrl.envs.libs.gym import GymEnv, set_gym_backend @@ -118,7 +118,7 @@ def _build_gym_env(self, env, pixels_only): return super()._build_gym_env(env, pixels_only) def to(self, device: DEVICE_TYPING) -> EnvBase: - device = torch.device(device) + device = _make_ordinal_device(torch.device(device)) if device.type != "cuda": raise ValueError("The device must be of type cuda for Habitat.") device_num = device.index diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 5a313ea0ec4..7802fba368d 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -44,7 +44,7 @@ from torch import nn, Tensor from torch.utils._pytree import tree_map -from torchrl._utils import _append_last, _ends_with, _replace_last +from torchrl._utils import _append_last, _ends_with, _make_ordinal_device, _replace_last from torchrl.data.tensor_specs import ( BinaryDiscreteTensorSpec, @@ -2084,13 +2084,21 @@ def __new__(cls, *args, **kwargs): def __init__( self, - unsqueeze_dim: int, + dim: int = None, allow_positive_dim: bool = False, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, + **kwargs, ): + if "unsqueeze_dim" in kwargs: + warnings.warn( + "The `unsqueeze_dim` kwarg will be removed in v0.6. Please use `dim` instead." + ) + dim = kwargs["unsqueeze_dim"] + elif dim is None: + raise TypeError("dim must be provided.") if in_keys is None: in_keys = [] # default if out_keys is None: @@ -2106,19 +2114,19 @@ def __init__( out_keys_inv=out_keys_inv, ) self.allow_positive_dim = allow_positive_dim - if unsqueeze_dim >= 0 and not allow_positive_dim: + if dim >= 0 and not allow_positive_dim: raise RuntimeError( - "unsqueeze_dim should be smaller than 0 to accommodate for " + "dim should be smaller than 0 to accommodate for " "envs of different batch_sizes. Turn allow_positive_dim to accommodate " "for positive unsqueeze_dim." ) - self._unsqueeze_dim = unsqueeze_dim + self._dim = dim @property def unsqueeze_dim(self): - if self._unsqueeze_dim >= 0 and self.parent is not None: - return len(self.parent.batch_size) + self._unsqueeze_dim - return self._unsqueeze_dim + if self._dim >= 0 and self.parent is not None: + return len(self.parent.batch_size) + self._dim + return self._dim def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: observation = observation.unsqueeze(self.unsqueeze_dim) @@ -3808,7 +3816,7 @@ def __init__( in_keys_inv=None, out_keys_inv=None, ): - device = self.device = torch.device(device) + device = self.device = _make_ordinal_device(torch.device(device)) self.orig_device = ( torch.device(orig_device) if orig_device is not None else orig_device ) diff --git a/torchrl/trainers/helpers/replay_buffer.py b/torchrl/trainers/helpers/replay_buffer.py index de5102e0613..6ccbb15a291 100644 --- a/torchrl/trainers/helpers/replay_buffer.py +++ b/torchrl/trainers/helpers/replay_buffer.py @@ -6,6 +6,7 @@ from typing import Optional import torch +from torchrl._utils import _make_ordinal_device from torchrl.data.replay_buffers.replay_buffers import ( ReplayBuffer, @@ -20,7 +21,7 @@ def make_replay_buffer( device: DEVICE_TYPING, cfg: "DictConfig" # noqa: F821 ) -> ReplayBuffer: # noqa: F821 """Builds a replay buffer using the config built from ReplayArgsConfig.""" - device = torch.device(device) + device = _make_ordinal_device(torch.device(device)) if not cfg.prb: sampler = RandomSampler() else: