Skip to content

Commit

Permalink
[BugFix] Fix another ctx test (pytorch#2284)
Browse files: browse the repository at this point in the history
  • Loading branch information
vmoens authored Jul 10, 2024
1 parent ea79350 commit d0fa836
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 8 deletions.
1 change: 1 addition & 0 deletions .github/unittest/linux_optdeps/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies:
- pytest-mock
- pytest-instafail
- pytest-rerunfailures
- pytest-timeout
- expecttest
- pyyaml
- scipy
Expand Down
4 changes: 3 additions & 1 deletion .github/unittest/linux_optdeps/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ export MKL_THREADING_LAYER=GNU
export CKPT_BACKEND=torch
export BATCHED_PIPE_TIMEOUT=60

MUJOCO_GL=egl python .github/unittest/helpers/coverage_run_parallel.py -m pytest --instafail -v --durations 200 --ignore test/test_distributed.py --ignore test/test_rlhf.py
MUJOCO_GL=egl python .github/unittest/helpers/coverage_run_parallel.py -m pytest --instafail \
-v --durations 200 --ignore test/test_distributed.py --ignore test/test_rlhf.py --capture no \
--timeout=120 --mp_fork_if_no_cuda
coverage combine
coverage xml -i
26 changes: 19 additions & 7 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,7 @@ def test_parallel_devices(

@pytest.mark.parametrize("start_method", [None, mp_ctx])
def test_serial_for_single(self, maybe_fork_ParallelEnv, start_method):
gc.collect()
env = ParallelEnv(
1,
ContinuousActionVecMockEnv,
Expand Down Expand Up @@ -2002,6 +2003,7 @@ def forward(self, tensordict):

@staticmethod
def main_penv(j, q=None):
gc.collect()
device = "cpu" if not torch.cuda.device_count() else "cuda:0"
n_workers = 1
env_p = ParallelEnv(
Expand Down Expand Up @@ -2329,6 +2331,7 @@ def test_rollout_policy(self, batch_size, rollout_steps, count):
def test_vec_env(
self, batch_size, env_type, break_when_any_done, rollout_steps=4, n_workers=2
):
gc.collect()
env_fun = lambda: HeterogeneousCountingEnv(batch_size=batch_size)
if env_type == "serial":
vec_env = SerialEnv(n_workers, env_fun)
Expand Down Expand Up @@ -2592,6 +2595,7 @@ class TestLibThreading:
reason="setting different threads across workers can randomly fail on OSX.",
)
def test_num_threads(self):
gc.collect()
from torchrl.envs import batched_envs

_run_worker_pipe_shared_mem_save = batched_envs._run_worker_pipe_shared_mem
Expand All @@ -2618,6 +2622,7 @@ def test_num_threads(self):
reason="setting different threads across workers can randomly fail on OSX.",
)
def test_auto_num_threads(self, maybe_fork_ParallelEnv):
gc.collect()
init_threads = torch.get_num_threads()

try:
Expand Down Expand Up @@ -2672,6 +2677,8 @@ def test_run_type_checks():
@pytest.mark.skipif(not torch.cuda.device_count(), reason="No cuda device found.")
@pytest.mark.parametrize("break_when_any_done", [True, False])
def test_auto_cast_to_device(break_when_any_done):
gc.collect()

env = ContinuousActionVecMockEnv(device="cpu")
policy = Actor(
nn.Linear(
Expand Down Expand Up @@ -2702,6 +2709,8 @@ def test_auto_cast_to_device(break_when_any_done):
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("share_individual_td", [True, False])
def test_backprop(device, maybe_fork_ParallelEnv, share_individual_td):
gc.collect()

# Tests that backprop through a series of single envs and through a serial env are identical
# Also tests that no backprop can be achieved with parallel env.
class DifferentiableEnv(EnvBase):
Expand Down Expand Up @@ -2786,6 +2795,7 @@ def make_env(seed, device=device):
assert not r_parallel.exclude("action").requires_grad
finally:
p_env.close()
del p_env


@pytest.mark.skipif(not _has_gym, reason="Gym required for this test")
Expand All @@ -2809,20 +2819,22 @@ def forward(self, values):
def test_parallel_another_ctx():
from torch import multiprocessing as mp

sm = mp.get_start_method()
if sm == "spawn":
other_sm = "fork"
else:
other_sm = "spawn"
env = ParallelEnv(2, ContinuousActionVecMockEnv, mp_start_method=other_sm)
gc.collect()

try:
sm = mp.get_start_method()
if sm == "spawn":
other_sm = "fork"
else:
other_sm = "spawn"
env = ParallelEnv(2, ContinuousActionVecMockEnv, mp_start_method=other_sm)
assert env.rollout(3) is not None
assert env._workers[0]._start_method == other_sm
finally:
try:
env.close()
del env
except RuntimeError:
except Exception:
pass


Expand Down

0 comments on commit d0fa836

Please sign in to comment.