Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Versioning] Bump torch 2.0 as minimal version #2200

Merged
merged 21 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
init
  • Loading branch information
vmoens committed Jun 3, 2024
commit df1e7f4d8eb26231f297de0f847b0a5a06ae7e88
4 changes: 2 additions & 2 deletions .github/unittest/linux_libs/scripts_gym/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ git submodule sync && git submodule update --init --recursive

printf "Installing PyTorch with %s\n" "${CU_VERSION}"
if [ "${CU_VERSION:-}" == cpu ] ; then
conda install pytorch==1.13.1 torchvision==0.14.1 cpuonly -c pytorch
conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch
else
conda install pytorch==1.13.1 torchvision==0.14.1 pytorch-cuda=11.6 -c pytorch -c nvidia -y
conda install pytorch==2.0 torchvision==0.15 pytorch-cuda=11.6 -c pytorch -c nvidia -y
fi

# Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has
Expand Down
4 changes: 2 additions & 2 deletions .github/unittest/linux_olddeps/scripts_gym_0_13/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ git submodule sync && git submodule update --init --recursive

printf "Installing PyTorch with %s\n" "${CU_VERSION}"
if [ "${CU_VERSION:-}" == cpu ] ; then
conda install pytorch==1.13.1 torchvision==0.14.1 cpuonly -c pytorch
conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch
else
conda install pytorch==1.13.1 torchvision==0.14.1 pytorch-cuda=11.6 -c pytorch -c nvidia -y
conda install pytorch==2.0 torchvision==0.15 pytorch-cuda=11.6 -c pytorch -c nvidia -y
fi

# Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has
Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Installation

TorchRL releases are synced with PyTorch, so make sure you always enjoy the latest
features of the library with the `most recent version of PyTorch <https://pytorch.org/get-started/locally/>`__ (although core features
are guaranteed to be backward compatible with pytorch>=1.13).
are guaranteed to be backward compatible with pytorch>=2.0).
Nightly releases can be installed via

.. code-block::
Expand Down
2 changes: 1 addition & 1 deletion knowledge_base/VERSIONING_ISSUES.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Versioning Issues

## Pytorch version
This issue is related to https://github.com/pytorch/rl/issues/689. Using PyTorch versions <1.13 and installing stable package leads to undefined symbol errors. For example:
This issue is related to https://github.com/pytorch/rl/issues/689. Using PyTorch versions <2.0 and installing stable package leads to undefined symbol errors. For example:
```
ImportError: /usr/local/lib/python3.7/dist-packages/torchrl/_torchrl.so: undefined symbol: _ZN8pybind116detail11type_casterIN2at6TensorEvE4loadENS_6handleEb
```
Expand Down
2 changes: 1 addition & 1 deletion torchrl/modules/tensordict_module/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ class VmapModule(TensorDictModuleBase):

def __init__(self, module: TensorDictModuleBase, vmap_dim=None):
if not _has_functorch:
raise ImportError("VmapModule requires torch>=1.13.")
raise ImportError("VmapModule requires torch>=2.0.")
super().__init__()
self.in_keys = module.in_keys
self.out_keys = module.out_keys
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/value/advantages.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from functorch import vmap
except ImportError:
raise ImportError(
"vmap couldn't be found. Make sure you have torch>1.13 installed."
"vmap couldn't be found. Make sure you have torch>2.0 installed."
) from err


Expand Down
Loading