[Feature] _make_ordinal_device #2237

Merged: 5 commits, Jun 19, 2024
Changes from 1 commit
init
vmoens committed Jun 18, 2024
commit 748b2babe83245a1b0f80aadf3b7c41542ecc5da
50 changes: 49 additions & 1 deletion test/test_specs.py
@@ -29,7 +29,11 @@
UnboundedContinuousTensorSpec,
UnboundedDiscreteTensorSpec,
)
from torchrl.data.utils import check_no_exclusive_keys, consolidate_spec
from torchrl.data.utils import (
_make_ordinal_device,
check_no_exclusive_keys,
consolidate_spec,
)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None])
@@ -3689,6 +3693,50 @@ def test_sample(self):
assert nts.zero((2,)).shape == (2, 3, 4)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA device not available")
def test_device_ordinal():
device = torch.device("cpu")
assert _make_ordinal_device(device) == torch.device("cpu")
device = torch.device("cuda")
assert _make_ordinal_device(device) == torch.device("cuda:0")
device = torch.device("cuda:0")
assert _make_ordinal_device(device) == torch.device("cuda:0")
device = None
assert _make_ordinal_device(device) is None

device = torch.device("cuda")
unb = UnboundedContinuousTensorSpec((-1, 1, 2), device=device)
assert unb.device == torch.device("cuda:0")
unbd = UnboundedDiscreteTensorSpec((-1, 1, 2), device=device)
assert unbd.device == torch.device("cuda:0")
bound = BoundedTensorSpec(shape=(-1, 1, 2), low=-1, high=1, device=device)
assert bound.device == torch.device("cuda:0")
oneh = OneHotDiscreteTensorSpec(shape=(-1, 1, 2, 4), n=4, device=device)
assert oneh.device == torch.device("cuda:0")
disc = DiscreteTensorSpec(shape=(-1, 1, 2), n=4, device=device)
assert disc.device == torch.device("cuda:0")
moneh = MultiOneHotDiscreteTensorSpec(
shape=(-1, 1, 2, 7), nvec=[3, 4], device=device
)
assert moneh.device == torch.device("cuda:0")
mdisc = MultiDiscreteTensorSpec(shape=(-1, 1, 2, 2), nvec=[3, 4], device=device)
assert mdisc.device == torch.device("cuda:0")
ntspec = NonTensorSpec(shape=(-1, 1, 2, 2), device=device)
assert ntspec.device == torch.device("cuda:0")

spec = CompositeSpec(
unb=unb,
unbd=unbd,
bound=bound,
oneh=oneh,
disc=disc,
moneh=moneh,
mdisc=mdisc,
shape=(-1, 1, 2),
)
assert spec.device == torch.device("cuda:0")


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
16 changes: 11 additions & 5 deletions torchrl/collectors/collectors.py
@@ -50,7 +50,7 @@
)
from torchrl.collectors.utils import split_trajectories
from torchrl.data.tensor_specs import TensorSpec
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
from torchrl.data.utils import _make_ordinal_device, CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs.common import _do_nothing, EnvBase
from torchrl.envs.transforms import StepCounter, TransformedEnv
from torchrl.envs.utils import (
@@ -820,10 +820,16 @@ def _get_devices(
env_device: torch.device,
device: torch.device,
):
device = torch.device(device) if device else device
storing_device = torch.device(storing_device) if storing_device else device
policy_device = torch.device(policy_device) if policy_device else device
env_device = torch.device(env_device) if env_device else device
device = _make_ordinal_device(torch.device(device) if device else device)
storing_device = _make_ordinal_device(
torch.device(storing_device) if storing_device else device
)
policy_device = _make_ordinal_device(
torch.device(policy_device) if policy_device else device
)
env_device = _make_ordinal_device(
torch.device(env_device) if env_device else device
)
if storing_device is None and (env_device == policy_device):
storing_device = env_device
return storing_device, policy_device, env_device
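For readers skimming the diff, the resolution order can be summarized in a standalone sketch (illustrative only; the real `_get_devices` is a private collector method whose full signature is collapsed above):

```python
import torch

from torchrl.data.utils import _make_ordinal_device


def resolve_collector_devices(
    device=None, storing_device=None, policy_device=None, env_device=None
):
    """Illustrative restatement of the collector's device resolution.

    Each specific device falls back to the generic ``device`` and is then
    normalized to an explicit ordinal (e.g. "cuda" -> "cuda:0").
    """
    device = _make_ordinal_device(torch.device(device) if device else device)
    storing_device = _make_ordinal_device(
        torch.device(storing_device) if storing_device else device
    )
    policy_device = _make_ordinal_device(
        torch.device(policy_device) if policy_device else device
    )
    env_device = _make_ordinal_device(
        torch.device(env_device) if env_device else device
    )
    if storing_device is None and (env_device == policy_device):
        storing_device = env_device
    return storing_device, policy_device, env_device
```

With only `device="cuda"` given, all three devices resolve to `cuda:0` on a single-GPU machine.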
4 changes: 2 additions & 2 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -58,7 +58,7 @@
Writer,
WriterEnsemble,
)
from torchrl.data.utils import DEVICE_TYPING
from torchrl.data.utils import _make_ordinal_device, DEVICE_TYPING
from torchrl.envs.transforms.transforms import _InvertTransform


@@ -1457,7 +1457,7 @@ def __init__(self, device: DEVICE_TYPING | None = None):
self.out = None
if device is None:
device = "cpu"
self.device = torch.device(device)
self.device = _make_ordinal_device(torch.device(device))

def __call__(self, list_of_tds):
if self.out is None:
9 changes: 7 additions & 2 deletions torchrl/data/replay_buffers/storages.py
@@ -38,6 +38,7 @@
INT_CLASSES,
tree_iter,
)
from torchrl.data.utils import _make_ordinal_device


class Storage:
@@ -405,7 +406,7 @@ def __init__(
else:
self._len = 0
self.device = (
torch.device(device)
_make_ordinal_device(torch.device(device))
if device != "auto"
else storage.device
if storage is not None
@@ -983,7 +984,11 @@ def __init__(
self.scratch_dir = str(scratch_dir)
if self.scratch_dir[-1] != "/":
self.scratch_dir += "/"
self.device = torch.device(device) if device != "auto" else torch.device("cpu")
self.device = (
_make_ordinal_device(torch.device(device))
if device != "auto"
else torch.device("cpu")
)
if self.device.type != "cpu":
raise ValueError(
"Memory map device other than CPU isn't supported. To cast your data to the desired device, "
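A hedged usage sketch, assuming the hunk above belongs to `LazyMemmapStorage` (the storage class that takes a `scratch_dir`): memory-mapped storage must stay on CPU, so a CUDA device is normalized and then rejected.

```python
import torch

from torchrl.data.replay_buffers.storages import LazyMemmapStorage

# The device argument is normalized, then checked: only CPU is accepted.
storage = LazyMemmapStorage(max_size=1000, device="cpu")
assert storage.device == torch.device("cpu")

try:
    LazyMemmapStorage(max_size=1000, device="cuda")
except ValueError as err:
    # "Memory map device other than CPU isn't supported. ..."
    print(err)
```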
18 changes: 14 additions & 4 deletions torchrl/data/tensor_specs.py
@@ -39,8 +39,8 @@
unravel_key,
)
from tensordict.utils import _getitem_batch_size, NestedKey

from torchrl._utils import get_binary_env_var
from torchrl.data.utils import _make_ordinal_device

DEVICE_TYPING = Union[torch.device, str, int]

@@ -91,7 +91,7 @@ def _default_dtype_and_device(
if dtype is None:
dtype = torch.get_default_dtype()
if device is not None:
device = torch.device(device)
device = _make_ordinal_device(torch.device(device))
elif not allow_none_device:
device = torch.zeros(()).device
return dtype, device
@@ -536,6 +536,14 @@ def decorator(func):

return decorator

@property
def device(self) -> torch.device:
return self._device

@device.setter
def device(self, device: torch.device | None) -> None:
self._device = _make_ordinal_device(device)

def clear_device_(self):
"""A no-op for all leaf specs (which must have a device)."""
return self
@@ -3802,7 +3810,9 @@ def __init__(self, *args, shape=None, device=None, **kwargs):
for key, value in kwargs.items():
self.set(key, value)

_device = torch.device(device) if device is not None else device
_device = (
_make_ordinal_device(torch.device(device)) if device is not None else device
)
if len(kwargs):
for key, item in self.items():
if item is None:
@@ -3845,7 +3855,7 @@ def device(self, device: DEVICE_TYPING):
raise RuntimeError(
"To erase the device of a composite spec, call " "spec.clear_device_()."
)
device = torch.device(device)
device = _make_ordinal_device(torch.device(device))
self.to(device)

def clear_device_(self):
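For illustration, a small sketch of the effect at the spec level (assumes a CUDA build, and that `CompositeSpec` and `UnboundedContinuousTensorSpec` are imported from `torchrl.data`, as in the test above):

```python
import torch

from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

# Construction on a bare "cuda" device now resolves to an explicit ordinal.
leaf = UnboundedContinuousTensorSpec((3,), device="cuda")
assert leaf.device == torch.device("cuda:0")

# The CompositeSpec device setter normalizes before calling ``.to(device)``.
spec = CompositeSpec(obs=leaf, shape=())
spec.device = "cuda"
assert spec.device == torch.device("cuda:0")
```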
8 changes: 8 additions & 0 deletions torchrl/data/utils.py
@@ -324,3 +324,11 @@ def _find_action_space(action_space):
f"action_space was not specified/not compatible and could not be retrieved from the value network. Got action_space={action_space}."
)
return action_space


def _make_ordinal_device(device: torch.device):
if device is None:
return device
if device.type == "cuda" and device.index is None:
return torch.device("cuda", index=torch.cuda.current_device())
return device
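A minimal usage sketch of the new helper (requires a CUDA build, since `torch.cuda.current_device()` is queried for the bare `"cuda"` case):

```python
import torch

from torchrl.data.utils import _make_ordinal_device

# A bare "cuda" device gains the current ordinal; indexed, CPU and None
# devices pass through unchanged.
assert _make_ordinal_device(torch.device("cuda")) == torch.device(
    "cuda", torch.cuda.current_device()
)
assert _make_ordinal_device(torch.device("cuda:1")) == torch.device("cuda:1")
assert _make_ordinal_device(torch.device("cpu")) == torch.device("cpu")
assert _make_ordinal_device(None) is None
```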
17 changes: 12 additions & 5 deletions torchrl/envs/batched_envs.py
@@ -36,7 +36,12 @@
VERBOSE,
)
from torchrl.data.tensor_specs import CompositeSpec
from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING
from torchrl.data.utils import (
_make_ordinal_device,
CloudpickleWrapper,
contains_lazy_spec,
DEVICE_TYPING,
)
from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, EnvMetaData
from torchrl.envs.env_creator import get_env_metadata

@@ -346,7 +351,9 @@ def __init__(
"memmap and shared memory are mutually exclusive features."
)
self._batch_size = None
self._device = torch.device(device) if device is not None else device
self._device = (
_make_ordinal_device(torch.device(device)) if device is not None else device
)
self._dummy_env_str = None
self._seeds = None
self.__dict__["_input_spec"] = None
@@ -835,7 +842,7 @@ def start(self) -> None:

def to(self, device: DEVICE_TYPING):
self._non_blocking = None
device = torch.device(device)
device = _make_ordinal_device(torch.device(device))
if device == self.device:
return self
self._device = device
@@ -1114,7 +1121,7 @@ def __getattr__(self, attr: str) -> Any:
)

def to(self, device: DEVICE_TYPING):
device = torch.device(device)
device = _make_ordinal_device(torch.device(device))
if device == self.device:
return self
super().to(device)
@@ -1789,7 +1796,7 @@ def __getattr__(self, attr: str) -> Any:
)

def to(self, device: DEVICE_TYPING):
device = torch.device(device)
device = _make_ordinal_device(torch.device(device))
if device == self.device:
return self
super().to(device)
8 changes: 4 additions & 4 deletions torchrl/envs/common.py
@@ -31,7 +31,7 @@
TensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.data.utils import DEVICE_TYPING
from torchrl.data.utils import _make_ordinal_device, DEVICE_TYPING
from torchrl.envs.utils import (
_make_compatible_policy,
_repr_by_depth,
@@ -154,7 +154,7 @@ def clone(self):

def to(self, device: DEVICE_TYPING) -> EnvMetaData:
if device is not None:
device = torch.device(device)
device = _make_ordinal_device(torch.device(device))
device_map = {key: device for key in self.device_map}
tensordict = self.tensordict.contiguous().to(device)
specs = self.specs.to(device)
@@ -348,7 +348,7 @@ def __init__(
):
self.__dict__.setdefault("_batch_size", None)
if device is not None:
self.__dict__["_device"] = torch.device(device)
self.__dict__["_device"] = _make_ordinal_device(torch.device(device))
output_spec = self.__dict__.get("_output_spec")
if output_spec is not None:
self.__dict__["_output_spec"] = (
@@ -2947,7 +2947,7 @@ def __del__(self):
pass

def to(self, device: DEVICE_TYPING) -> EnvBase:
device = torch.device(device)
device = _make_ordinal_device(torch.device(device))
if device == self.device:
return self
self.__dict__["_input_spec"] = self.input_spec.to(device).lock_()
4 changes: 2 additions & 2 deletions torchrl/envs/libs/habitat.py
@@ -7,7 +7,7 @@

import torch

from torchrl.data.utils import DEVICE_TYPING
from torchrl.data.utils import _make_ordinal_device, DEVICE_TYPING
from torchrl.envs.common import EnvBase
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import _classproperty
@@ -118,7 +118,7 @@ def _build_gym_env(self, env, pixels_only):
return super()._build_gym_env(env, pixels_only)

def to(self, device: DEVICE_TYPING) -> EnvBase:
device = torch.device(device)
device = _make_ordinal_device(torch.device(device))
if device.type != "cuda":
raise ValueError("The device must be of type cuda for Habitat.")
device_num = device.index
25 changes: 17 additions & 8 deletions torchrl/envs/transforms/transforms.py
@@ -58,6 +58,7 @@
TensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.data.utils import _make_ordinal_device
from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, make_tensordict
from torchrl.envs.transforms import functional as F
from torchrl.envs.transforms.utils import (
@@ -2084,13 +2085,21 @@ def __new__(cls, *args, **kwargs):

def __init__(
self,
unsqueeze_dim: int,
dim: int = None,
allow_positive_dim: bool = False,
in_keys: Sequence[NestedKey] | None = None,
out_keys: Sequence[NestedKey] | None = None,
in_keys_inv: Sequence[NestedKey] | None = None,
out_keys_inv: Sequence[NestedKey] | None = None,
**kwargs,
):
if "unsqueeze_dim" in kwargs:
warnings.warn(
"The `unsqueeze_dim` kwarg will be removed in v0.6. Please use `dim` instead."
)
dim = kwargs["unsqueeze_dim"]
elif dim is None:
raise TypeError("dim must be provided.")
if in_keys is None:
in_keys = [] # default
if out_keys is None:
@@ -2106,19 +2115,19 @@ def __init__(
out_keys_inv=out_keys_inv,
)
self.allow_positive_dim = allow_positive_dim
if unsqueeze_dim >= 0 and not allow_positive_dim:
if dim >= 0 and not allow_positive_dim:
raise RuntimeError(
"unsqueeze_dim should be smaller than 0 to accommodate for "
"dim should be smaller than 0 to accommodate for "
"envs of different batch_sizes. Turn allow_positive_dim to accommodate "
"for positive unsqueeze_dim."
)
self._unsqueeze_dim = unsqueeze_dim
self._dim = dim

@property
def unsqueeze_dim(self):
if self._unsqueeze_dim >= 0 and self.parent is not None:
return len(self.parent.batch_size) + self._unsqueeze_dim
return self._unsqueeze_dim
if self._dim >= 0 and self.parent is not None:
return len(self.parent.batch_size) + self._dim
return self._dim

def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
observation = observation.unsqueeze(self.unsqueeze_dim)
@@ -3808,7 +3817,7 @@ def __init__(
in_keys_inv=None,
out_keys_inv=None,
):
device = self.device = torch.device(device)
device = self.device = _make_ordinal_device(torch.device(device))
self.orig_device = (
torch.device(orig_device) if orig_device is not None else orig_device
)
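Usage sketch for the renamed `UnsqueezeTransform` argument (the legacy keyword keeps working until v0.6 but emits a deprecation warning):

```python
import warnings

from torchrl.envs.transforms import UnsqueezeTransform

# Preferred going forward: pass the (negative) dimension through ``dim``.
t = UnsqueezeTransform(dim=-1, in_keys=["observation"], out_keys=["observation"])

# Legacy spelling still works but warns about the upcoming removal.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    UnsqueezeTransform(unsqueeze_dim=-1, in_keys=["observation"])
assert any("unsqueeze_dim" in str(w.message) for w in caught)
```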
4 changes: 2 additions & 2 deletions torchrl/trainers/helpers/replay_buffer.py
@@ -13,14 +13,14 @@
)
from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.data.utils import DEVICE_TYPING
from torchrl.data.utils import _make_ordinal_device, DEVICE_TYPING


def make_replay_buffer(
device: DEVICE_TYPING, cfg: "DictConfig" # noqa: F821
) -> ReplayBuffer: # noqa: F821
"""Builds a replay buffer using the config built from ReplayArgsConfig."""
device = torch.device(device)
device = _make_ordinal_device(torch.device(device))
if not cfg.prb:
sampler = RandomSampler()
else: