diff --git a/.github/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml index 3ae16869835..30e01cfc4b5 100644 --- a/.github/unittest/linux/scripts/environment.yml +++ b/.github/unittest/linux/scripts/environment.yml @@ -28,6 +28,7 @@ dependencies: - mlflow - av - coverage - - ray<2.8.0 + - ray - transformers - ninja + - timm diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh index b5066472907..38235043d3f 100755 --- a/.github/unittest/linux/scripts/run_all.sh +++ b/.github/unittest/linux/scripts/run_all.sh @@ -88,7 +88,8 @@ conda deactivate conda activate "${env_dir}" echo "installing gymnasium" -pip3 install "gymnasium[atari,ale-py,accept-rom-license]" +pip3 install "gymnasium" +pip3 install ale_py pip3 install mo-gymnasium[mujoco] # requires here bc needs mujoco-py pip3 install mujoco -U diff --git a/.github/unittest/linux_distributed/scripts/environment.yml b/.github/unittest/linux_distributed/scripts/environment.yml index 2f5210135fe..6d27071791b 100644 --- a/.github/unittest/linux_distributed/scripts/environment.yml +++ b/.github/unittest/linux_distributed/scripts/environment.yml @@ -27,5 +27,5 @@ dependencies: - mlflow - av - coverage - - ray<2.8.0 + - ray - virtualenv diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml index 9efcbbfa640..d34011e7bdc 100644 --- a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml +++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml @@ -24,5 +24,5 @@ dependencies: - dm_control -e git+https://github.com/deepmind/dm_control.git@c053360edea6170acfd9c8f65446703307d9d352#egg={dm_control} - patchelf - pyopengl==3.1.4 - - ray<2.8.0 + - ray - av diff --git a/.github/unittest/linux_optdeps/scripts/environment.yml b/.github/unittest/linux_optdeps/scripts/environment.yml index 7263c14192f..fcc3c3481d0 100644 --- a/.github/unittest/linux_optdeps/scripts/environment.yml +++ b/.github/unittest/linux_optdeps/scripts/environment.yml @@ -17,4 +17,4 @@ dependencies: - pyyaml - scipy - coverage - - ray<2.8.0 + - ray diff --git a/.github/workflows/test-linux.yml b/.github/workflows/test-linux.yml index d2e13eddd63..e8728180c67 100644 --- a/.github/workflows/test-linux.yml +++ b/.github/workflows/test-linux.yml @@ -22,7 +22,7 @@ jobs: tests-cpu: strategy: matrix: - python_version: ["3.8", "3.9", "3.10", "3.11"] + python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: @@ -51,7 +51,7 @@ jobs: tests-gpu: strategy: matrix: - python_version: ["3.10"] + python_version: ["3.11"] cuda_arch_version: ["12.1"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job.yml@main diff --git a/setup.py b/setup.py index 95dc0802a4f..73541790e8f 100644 --- a/setup.py +++ b/setup.py @@ -274,6 +274,7 @@ def _main(argv): "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Development Status :: 4 - Beta", diff --git a/test/_utils_internal.py b/test/_utils_internal.py index e43c0ff2ecf..61b0c003f9d 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -56,11 +56,32 @@ def HALFCHEETAH_VERSIONED(): def PONG_VERSIONED(): # load gym + # Gymnasium says that the ale_py behaviour changes from 1.0 + # but with python 3.12 it is already the case with 0.29.1 + try: + import ale_py # noqa + except ImportError: + pass + if gym_backend() is not None: _set_gym_environments() return _PONG_VERSIONED +def BREAKOUT_VERSIONED(): + # load gym + # Gymnasium says that the ale_py behaviour changes from 1.0 + # but with python 3.12 it is already the case with 0.29.1 + try: + import ale_py # noqa + except ImportError: + pass + + if gym_backend() is not None: + _set_gym_environments() + return _BREAKOUT_VERSIONED + + def PENDULUM_VERSIONED(): # load gym if gym_backend() is not None: @@ -69,42 +90,46 @@ def PENDULUM_VERSIONED(): def _set_gym_environments(): - global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED + global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED _CARTPOLE_VERSIONED = None _HALFCHEETAH_VERSIONED = None _PENDULUM_VERSIONED = None _PONG_VERSIONED = None + _BREAKOUT_VERSIONED = None @implement_for("gym", None, "0.21.0") def _set_gym_environments(): # noqa: F811 - global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED + global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED _CARTPOLE_VERSIONED = "CartPole-v0" _HALFCHEETAH_VERSIONED = "HalfCheetah-v2" _PENDULUM_VERSIONED = "Pendulum-v0" _PONG_VERSIONED = "Pong-v4" + _BREAKOUT_VERSIONED = "Breakout-v4" @implement_for("gym", "0.21.0", None) def _set_gym_environments(): # noqa: F811 - global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED + global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED _CARTPOLE_VERSIONED = "CartPole-v1" _HALFCHEETAH_VERSIONED = "HalfCheetah-v4" _PENDULUM_VERSIONED = "Pendulum-v1" _PONG_VERSIONED = "ALE/Pong-v5" + _BREAKOUT_VERSIONED = "ALE/Breakout-v5" @implement_for("gymnasium") def _set_gym_environments(): # noqa: F811 - global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED + global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED _CARTPOLE_VERSIONED = "CartPole-v1" _HALFCHEETAH_VERSIONED = "HalfCheetah-v4" _PENDULUM_VERSIONED = "Pendulum-v1" _PONG_VERSIONED = "ALE/Pong-v5" + _BREAKOUT_VERSIONED = "ALE/Breakout-v5" if _has_gym: diff --git a/test/test_transforms.py b/test/test_transforms.py index fcfd6f08aff..caac5c4f055 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -22,6 +22,7 @@ import torch from _utils_internal import ( # noqa + BREAKOUT_VERSIONED, dtype_fixture, get_default_devices, HALFCHEETAH_VERSIONED, @@ -248,7 +249,10 @@ def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -257,7 +261,10 @@ def test_trans_serial_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( @@ -267,7 +274,10 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("batch", [[], [4], [6, 4]]) @@ -572,7 +582,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_serial_trans_env_check(self): def make_env(): @@ -604,7 +617,10 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = ContinuousActionVecMockEnv() @@ -650,7 +666,10 @@ def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -674,7 +693,10 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.skipif(not _has_gym, reason="Test executed on gym") @pytest.mark.parametrize("batched_class", [ParallelEnv, SerialEnv]) @@ -1621,7 +1643,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_serial_trans_env_check(self): def make_env(): @@ -1637,7 +1662,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), StepCounter(10)) @@ -1921,7 +1949,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): ct = CatTensors( @@ -1946,7 +1977,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize( @@ -2298,7 +2332,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): keys = ["pixels"] @@ -2313,7 +2350,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.skipif(not _has_gym, reason="No Gym detected") @pytest.mark.parametrize("out_key", [None, ["outkey"], [("out", "key")]]) @@ -2360,7 +2400,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -2377,7 +2420,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("action_key", ["action", ("nested", "stuff")]) def test_transform_no_env(self, action_key): @@ -2605,7 +2651,10 @@ def make_env(): env = ParallelEnv(1, make_env) check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass del env def test_trans_serial_env_check(self, dtype_fixture): # noqa: F811 @@ -2623,7 +2672,10 @@ def test_trans_parallel_env_check(self, dtype_fixture): # noqa: F811 try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_no_env(self, dtype_fixture): # noqa: F811 t = DoubleToFloat(in_keys=["observation"], in_keys_inv=["action"]) @@ -2778,7 +2830,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): t = Compose( @@ -2801,7 +2856,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_env(self): base_env = TestExcludeTransform.EnvWithManyKeys() @@ -3010,7 +3068,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): t = Compose( @@ -3033,7 +3094,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_env(self): base_env = TestExcludeTransform.EnvWithManyKeys() @@ -3203,7 +3267,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -3226,7 +3293,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.skipif(not _has_tv, reason="no torchvision") @pytest.mark.parametrize("nchannels", [1, 3]) @@ -3381,7 +3451,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -3396,7 +3469,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_no_env(self): t = FrameSkipTransform(2) @@ -3622,7 +3698,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): out_keys = None @@ -3641,7 +3720,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("out_keys", [None, ["stuff"]]) def test_transform_env(self, out_keys): @@ -3717,7 +3799,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), NoopResetEnv()) @@ -3890,7 +3975,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check( self, @@ -3919,7 +4007,10 @@ def test_trans_parallel_env_check( try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("standard_normal", [True, False]) @pytest.mark.parametrize("in_key", ["observation", ("some_other", "observation")]) @@ -4473,7 +4564,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -4490,7 +4584,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.skipif(not _has_gym, reason="No gym") @pytest.mark.parametrize("out_key", ["pixels", ("agents", "pixels")]) @@ -4549,7 +4646,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -4564,7 +4664,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("reward_key", ["reward", ("agents", "reward")]) def test_transform_no_env(self, reward_key): @@ -4684,7 +4787,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -4699,7 +4805,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("standard_normal", [True, False]) def test_transform_no_env(self, standard_normal): @@ -4817,7 +4926,10 @@ def make_env(): r = env.rollout(4) assert r["next", "episode_reward"].unique().numel() > 1 finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -4838,7 +4950,10 @@ def test_trans_parallel_env_check(self): r = env.rollout(4) assert r["next", "episode_reward"].unique().numel() > 1 finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("has_in_keys,", [True, False]) @pytest.mark.parametrize("reset_keys,", [None, ["_reset"] * 3]) @@ -5491,7 +5606,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -5508,7 +5626,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("unsqueeze_dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @@ -5796,7 +5917,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -5805,7 +5929,10 @@ def test_trans_serial_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_parallel_env_check(self): env = TransformedEnv( @@ -5814,7 +5941,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("squeeze_dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @@ -6007,7 +6137,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("mode", ["reduce", "constant"]) @pytest.mark.parametrize("device", get_default_devices()) @@ -6020,7 +6153,10 @@ def test_trans_serial_env_check(self, mode, device): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("mode", ["reduce", "constant"]) @pytest.mark.parametrize("device", get_default_devices()) @@ -6033,7 +6169,10 @@ def test_trans_parallel_env_check(self, mode, device): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.skipif(not _has_gym, reason="Test executed on gym") @pytest.mark.parametrize("batched_class", [SerialEnv, ParallelEnv]) @@ -6242,7 +6381,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -6259,7 +6401,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("out_keys", [None, ["stuff"], [("nested", "stuff")]]) @pytest.mark.parametrize("default_dtype", [torch.float32, torch.float64]) @@ -6392,7 +6537,10 @@ def make_env(): assert "mykey" in env.reset().keys() assert ("next", "mykey") in env.rollout(3).keys(True) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_serial_trans_env_check(self): def make_env(): @@ -6407,7 +6555,10 @@ def make_env(): assert "mykey" in env.reset().keys() assert ("next", "mykey") in env.rollout(3).keys(True) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_parallel_env_check(self): env = TransformedEnv( @@ -6421,7 +6572,10 @@ def test_trans_parallel_env_check(self): assert ("next", "mykey") in r.keys(True) assert r["next", "mykey"].shape == torch.Size([2, 3, 4]) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("spec_shape", [[4], [2, 4]]) def test_trans_serial_env_check(self, spec_shape): @@ -6684,7 +6838,10 @@ def test_parallel_trans_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -6707,7 +6864,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.skipif(not _has_gym, reason="Test executed on gym") @pytest.mark.parametrize("batched_class", [ParallelEnv, SerialEnv]) @@ -6871,7 +7031,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("shape", [(), (2,)]) def test_trans_serial_env_check(self, shape): @@ -6884,7 +7047,10 @@ def test_trans_serial_env_check(self, shape): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_parallel_env_check(self): state_dim = 7 @@ -6896,7 +7062,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_no_env(self): state_dim = 7 @@ -8748,7 +8917,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def make_env(): return TransformedEnv( @@ -8766,7 +8938,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self, create_copy): def make_env(): @@ -8808,7 +8983,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass env = TransformedEnv( ParallelEnv(2, make_env), RenameTransform( @@ -8822,7 +9000,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("mode", ["forward", "_call"]) @pytest.mark.parametrize( @@ -9021,7 +9202,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): def make_env(): @@ -9033,7 +9217,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_parallel_env_check(self): def make_env(): @@ -9045,7 +9232,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_no_env(self): with pytest.raises(ValueError, match="init_key can only be of type str"): @@ -9315,7 +9505,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): out_key = "reward" @@ -9324,7 +9517,10 @@ def test_trans_serial_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_parallel_env_check(self): out_key = "reward" @@ -9333,7 +9529,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_model(self): actor = self._make_actor() @@ -9472,21 +9671,30 @@ def test_parallel_trans_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, self._env_class), ActionMask()) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_parallel_env_check(self): env = TransformedEnv(ParallelEnv(2, self._env_class), ActionMask()) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_no_env(self): t = ActionMask() @@ -9603,7 +9811,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("in_keys", ["observation"]) @pytest.mark.parametrize("out_keys", [None, ["obs_device"]]) @@ -9654,7 +9865,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_no_env(self): t = DeviceCastTransform("cpu:1", "cpu:0", in_keys=["a"], out_keys=["b"]) @@ -9794,7 +10008,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): def make_env(): @@ -9813,7 +10030,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_no_env(self): t = DeviceCastTransform("cpu:1", "cpu:0") @@ -9905,7 +10125,10 @@ def test_parallel_trans_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -9915,7 +10138,10 @@ def test_trans_serial_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_parallel_env_check(self): env = TransformedEnv( @@ -9925,7 +10151,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) def test_transform_compose(self, batch): @@ -10022,7 +10251,7 @@ class TestEndOfLife(TransformBase): def test_trans_parallel_env_check(self): def make(): with set_gym_backend("gymnasium"): - return GymEnv("ALE/Breakout-v5") + return GymEnv(BREAKOUT_VERSIONED()) with pytest.warns(UserWarning, match="The base_env is not a gym env"): with pytest.raises(AttributeError): @@ -10034,7 +10263,7 @@ def make(): def test_trans_serial_env_check(self): def make(): with set_gym_backend("gymnasium"): - return GymEnv("ALE/Breakout-v5") + return GymEnv(BREAKOUT_VERSIONED()) with pytest.warns(UserWarning, match="The base_env is not a gym env"): env = TransformedEnv(SerialEnv(2, make), transform=EndOfLifeTransform()) @@ -10045,7 +10274,7 @@ def make(): def test_single_trans_env_check(self, eol_key, lives_key): with set_gym_backend("gymnasium"): env = TransformedEnv( - GymEnv("ALE/Breakout-v5"), + GymEnv(BREAKOUT_VERSIONED()), transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key), ) check_env_specs(env) @@ -10056,7 +10285,7 @@ def test_serial_trans_env_check(self, eol_key, lives_key): def make(): with set_gym_backend("gymnasium"): return TransformedEnv( - GymEnv("ALE/Breakout-v5"), + GymEnv(BREAKOUT_VERSIONED()), transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key), ) @@ -10069,7 +10298,7 @@ def test_parallel_trans_env_check(self, eol_key, lives_key): def make(): with set_gym_backend("gymnasium"): return TransformedEnv( - GymEnv("ALE/Breakout-v5"), + GymEnv(BREAKOUT_VERSIONED()), transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key), ) @@ -10077,7 +10306,10 @@ def make(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_no_env(self): t = EndOfLifeTransform() @@ -10098,7 +10330,7 @@ def test_transform_env(self, eol_key, lives_key): with set_gym_backend("gymnasium"): env = TransformedEnv( - GymEnv("ALE/Breakout-v5"), + GymEnv(BREAKOUT_VERSIONED()), transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key), ) check_env_specs(env) @@ -10480,7 +10712,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_serial_trans_env_check(self): def make_env(): @@ -10507,7 +10742,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -10520,7 +10758,10 @@ def test_trans_serial_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass class TestRemoveEmptySpecs(TransformBase): @@ -10581,7 +10822,10 @@ def test_parallel_trans_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): with pytest.raises(