diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh b/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh
index c1dde8bb7d0..75112146f9e 100755
--- a/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh
+++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh
@@ -39,7 +39,7 @@ printf "Installing PyTorch with %s\n" "${CU_VERSION}"
 if [ "${CU_VERSION:-}" == cpu ] ; then
     conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch -y
 else
-    conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 numpy==1.26 -c pytorch -c nvidia -y
+    conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 "numpy<2.0" -c pytorch -c nvidia -y
 fi
 
 # Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has
diff --git a/.github/workflows/test-linux.yml b/.github/workflows/test-linux.yml
index 75a646c25c4..e202496852a 100644
--- a/.github/workflows/test-linux.yml
+++ b/.github/workflows/test-linux.yml
@@ -128,13 +128,11 @@ jobs:
     with:
       repository: pytorch/rl
       runner: "linux.g5.4xlarge.nvidia.gpu"
-      # gpu-arch-type: cuda
-      # gpu-arch-version: "11.7"
       docker-image: "nvidia/cudagl:11.4.0-base"
       timeout: 120
       script: |
         set -euo pipefail
-        export PYTHON_VERSION="3.9"
+        export PYTHON_VERSION="3.8"
         export CU_VERSION="cu116"
         export TAR_OPTIONS="--no-same-owner"
         if [[ "${{ github.ref }}" =~ release/* ]]; then
diff --git a/test/test_distributions.py b/test/test_distributions.py
index e283fb9a9b8..9a09834f0e9 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -6,6 +6,7 @@
 import argparse
 import importlib.util
 import os
+from typing import Tuple
 
 import pytest
 import torch
@@ -685,7 +686,7 @@ class TestOrdinal:
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("logit_shape", [(10,), (1, 1), (10, 10), (5, 10, 20)])
     def test_correct_sampling_shape(
-        self, logit_shape: tuple[int, ...], dtype: torch.dtype, device: str
+        self, logit_shape: Tuple[int, ...], dtype: torch.dtype, device: str
     ) -> None:
         logits = torch.testing.make_tensor(logit_shape, dtype=dtype, device=device)
 
@@ -753,7 +754,7 @@ class TestOneHotOrdinal:
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("logit_shape", [(10,), (10, 10), (5, 10, 20)])
     def test_correct_sampling_shape(
-        self, logit_shape: tuple[int, ...], dtype: torch.dtype, device: str
+        self, logit_shape: Tuple[int, ...], dtype: torch.dtype, device: str
     ) -> None:
         logits = torch.testing.make_tensor(logit_shape, dtype=dtype, device=device)
 
diff --git a/test/test_transforms.py b/test/test_transforms.py
index b0d660ecedc..679057bd6a6 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -2089,10 +2089,17 @@ def make_env(max_steps=4):
             total_frames=99,
             frames_per_batch=8,
         )
-        for d in collector:
-            # The env has one more traj because the collector calls reset during init
-            assert d["collector", "traj_ids"].max() == d["next", "traj_count"].max() - 1
-            assert d["traj_count"].max() > 0
+
+        try:
+            traj_ids_collector = []
+            traj_ids_env = []
+            for d in collector:
+                traj_ids_collector.extend(d["collector", "traj_ids"].view(-1).tolist())
+                traj_ids_env.extend(d["next", "traj_count"].view(-1).tolist())
+            assert len(set(traj_ids_env)) == len(set(traj_ids_collector))
+        finally:
+            collector.shutdown()
+            del collector
 
     def test_transform_compose(self):
         t = TrajCounter()
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 3590d76d2ce..7fbfaab3280 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -42,7 +42,7 @@
 )
 from tensordict.base import NO_DEFAULT
 from tensordict.utils import _getitem_batch_size, NestedKey
-from torchrl._utils import _make_ordinal_device, get_binary_env_var
+from torchrl._utils import _make_ordinal_device, get_binary_env_var, implement_for
 
 DEVICE_TYPING = Union[torch.device, str, int]
 
@@ -193,14 +193,14 @@ def _slice_indexing(shape: list[int], idx: slice) -> List[int]:
 
 
 def _shape_indexing(
-    shape: Union[list[int], torch.Size, tuple[int]], idx: SHAPE_INDEX_TYPING
+    shape: Union[list[int], torch.Size, Tuple[int]], idx: SHAPE_INDEX_TYPING
 ) -> List[int]:
     """Given an input shape and an index, returns the size of the resulting indexed spec.
 
     This function includes indexing checks and may raise IndexErrors.
 
     Args:
-        shape (list[int], torch.Size, tuple[int): Input shape
+        shape (list[int], torch.Size, Tuple[int): Input shape
         idx (SHAPE_INDEX_TYPING): Index
     Returns:
         Shape of the resulting spec
@@ -1020,7 +1020,7 @@ def unbind(self, dim: int = 0):
 
 
 class _LazyStackedMixin(Generic[T]):
-    def __init__(self, *specs: tuple[T, ...], dim: int) -> None:
+    def __init__(self, *specs: Tuple[T, ...], dim: int) -> None:
         self._specs = list(specs)
         self.dim = dim
         if self.dim < 0:
@@ -1682,7 +1682,31 @@ def unbind(self, dim: int = 0):
             for i in range(self.shape[dim])
         )
 
+    @implement_for("torch", None, "2.1")
     def rand(self, shape: torch.Size = None) -> torch.Tensor:
+        if shape is None:
+            shape = self.shape[:-1]
+        else:
+            shape = _size([*shape, *self.shape[:-1]])
+        mask = self.mask
+        n = int(self.space.n)
+        if mask is None:
+            m = torch.randint(n, shape, device=self.device)
+        else:
+            mask = mask.expand(_remove_neg_shapes(*shape, mask.shape[-1]))
+            if mask.ndim > 2:
+                mask_flat = torch.flatten(mask, 0, -2)
+            else:
+                mask_flat = mask
+            shape_out = mask.shape[:-1]
+            m = torch.multinomial(mask_flat.float(), 1).reshape(shape_out)
+        out = torch.nn.functional.one_hot(m, n).to(self.dtype)
+        # torch.zeros((*shape, self.space.n), device=self.device, dtype=self.dtype)
+        # out.scatter_(-1, m, 1)
+        return out
+
+    @implement_for("torch", "2.1")
+    def rand(self, shape: torch.Size = None) -> torch.Tensor:  # noqa: F811
         if shape is None:
             shape = self.shape[:-1]
         else:
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 4193af8b751..945cd833c6c 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -52,7 +52,13 @@
 from torch import nn, Tensor
 from torch.utils._pytree import tree_map
 
-from torchrl._utils import _append_last, _ends_with, _make_ordinal_device, _replace_last
+from torchrl._utils import (
+    _append_last,
+    _ends_with,
+    _make_ordinal_device,
+    _replace_last,
+    implement_for,
+)
 
 from torchrl.data.tensor_specs import (
     Binary,
@@ -8772,7 +8778,14 @@ def __init__(self, out_key: NestedKey = "traj_count"):
     def _make_shared_value(self):
         self._traj_count = mp.Value("i", 0)
 
+    @implement_for("torch", None, "2.1")
     def __getstate__(self):
+        state = self.__dict__.copy()
+        state["_traj_count"] = None
+        return state
+
+    @implement_for("torch", "2.1")
+    def __getstate__(self):  # noqa: F811
         state = super().__getstate__()
         state["_traj_count"] = None
         return state