[BugFix] Fix another ctx test #2284

Merged (4 commits) on Jul 10, 2024

Changes from 1 commit
amend
vmoens committed Jul 10, 2024
commit 980c015ea75f3c8db30e1f4e542a0af44f7b15d7
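
Most of this amend applies a single test-hygiene pattern across test/test_env.py: call gc.collect() at the top of tests that spawn worker processes, and drop the environment reference once it has been closed. Below is a minimal sketch of that pattern, not taken from the diff; the fixture name make_batched_env is hypothetical.

import gc


def test_example(make_batched_env):  # make_batched_env: hypothetical pytest fixture
    # Reclaim workers, pipes and shared buffers left over from earlier tests
    # before this test forks/spawns its own processes.
    gc.collect()

    env = make_batched_env()
    try:
        assert env.rollout(3) is not None
    finally:
        env.close()
        del env  # drop the last reference so worker processes can be reaped promptly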
26 changes: 19 additions & 7 deletions test/test_env.py
@@ -483,6 +483,7 @@ def test_parallel_devices(
 
     @pytest.mark.parametrize("start_method", [None, mp_ctx])
     def test_serial_for_single(self, maybe_fork_ParallelEnv, start_method):
+        gc.collect()
         env = ParallelEnv(
             1,
             ContinuousActionVecMockEnv,
@@ -2002,6 +2003,7 @@ def forward(self, tensordict):
 
     @staticmethod
     def main_penv(j, q=None):
+        gc.collect()
         device = "cpu" if not torch.cuda.device_count() else "cuda:0"
         n_workers = 1
         env_p = ParallelEnv(
@@ -2328,6 +2330,7 @@ def test_rollout_policy(self, batch_size, rollout_steps, count):
     def test_vec_env(
         self, batch_size, env_type, break_when_any_done, rollout_steps=4, n_workers=2
     ):
+        gc.collect()
         env_fun = lambda: HeterogeneousCountingEnv(batch_size=batch_size)
         if env_type == "serial":
             vec_env = SerialEnv(n_workers, env_fun)
@@ -2591,6 +2594,7 @@ class TestLibThreading:
         reason="setting different threads across workers can randomly fail on OSX.",
     )
     def test_num_threads(self):
+        gc.collect()
         from torchrl.envs import batched_envs
 
         _run_worker_pipe_shared_mem_save = batched_envs._run_worker_pipe_shared_mem
@@ -2617,6 +2621,7 @@ def test_num_threads(self):
         reason="setting different threads across workers can randomly fail on OSX.",
     )
     def test_auto_num_threads(self, maybe_fork_ParallelEnv):
+        gc.collect()
         init_threads = torch.get_num_threads()
 
         try:
@@ -2671,6 +2676,8 @@ def test_run_type_checks():
 @pytest.mark.skipif(not torch.cuda.device_count(), reason="No cuda device found.")
 @pytest.mark.parametrize("break_when_any_done", [True, False])
 def test_auto_cast_to_device(break_when_any_done):
+    gc.collect()
+
     env = ContinuousActionVecMockEnv(device="cpu")
     policy = Actor(
         nn.Linear(
@@ -2701,6 +2708,8 @@ def test_auto_cast_to_device(break_when_any_done):
 @pytest.mark.parametrize("device", get_default_devices())
 @pytest.mark.parametrize("share_individual_td", [True, False])
 def test_backprop(device, maybe_fork_ParallelEnv, share_individual_td):
+    gc.collect()
+
     # Tests that backprop through a series of single envs and through a serial env are identical
     # Also tests that no backprop can be achieved with parallel env.
     class DifferentiableEnv(EnvBase):
@@ -2785,6 +2794,7 @@ def make_env(seed, device=device):
         assert not r_parallel.exclude("action").requires_grad
     finally:
         p_env.close()
+        del p_env
 
 
 @pytest.mark.skipif(not _has_gym, reason="Gym required for this test")
@@ -2808,20 +2818,22 @@ def forward(self, values):
 def test_parallel_another_ctx():
     from torch import multiprocessing as mp
 
-    sm = mp.get_start_method()
-    if sm == "spawn":
-        other_sm = "fork"
-    else:
-        other_sm = "spawn"
-    env = ParallelEnv(2, ContinuousActionVecMockEnv, mp_start_method=other_sm)
+    gc.collect()
 
     try:
+        sm = mp.get_start_method()
+        if sm == "spawn":
+            other_sm = "fork"
+        else:
+            other_sm = "spawn"
+        env = ParallelEnv(2, ContinuousActionVecMockEnv, mp_start_method=other_sm)
         assert env.rollout(3) is not None
         assert env._workers[0]._start_method == other_sm
     finally:
         try:
             env.close()
-        except RuntimeError:
+            del env
+        except Exception:
             pass
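
For context on the reworked test above: the multiprocessing start method ("fork" or "spawn") is normally fixed once per interpreter, which is why the test asks ParallelEnv for whichever method is not the current default via mp_start_method. Below is a standalone sketch of the underlying torch.multiprocessing API, with no torchrl involved; note that "fork" is only available on POSIX systems.

from torch import multiprocessing as mp


def child(q):
    q.put("hello from worker")


if __name__ == "__main__":
    # The global default start method is set once per interpreter; a different
    # one can be requested through a dedicated context without changing it.
    default_sm = mp.get_start_method()
    other_sm = "fork" if default_sm == "spawn" else "spawn"
    ctx = mp.get_context(other_sm)

    q = ctx.Queue()
    p = ctx.Process(target=child, args=(q,))
    p.start()
    print(f"{default_sm} -> {other_sm}: {q.get()}")
    p.join()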

