From cadf4d90f2c9a1d2006523363a9dd37be33d969e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 26 Feb 2024 19:35:04 -0500 Subject: [PATCH] [Feature] Use non-default mp start method in ParallelEnv (#1966) --- test/test_env.py | 17 +++++++++++++++++ torchrl/_utils.py | 4 +++- torchrl/envs/batched_envs.py | 35 +++++++++++++++++++++++++++++------ 3 files changed, 49 insertions(+), 7 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index f03b35e20c4..6ff97e0c37f 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -2609,6 +2609,23 @@ def forward(self, values): env.rollout(10, policy) +def test_parallel_another_ctx(): + from torch import multiprocessing as mp + + sm = mp.get_start_method() + if sm == "spawn": + other_sm = "fork" + else: + other_sm = "spawn" + env = ParallelEnv(2, ContinuousActionVecMockEnv, mp_start_method=other_sm) + try: + assert env.rollout(3) is not None + assert env._workers[0]._start_method == other_sm + finally: + env.close() + del env + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index ae01556f0e6..c3fc3e09ea1 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -636,11 +636,13 @@ class _ProcessNoWarn(mp.Process): """A private Process class that shuts down warnings on the subprocess and controls the number of threads in the subprocess.""" @wraps(mp.Process.__init__) - def __init__(self, *args, num_threads=None, **kwargs): + def __init__(self, *args, num_threads=None, _start_method=None, **kwargs): import torchrl self.filter_warnings_subprocess = torchrl.filter_warnings_subprocess self.num_threads = num_threads + if _start_method is not None: + self._start_method = _start_method super().__init__(*args, **kwargs) def run(self, *args, **kwargs): diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index ae313ce5f19..7546459c988 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -5,6 +5,8 @@ from __future__ import annotations +import functools + import gc import os @@ -172,6 +174,10 @@ class BatchedEnvBase(EnvBase): non_blocking (bool, optional): if ``True``, device moves will be done using the ``non_blocking=True`` option. Defaults to ``True`` for batched environments on cuda devices, and ``False`` otherwise. + mp_start_method (str, optional): the multiprocessing start method. + Uses the default start method if not indicated ('spawn' by default in + TorchRL if not initiated differently before first import). + To be used only with :class:`~torchrl.envs.ParallelEnv` subclasses. Examples: >>> from torchrl.envs import GymEnv, ParallelEnv, SerialEnv, EnvCreator @@ -275,6 +281,7 @@ def __init__( num_sub_threads: int = 1, serial_for_single: bool = False, non_blocking: bool = False, + mp_start_method: str = None, ): super().__init__(device=device) self.serial_for_single = serial_for_single @@ -333,6 +340,11 @@ def __init__( self._properties_set = False self._get_metadata(create_env_fn, create_env_kwargs) self._non_blocking = non_blocking + if mp_start_method is not None and not isinstance(self, ParallelEnv): + raise TypeError( + f"Cannot use mp_start_method={mp_start_method} with envs of type {type(self)}." + ) + self._mp_start_method = mp_start_method @property def non_blocking(self): @@ -1059,7 +1071,18 @@ def _start_workers(self) -> None: torch.set_num_threads(self.num_threads) - ctx = mp.get_context("spawn") + if self._mp_start_method is not None: + ctx = mp.get_context(self._mp_start_method) + proc_fun = ctx.Process + num_sub_threads = self.num_sub_threads + else: + ctx = mp.get_context("spawn") + proc_fun = functools.partial( + _ProcessNoWarn, + num_threads=self.num_sub_threads, + _start_method=self._mp_start_method, + ) + num_sub_threads = None _num_workers = self.num_workers @@ -1102,13 +1125,10 @@ def look_for_cuda(tensor, has_cuda=has_cuda): "_selected_reset_keys": self._selected_reset_keys, "_selected_step_keys": self._selected_step_keys, "has_lazy_inputs": self.has_lazy_inputs, + "num_threads": num_sub_threads, } ) - process = _ProcessNoWarn( - target=func, - num_threads=self.num_sub_threads, - kwargs=kwargs[idx], - ) + process = proc_fun(target=func, kwargs=kwargs[idx]) process.daemon = True process.start() child_pipe.close() @@ -1474,7 +1494,10 @@ def _run_worker_pipe_shared_mem( _selected_step_keys=None, has_lazy_inputs: bool = False, verbose: bool = False, + num_threads: int | None = None, # for fork start method ) -> None: + if num_threads is not None: + torch.set_num_threads(num_threads) device = shared_tensordict.device if device is None or device.type != "cuda": # Check if some tensors are shared on cuda