From b3f99b30efc3522498fdb2f1b6d6c2de4bf2f040 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 30 Jul 2024 16:53:46 +0100 Subject: [PATCH] [CI, Tests] Fix windows tests (#2337) --- .../windows_optdepts/scripts/install.sh | 1 + test/test_transforms.py | 319 +++++++++--------- 2 files changed, 152 insertions(+), 168 deletions(-) diff --git a/.github/unittest/windows_optdepts/scripts/install.sh b/.github/unittest/windows_optdepts/scripts/install.sh index 5c425d18a95..f13b83a0be0 100644 --- a/.github/unittest/windows_optdepts/scripts/install.sh +++ b/.github/unittest/windows_optdepts/scripts/install.sh @@ -37,6 +37,7 @@ fi # submodules git submodule sync && git submodule update --init --recursive +python -m pip install "numpy<2.0" printf "Installing PyTorch with %s\n" "${cudatoolkit}" if [[ "$TORCH_VERSION" == "nightly" ]]; then diff --git a/test/test_transforms.py b/test/test_transforms.py index caac5c4f055..bdafba648eb 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1635,11 +1635,11 @@ def test_single_trans_env_check(self, update_done, max_steps): assert "truncated" not in r.keys() assert ("next", "truncated") not in r.keys(True) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv(ContinuousActionVecMockEnv(), StepCounter(10)) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -1655,9 +1655,9 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), StepCounter(10) + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), StepCounter(10) ) try: check_env_specs(env) @@ -1935,7 +1935,7 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): ct = CatTensors( in_keys=["observation", "observation_orig"], @@ -1945,7 +1945,7 @@ def make_env(): ) return TransformedEnv(ContinuousActionVecMockEnv(), ct) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -1965,7 +1965,7 @@ def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), ct) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): ct = CatTensors( in_keys=["observation", "observation_orig"], out_key="observation_out", @@ -1973,7 +1973,7 @@ def test_trans_parallel_env_check(self): del_keys=False, ) - env = TransformedEnv(ParallelEnv(2, ContinuousActionVecMockEnv), ct) + env = TransformedEnv(maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), ct) try: check_env_specs(env) finally: @@ -2177,18 +2177,8 @@ def test_transform_no_env(self, keys, h, nchannels, batch, device): assert observation_spec[key].shape == torch.Size([nchannels, 20, h]) @pytest.mark.parametrize("nchannels", [3]) - @pytest.mark.parametrize( - "batch", - [ - [2], - ], - ) - @pytest.mark.parametrize( - "h", - [ - None, - ], - ) + @pytest.mark.parametrize("batch", [[2]]) + @pytest.mark.parametrize("h", [None]) @pytest.mark.parametrize("keys", [["observation_pixels"]]) @pytest.mark.parametrize("device", get_default_devices()) def test_transform_model(self, keys, h, nchannels, batch, device): @@ -2213,18 +2203,8 @@ def test_transform_model(self, keys, h, nchannels, batch, device): assert (td.get("dont touch") == dont_touch).all() @pytest.mark.parametrize("nchannels", [3]) - @pytest.mark.parametrize( - "batch", - [ - [2], - ], - ) - @pytest.mark.parametrize( - "h", - [ - None, - ], - ) + @pytest.mark.parametrize("batch", [[2]]) + @pytest.mark.parametrize("h", [None]) @pytest.mark.parametrize("keys", [["observation_pixels"]]) @pytest.mark.parametrize("device", get_default_devices()) def test_transform_compose(self, keys, h, nchannels, batch, device): @@ -2253,18 +2233,8 @@ def test_transform_compose(self, keys, h, nchannels, batch, device): assert (tdc.get("dont touch") == dont_touch).all() @pytest.mark.parametrize("nchannels", [3]) - @pytest.mark.parametrize( - "batch", - [ - [2], - ], - ) - @pytest.mark.parametrize( - "h", - [ - None, - ], - ) + @pytest.mark.parametrize("batch", [[2]]) + @pytest.mark.parametrize("h", [None]) @pytest.mark.parametrize("keys", [["observation_pixels"]]) @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) def test_transform_rb( @@ -2343,10 +2313,12 @@ def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, DiscreteActionConvMockEnvNumpy), ct) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): keys = ["pixels"] ct = Compose(ToTensorImage(), CenterCrop(w=20, h=20, in_keys=keys)) - env = TransformedEnv(ParallelEnv(2, DiscreteActionConvMockEnvNumpy), ct) + env = TransformedEnv( + maybe_fork_ParallelEnv(2, DiscreteActionConvMockEnvNumpy), ct + ) try: check_env_specs(env) finally: @@ -2390,13 +2362,13 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( DiscreteActionConvMockEnvNumpy(), DiscreteActionProjection(7, 10) ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -2412,9 +2384,9 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, DiscreteActionConvMockEnvNumpy), + maybe_fork_ParallelEnv(2, DiscreteActionConvMockEnvNumpy), DiscreteActionProjection(7, 10), ) try: @@ -2640,7 +2612,9 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self, dtype_fixture): # noqa: F811 + def test_parallel_trans_env_check( + self, dtype_fixture, maybe_fork_ParallelEnv # noqa: F811 + ): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(dtype=torch.float64), @@ -2648,7 +2622,7 @@ def make_env(): ) try: - env = ParallelEnv(1, make_env) + env = maybe_fork_ParallelEnv(1, make_env) check_env_specs(env) finally: try: @@ -2664,9 +2638,13 @@ def test_trans_serial_env_check(self, dtype_fixture): # noqa: F811 ) check_env_specs(env) - def test_trans_parallel_env_check(self, dtype_fixture): # noqa: F811 + def test_trans_parallel_env_check( + self, dtype_fixture, maybe_fork_ParallelEnv # noqa: F811 + ): env = TransformedEnv( - ParallelEnv(2, lambda: ContinuousActionVecMockEnv(dtype=torch.float64)), + maybe_fork_ParallelEnv( + 2, lambda: ContinuousActionVecMockEnv(dtype=torch.float64) + ), DoubleToFloat(in_keys=["observation"], in_keys_inv=["action"]), ) try: @@ -2815,7 +2793,7 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): t = Compose( CatTensors( @@ -2826,7 +2804,7 @@ def make_env(): env = TransformedEnv(ContinuousActionVecMockEnv(), t) return env - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -2845,14 +2823,14 @@ def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), t) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): t = Compose( CatTensors( in_keys=["observation"], out_key="observation_copy", del_keys=False ), ExcludeTransform("observation_copy"), ) - env = TransformedEnv(ParallelEnv(2, ContinuousActionVecMockEnv), t) + env = TransformedEnv(maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), t) try: check_env_specs(env) finally: @@ -3053,7 +3031,7 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): t = Compose( CatTensors( @@ -3064,7 +3042,7 @@ def make_env(): env = TransformedEnv(ContinuousActionVecMockEnv(), t) return env - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -3083,14 +3061,14 @@ def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), t) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): t = Compose( CatTensors( in_keys=["observation"], out_key="observation_copy", del_keys=False ), SelectTransform("observation", "observation_orig"), ) - env = TransformedEnv(ParallelEnv(2, ContinuousActionVecMockEnv), t) + env = TransformedEnv(maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), t) try: check_env_specs(env) finally: @@ -3256,14 +3234,14 @@ def make_env(): env = SerialEnv(2, make_env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): env = TransformedEnv( DiscreteActionConvMockEnvNumpy(), FlattenObservation(-3, -1) ) return env - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -3282,9 +3260,9 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, DiscreteActionConvMockEnvNumpy), + maybe_fork_ParallelEnv(2, DiscreteActionConvMockEnvNumpy), FlattenObservation( -3, -1, @@ -3442,12 +3420,12 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): env = TransformedEnv(ContinuousActionVecMockEnv(), FrameSkipTransform(2)) return env - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -3462,9 +3440,9 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), FrameSkipTransform(2) + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), FrameSkipTransform(2) ) try: check_env_specs(env) @@ -3685,7 +3663,7 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): out_keys = None def make_env(): @@ -3694,7 +3672,7 @@ def make_env(): Compose(ToTensorImage(), GrayScale(out_keys=out_keys)), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -3711,10 +3689,10 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): out_keys = None env = TransformedEnv( - ParallelEnv(2, DiscreteActionConvMockEnvNumpy), + maybe_fork_ParallelEnv(2, DiscreteActionConvMockEnvNumpy), Compose(ToTensorImage(), GrayScale(out_keys=out_keys)), ) try: @@ -3791,11 +3769,11 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv(ContinuousActionVecMockEnv(), NoopResetEnv()) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -3958,9 +3936,7 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check( - self, - ): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), @@ -3971,7 +3947,7 @@ def make_env(): ), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -3993,11 +3969,9 @@ def test_trans_serial_env_check( ) check_env_specs(env) - def test_trans_parallel_env_check( - self, - ): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), ObservationNorm( loc=torch.zeros(7), in_keys=["observation"], @@ -4553,14 +4527,14 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( DiscreteActionConvMockEnvNumpy(), Compose(ToTensorImage(), Resize(20, 21, in_keys=["pixels"])), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -4576,9 +4550,9 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, DiscreteActionConvMockEnvNumpy), + maybe_fork_ParallelEnv(2, DiscreteActionConvMockEnvNumpy), Compose(ToTensorImage(), Resize(20, 21, in_keys=["pixels"])), ) try: @@ -4636,13 +4610,13 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), RewardClipping(-0.1, 0.1) ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -4657,9 +4631,10 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), RewardClipping(-0.1, 0.1) + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), + RewardClipping(-0.1, 0.1), ) try: check_env_specs(env) @@ -4779,11 +4754,11 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv(ContinuousActionVecMockEnv(), RewardScaling(0.5, 1.5)) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -4798,9 +4773,10 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), RewardScaling(0.5, 1.5) + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), + RewardScaling(0.5, 1.5), ) try: check_env_specs(env) @@ -4913,14 +4889,14 @@ def make_env(): r = env.rollout(4) assert r["next", "episode_reward"].unique().numel() > 1 - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), Compose(RewardScaling(loc=-1, scale=1), RewardSum()), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) r = env.rollout(4) @@ -4940,9 +4916,9 @@ def test_trans_serial_env_check(self): r = env.rollout(4) assert r["next", "episode_reward"].unique().numel() > 1 - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), Compose(RewardScaling(loc=-1, scale=1), RewardSum()), ) try: @@ -5595,14 +5571,14 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), UnsqueezeTransform(-1, in_keys=["observation"]), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -5618,9 +5594,9 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), UnsqueezeTransform(-1, in_keys=["observation"]), ) try: @@ -5907,13 +5883,13 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), self._circular_transform ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -5934,9 +5910,10 @@ def test_trans_serial_env_check(self): except RuntimeError: pass - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), self._circular_transform + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), + self._circular_transform, ) try: check_env_specs(env) @@ -6125,7 +6102,7 @@ def make_env(): @pytest.mark.parametrize("mode", ["reduce", "constant"]) @pytest.mark.parametrize("device", get_default_devices()) - def test_parallel_trans_env_check(self, mode, device): + def test_parallel_trans_env_check(self, mode, device, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), @@ -6133,7 +6110,7 @@ def make_env(): device=device, ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -6160,9 +6137,9 @@ def test_trans_serial_env_check(self, mode, device): @pytest.mark.parametrize("mode", ["reduce", "constant"]) @pytest.mark.parametrize("device", get_default_devices()) - def test_trans_parallel_env_check(self, mode, device): + def test_trans_parallel_env_check(self, mode, device, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, DiscreteActionConvMockEnvNumpy).to(device), + maybe_fork_ParallelEnv(2, DiscreteActionConvMockEnvNumpy).to(device), TargetReturn(target_return=10.0, mode=mode), device=device, ) @@ -6370,14 +6347,14 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( DiscreteActionConvMockEnvNumpy(), ToTensorImage(in_keys=["pixels"], out_keys=None), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -6393,9 +6370,9 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, DiscreteActionConvMockEnvNumpy), + maybe_fork_ParallelEnv(2, DiscreteActionConvMockEnvNumpy), ToTensorImage(in_keys=["pixels"], out_keys=None), ) try: @@ -6524,14 +6501,14 @@ def test_transform_compose(self): t(td) assert "mykey" in td.keys() - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3])), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) assert "mykey" in env.reset().keys() @@ -6560,9 +6537,9 @@ def make_env(): except RuntimeError: pass - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 4])), ) try: @@ -6824,8 +6801,8 @@ def test_serial_trans_env_check(self): ) check_env_specs(env) - def test_parallel_trans_env_check(self): - env = ParallelEnv( + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): + env = maybe_fork_ParallelEnv( 2, lambda: TransformedEnv( ContinuousActionVecMockEnv(), @@ -6853,9 +6830,9 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, lambda: ContinuousActionVecMockEnv()), + maybe_fork_ParallelEnv(2, lambda: ContinuousActionVecMockEnv()), TimeMaxPool( in_keys=["observation"], T=3, @@ -7018,7 +6995,7 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): state_dim = 7 action_dim = 7 @@ -7027,7 +7004,7 @@ def make_env(): gSDENoise(state_dim=state_dim, action_dim=action_dim), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -7052,11 +7029,11 @@ def test_trans_serial_env_check(self, shape): except RuntimeError: pass - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): state_dim = 7 action_dim = 7 env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=(2,)), ) try: @@ -8902,7 +8879,7 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self, create_copy): + def test_parallel_trans_env_check(self, create_copy, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), @@ -8913,7 +8890,7 @@ def make_env(): ), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -8934,7 +8911,7 @@ def make_env(): ), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -8968,12 +8945,12 @@ def make_env(): ) check_env_specs(env) - def test_trans_parallel_env_check(self, create_copy): + def test_trans_parallel_env_check(self, create_copy, maybe_fork_ParallelEnv): def make_env(): return ContinuousActionVecMockEnv() env = TransformedEnv( - ParallelEnv(2, make_env), + maybe_fork_ParallelEnv(2, make_env), RenameTransform( ["observation"], ["stuff"], @@ -8988,7 +8965,7 @@ def make_env(): except RuntimeError: pass env = TransformedEnv( - ParallelEnv(2, make_env), + maybe_fork_ParallelEnv(2, make_env), RenameTransform( ["observation_orig"], ["stuff"], @@ -9192,13 +9169,13 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): env = CountingBatchedEnv(max_steps=torch.tensor([4, 5]), batch_size=[2]) env = TransformedEnv(env, InitTracker()) return env - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -9222,12 +9199,12 @@ def make_env(): except RuntimeError: pass - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): def make_env(): env = CountingBatchedEnv(max_steps=torch.tensor([4, 5]), batch_size=[2]) return env - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) env = TransformedEnv(env, InitTracker()) try: check_env_specs(env) @@ -9494,14 +9471,14 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): out_key = "reward" def make_env(): base_env = self.envclass() return TransformedEnv(base_env, self._make_transform_env(out_key, base_env)) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -9522,9 +9499,9 @@ def test_trans_serial_env_check(self): except RuntimeError: pass - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): out_key = "reward" - base_env = ParallelEnv(2, self.envclass) + base_env = maybe_fork_ParallelEnv(2, self.envclass) env = TransformedEnv(base_env, self._make_transform_env(out_key, base_env)) try: check_env_specs(env) @@ -9666,8 +9643,10 @@ def test_serial_trans_env_check(self): env = SerialEnv(2, lambda: TransformedEnv(self._env_class(), ActionMask())) check_env_specs(env) - def test_parallel_trans_env_check(self): - env = ParallelEnv(2, lambda: TransformedEnv(self._env_class(), ActionMask())) + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): + env = maybe_fork_ParallelEnv( + 2, lambda: TransformedEnv(self._env_class(), ActionMask()) + ) try: check_env_specs(env) finally: @@ -9686,8 +9665,8 @@ def test_trans_serial_env_check(self): except RuntimeError: pass - def test_trans_parallel_env_check(self): - env = TransformedEnv(ParallelEnv(2, self._env_class), ActionMask()) + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): + env = TransformedEnv(maybe_fork_ParallelEnv(2, self._env_class), ActionMask()) try: check_env_specs(env) finally: @@ -9997,13 +9976,13 @@ def make_env(): assert env.device == torch.device("cpu:1") check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(device="cpu:0"), DeviceCastTransform("cpu:1") ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) assert env.device == torch.device("cpu:1") try: check_env_specs(env) @@ -10021,11 +10000,13 @@ def make_env(): assert env.device == torch.device("cpu:1") check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): def make_env(): return ContinuousActionVecMockEnv(device="cpu:0") - env = TransformedEnv(ParallelEnv(2, make_env), DeviceCastTransform("cpu:1")) + env = TransformedEnv( + maybe_fork_ParallelEnv(2, make_env), DeviceCastTransform("cpu:1") + ) assert env.device == torch.device("cpu:1") try: check_env_specs(env) @@ -10115,8 +10096,8 @@ def test_serial_trans_env_check(self): ) check_env_specs(env) - def test_parallel_trans_env_check(self): - env = ParallelEnv( + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): + env = maybe_fork_ParallelEnv( 2, lambda: TransformedEnv( TestPermuteTransform.envclass(), TestPermuteTransform._get_permute() @@ -10143,9 +10124,9 @@ def test_trans_serial_env_check(self): except RuntimeError: pass - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, TestPermuteTransform.envclass), + maybe_fork_ParallelEnv(2, TestPermuteTransform.envclass), TestPermuteTransform._get_permute(), ) try: @@ -10248,7 +10229,7 @@ def test_transform_no_env(self, batch): reason="EndOfLifeTransform can only be tested when Gym is present.", ) class TestEndOfLife(TransformBase): - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): def make(): with set_gym_backend("gymnasium"): return GymEnv(BREAKOUT_VERSIONED()) @@ -10256,7 +10237,7 @@ def make(): with pytest.warns(UserWarning, match="The base_env is not a gym env"): with pytest.raises(AttributeError): env = TransformedEnv( - ParallelEnv(2, make), transform=EndOfLifeTransform() + maybe_fork_ParallelEnv(2, make), transform=EndOfLifeTransform() ) check_env_specs(env) @@ -10294,7 +10275,7 @@ def make(): @pytest.mark.parametrize("eol_key", ["eol_key", ("nested", "eol")]) @pytest.mark.parametrize("lives_key", ["lives_key", ("nested", "lives")]) - def test_parallel_trans_env_check(self, eol_key, lives_key): + def test_parallel_trans_env_check(self, eol_key, lives_key, maybe_fork_ParallelEnv): def make(): with set_gym_backend("gymnasium"): return TransformedEnv( @@ -10302,7 +10283,7 @@ def make(): transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key), ) - env = ParallelEnv(2, make) + env = maybe_fork_ParallelEnv(2, make) try: check_env_specs(env) finally: @@ -10697,7 +10678,7 @@ def test_transform_no_env(self): assert data["reward"] == 2 assert self.check_sign_applied(data["reward_sign"]) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): env = ContinuousActionVecMockEnv() return TransformedEnv( @@ -10708,7 +10689,7 @@ def make_env(): ), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: @@ -10731,9 +10712,9 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), SignTransform( in_keys=["observation", "reward"], in_keys_inv=["observation_orig"], @@ -10815,8 +10796,8 @@ def test_serial_trans_env_check(self): env = SerialEnv(2, lambda: TransformedEnv(self.DummyEnv(), RemoveEmptySpecs())) check_env_specs(env) - def test_parallel_trans_env_check(self): - env = ParallelEnv( + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): + env = maybe_fork_ParallelEnv( 2, lambda: TransformedEnv(self.DummyEnv(), RemoveEmptySpecs()) ) try: @@ -10833,11 +10814,13 @@ def test_trans_serial_env_check(self): ): env = TransformedEnv(SerialEnv(2, self.DummyEnv), RemoveEmptySpecs()) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): with pytest.raises( RuntimeError, match="The environment passed to ParallelEnv has empty specs" ): - env = TransformedEnv(ParallelEnv(2, self.DummyEnv), RemoveEmptySpecs()) + env = TransformedEnv( + maybe_fork_ParallelEnv(2, self.DummyEnv), RemoveEmptySpecs() + ) def test_transform_no_env(self): td = TensorDict({"a": {"b": {"c": {}}}}, []) @@ -11289,7 +11272,7 @@ def make_env(): ) return env - env = ParallelEnv(2, make_env, mp_start_method="fork") + env = ParallelEnv(2, make_env, mp_start_method=mp_ctx) check_env_specs(env) @pytest.mark.parametrize("categorical", [True, False]) @@ -11302,7 +11285,7 @@ def test_trans_serial_env_check(self, categorical): @pytest.mark.parametrize("categorical", [True, False]) def test_trans_parallel_env_check(self, categorical): env = ParallelEnv( - 2, ContinuousActionVecMockEnv, mp_start_method="fork" + 2, ContinuousActionVecMockEnv, mp_start_method=mp_ctx ).append_transform(ActionDiscretizer(num_intervals=5, categorical=categorical)) check_env_specs(env)