Skip to content

Commit

Permalink
[Version] Updating to torch 1.13 (pytorch#627)
Browse files Browse the repository at this point in the history
* init

* new release updates

* tutorials

* amend

* lint

* lint

* lint
  • Loading branch information
vmoens authored Oct 31, 2022
1 parent 469c871 commit f58862b
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ jobs:
bash <(curl -s https://codecov.io/bash) -Z -F habitat-gpu
- run:
name: Post Process
command: docker run -t --gpus all -v $PWD:$PWD .circleci/unittest/linux_libs/scripts_habitat/post_process.sh
command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_libs/scripts_habitat/post_process.sh
- store_test_results:
path: test-results

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ export MKL_THREADING_LAYER=GNU

coverage run -m pytest test/smoke_test.py -v --durations 20
coverage run -m pytest test/smoke_test_deps.py -v --durations 20 -k 'test_gym or test_dm_control_pixels or test_dm_control'
MUJOCO_GL=egl coverage run -m xvfb-run -a pytest --instafail -v --durations 20
#MUJOCO_GL=egl coverage run -m xvfb-run -a pytest --instafail -v --durations 20
MUJOCO_GL=egl coverage run -m pytest --instafail -v --durations 20
#pytest --instafail -v --durations 20
#python test/test_libs.py
coverage xml -i
5 changes: 1 addition & 4 deletions .circleci/unittest/linux_stable/scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,8 @@ else
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
fi

printf "Installing functorch\n"
pip3 install functorch

# smoke test
python -c "import functorch"
python -c "import torch;import functorch"

printf "* Installing torchrl\n"
printf "g++ version: "
Expand Down
2 changes: 1 addition & 1 deletion .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ Language: Cpp
AccessModifierOffset: -1
AlignAfterOpenBracket: Align
AlignConsecutiveMacros: false
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignConsecutiveAssignments: false
AlignEscapedNewlines: Left
AlignOperands: true
AlignTrailingComments: true
Expand Down
15 changes: 6 additions & 9 deletions .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@ jobs:
steps:
- name: Checkout torchrl
uses: actions/checkout@v2
- name: Install PyTorch 1.12 RC
- name: Install PyTorch RC
run: |
export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH"
python3 -mpip install torch==1.12 ${{ matrix.cuda_support[1] }}
python3 -mpip install "git+https://github.com/pytorch/functorch.git@release/0.2"
python3 -mpip install torch ${{ matrix.cuda_support[1] }}
- name: Build wheel
run: |
export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH"
Expand Down Expand Up @@ -60,10 +59,9 @@ jobs:
architecture: x64
- name: Checkout torchrl
uses: actions/checkout@v2
- name: Install PyTorch 1.12 RC
- name: Install PyTorch RC
run: |
python3 -mpip install torch==1.12 --extra-index-url https://download.pytorch.org/whl/cpu
python3 -mpip install "git+https://github.com/pytorch/functorch.git@release/0.2"
python3 -mpip install torch --extra-index-url https://download.pytorch.org/whl/cpu
- name: Build wheel
run: |
export CC=clang CXX=clang++
Expand Down Expand Up @@ -95,10 +93,9 @@ jobs:
architecture: x64
- name: Checkout torchrl
uses: actions/checkout@v2
- name: Install PyTorch 1.12 RC
- name: Install PyTorch RC
run: |
python3 -mpip install torch==1.12 torchvision --extra-index-url https://download.pytorch.org/whl/cpu
python3 -mpip install "git+https://github.com/pytorch/functorch.git@release/0.2"
python3 -mpip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu
- name: Upgrade pip
run: |
python3 -mpip install --upgrade pip
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ conda install pytorch torchvision torchaudio cudatoolkit=11.6 -c pytorch -c cond
# For CPU-only build
conda install pytorch torchvision cpuonly -c pytorch

# Functorch will be integrated in torch from 1.13. As of now, we still need the latest pip release
# For torch 1.12 (and not above), one should install functorch separately:
pip3 install functorch
```
Expand Down
2 changes: 2 additions & 0 deletions test/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def test_selectkeys(self):
assert key2 not in td_out.keys()

def test_selectkeys_statedict(self):
if not _has_ts:
os.environ["CKPT_BACKEND"] = "torch"
trainer = mocking_trainer()
key1 = "first key"
key2 = "second key"
Expand Down
15 changes: 7 additions & 8 deletions torchrl/trainers/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,6 @@
from torch import nn, optim

from torchrl._utils import KeyDependentDefaultDict

try:
from tqdm import tqdm

_has_tqdm = True
except ImportError:
_has_tqdm = False

from torchrl._utils import _CKPT_BACKEND
from torchrl.collectors.collectors import _DataCollector
from torchrl.data import (
Expand All @@ -40,6 +32,13 @@
from torchrl.objectives.common import LossModule
from torchrl.trainers.loggers import Logger

try:
from tqdm import tqdm

_has_tqdm = True
except ImportError:
_has_tqdm = False

try:
from torchsnapshot import StateDict, Snapshot

Expand Down
12 changes: 12 additions & 0 deletions tutorials/src/envs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from gym.envs.classic_control.pendulum import PendulumEnv


class PendulumWithSafety(PendulumEnv):
def step(self, u):
out = super().step(u)
sin, cos, vel = out[0]
safe = sin > 0 and cos > 0 # some quadrant is considered safe
safety_cost = abs(sin) * float(sin <= 0) + abs(cos) * float(cos <= 0)
out[-1]["safety_cost"] = safety_cost
out[-1]["safe"] = safe
return out
41 changes: 41 additions & 0 deletions tutorials/train_demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "c2e7018e-62a9-4d3f-9e75-343e8910e981",
"metadata": {},
"source": [
"# TorchRL overview"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c75a1ad-128c-4a8c-b387-7021dd6767a1",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit f58862b

Please sign in to comment.