Skip to content

Commit

Permalink
[CI,Feature] Upgrade to gymnasium (pytorch#898)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 9, 2023
1 parent 74fdadc commit 5ea4f05
Show file tree
Hide file tree
Showing 22 changed files with 383 additions and 66 deletions.
3 changes: 0 additions & 3 deletions .circleci/unittest/linux/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@ dependencies:
- hypothesis
- future
- cloudpickle
- gym
- pygame
- gym[accept-rom-license]
- gym[atari]
- moviepy
- tqdm
- pytest
Expand Down
32 changes: 32 additions & 0 deletions .circleci/unittest/linux/scripts/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,35 @@ fi
pip install pip --upgrade

conda env update --file "${this_dir}/environment.yml" --prune

conda deactivate
conda activate "${env_dir}"

if [[ $OSTYPE != 'darwin'* ]]; then
# install ale-py: manylinux names are broken for CentOS so we need to manually download and
# rename them
PY_VERSION=$(python --version)
echo "installing ale-py for ${PY_PY_VERSION}"
if [[ $PY_VERSION == *"3.7"* ]]; then
wget https://files.pythonhosted.org/packages/ab/fd/6615982d9460df7f476cad265af1378057eee9daaa8e0026de4cedbaffbd/ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.8"* ]]; then
wget https://files.pythonhosted.org/packages/0f/8a/feed20571a697588bc4bfef05d6a487429c84f31406a52f8af295a0346a2/ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.9"* ]]; then
wget https://files.pythonhosted.org/packages/a0/98/4316c1cedd9934f9a91b6e27a9be126043b4445594b40cfa391c8de2e5e8/ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.10"* ]]; then
wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
fi
echo "installing gymnasium"
pip install "gymnasium[atari,accept-rom-license]"
else
pip install "gymnasium[atari,accept-rom-license]"
fi
3 changes: 0 additions & 3 deletions .circleci/unittest/linux_examples/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@ dependencies:
- hypothesis
- future
- cloudpickle
- gym
- pygame
- gym[accept-rom-license]
- gym[atari]
- moviepy
- tqdm
- pytest
Expand Down
30 changes: 30 additions & 0 deletions .circleci/unittest/linux_examples/scripts/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,33 @@ fi
pip install pip --upgrade

conda env update --file "${this_dir}/environment.yml" --prune

conda deactivate
conda activate "${env_dir}"

if [[ $OSTYPE != 'darwin'* ]]; then
# install ale-py: manylinux names are broken for CentOS so we need to manually download and
# rename them
PY_VERSION=$(python --version)
if [[ $PY_VERSION == *"3.7"* ]]; then
wget https://files.pythonhosted.org/packages/ab/fd/6615982d9460df7f476cad265af1378057eee9daaa8e0026de4cedbaffbd/ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.8"* ]]; then
wget https://files.pythonhosted.org/packages/0f/8a/feed20571a697588bc4bfef05d6a487429c84f31406a52f8af295a0346a2/ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.9"* ]]; then
wget https://files.pythonhosted.org/packages/a0/98/4316c1cedd9934f9a91b6e27a9be126043b4445594b40cfa391c8de2e5e8/ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.10"* ]]; then
wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
fi
pip install "gymnasium[atari,accept-rom-license]"
else
pip install "gymnasium[atari,accept-rom-license]"
fi
4 changes: 0 additions & 4 deletions .circleci/unittest/linux_libs/scripts_envpool/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@ dependencies:
- hypothesis
- future
- cloudpickle
- gym
- pygame
- gym[accept-rom-license]
- gym[atari]
- moviepy
- pytest-cov
- pytest-mock
Expand All @@ -22,4 +19,3 @@ dependencies:
- scipy
- dm_control
- coverage
- envpool
34 changes: 34 additions & 0 deletions .circleci/unittest/linux_libs/scripts_envpool/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,37 @@ cat "${this_dir}/environment.yml"
pip install pip --upgrade

conda env update --file "${this_dir}/environment.yml" --prune

conda deactivate
conda activate "${env_dir}"

if [[ $OSTYPE != 'darwin'* ]]; then
# install ale-py: manylinux names are broken for CentOS so we need to manually download and
# rename them
PY_VERSION=$(python --version)
echo "installing ale-py for ${PY_PY_VERSION}"
if [[ $PY_VERSION == *"3.7"* ]]; then
wget https://files.pythonhosted.org/packages/ab/fd/6615982d9460df7f476cad265af1378057eee9daaa8e0026de4cedbaffbd/ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.8"* ]]; then
wget https://files.pythonhosted.org/packages/0f/8a/feed20571a697588bc4bfef05d6a487429c84f31406a52f8af295a0346a2/ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.9"* ]]; then
wget https://files.pythonhosted.org/packages/a0/98/4316c1cedd9934f9a91b6e27a9be126043b4445594b40cfa391c8de2e5e8/ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.10"* ]]; then
wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
fi
echo "installing gym"
# envpool does not currently work with gymnasium
pip install "gym[atari,accept-rom-license]"
else
pip install "gym[atari,accept-rom-license]"
fi
pip install envpool
46 changes: 46 additions & 0 deletions .circleci/unittest/linux_libs/scripts_gym/batch_scripts.sh
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,49 @@ do
conda deactivate
conda env remove --prefix ./cloned_env
done

# For this version "gym[accept-rom-license]" is required.
for GYM_VERSION in '0.27'
do
# Create a copy of the conda env and work with this
conda deactivate
conda create --prefix ./cloned_env --clone ./env -y
conda activate ./cloned_env

echo "Testing gym version: ${GYM_VERSION}"
pip3 install 'gymnasium[accept-rom-license]'==$GYM_VERSION


if [[ $OSTYPE != 'darwin'* ]]; then
# install ale-py: manylinux names are broken for CentOS so we need to manually download and
# rename them
PY_VERSION=$(python --version)
if [[ $PY_VERSION == *"3.7"* ]]; then
wget https://files.pythonhosted.org/packages/ab/fd/6615982d9460df7f476cad265af1378057eee9daaa8e0026de4cedbaffbd/ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.8"* ]]; then
wget https://files.pythonhosted.org/packages/0f/8a/feed20571a697588bc4bfef05d6a487429c84f31406a52f8af295a0346a2/ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.9"* ]]; then
wget https://files.pythonhosted.org/packages/a0/98/4316c1cedd9934f9a91b6e27a9be126043b4445594b40cfa391c8de2e5e8/ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.10"* ]]; then
wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
fi
pip install gymnasium[atari]
else
pip install gymnasium[atari]
fi

$DIR/run_test.sh

# delete the conda copy
conda deactivate
conda env remove --prefix ./cloned_env
done
3 changes: 0 additions & 3 deletions .circleci/unittest/linux_stable/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@ dependencies:
- hypothesis
- future
- cloudpickle
- gym
- pygame
- gym[accept-rom-license]
- gym[atari]
- moviepy
- tqdm
- pytest
Expand Down
32 changes: 32 additions & 0 deletions .circleci/unittest/linux_stable/scripts/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,35 @@ fi
pip install pip --upgrade

conda env update --file "${this_dir}/environment.yml" --prune

conda deactivate
conda activate "${env_dir}"

if [[ $OSTYPE != 'darwin'* ]]; then
# install ale-py: manylinux names are broken for CentOS so we need to manually download and
# rename them
PY_VERSION=$(python --version)
echo "installing ale-py for ${PY_PY_VERSION}"
if [[ $PY_VERSION == *"3.7"* ]]; then
wget https://files.pythonhosted.org/packages/ab/fd/6615982d9460df7f476cad265af1378057eee9daaa8e0026de4cedbaffbd/ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.8"* ]]; then
wget https://files.pythonhosted.org/packages/0f/8a/feed20571a697588bc4bfef05d6a487429c84f31406a52f8af295a0346a2/ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.9"* ]]; then
wget https://files.pythonhosted.org/packages/a0/98/4316c1cedd9934f9a91b6e27a9be126043b4445594b40cfa391c8de2e5e8/ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.10"* ]]; then
wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
fi
echo "installing gymnasium"
pip install "gymnasium[atari,accept-rom-license]"
else
pip install "gymnasium[atari,accept-rom-license]"
fi
10 changes: 10 additions & 0 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ def _set_gym_environments(): # noqa: F811
PONG_VERSIONED = "ALE/Pong-v5"


@implement_for("gymnasium", "0.27.0", None)
def _set_gym_environments(): # noqa: F811
global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED

CARTPOLE_VERSIONED = "CartPole-v1"
HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
PENDULUM_VERSIONED = "Pendulum-v1"
PONG_VERSIONED = "ALE/Pong-v5"


if _has_gym:
_set_gym_environments()

Expand Down
11 changes: 10 additions & 1 deletion test/smoke_test_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,16 @@ def test_dm_control_pixels():


def test_gym():
import gym # noqa: F401
try:
import gymnasium as gym
except ImportError as err:
ERROR = err
try:
import gym # noqa: F401
except ImportError as err:
raise ImportError(
f"gym and gymnasium load failed. Gym got error {err}."
) from ERROR

assert _has_gym
env = GymEnv(PONG_VERSIONED)
Expand Down
10 changes: 8 additions & 2 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@

gym_version = None
if _has_gym:
import gym
try:
import gymnasium as gym
except ModuleNotFoundError:
import gym

gym_version = version.parse(gym.__version__)

Expand Down Expand Up @@ -1055,7 +1058,10 @@ def test_batch_unlocked_with_batch_size(device):
)
@pytest.mark.parametrize("device", get_available_devices())
def test_info_dict_reader(device, seed=0):
import gym
try:
import gymnasium as gym
except ModuleNotFoundError:
import gym

