[Feature] VC1 integration (#1211)
vmoens authored Jul 7, 2023
1 parent 3e512ff commit 002a58a
Showing 14 changed files with 611 additions and 23 deletions.
9 changes: 5 additions & 4 deletions .circleci/unittest/linux/scripts/install.sh
@@ -6,6 +6,7 @@ unset PYTORCH_VERSION
# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config.

set -e
+set -v

eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env
@@ -34,13 +35,13 @@ else
fi

# smoke test
-python -c "import functorch"
+python3 -c "import functorch"

# install snapshot
-pip install git+https://github.com/pytorch/torchsnapshot
+pip3 install git+https://github.com/pytorch/torchsnapshot

# install tensordict
-pip install git+https://github.com/pytorch-labs/tensordict.git
+pip3 install git+https://github.com/pytorch-labs/tensordict.git

printf "* Installing torchrl\n"
-python setup.py develop
+python3 setup.py develop
16 changes: 16 additions & 0 deletions .circleci/unittest/linux/scripts/run_all.sh
@@ -138,6 +138,21 @@ pip3 install git+https://github.com/pytorch-labs/tensordict.git
printf "* Installing torchrl\n"
python setup.py develop

+
+if [ "${CU_VERSION:-}" != cpu ] ; then
+printf "* Installing VC1\n"
+python3 -c """
+from torchrl.envs.transforms.vc1 import VC1Transform
+VC1Transform.install_vc_models(auto_exit=True)
+"""
+
+python3 -c """
+import vc_models
+from vc_models.models.vit import model_utils
+print(model_utils)
+"""
+fi
+
# ==================================================================================== #
# ================================ Run tests ========================================= #

@@ -151,6 +166,7 @@ python -m torch.utils.collect_env
export MKL_THREADING_LAYER=GNU
export CKPT_BACKEND=torch

+
pytest test/smoke_test.py -v --durations 200
pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb'
python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test \
22 changes: 18 additions & 4 deletions .circleci/unittest/linux/scripts/run_test.sh
@@ -1,12 +1,13 @@
#!/usr/bin/env bash

set -e
+set -v

eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env

export PYTORCH_TEST_WITH_SLOW='1'
-python -m torch.utils.collect_env
+python3 -m torch.utils.collect_env
# Avoid error: "fatal: unsafe repository"
git config --global --add safe.directory '*'

@@ -19,8 +20,21 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
export MKL_THREADING_LAYER=GNU
export CKPT_BACKEND=torch

-python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200
-python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb'
-python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest --instafail -v --durations 200 --ignore test/test_distributed.py --ignore test/test_rlhf.py
+# install vc1
+python3 -c """
+from torchrl.envs.transforms.vc1 import VC1Transform
+VC1Transform.install_vc_models(auto_exit=True)
+"""
+
+python3 -c """
+import vc_models
+from vc_models.models.vit import model_utils
+print(model_utils)
+"""
+
+python3 .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200
+python3 .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb'
+python3 .circleci/unittest/helpers/coverage_run_parallel.py -m pytest --instafail -v --durations 200 --ignore test/test_distributed.py --ignore test/test_rlhf.py -k VC1
+python3 .circleci/unittest/helpers/coverage_run_parallel.py -m pytest --instafail -v --durations 200 --ignore test/test_distributed.py --ignore test/test_rlhf.py
coverage combine
coverage xml -i
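
Note on the VC1 check above: the script verifies the install by importing `vc_models` in a fresh interpreter, which fails loudly if the package is missing. A softer variant that only reports availability could look like the sketch below. `is_available` is a hypothetical helper written for illustration, not part of torchrl or of this commit; it relies only on the standard-library `importlib.util.find_spec`.

```python
import importlib.util


def is_available(module_name: str) -> bool:
    """Return True if a top-level module can be located, without importing it.

    Hypothetical helper for illustration; pass top-level names only
    (e.g. "vc_models"), since dotted names would import the parent package.
    """
    return importlib.util.find_spec(module_name) is not None


# Report whether the optional dependency checked by the CI script is present.
if is_available("vc_models"):
    print("vc_models is installed")
else:
    print("vc_models is missing; VC1 tests would need to be skipped")
```

Unlike the bare import in the script, this does not abort under `set -e`, so it suits test-skipping logic better than a hard CI gate.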
1 change: 1 addition & 0 deletions docs/source/reference/data.rst
@@ -42,6 +42,7 @@ We also give users the ability to compose a replay buffer using the following co
TensorStorage
Writer
RoundRobinWriter
+TensorDictRoundRobinWriter

Storage choice is very influential on replay buffer sampling latency, especially in distributed reinforcement learning settings with larger data volumes.
:class:`LazyMemmapStorage` is highly advised in distributed settings with shared storage due to the lower serialisation cost of MemmapTensors as well as the ability to specify file storage locations for improved node failure recovery.
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
@@ -354,6 +354,7 @@ to be able to create this other composition:
ToTensorImage
UnsqueezeTransform
VecNorm
+VC1Transform
VIPRewardTransform
VIPTransform

1 change: 1 addition & 0 deletions setup.py
@@ -230,6 +230,7 @@ def _main(argv):
"tqdm",
"hydra-core>=1.1",
"hydra-submitit-launcher",
+"git",
],
"checkpointing": [
"torchsnapshot",

1 comment on commit 002a58a

@github-actions

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

| Benchmark suite | Current: 002a58a | Previous: 3e512ff | Ratio |
|-|-|-|-|
| `benchmarks/test_objectives_benchmarks.py::test_gae_speed[vec_generalized_advantage_estimate-False-32-512]` | `158.06968316342306` iter/sec (stddev: `0.0006614598341198669`) | `339.3774929201242` iter/sec (stddev: `0.000135207431415453`) | `2.15` |
| `benchmarks/test_objectives_benchmarks.py::test_cql_speed` | `18.307000205398285` iter/sec (stddev: `0.0038074374219489275`) | `36.999548272516016` iter/sec (stddev: `0.0013107267237379887`) | `2.02` |
| `benchmarks/test_objectives_benchmarks.py::test_a2c_speed` | `85.4212555324842` iter/sec (stddev: `0.0007474653596848353`) | `173.6022277851647` iter/sec (stddev: `0.00026813353917611667`) | `2.03` |
| `benchmarks/test_objectives_benchmarks.py::test_ppo_speed` | `80.84765393800338` iter/sec (stddev: `0.0013935248030822345`) | `167.49650082467224` iter/sec (stddev: `0.00026963124924429436`) | `2.07` |
| `benchmarks/test_objectives_benchmarks.py::test_reinforce_speed` | `97.57530684771916` iter/sec (stddev: `0.0017098363011464558`) | `211.05872334936794` iter/sec (stddev: `0.0004300661954815756`) | `2.16` |
| `benchmarks/test_objectives_benchmarks.py::test_iql_speed` | `20.576744043601323` iter/sec (stddev: `0.003108476678582431`) | `44.86697354160158` iter/sec (stddev: `0.0007142151655193952`) | `2.18` |
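
The Ratio column appears to be the previous throughput divided by the current one, so values above the threshold of 2 mean the benchmark ran at less than half its earlier speed. The sketch below recomputes the column from the iter/sec figures reported in this comment. The formula is an assumption inferred from the numbers, not taken from the benchmark tool's source, and the names `RESULTS`, `ALERT_THRESHOLD`, and `slowdown_ratio` are illustrative.

```python
# Throughput in iter/sec, copied from the alert: name -> (current, previous).
RESULTS = {
    "test_gae_speed": (158.06968316342306, 339.3774929201242),
    "test_cql_speed": (18.307000205398285, 36.999548272516016),
    "test_a2c_speed": (85.4212555324842, 173.6022277851647),
    "test_ppo_speed": (80.84765393800338, 167.49650082467224),
    "test_reinforce_speed": (97.57530684771916, 211.05872334936794),
    "test_iql_speed": (20.576744043601323, 44.86697354160158),
}

ALERT_THRESHOLD = 2.0  # "threshold 2" in the alert text


def slowdown_ratio(current: float, previous: float) -> float:
    """Assumed formula: higher iter/sec is better, so ratio = previous / current."""
    return previous / current


for name, (current, previous) in RESULTS.items():
    ratio = slowdown_ratio(current, previous)
    flagged = ratio > ALERT_THRESHOLD
    print(f"{name}: ratio={ratio:.2f} flagged={flagged}")
```

Under this assumption, all six benchmarks land just above the threshold, in a narrow band between 2.02 and 2.18.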

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens
