diff --git a/.circleci/config.yml b/.circleci/config.yml index bbb9a767727..6a46376563b 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -656,32 +656,102 @@ workflows: unittest: jobs: + - unittest_macos_cpu: + cu_version: cpu + name: unittest_macos_cpu_py3.7 + python_version: '3.7' + - unittest_linux_cpu: + cu_version: cpu + name: unittest_linux_cpu_py3.7 + python_version: '3.7' + - unittest_linux_gpu: + cu_version: cu113 + name: unittest_linux_gpu_py3.7 + python_version: '3.7' + - unittest_linux_optdeps_gpu: + cu_version: cu113 + name: unittest_linux_optdeps_gpu_py3.7 + python_version: '3.7' + - unittest_linux_stable_cpu: + cu_version: cpu + name: unittest_linux_stable_cpu_py3.7 + python_version: '3.7' + - unittest_linux_stable_gpu: + cu_version: cu113 + name: unittest_linux_stable_gpu_py3.7 + python_version: '3.7' + - unittest_macos_cpu: cu_version: cpu name: unittest_macos_cpu_py3.8 python_version: '3.8' - - unittest_linux_cpu: cu_version: cpu name: unittest_linux_cpu_py3.8 python_version: '3.8' - - unittest_linux_gpu: cu_version: cu113 name: unittest_linux_gpu_py3.8 python_version: '3.8' - - unittest_linux_optdeps_gpu: cu_version: cu113 name: unittest_linux_optdeps_gpu_py3.8 python_version: '3.8' - - unittest_linux_stable_cpu: cu_version: cpu name: unittest_linux_stable_cpu_py3.8 python_version: '3.8' - - unittest_linux_stable_gpu: cu_version: cu113 name: unittest_linux_stable_gpu_py3.8 python_version: '3.8' + + - unittest_macos_cpu: + cu_version: cpu + name: unittest_macos_cpu_py3.9 + python_version: '3.9' + - unittest_linux_cpu: + cu_version: cpu + name: unittest_linux_cpu_py3.9 + python_version: '3.9' + - unittest_linux_gpu: + cu_version: cu113 + name: unittest_linux_gpu_py3.9 + python_version: '3.9' + - unittest_linux_optdeps_gpu: + cu_version: cu113 + name: unittest_linux_optdeps_gpu_py3.9 + python_version: '3.9' + - unittest_linux_stable_cpu: + cu_version: cpu + name: unittest_linux_stable_cpu_py3.9 + python_version: '3.9' + - unittest_linux_stable_gpu: + cu_version: cu113 + name: unittest_linux_stable_gpu_py3.9 + python_version: '3.9' + + - unittest_macos_cpu: + cu_version: cpu + name: unittest_macos_cpu_py3.10 + python_version: '3.10' + - unittest_linux_cpu: + cu_version: cpu + name: unittest_linux_cpu_py3.10 + python_version: '3.10' + - unittest_linux_gpu: + cu_version: cu113 + name: unittest_linux_gpu_py3.10 + python_version: '3.10' + - unittest_linux_optdeps_gpu: + cu_version: cu113 + name: unittest_linux_optdeps_gpu_py3.10 + python_version: '3.10' + - unittest_linux_stable_cpu: + cu_version: cpu + name: unittest_linux_stable_cpu_py3.10 + python_version: '3.10' + - unittest_linux_stable_gpu: + cu_version: cu113 + name: unittest_linux_stable_gpu_py3.10 + python_version: '3.10' diff --git a/.circleci/unittest/linux/scripts/environment.yml b/.circleci/unittest/linux/scripts/environment.yml index 01676cfe837..f0a4db6676e 100644 --- a/.circleci/unittest/linux/scripts/environment.yml +++ b/.circleci/unittest/linux/scripts/environment.yml @@ -8,7 +8,6 @@ dependencies: - hypothesis - future - cloudpickle - - gym_retro - gym - pygame - gym[accept-rom-license] @@ -24,5 +23,4 @@ dependencies: - dm_control - mujoco_py - hydra-core - - pyrender - tensorboard diff --git a/.circleci/unittest/linux/scripts/install.sh b/.circleci/unittest/linux/scripts/install.sh index 23267bdc8d1..da265cd6305 100755 --- a/.circleci/unittest/linux/scripts/install.sh +++ b/.circleci/unittest/linux/scripts/install.sh @@ -36,18 +36,14 @@ printf "Installing PyTorch with %s\n" "${CU_VERSION}" if [ "${CU_VERSION:-}" == cpu ] ; then # conda install -y pytorch torchvision cpuonly -c pytorch-nightly # use pip to install pytorch as conda can frequently pick older release - if [[ $OSTYPE == 'darwin'* ]]; then - pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu - else - pip3 install torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre - fi + conda install -y pytorch torchvision torchaudio cpuonly -c pytorch-nightly else - pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu113 + conda install -y pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch-nightly fi printf "Installing functorch\n" -pip install ninja # Makes the build go faster -pip install "git+https://github.com/pytorch/functorch.git" +python -m pip install ninja # Makes the build go faster +python -m pip install "git+https://github.com/pytorch/functorch.git" # smoke test python -c "import functorch" diff --git a/.circleci/unittest/linux_optdeps/scripts/install.sh b/.circleci/unittest/linux_optdeps/scripts/install.sh index 9cba799bd3a..1aac9899282 100755 --- a/.circleci/unittest/linux_optdeps/scripts/install.sh +++ b/.circleci/unittest/linux_optdeps/scripts/install.sh @@ -36,18 +36,14 @@ printf "Installing PyTorch with %s\n" "${CU_VERSION}" if [ "${CU_VERSION:-}" == cpu ] ; then # conda install -y pytorch torchvision cpuonly -c pytorch-nightly # use pip to install pytorch as conda can frequently pick older release - if [[ $OSTYPE == 'darwin'* ]]; then - pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu - else - pip3 install torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre - fi + conda install -y pytorch cpuonly -c pytorch-nightly else - pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu113 + conda install -y pytorch cudatoolkit=11.3 -c pytorch-nightly fi printf "Installing functorch\n" -pip install ninja # Makes the build go faster -pip install "git+https://github.com/pytorch/functorch.git" +python -m pip install ninja # Makes the build go faster +python -m pip install "git+https://github.com/pytorch/functorch.git" # smoke test python -c "import functorch" diff --git a/.circleci/unittest/linux_stable/scripts/environment.yml b/.circleci/unittest/linux_stable/scripts/environment.yml index 99b4ae39649..e5a232bb295 100644 --- a/.circleci/unittest/linux_stable/scripts/environment.yml +++ b/.circleci/unittest/linux_stable/scripts/environment.yml @@ -9,7 +9,6 @@ dependencies: - hypothesis - future - cloudpickle - - gym_retro - gym - pygame - gym[accept-rom-license] @@ -25,5 +24,4 @@ dependencies: - dm_control - mujoco_py - hydra-core - - pyrender - tensorboard diff --git a/.circleci/unittest/linux_stable/scripts/install.sh b/.circleci/unittest/linux_stable/scripts/install.sh index 47321c32001..0c8ec43a7e8 100755 --- a/.circleci/unittest/linux_stable/scripts/install.sh +++ b/.circleci/unittest/linux_stable/scripts/install.sh @@ -37,13 +37,13 @@ printf "Installing PyTorch with %s\n" "${CU_VERSION}" if [ "${CU_VERSION:-}" == cpu ] ; then # conda install -y pytorch torchvision cpuonly -c pytorch-nightly # use pip to install pytorch as conda can frequently pick older release - pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu + conda install -y pytorch torchvision torchaudio cpuonly -c pytorch else - pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 + conda install -y pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch fi printf "Installing functorch\n" -pip install git+https://github.com/pytorch/functorch.git@release/0.2 +python -m pip install git+https://github.com/pytorch/functorch.git@release/0.2 # smoke test python -c "import functorch" diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml new file mode 100644 index 00000000000..ebb67cf6de8 --- /dev/null +++ b/.github/workflows/wheels.yml @@ -0,0 +1,130 @@ +name: Wheels +on: + pull_request: + types: [opened, synchronize, reopened] + push: + branches: + - release/0.0.1 + +jobs: + + build-wheel-linux: + runs-on: ubuntu-18.04 + strategy: + matrix: + python_version: [["3.7", "cp37-cp37m"], ["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"]] + cuda_support: [["", "--extra-index-url https://download.pytorch.org/whl/cpu", "\"['cpu', '11.3', '11.6']\"", "cpu"], ["+cu102", "", "\"['10.2']\"", "cuda102"]] + container: pytorch/manylinux-${{ matrix.cuda_support[3] }} + steps: + - name: Checkout torchrl + uses: actions/checkout@v2 + - name: Install PyTorch 1.12 RC + run: | + export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" + python3 -mpip install torch==1.12 ${{ matrix.cuda_support[1] }} + python3 -mpip install "git+https://github.com/pytorch/functorch.git@release/0.2" + - name: Build wheel + run: | + export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" + python3 -mpip install wheel + VERSION_TAG=${{ matrix.cuda_support[0] }} PYTORCH_CUDA_RESTRICTIONS=${{ matrix.cuda_support[2] }} python3 setup.py bdist_wheel + # NB: wheels have the linux_x86_64 tag so we rename to manylinux1 + # find . -name 'dist/*whl' -exec bash -c ' mv $0 ${0/linux/manylinux1}' {} \; + # pytorch/pytorch binaries are also manylinux_2_17 compliant but they + # pretend that they're manylinux1 compliant so we do the same. + - name: Show auditwheel output; confirm 2-17 + run: | + python3 -mpip install auditwheel + auditwheel show dist/* + - name: Upload wheel for the test-wheel job + uses: actions/upload-artifact@v2 + with: + name: torchrl-linux-${{ matrix.python_version[0] }}.whl + path: dist/torchrl-0.0.1-*.whl + - name: Upload wheel for download + uses: actions/upload-artifact@v2 + with: + name: torchrl-batch.whl + path: dist/*.whl + + build-wheel-mac: + runs-on: macos-latest + strategy: + matrix: + python_version: [["3.7", "3.7"], ["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]] + steps: + - name: Setup Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python_version[1] }} + architecture: x64 + - name: Checkout torchrl + uses: actions/checkout@v2 + - name: Install PyTorch 1.12 RC + run: | + python3 -mpip install torch==1.12 --extra-index-url https://download.pytorch.org/whl/cpu + python3 -mpip install "git+https://github.com/pytorch/functorch.git@release/0.2" + - name: Build wheel + run: | + export CC=clang CXX=clang++ + python3 -mpip install wheel + python3 setup.py bdist_wheel + - name: Upload wheel for the test-wheel job + uses: actions/upload-artifact@v2 + with: + name: torchrl-mac-${{ matrix.python_version[0] }}.whl + path: dist/*.whl + - name: Upload wheel for download + uses: actions/upload-artifact@v2 + with: + name: torchrl-batch.whl + path: dist/*.whl + + test-wheel: + needs: [build-wheel-linux, build-wheel-mac] + strategy: + matrix: + os: [["linux", "ubuntu-18.04"], ["mac", "macos-latest"]] + python_version: [ "3.7", "3.8", "3.9", "3.10" ] + runs-on: ${{ matrix.os[1] }} + steps: + - name: Setup Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python_version }} + architecture: x64 + - name: Checkout torchrl + uses: actions/checkout@v2 + - name: Install PyTorch 1.12 RC + run: | + python3 -mpip install torch==1.12 torchvision --extra-index-url https://download.pytorch.org/whl/cpu + python3 -mpip install "git+https://github.com/pytorch/functorch.git@release/0.2" + - name: Upgrade pip + run: | + python3 -mpip install --upgrade pip + - name: Install test dependencies + run: | + python3 -mpip install numpy pytest pytest-cov codecov unittest-xml-reporting pillow>=4.1.1 scipy av networkx expecttest pyyaml + - name: Download built wheels + uses: actions/download-artifact@v2 + with: + name: torchrl-${{ matrix.os[0] }}-${{ matrix.python_version }}.whl + path: /tmp/wheels + - name: Install built wheels + run: | + python3 -mpip install /tmp/wheels/* + - name: Log version string + run: | + # Avoid ambiguity of "import torchrl" by deleting the source files. + rm -rf torchrl/ + python -c "import torchrl; print(torchrl.__version__)" + - name: Run tests + run: | + set -e + export IN_CI=1 + mkdir test-reports + python -m torch.utils.collect_env + python -c "import torchrl; print(torchrl.__version__)" + EXIT_STATUS=0 + pytest test/smoke_test.py -v --durations 20 + exit $EXIT_STATUS diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 6378d485197..8fd7693af0a 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -7,11 +7,11 @@ import os.path import re -import numpy as np import pytest import torch from _utils_internal import get_available_devices from torch import multiprocessing as mp +from torchrl import prod from torchrl.data import SavedTensorDict, TensorDict, MemmapTensor from torchrl.data.tensordict.tensordict import ( assert_allclose_td, @@ -810,7 +810,7 @@ def test_view(self, td_name): td = getattr(self, td_name) td_view = td.view(-1) tensor = td.get("a") - tensor = tensor.view(-1, tensor.numel() // np.prod(td.batch_size)) + tensor = tensor.view(-1, tensor.numel() // prod(td.batch_size)) tensor = torch.ones_like(tensor) if td_name == "sub_td": td_view.set_("a", tensor) diff --git a/test/test_transforms.py b/test/test_transforms.py index b42b0935a46..a32004955e4 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import argparse -import math from multiprocessing import Lock import pytest @@ -12,6 +11,7 @@ from mocking_classes import ContinuousActionVecMockEnv from torch import Tensor from torch import multiprocessing as mp +from torchrl import prod from torchrl.data import NdBoundedTensorSpec, CompositeSpec from torchrl.data import TensorDict from torchrl.envs import EnvCreator, SerialEnv @@ -448,7 +448,7 @@ def test_flatten(self, keys, size, nchannels, batch, device): ) td.set("dont touch", dont_touch.clone()) flatten(td) - expected_size = math.prod(size + [nchannels]) + expected_size = prod(size + [nchannels]) for key in keys: assert td.get(key).shape[-3] == expected_size assert (td.get("dont touch") == dont_touch).all() diff --git a/torchrl/__init__.py b/torchrl/__init__.py index 6a5cea6ba25..bead3693e38 100644 --- a/torchrl/__init__.py +++ b/torchrl/__init__.py @@ -5,7 +5,10 @@ import abc import collections +import math import time +import typing +from typing import Optional, Type, Tuple from warnings import warn import numpy as np @@ -112,3 +115,10 @@ def __init__(self, fun=lambda x: x): def __missing__(self, key): value = self.fun(key) return value + + +def prod(sequence): + if hasattr(math, "prod"): + return math.prod(sequence) + else: + return int(np.prod(sequence)) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index b551c4cd971..fd5229d0cd5 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. import abc -import math import os import queue import time @@ -20,7 +19,7 @@ from torch.utils.data import IterableDataset from torchrl.envs.utils import set_exploration_mode, step_tensordict -from .. import _check_for_faulty_process +from .. import _check_for_faulty_process, prod from ..modules.tensordict_module import ProbabilisticTensorDictModule from .utils import split_trajectories @@ -485,7 +484,7 @@ def reset(self, index=None, **kwargs) -> None: """Resets the environments to a new initial state.""" if index is not None: # check that the env supports partial reset - if np.prod(self.env.batch_size) == 0: + if prod(self.env.batch_size) == 0: raise RuntimeError("resetting unique env with index is not permitted.") reset_workers = torch.zeros( *self.env.batch_size, @@ -981,7 +980,7 @@ def iterator(self) -> Iterator[_TensorDict]: out = split_trajectories(out) frames += out.get("mask").sum() else: - frames += math.prod(out.shape) + frames += prod(out.shape) if self.postprocs: self.postprocs = self.postprocs.to(out.device) out = self.postprocs(out) diff --git a/torchrl/data/replay_buffers/utils.py b/torchrl/data/replay_buffers/utils.py index d01f3c4138e..30d749d4c66 100644 --- a/torchrl/data/replay_buffers/utils.py +++ b/torchrl/data/replay_buffers/utils.py @@ -2,16 +2,20 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Union, get_args +# import tree +import typing +from typing import Union import numpy as np import torch - -# import tree from torch import Tensor INT_CLASSES_TYPING = Union[int, np.integer] -INT_CLASSES = get_args(INT_CLASSES_TYPING) +if hasattr(typing, "get_args"): + INT_CLASSES = typing.get_args(INT_CLASSES_TYPING) +else: + # python 3.7 + INT_CLASSES = (int, np.integer) def fields_pin_memory(input): @@ -41,13 +45,7 @@ def stack_tensors(input): if not len(input): raise RuntimeError("input length must be non-null") if isinstance(input[0], torch.Tensor): - size = input[0].size() - if len(size) == 0: - return torch.stack(input) - else: - # torch.cat is much faster than torch.stack - # https://github.com/pytorch/pytorch/issues/22462 - return torch.cat(input).view(-1, *size) + return torch.stack(input) else: return np.stack(input) @@ -60,7 +58,6 @@ def stack_fields(input): def first_field(data) -> Tensor: raise NotImplementedError - # return next(iter(tree.flatten(data))) def to_torch( diff --git a/torchrl/data/tensordict/memmap.py b/torchrl/data/tensordict/memmap.py index 1fe55c40ef2..5c1b3c638de 100644 --- a/torchrl/data/tensordict/memmap.py +++ b/torchrl/data/tensordict/memmap.py @@ -8,12 +8,12 @@ import functools import os import tempfile -from math import prod from typing import Any, Callable, List, Optional, Tuple, Union import numpy as np import torch +from torchrl import prod from torchrl.data.tensordict.utils import _getitem_batch_size from torchrl.data.utils import ( DEVICE_TYPING, diff --git a/torchrl/data/tensordict/metatensor.py b/torchrl/data/tensordict/metatensor.py index 96021ad33b7..3a753a8b0d7 100644 --- a/torchrl/data/tensordict/metatensor.py +++ b/torchrl/data/tensordict/metatensor.py @@ -6,10 +6,10 @@ from __future__ import annotations import functools -import math from numbers import Number from typing import Callable, List, Optional, Sequence, Tuple, Union +import numpy as np import torch from torchrl.data.utils import DEVICE_TYPING, INDEX_TYPING @@ -92,7 +92,7 @@ def __init__( self.dtype = dtype self.requires_grad = requires_grad self._ndim = len(shape) - self._numel = math.prod(shape) + self._numel = np.prod(shape) self._is_shared = bool(_is_shared) self._is_memmap = bool(_is_memmap) if _is_memmap: diff --git a/torchrl/data/tensordict/tensordict.py b/torchrl/data/tensordict/tensordict.py index 1fb56bae65b..423fcd5f957 100644 --- a/torchrl/data/tensordict/tensordict.py +++ b/torchrl/data/tensordict/tensordict.py @@ -7,7 +7,6 @@ import abc import functools -import math import tempfile import textwrap import uuid @@ -34,7 +33,7 @@ import numpy as np import torch -from torchrl import KeyDependentDefaultDict +from torchrl import KeyDependentDefaultDict, prod from torchrl.data.tensordict.memmap import MemmapTensor from torchrl.data.tensordict.metatensor import MetaTensor from torchrl.data.tensordict.utils import ( @@ -185,7 +184,7 @@ def is_memmap(self, no_check: bool = True) -> bool: def numel(self) -> int: """Total number of elements in the batch.""" - return max(1, math.prod(self.batch_size)) + return max(1, prod(self.batch_size)) def _check_batch_size(self) -> None: bs = [value.shape[: self.batch_dims] for key, value in self.items_meta()] + [ diff --git a/torchrl/data/utils.py b/torchrl/data/utils.py index a2ec4fe22d0..5990b400983 100644 --- a/torchrl/data/utils.py +++ b/torchrl/data/utils.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import typing from typing import Any, Callable, List, Sequence, Tuple, Union import numpy as np @@ -26,6 +27,10 @@ value: key for key, value in numpy_to_torch_dtype_dict.items() } DEVICE_TYPING = Union[torch.device, str, int] +if hasattr(typing, "get_args"): + DEVICE_TYPING_ARGS = typing.get_args(DEVICE_TYPING) +else: + DEVICE_TYPING_ARGS = (torch.device, str, int) INDEX_TYPING = Union[None, int, slice, Tensor, List[Any], Tuple[Any, ...]] diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 27058873128..74a6f075992 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -6,7 +6,6 @@ from __future__ import annotations import abc -import math from collections import OrderedDict from numbers import Number from typing import Any, Callable, Iterator, Optional, Union, Dict @@ -14,7 +13,7 @@ import numpy as np import torch -from torchrl import seed_generator +from torchrl import seed_generator, prod from torchrl.data import CompositeSpec, TensorDict, TensorSpec from ..data.tensordict.tensordict import _TensorDict from ..data.utils import DEVICE_TYPING @@ -335,7 +334,7 @@ def reset( return tensordict def numel(self) -> int: - return math.prod(self.batch_size) + return prod(self.batch_size) def set_seed(self, seed: int) -> int: """Sets the seed of the environment and returns the next seed to be used ( diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py index c02d898f117..3b13fd5dc8f 100644 --- a/torchrl/modules/models/exploration.py +++ b/torchrl/modules/models/exploration.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import math -from typing import Optional, Sequence, Union, get_args +from typing import Optional, Sequence, Union import torch from torch import nn, distributions as d @@ -13,7 +13,8 @@ __all__ = ["NoisyLinear", "NoisyLazyLinear", "reset_noise"] -from torchrl.data.utils import DEVICE_TYPING +from torchrl import prod +from torchrl.data.utils import DEVICE_TYPING, DEVICE_TYPING_ARGS from torchrl.envs.utils import exploration_mode from torchrl.modules.distributions.utils import _cast_transform_device from torchrl.modules.utils import inv_softplus @@ -361,7 +362,7 @@ def forward(self, mu, state, _eps_gSDE): ) elif (_eps_gSDE is None and exploration_mode() == "random") or ( _eps_gSDE is not None - and _eps_gSDE.numel() == math.prod(state.shape[:-1]) + and _eps_gSDE.numel() == prod(state.shape[:-1]) and (_eps_gSDE == 0).all() ): _eps_gSDE = torch.randn( @@ -389,7 +390,7 @@ def forward(self, mu, state, _eps_gSDE): return mu, sigma, action, _eps_gSDE def to(self, device_or_dtype: Union[torch.dtype, DEVICE_TYPING]): - if isinstance(device_or_dtype, get_args(DEVICE_TYPING)): + if isinstance(device_or_dtype, DEVICE_TYPING_ARGS): self.transform = _cast_transform_device(self.transform, device_or_dtype) return super().to(device_or_dtype) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index be8fc6de60b..526b67780b2 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -6,11 +6,11 @@ from numbers import Number from typing import Dict, List, Optional, Sequence, Tuple, Type, Union -import numpy as np import torch from torch import nn from torch.nn import functional as F +from torchrl import prod from torchrl.data import DEVICE_TYPING from torchrl.modules.models.utils import ( _find_depth, @@ -180,7 +180,7 @@ def __init__( _out_features_num = out_features if not isinstance(out_features, Number): - _out_features_num = np.prod(out_features) + _out_features_num = prod(out_features) self.out_features = out_features self._out_features_num = _out_features_num self.activation_class = activation_class