diff --git a/.github/unittest/linux_libs/scripts_gym/install.sh b/.github/unittest/linux_libs/scripts_gym/install.sh
index 89d6504f52b..98bb504493c 100755
--- a/.github/unittest/linux_libs/scripts_gym/install.sh
+++ b/.github/unittest/linux_libs/scripts_gym/install.sh
@@ -37,9 +37,9 @@ git submodule sync && git submodule update --init --recursive

 printf "Installing PyTorch with %s\n" "${CU_VERSION}"
 if [ "${CU_VERSION:-}" == cpu ] ; then
-    conda install pytorch==1.13.1 torchvision==0.14.1 cpuonly -c pytorch
+    conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch
 else
-    conda install pytorch==1.13.1 torchvision==0.14.1 pytorch-cuda=11.6 -c pytorch -c nvidia -y
+    conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
 fi

 # Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has
diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh b/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh
index c3f24e52708..158007d5021 100755
--- a/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh
+++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh
@@ -37,9 +37,9 @@ git submodule sync && git submodule update --init --recursive

 printf "Installing PyTorch with %s\n" "${CU_VERSION}"
 if [ "${CU_VERSION:-}" == cpu ] ; then
-    conda install pytorch==1.13.1 torchvision==0.14.1 cpuonly -c pytorch
+    conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch
 else
-    conda install pytorch==1.13.1 torchvision==0.14.1 pytorch-cuda=11.6 -c pytorch -c nvidia -y
+    conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
 fi

 # Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has
diff --git a/.github/unittest/linux_optdeps/scripts/install.sh b/.github/unittest/linux_optdeps/scripts/install.sh
index 885d192e135..8ccbfbb8e19 100755
--- a/.github/unittest/linux_optdeps/scripts/install.sh
+++ b/.github/unittest/linux_optdeps/scripts/install.sh
@@ -29,9 +29,6 @@ else
   pip3 install tensordict
 fi

-# smoke test
-python -c "import functorch"
-
 printf "* Installing torchrl\n"
 python setup.py develop
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 44b7d406cd2..5f3882faed7 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -42,7 +42,7 @@ Installation

 TorchRL releases are synced with PyTorch, so make sure you always enjoy the latest
 features of the library with the `most recent version of PyTorch `__ (although core features
-are guaranteed to be backward compatible with pytorch>=1.13).
+are guaranteed to be backward compatible with pytorch>=2.0).
 Nightly releases can be installed via

 .. code-block::
diff --git a/knowledge_base/VERSIONING_ISSUES.md b/knowledge_base/VERSIONING_ISSUES.md
index 2a80aecddf3..1049199bbdf 100644
--- a/knowledge_base/VERSIONING_ISSUES.md
+++ b/knowledge_base/VERSIONING_ISSUES.md
@@ -1,7 +1,7 @@
 # Versioning Issues

 ## Pytorch version
-This issue is related to https://github.com/pytorch/rl/issues/689. Using PyTorch versions <1.13 and installing stable package leads to undefined symbol errors. For example:
+This issue is related to https://github.com/pytorch/rl/issues/689. Using PyTorch versions <2.0 and installing stable package leads to undefined symbol errors. For example:
 ```
 ImportError: /usr/local/lib/python3.7/dist-packages/torchrl/_torchrl.so: undefined symbol: _ZN8pybind116detail11type_casterIN2at6TensorEvE4loadENS_6handleEb
 ```
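Note: the knowledge-base entry above describes the failure this bump avoids: importing the stable torchrl wheel against a pre-2.0 torch fails with undefined C++ symbols. Below is a minimal, purely illustrative sketch (not part of this patch) of how a downstream script could fail fast on such a mismatch; it assumes the `packaging` helper library is installed and hard-codes the 2.0 floor documented in the index.rst change above.

```python
# Illustrative sketch only -- not part of this patch.
# Assumes the `packaging` library is available; the 2.0 floor mirrors the
# minimum documented in docs/source/index.rst above.
from packaging import version

import torch

if version.parse(torch.__version__) < version.parse("2.0"):
    raise RuntimeError(
        f"Found torch {torch.__version__}, but torchrl's stable wheels expect torch>=2.0; "
        "older builds fail at import time with undefined symbols in _torchrl.so."
    )

import torchrl  # noqa: E402  # safe to import once the check above passed
```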
diff --git a/test/test_env.py b/test/test_env.py
index e86cc06b14c..d6ebb16084c 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -81,7 +81,7 @@
 from torchrl.envs.batched_envs import _stackable
 from torchrl.envs.gym_like import default_info_dict_reader
 from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
-from torchrl.envs.libs.gym import _has_gym, GymEnv, GymWrapper
+from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, GymWrapper
 from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv
 from torchrl.envs.transforms.transforms import AutoResetEnv, AutoResetTransform
 from torchrl.envs.utils import (
@@ -203,6 +203,12 @@ def test_env_seed(env_name, frame_skip, seed=0):
 @pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED, PONG_VERSIONED])
 @pytest.mark.parametrize("frame_skip", [1, 4])
 def test_rollout(env_name, frame_skip, seed=0):
+    if env_name is PONG_VERSIONED and version.parse(
+        gym_backend().__version__
+    ) < version.parse("0.19"):
+        # Then 100 steps in pong are not sufficient to detect a difference
+        pytest.skip("can't detect difference in gym rollout with this gym version.")
+
     env_name = env_name()
     env = GymEnv(env_name, frame_skip=frame_skip)
diff --git a/test/test_libs.py b/test/test_libs.py
index fdd164b5e89..672b64cdb22 100644
--- a/test/test_libs.py
+++ b/test/test_libs.py
@@ -2823,16 +2823,15 @@ def _minari_selected_datasets():
     _MINARI_DATASETS += keys


+_minari_selected_datasets()
+
+
 @pytest.mark.skipif(not _has_minari or not _has_gymnasium, reason="Minari not found")
 @pytest.mark.slow
 class TestMinari:
     @pytest.mark.parametrize("split", [False, True])
     @pytest.mark.parametrize("selected_dataset", _MINARI_DATASETS)
     def test_load(self, selected_dataset, split):
-        global _MINARI_DATASETS
-        if not _MINARI_DATASETS:
-            _minari_selected_datasets()
-
         torchrl_logger.info(f"dataset {selected_dataset}")
         data = MinariExperienceReplay(
             selected_dataset, batch_size=32, split_trajs=split
diff --git a/test/test_modules.py b/test/test_modules.py
index 932bbbcb194..59adbea653d 100644
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -841,7 +841,7 @@ def test_multiagent_mlp_lazy(self):
             share_params=False,
             depth=2,
         )
-        optim = torch.optim.SGD(mlp.parameters())
+        optim = torch.optim.SGD(mlp.parameters(), lr=1e-3)
         for p in mlp.parameters():
             if isinstance(p, torch.nn.parameter.UninitializedParameter):
                 break
@@ -975,7 +975,7 @@ def test_multiagent_cnn_lazy(self):
             in_features=None,
             kernel_sizes=3,
         )
-        optim = torch.optim.SGD(cnn.parameters())
+        optim = torch.optim.SGD(cnn.parameters(), lr=1e-3)
         for p in cnn.parameters():
             if isinstance(p, torch.nn.parameter.UninitializedParameter):
                 break
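A side note on the `lr=1e-3` additions above: on the older torch releases this PR still supports, `torch.optim.SGD` has no default learning rate, so constructing it without `lr` fails; passing it explicitly keeps the lazy-module tests working on every supported version. A minimal illustration (not part of the patch):

```python
# Illustrative only: an explicit lr keeps SGD construction portable across
# the torch versions supported here (older releases define no default lr).
import torch

params = [torch.nn.Parameter(torch.zeros(3))]
optim = torch.optim.SGD(params, lr=1e-3)
optim.zero_grad()
```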
diff --git a/test/test_rb.py b/test/test_rb.py
index a8ce5ac7aea..a0aa24513a1 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -66,6 +66,7 @@
     StorageEnsemble,
     TensorStorage,
 )
+from torchrl.data.replay_buffers.utils import tree_iter
 from torchrl.data.replay_buffers.writers import (
     RoundRobinWriter,
     TensorDictMaxValueWriter,
@@ -880,7 +881,7 @@ def test_extend_list_pytree(self, max_size, shape, storage):
         assert len(memory) == 10
         assert len(memory._storage) == 10
         sample = memory.sample(10)
-        for leaf in torch.utils._pytree.tree_leaves(sample):
+        for leaf in tree_iter(sample):
             assert (leaf.unique(sorted=True) == torch.arange(10)).all()
         memory = ReplayBuffer(
             storage=storage(max_size=max_size),
@@ -2961,7 +2962,7 @@ def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls, sampler_cls):
             assert (s.exclude("index") == 1).all()
             assert s.numel() == 4
         else:
-            for leaf in torch.utils._pytree.tree_leaves(s):
+            for leaf in tree_iter(s):
                 assert leaf.shape[0] == 4
                 assert (leaf == 1).all()

@@ -3122,6 +3123,13 @@ def test_simple_env(self, storage_type, checkpointer, tmpdir):
         )
         rb = ReplayBuffer(storage=storage_type(100))
         rb_test = ReplayBuffer(storage=storage_type(100))
+        if torch.__version__ < "2.4.0" and checkpointer in (
+            H5StorageCheckpointer,
+            NestedStorageCheckpointer,
+        ):
+            with pytest.raises(ValueError, match="Unsupported torch version"):
+                checkpointer()
+            return
         rb.storage.checkpointer = checkpointer()
         rb_test.storage.checkpointer = checkpointer()
         for data in collector:
@@ -3144,12 +3152,20 @@ def test_multi_env(self, storage_type, checkpointer, tmpdir):
         )
         rb = ReplayBuffer(storage=storage_type(100, ndim=2))
         rb_test = ReplayBuffer(storage=storage_type(100, ndim=2))
+        if torch.__version__ < "2.4.0" and checkpointer in (
+            H5StorageCheckpointer,
+            NestedStorageCheckpointer,
+        ):
+            with pytest.raises(ValueError, match="Unsupported torch version"):
+                checkpointer()
+            return
         rb.storage.checkpointer = checkpointer()
         rb_test.storage.checkpointer = checkpointer()
         for data in collector:
             rb.extend(data)
             assert rb._storage.max_size == 102
         rb.dumps(tmpdir)
+        rb.dumps(tmpdir)
         rb_test.loads(tmpdir)
         assert_allclose_td(rb_test[:], rb[:])
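The pytree assertions above now go through `tree_iter`, the version-compatible leaf iterator added later in this patch (torchrl/data/replay_buffers/utils.py). A small usage sketch, with a made-up pytree, of what those loops rely on:

```python
# Illustrative usage of the tree_iter wrapper introduced in this patch
# (torchrl/data/replay_buffers/utils.py); the pytree below is made up.
import torch

from torchrl.data.replay_buffers.utils import tree_iter

sample = {"obs": torch.zeros(4, 3), "extra": (torch.ones(4), torch.arange(4))}
for leaf in tree_iter(sample):
    # yields each tensor leaf of the nested structure, one at a time
    assert leaf.shape[0] == 4
```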
diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index 75a4ba9b24b..48223135e53 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -24,7 +24,7 @@
 from torchrl._utils import _replace_last, implement_for, logger
 from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
-from torchrl.data.replay_buffers.utils import _is_int
+from torchrl.data.replay_buffers.utils import _is_int, unravel_index

 try:
     from torchrl._torchrl import (
@@ -204,7 +204,9 @@ def _single_sample(self, len_storage, batch_size):
     def _storage_len(self, storage):
         return len(storage)

-    def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
+    def sample(
+        self, storage: Storage, batch_size: int
+    ) -> Tuple[Any, dict]:  # noqa: F811
         len_storage = self._storage_len(storage)
         if len_storage == 0:
             raise RuntimeError(_EMPTY_STORAGE_ERROR)
@@ -221,7 +223,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
         self.len_storage = len_storage
         index = self._single_sample(len_storage, batch_size)
         if storage.ndim > 1:
-            index = torch.unravel_index(index, storage.shape)
+            index = unravel_index(index, storage.shape)
         # we 'always' return the indices. The 'drop_last' just instructs the
         # sampler to turn to `ran_out = True` whenever the next sample
         # will be too short. This will be read by the replay buffer
@@ -470,7 +472,7 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
         # weight = np.power(weight / (p_min + self._eps), -self._beta)
         weight = torch.pow(weight / p_min, -self._beta)
         if storage.ndim > 1:
-            index = torch.unravel_index(index, storage.shape)
+            index = unravel_index(index, storage.shape)
             return index, {"_weight": weight}
         return index, {"_weight": weight}
@@ -1807,7 +1809,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
         if storage.ndim > 1:
             # we need to convert indices of the permuted, flatten storage to indices in a flatten storage (not permuted)
             # This is because the lengths come as they would for a permuted storage
-            preceding_stop_idx = torch.unravel_index(
+            preceding_stop_idx = unravel_index(
                 preceding_stop_idx, (storage.shape[-1], *storage.shape[:-1])
             )
             preceding_stop_idx = (preceding_stop_idx[-1], *preceding_stop_idx[:-1])
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index 8de1b1d94a5..4a041d16c8d 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -25,7 +25,6 @@
 from tensordict.memmap import MemoryMappedTensor
 from torch import multiprocessing as mp
 from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
-
 from torchrl._utils import implement_for, logger as torchrl_logger
 from torchrl.data.replay_buffers.checkpointers import (
     ListStorageCheckpointer,
@@ -33,7 +32,12 @@
     StorageEnsembleCheckpointer,
     TensorStorageCheckpointer,
 )
-from torchrl.data.replay_buffers.utils import _init_pytree, _is_int, INT_CLASSES
+from torchrl.data.replay_buffers.utils import (
+    _init_pytree,
+    _is_int,
+    INT_CLASSES,
+    tree_iter,
+)


 class Storage:
@@ -425,7 +429,7 @@ def _total_shape(self):
         if is_tensor_collection(self._storage):
             _total_shape = self._storage.shape[: self.ndim]
         else:
-            leaf, *_ = torch.utils._pytree.tree_leaves(self._storage)
+            leaf = next(tree_iter(self._storage))
             _total_shape = leaf.shape[: self.ndim]
         self.__dict__["_total_shape_value"] = _total_shape
         return _total_shape
@@ -462,7 +466,7 @@ def _max_size_along_dim0(self, *, single_data=None, batched_data=None):
             if is_tensor_collection(data):
                 datashape = data.shape[: self.ndim]
             else:
-                for leaf in torch.utils._pytree.tree_leaves(data):
+                for leaf in tree_iter(data):
                     datashape = leaf.shape[: self.ndim]
                     break
         if batched_data is not None:
@@ -615,8 +619,7 @@ def _get_new_len(self, data, cursor):
         if is_tensor_collection(data) or isinstance(data, torch.Tensor):
             numel = data.shape[:ndim].numel()
         else:
-            # unfortunately tree_flatten isn't an iterator so we will have to flatten it all
-            leaf, *_ = torch.utils._pytree.tree_leaves(data)
+            leaf = next(tree_iter(data))
             numel = leaf.shape[:ndim].numel()
         self._len = min(self._len + numel, self.max_size)
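The sampler changes above swap `torch.unravel_index` for the `unravel_index` wrapper added below in torchrl/data/replay_buffers/utils.py, so that flat indices can still be mapped to per-dimension coordinates on torch builds that predate `torch.unravel_index`. A small, illustrative sketch of that mapping for a 2D storage (the shape and indices below are made up):

```python
# Illustrative only. unravel_index is the version-agnostic wrapper added in
# this patch; the storage shape and indices below are made up.
import torch

from torchrl.data.replay_buffers.utils import unravel_index

storage_shape = (5, 4)                 # e.g. 5 trajectories x 4 steps each
flat_index = torch.tensor([0, 6, 19])
traj, step = unravel_index(flat_index, storage_shape)
assert (traj == torch.tensor([0, 1, 4])).all()   # flat_index // 4
assert (step == torch.tensor([0, 2, 3])).all()   # flat_index % 4
```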
" + f"torch>=2.4 is required for {type(self).__name__} to be used." + ) + return super().__init__(*args, **kwargs) + def __call__(self, data: TensorDictBase, path: Path = None): data = super().__call__(data, path=path) @@ -949,3 +959,64 @@ def _roll_inplace(tensor, shift, out, index_dest=None, index_source=None): slice1 = out[:-slice0_shift] slice1.copy_(source1) return out + + +# Copy-paste of unravel-index for PT 2.0 +def _unravel_index( + indices: Tensor, shape: Union[int, typing.Sequence[int], torch.Size] +) -> typing.Tuple[Tensor, ...]: + res_tensor = _unravel_index_impl(indices, shape) + return res_tensor.unbind(-1) + + +def _unravel_index_impl( + indices: Tensor, shape: Union[int, typing.Sequence[int]] +) -> Tensor: + if isinstance(shape, (int, torch.SymInt)): + shape = torch.Size([shape]) + else: + shape = torch.Size(shape) + + coefs = list( + reversed( + list( + itertools.accumulate( + reversed(shape[1:] + torch.Size([1])), func=operator.mul + ) + ) + ) + ) + return indices.unsqueeze(-1).floor_divide( + torch.tensor(coefs, device=indices.device, dtype=torch.int64) + ) % torch.tensor(shape, device=indices.device, dtype=torch.int64) + + +@implement_for("torch", None, "2.2") +def unravel_index(indices, shape): + """A version-compatible wrapper around torch.unravel_index.""" + return _unravel_index(indices, shape) + + +@implement_for("torch", "2.2") +def unravel_index(indices, shape): # noqa: F811 + """A version-compatible wrapper around torch.unravel_index.""" + return torch.unravel_index(indices, shape) + + +@implement_for("torch", None, "2.3") +def tree_iter(pytree): + """A version-compatible wrapper around tree_iter.""" + flat_tree, _ = torch.utils._pytree.tree_flatten(pytree) + yield from flat_tree + + +@implement_for("torch", "2.3", "2.4") +def tree_iter(pytree): # noqa: F811 + """A version-compatible wrapper around tree_iter.""" + yield from torch.utils._pytree.tree_leaves(pytree) + + +@implement_for("torch", "2.4") +def tree_iter(pytree): # noqa: F811 + """A version-compatible wrapper around tree_iter.""" + yield from torch.utils._pytree.tree_iter(pytree) diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 7ac5d9873e5..11cc363b461 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -445,7 +445,7 @@ class VmapModule(TensorDictModuleBase): def __init__(self, module: TensorDictModuleBase, vmap_dim=None): if not _has_functorch: - raise ImportError("VmapModule requires torch>=1.13.") + raise ImportError("VmapModule requires torch>=2.0.") super().__init__() self.in_keys = module.in_keys self.out_keys = module.out_keys diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 9e137df98cb..b977a3440dd 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -46,7 +46,7 @@ from functorch import vmap except ImportError: raise ImportError( - "vmap couldn't be found. Make sure you have torch>1.13 installed." + "vmap couldn't be found. Make sure you have torch>2.0 installed." ) from err