env = GymWrapper(gym.make(HALFCHEETAH_VERSIONED), device=device)
env.set_info_dict_reader(default_info_dict_reader(["x_position"]))
Expand Down
33 changes: 23 additions & 10 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,23 @@
from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator

if _has_gym:
import gym

gym_version = version.parse(gym.__version__)
if gym_version > version.parse("0.19"):
from gym.wrappers.pixel_observation import PixelObservationWrapper
else:
from torchrl.envs.libs.utils import (
GymPixelObservationWrapper as PixelObservationWrapper,
)
try:
import gymnasium as gym
from gymnasium import __version__ as gym_version

gym_version = version.parse(gym_version)
from gymnasium.wrappers.pixel_observation import PixelObservationWrapper
except ModuleNotFoundError:
import gym

gym_version = version.parse(gym.__version__)
if gym_version > version.parse("0.19"):
from gym.wrappers.pixel_observation import PixelObservationWrapper
else:
from torchrl.envs.libs.utils import (
GymPixelObservationWrapper as PixelObservationWrapper,
)


if _has_dmc:
from dm_control import suite
Expand Down Expand Up @@ -163,6 +171,11 @@ def _make_gym_environment(env_name): # noqa: F811
return gym.make(env_name, render_mode="rgb_array")


@implement_for("gymnasium", "0.27", None)
def _make_gym_environment(env_name): # noqa: F811
return gym.make(env_name, render_mode="rgb_array")


@pytest.mark.skipif(not _has_dmc, reason="no dm_control library found")
@pytest.mark.parametrize("env_name,task", [["cheetah", "run"]])
@pytest.mark.parametrize("frame_skip", [1, 3])
Expand Down Expand Up @@ -281,7 +294,7 @@ def test_td_creation_from_spec(env_lib, env_args, env_kwargs):
and env_kwargs.get("from_pixels", False)
and torch.cuda.device_count() == 0
):
pytest.skip(
raise pytest.skip(
"Skipping test as rendering is not supported in tests before gym 0.26."
)
env = env_lib(*env_args, **env_kwargs)
Expand Down
10 changes: 5 additions & 5 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,11 +994,6 @@ def test_equality_composite(self):
assert ts != ts_other


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)


class TestSpec:
@pytest.mark.parametrize(
"action_spec_cls", [OneHotDiscreteTensorSpec, DiscreteTensorSpec]
Expand Down Expand Up @@ -1672,3 +1667,8 @@ def test_unboundeddiscrete(
spec = UnboundedDiscreteTensorSpec(shape=shape1, device="cpu", dtype=torch.long)
assert spec == spec.clone()
assert spec is not spec.clone()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
Loading

0 comments on commit 5ea4f05

Please sign in to comment.