[Versioning] Bump torch 2.0 as minimal version (#2200)
vmoens authored Jun 10, 2024
1 parent 4d37ee1 commit 166467a
Showing 14 changed files with 125 additions and 31 deletions.

4 changes: 2 additions & 2 deletions in .github/unittest/linux_libs/scripts_gym/install.sh
@@ -37,9 +37,9 @@ git submodule sync && git submodule update --init --recursive

printf "Installing PyTorch with %s\n" "${CU_VERSION}"
if [ "${CU_VERSION:-}" == cpu ] ; then
conda install pytorch==1.13.1 torchvision==0.14.1 cpuonly -c pytorch
conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch
else
conda install pytorch==1.13.1 torchvision==0.14.1 pytorch-cuda=11.6 -c pytorch -c nvidia -y
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
fi

# Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has

4 changes: 2 additions & 2 deletions in .github/unittest/linux_olddeps/scripts_gym_0_13/install.sh
@@ -37,9 +37,9 @@ git submodule sync && git submodule update --init --recursive

printf "Installing PyTorch with %s\n" "${CU_VERSION}"
if [ "${CU_VERSION:-}" == cpu ] ; then
conda install pytorch==1.13.1 torchvision==0.14.1 cpuonly -c pytorch
conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch
else
conda install pytorch==1.13.1 torchvision==0.14.1 pytorch-cuda=11.6 -c pytorch -c nvidia -y
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
fi

# Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has

3 changes: 0 additions & 3 deletions in .github/unittest/linux_optdeps/scripts/install.sh
@@ -29,9 +29,6 @@ else
pip3 install tensordict
fi

# smoke test
python -c "import functorch"

printf "* Installing torchrl\n"
python setup.py develop


2 changes: 1 addition & 1 deletion in docs/source/index.rst
@@ -42,7 +42,7 @@ Installation

TorchRL releases are synced with PyTorch, so make sure you always enjoy the latest
features of the library with the `most recent version of PyTorch <https://pytorch.org/get-started/locally/>`__ (although core features
are guaranteed to be backward compatible with pytorch>=1.13).
are guaranteed to be backward compatible with pytorch>=2.0).
Nightly releases can be installed via

.. code-block::

2 changes: 1 addition & 1 deletion in knowledge_base/VERSIONING_ISSUES.md
@@ -1,7 +1,7 @@
# Versioning Issues

## Pytorch version
This issue is related to https://github.com/pytorch/rl/issues/689. Using PyTorch versions <1.13 and installing stable package leads to undefined symbol errors. For example:
This issue is related to https://github.com/pytorch/rl/issues/689. Using PyTorch versions <2.0 and installing stable package leads to undefined symbol errors. For example:
```
ImportError: /usr/local/lib/python3.7/dist-packages/torchrl/_torchrl.so: undefined symbol: _ZN8pybind116detail11type_casterIN2at6TensorEvE4loadENS_6handleEb
```
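
A quick preflight check along these lines can surface the mismatch before the opaque import error does. The snippet below is an illustrative sketch, not part of the commit; the 2.0 floor simply mirrors the minimum bumped here.

```python
from packaging import version

import torch

MIN_TORCH = version.parse("2.0.0")  # minimum version announced by this commit

# Drop local build tags such as "+cu118" before comparing.
installed = version.parse(torch.__version__.split("+")[0])
if installed < MIN_TORCH:
    raise RuntimeError(
        f"torch {torch.__version__} is older than the required {MIN_TORCH}; "
        "upgrade PyTorch before installing the stable torchrl package."
    )
```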

8 changes: 7 additions & 1 deletion in test/test_env.py
@@ -81,7 +81,7 @@
from torchrl.envs.batched_envs import _stackable
from torchrl.envs.gym_like import default_info_dict_reader
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
from torchrl.envs.libs.gym import _has_gym, GymEnv, GymWrapper
from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, GymWrapper
from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv
from torchrl.envs.transforms.transforms import AutoResetEnv, AutoResetTransform
from torchrl.envs.utils import (
@@ -203,6 +203,12 @@ def test_env_seed(env_name, frame_skip, seed=0):
@pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED, PONG_VERSIONED])
@pytest.mark.parametrize("frame_skip", [1, 4])
def test_rollout(env_name, frame_skip, seed=0):
if env_name is PONG_VERSIONED and version.parse(
gym_backend().__version__
) < version.parse("0.19"):
# Then 100 steps in pong are not sufficient to detect a difference
pytest.skip("can't detect difference in gym rollout with this gym version.")

env_name = env_name()
env = GymEnv(env_name, frame_skip=frame_skip)

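For reference, the skip added to `test_rollout` above relies on `packaging.version` comparisons rather than raw string comparison. A minimal, self-contained sketch of the same pattern (the helper name and the sample version strings are illustrative, not from the commit):

```python
from packaging import version


def older_than(installed: str, minimum: str) -> bool:
    """Return True if the installed backend version is below the required minimum."""
    return version.parse(installed) < version.parse(minimum)


# Mirrors the guard above, which skips the Pong rollout check on gym < 0.19.
assert older_than("0.18.3", "0.19")
assert not older_than("0.26.2", "0.19")
```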

7 changes: 3 additions & 4 deletions in test/test_libs.py
@@ -2823,16 +2823,15 @@ def _minari_selected_datasets():
_MINARI_DATASETS += keys


_minari_selected_datasets()


@pytest.mark.skipif(not _has_minari or not _has_gymnasium, reason="Minari not found")
@pytest.mark.slow
class TestMinari:
@pytest.mark.parametrize("split", [False, True])
@pytest.mark.parametrize("selected_dataset", _MINARI_DATASETS)
def test_load(self, selected_dataset, split):
global _MINARI_DATASETS
if not _MINARI_DATASETS:
_minari_selected_datasets()

torchrl_logger.info(f"dataset {selected_dataset}")
data = MinariExperienceReplay(
selected_dataset, batch_size=32, split_trajs=split

4 changes: 2 additions & 2 deletions in test/test_modules.py
@@ -841,7 +841,7 @@ def test_multiagent_mlp_lazy(self):
share_params=False,
depth=2,
)
optim = torch.optim.SGD(mlp.parameters())
optim = torch.optim.SGD(mlp.parameters(), lr=1e-3)
for p in mlp.parameters():
if isinstance(p, torch.nn.parameter.UninitializedParameter):
break
@@ -975,7 +975,7 @@ def test_multiagent_cnn_lazy(self):
in_features=None,
kernel_sizes=3,
)
optim = torch.optim.SGD(cnn.parameters())
optim = torch.optim.SGD(cnn.parameters(), lr=1e-3)
for p in cnn.parameters():
if isinstance(p, torch.nn.parameter.UninitializedParameter):
break

20 changes: 18 additions & 2 deletions in test/test_rb.py
@@ -66,6 +66,7 @@
StorageEnsemble,
TensorStorage,
)
from torchrl.data.replay_buffers.utils import tree_iter
from torchrl.data.replay_buffers.writers import (
RoundRobinWriter,
TensorDictMaxValueWriter,
@@ -880,7 +881,7 @@ def test_extend_list_pytree(self, max_size, shape, storage):
assert len(memory) == 10
assert len(memory._storage) == 10
sample = memory.sample(10)
for leaf in torch.utils._pytree.tree_leaves(sample):
for leaf in tree_iter(sample):
assert (leaf.unique(sorted=True) == torch.arange(10)).all()
memory = ReplayBuffer(
storage=storage(max_size=max_size),
@@ -2961,7 +2962,7 @@ def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls, sampler_cls):
assert (s.exclude("index") == 1).all()
assert s.numel() == 4
else:
for leaf in torch.utils._pytree.tree_leaves(s):
for leaf in tree_iter(s):
assert leaf.shape[0] == 4
assert (leaf == 1).all()

@@ -3122,6 +3123,13 @@ def test_simple_env(self, storage_type, checkpointer, tmpdir):
)
rb = ReplayBuffer(storage=storage_type(100))
rb_test = ReplayBuffer(storage=storage_type(100))
if torch.__version__ < "2.4.0" and checkpointer in (
H5StorageCheckpointer,
NestedStorageCheckpointer,
):
with pytest.raises(ValueError, match="Unsupported torch version"):
checkpointer()
return
rb.storage.checkpointer = checkpointer()
rb_test.storage.checkpointer = checkpointer()
for data in collector:
@@ -3144,12 +3152,20 @@ def test_multi_env(self, storage_type, checkpointer, tmpdir):
)
rb = ReplayBuffer(storage=storage_type(100, ndim=2))
rb_test = ReplayBuffer(storage=storage_type(100, ndim=2))
if torch.__version__ < "2.4.0" and checkpointer in (
H5StorageCheckpointer,
NestedStorageCheckpointer,
):
with pytest.raises(ValueError, match="Unsupported torch version"):
checkpointer()
return
rb.storage.checkpointer = checkpointer()
rb_test.storage.checkpointer = checkpointer()
for data in collector:
rb.extend(data)
assert rb._storage.max_size == 102
rb.dumps(tmpdir)
rb.dumps(tmpdir)
rb_test.loads(tmpdir)
assert_allclose_td(rb_test[:], rb[:])


12 changes: 7 additions & 5 deletions in torchrl/data/replay_buffers/samplers.py
@@ -24,7 +24,7 @@

from torchrl._utils import _replace_last, implement_for, logger
from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
from torchrl.data.replay_buffers.utils import _is_int
from torchrl.data.replay_buffers.utils import _is_int, unravel_index

try:
from torchrl._torchrl import (
@@ -204,7 +204,9 @@ def _single_sample(self, len_storage, batch_size):
def _storage_len(self, storage):
return len(storage)

def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
def sample(
self, storage: Storage, batch_size: int
) -> Tuple[Any, dict]: # noqa: F811
len_storage = self._storage_len(storage)
if len_storage == 0:
raise RuntimeError(_EMPTY_STORAGE_ERROR)
@@ -221,7 +223,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
self.len_storage = len_storage
index = self._single_sample(len_storage, batch_size)
if storage.ndim > 1:
index = torch.unravel_index(index, storage.shape)
index = unravel_index(index, storage.shape)
# we 'always' return the indices. The 'drop_last' just instructs the
# sampler to turn to `ran_out = True` whenever the next sample
# will be too short. This will be read by the replay buffer
@@ -470,7 +472,7 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
# weight = np.power(weight / (p_min + self._eps), -self._beta)
weight = torch.pow(weight / p_min, -self._beta)
if storage.ndim > 1:
index = torch.unravel_index(index, storage.shape)
index = unravel_index(index, storage.shape)
return index, {"_weight": weight}

return index, {"_weight": weight}
@@ -1807,7 +1809,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
if storage.ndim > 1:
# we need to convert indices of the permuted, flatten storage to indices in a flatten storage (not permuted)
# This is because the lengths come as they would for a permuted storage
preceding_stop_idx = torch.unravel_index(
preceding_stop_idx = unravel_index(
preceding_stop_idx, (storage.shape[-1], *storage.shape[:-1])
)
preceding_stop_idx = (preceding_stop_idx[-1], *preceding_stop_idx[:-1])
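
The sampler changes above only swap `torch.unravel_index` for the version-compatible `unravel_index` wrapper added in `utils.py` below; the operation itself is unchanged. As a reminder of what that conversion does for multi-dimensional storages, here is a small illustrative sketch (the shapes and indices are made up):

```python
import torch

# A storage laid out as 4 trajectories x 25 steps, sampled with flat indices.
shape = (4, 25)
flat_index = torch.tensor([0, 26, 99])

# On torch >= 2.2 this is torch.unravel_index; the wrapper below falls back to
# a pure-PyTorch backport on older versions.
rows, cols = torch.unravel_index(flat_index, shape)
print(rows.tolist())  # [0, 1, 3]
print(cols.tolist())  # [0, 1, 24]
```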

15 changes: 9 additions & 6 deletions in torchrl/data/replay_buffers/storages.py
@@ -25,15 +25,19 @@
from tensordict.memmap import MemoryMappedTensor
from torch import multiprocessing as mp
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

from torchrl._utils import implement_for, logger as torchrl_logger
from torchrl.data.replay_buffers.checkpointers import (
ListStorageCheckpointer,
StorageCheckpointerBase,
StorageEnsembleCheckpointer,
TensorStorageCheckpointer,
)
from torchrl.data.replay_buffers.utils import _init_pytree, _is_int, INT_CLASSES
from torchrl.data.replay_buffers.utils import (
_init_pytree,
_is_int,
INT_CLASSES,
tree_iter,
)


class Storage:
@@ -425,7 +429,7 @@ def _total_shape(self):
if is_tensor_collection(self._storage):
_total_shape = self._storage.shape[: self.ndim]
else:
leaf, *_ = torch.utils._pytree.tree_leaves(self._storage)
leaf = next(tree_iter(self._storage))
_total_shape = leaf.shape[: self.ndim]
self.__dict__["_total_shape_value"] = _total_shape
return _total_shape
@@ -462,7 +466,7 @@ def _max_size_along_dim0(self, *, single_data=None, batched_data=None):
if is_tensor_collection(data):
datashape = data.shape[: self.ndim]
else:
for leaf in torch.utils._pytree.tree_leaves(data):
for leaf in tree_iter(data):
datashape = leaf.shape[: self.ndim]
break
if batched_data is not None:
@@ -615,8 +619,7 @@ def _get_new_len(self, data, cursor):
if is_tensor_collection(data) or isinstance(data, torch.Tensor):
numel = data.shape[:ndim].numel()
else:
# unfortunately tree_flatten isn't an iterator so we will have to flatten it all
leaf, *_ = torch.utils._pytree.tree_leaves(data)
leaf = next(tree_iter(data))
numel = leaf.shape[:ndim].numel()
self._len = min(self._len + numel, self.max_size)

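The storage edits above replace eager flattening (`tree_leaves(...)` / `tree_flatten(...)`) with `next(tree_iter(...))` when only one representative leaf is needed, e.g. to read the shared leading shape. A rough usage sketch built on the wrapper this commit adds (the nested data here is invented for illustration):

```python
import torch

from torchrl.data.replay_buffers.utils import tree_iter  # wrapper added in this commit

data = {
    "obs": torch.zeros(10, 4),
    "nested": {"action": torch.zeros(10, 2), "reward": torch.zeros(10, 1)},
}

# Peek at one leaf without materializing the full flattened list of leaves.
first_leaf = next(tree_iter(data))
print(first_leaf.shape[:1])  # torch.Size([10]) -- the shared batch dimension
```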

71 changes: 71 additions & 0 deletions in torchrl/data/replay_buffers/utils.py
@@ -6,8 +6,10 @@
from __future__ import annotations

import contextlib
import itertools

import math
import operator
import os
import typing
from pathlib import Path
@@ -597,6 +599,14 @@ class TED2Nested(TED2Flat):
_shift: int = None
_is_full: bool = None

def __init__(self, *args, **kwargs):
if not hasattr(torch, "_nested_compute_contiguous_strides_offsets"):
raise ValueError(
f"Unsupported torch version {torch.__version__}. "
f"torch>=2.4 is required for {type(self).__name__} to be used."
)
return super().__init__(*args, **kwargs)

def __call__(self, data: TensorDictBase, path: Path = None):
data = super().__call__(data, path=path)

@@ -949,3 +959,64 @@ def _roll_inplace(tensor, shift, out, index_dest=None, index_source=None):
slice1 = out[:-slice0_shift]
slice1.copy_(source1)
return out


# Copy-paste of unravel-index for PT 2.0
def _unravel_index(
indices: Tensor, shape: Union[int, typing.Sequence[int], torch.Size]
) -> typing.Tuple[Tensor, ...]:
res_tensor = _unravel_index_impl(indices, shape)
return res_tensor.unbind(-1)


def _unravel_index_impl(
indices: Tensor, shape: Union[int, typing.Sequence[int]]
) -> Tensor:
if isinstance(shape, (int, torch.SymInt)):
shape = torch.Size([shape])
else:
shape = torch.Size(shape)

coefs = list(
reversed(
list(
itertools.accumulate(
reversed(shape[1:] + torch.Size([1])), func=operator.mul
)
)
)
)
return indices.unsqueeze(-1).floor_divide(
torch.tensor(coefs, device=indices.device, dtype=torch.int64)
) % torch.tensor(shape, device=indices.device, dtype=torch.int64)


@implement_for("torch", None, "2.2")
def unravel_index(indices, shape):
"""A version-compatible wrapper around torch.unravel_index."""
return _unravel_index(indices, shape)


@implement_for("torch", "2.2")
def unravel_index(indices, shape): # noqa: F811
"""A version-compatible wrapper around torch.unravel_index."""
return torch.unravel_index(indices, shape)


@implement_for("torch", None, "2.3")
def tree_iter(pytree):
"""A version-compatible wrapper around tree_iter."""
flat_tree, _ = torch.utils._pytree.tree_flatten(pytree)
yield from flat_tree


@implement_for("torch", "2.3", "2.4")
def tree_iter(pytree): # noqa: F811
"""A version-compatible wrapper around tree_iter."""
yield from torch.utils._pytree.tree_leaves(pytree)


@implement_for("torch", "2.4")
def tree_iter(pytree): # noqa: F811
"""A version-compatible wrapper around tree_iter."""
yield from torch.utils._pytree.tree_iter(pytree)
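
The three `tree_iter` definitions rely on `implement_for` to select, at import time, the body matching the installed torch. To make that dispatch concrete, here is a simplified hand-rolled equivalent; it is a stand-in sketch, not torchrl's actual `implement_for` machinery:

```python
import torch
import torch.utils._pytree as pytree
from packaging import version


def _pick_tree_iter():
    """Pick a leaf iterator matching the (None, "2.3"), ("2.3", "2.4") and
    ("2.4", None) ranges used by the decorators above."""
    v = version.parse(torch.__version__.split("+")[0])
    if v >= version.parse("2.4"):
        return pytree.tree_iter  # native lazy iterator
    if v >= version.parse("2.3"):
        return lambda t: iter(pytree.tree_leaves(t))  # flatten, then iterate
    return lambda t: iter(pytree.tree_flatten(t)[0])  # oldest fallback


tree_iter_compat = _pick_tree_iter()
print(list(tree_iter_compat({"a": 1, "b": (2, 3)})))  # [1, 2, 3]
```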

2 changes: 1 addition & 1 deletion in torchrl/modules/tensordict_module/common.py
@@ -445,7 +445,7 @@ class VmapModule(TensorDictModuleBase):

def __init__(self, module: TensorDictModuleBase, vmap_dim=None):
if not _has_functorch:
raise ImportError("VmapModule requires torch>=1.13.")
raise ImportError("VmapModule requires torch>=2.0.")
super().__init__()
self.in_keys = module.in_keys
self.out_keys = module.out_keys

2 changes: 1 addition & 1 deletion in torchrl/objectives/value/advantages.py
@@ -46,7 +46,7 @@
from functorch import vmap
except ImportError:
raise ImportError(
"vmap couldn't be found. Make sure you have torch>1.13 installed."
"vmap couldn't be found. Make sure you have torch>2.0 installed."
) from err

