Skip to content

Commit

Permalink
Linux GPU Brax Unittests (pytorch#1133)
Browse files Browse the repository at this point in the history
  • Loading branch information
osalpekar authored Jun 20, 2023
1 parent 63fb59b commit b0e19ff
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 9 deletions.
8 changes: 4 additions & 4 deletions .circleci/unittest/linux_libs/scripts_brax/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ unset PYTORCH_VERSION
# so no need to set PYTORCH_VERSION.
# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config.

set -e
set -euxo pipefail

eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env
Expand All @@ -30,13 +30,13 @@ if [ "${CU_VERSION:-}" == cpu ] ; then
# conda install -y pytorch torchvision cpuonly -c pytorch-nightly
# use pip to install pytorch as conda can frequently pick older release
# conda install -y pytorch cpuonly -c pytorch-nightly
pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall
pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall --progress-bar off
else
pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall
pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall --progress-bar off
fi

# install tensordict
pip install git+https://github.com/pytorch-labs/tensordict.git
pip install git+https://github.com/pytorch-labs/tensordict.git --progress-bar off

# smoke test
python -c "import functorch;import tensordict"
Expand Down
9 changes: 9 additions & 0 deletions .circleci/unittest/linux_libs/scripts_brax/run_all.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/usr/bin/env bash
#
# Runs the Brax unit-test pipeline by invoking the sibling step scripts in
# order: setup_env.sh, install.sh, run_test.sh, post_process.sh.

# -e: abort on the first failing step; -u: error on unset variables;
# -x: trace each command; pipefail: a pipeline fails if any stage fails.
set -euxo pipefail

# Absolute path of the directory containing this script, so the step scripts
# are resolved relative to this file rather than the caller's working directory.
this_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)"

# Paths are quoted so a checkout location containing whitespace or glob
# characters is passed to bash as a single argument.
bash "${this_dir}/setup_env.sh"
bash "${this_dir}/install.sh"
bash "${this_dir}/run_test.sh"
bash "${this_dir}/post_process.sh"
8 changes: 5 additions & 3 deletions .circleci/unittest/linux_libs/scripts_brax/run_test.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
#!/usr/bin/env bash

set -e
set -euxo pipefail

eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env
apt-get update && apt-get install -y git wget


export PYTORCH_TEST_WITH_SLOW='1'
Expand All @@ -17,7 +16,7 @@ env_dir="${root_dir}/env"
lib_dir="${env_dir}/lib"

# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
# export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
export MKL_THREADING_LAYER=GNU
# more logging
export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON
Expand All @@ -27,6 +26,9 @@ export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON

# this workflow only tests the libs
python -c "import brax"
python -c "import brax.envs"
python -c "import jax"
python3 -c 'import torch;t = torch.ones([2,2], device="cuda:0");print(t);print("tensor device:" + str(t.device))'

python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestBrax --error-for-skips
coverage combine
Expand Down
5 changes: 3 additions & 2 deletions .circleci/unittest/linux_libs/scripts_brax/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#
# Do not install PyTorch and torchvision here, otherwise they also get cached.

set -e
set -euxo pipefail

this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
# Avoid error: "fatal: unsafe repository"
Expand Down Expand Up @@ -57,6 +57,7 @@ pip install pip --upgrade
conda env update --file "${this_dir}/environment.yml" --prune

#yum makecache
#yum -y install glfw-devel
# sudo yum -y install glfw
yum -y install glfw-devel
#yum -y install libGLEW
#yum -y install gcc-c++
32 changes: 32 additions & 0 deletions .github/workflows/test-linux-brax.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Runs the Brax unit-test scripts on a Linux GPU runner by delegating to the
# reusable linux_job workflow from pytorch/test-infra.
name: Brax Tests on Linux

on:
  pull_request:
  push:
    branches:
      - nightly
      - main
      - release/*
  workflow_dispatch:

jobs:
  unittests:
    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
    with:
      repository: pytorch/rl
      runner: "linux.g5.4xlarge.nvidia.gpu"
      gpu-arch-type: cuda
      gpu-arch-version: "11.7"
      # NOTE(review): unit is defined by linux_job.yml (presumably minutes) - confirm.
      timeout: 120
      script: |
        set -euo pipefail
        export PYTHON_VERSION="3.8"
        export CU_VERSION="11.7"
        export TAR_OPTIONS="--no-same-owner"
        export UPLOAD_CHANNEL="nightly"
        export TF_CPP_MIN_LOG_LEVEL=0
        # Print the visible GPU/driver state before the tests start.
        nvidia-smi
        bash .circleci/unittest/linux_libs/scripts_brax/run_all.sh

0 comments on commit b0e19ff

Please sign in to comment.