[Feature] Use non-default mp start method in ParallelEnv (pytorch#1966)
vmoens authored Feb 27, 2024
1 parent 6274b27 commit cadf4d9
Showing 3 changed files with 49 additions and 7 deletions.
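
At a glance, the commit adds an mp_start_method keyword argument to ParallelEnv so the worker processes can be started with a method other than the process-wide multiprocessing default. A minimal usage sketch (the environment name "Pendulum-v1" and the availability of the "fork" method are assumptions; "fork" does not exist on Windows):

    from torchrl.envs import EnvCreator, GymEnv, ParallelEnv

    # Workers of this env are started with "fork"; the global default start
    # method (usually "spawn" once torchrl is imported) is left untouched.
    env = ParallelEnv(2, EnvCreator(lambda: GymEnv("Pendulum-v1")), mp_start_method="fork")
    rollout = env.rollout(3)
    env.close()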
17 changes: 17 additions & 0 deletions test/test_env.py
@@ -2609,6 +2609,23 @@ def forward(self, values):
      env.rollout(10, policy)


+ def test_parallel_another_ctx():
+     from torch import multiprocessing as mp
+
+     sm = mp.get_start_method()
+     if sm == "spawn":
+         other_sm = "fork"
+     else:
+         other_sm = "spawn"
+     env = ParallelEnv(2, ContinuousActionVecMockEnv, mp_start_method=other_sm)
+     try:
+         assert env.rollout(3) is not None
+         assert env._workers[0]._start_method == other_sm
+     finally:
+         env.close()
+         del env
+
+
  if __name__ == "__main__":
      args, unknown = argparse.ArgumentParser().parse_known_args()
      pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
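
The test simply flips to whichever start method is not the current default. Outside the test harness, the same toggle looks roughly like this (a sketch; the availability check matters because "fork" is not offered on Windows):

    from torch import multiprocessing as mp

    default_method = mp.get_start_method()             # "spawn" by default once TorchRL is imported
    other_method = "fork" if default_method == "spawn" else "spawn"
    assert other_method in mp.get_all_start_methods()  # e.g. "fork" is missing on Windows
    print(f"default={default_method}, requesting={other_method} for the workers")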
4 changes: 3 additions & 1 deletion torchrl/_utils.py
@@ -636,11 +636,13 @@ class _ProcessNoWarn(mp.Process):
"""A private Process class that shuts down warnings on the subprocess and controls the number of threads in the subprocess."""

      @wraps(mp.Process.__init__)
-     def __init__(self, *args, num_threads=None, **kwargs):
+     def __init__(self, *args, num_threads=None, _start_method=None, **kwargs):
          import torchrl

          self.filter_warnings_subprocess = torchrl.filter_warnings_subprocess
          self.num_threads = num_threads
+         if _start_method is not None:
+             self._start_method = _start_method
          super().__init__(*args, **kwargs)

      def run(self, *args, **kwargs):
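For orientation, here is a stripped-down sketch of the pattern _ProcessNoWarn implements (not TorchRL's actual code: the warning filtering is omitted and the class name is invented). The thread budget is recorded at construction time and applied inside the child, and the new _start_method override mirrors the line added above:

    import torch
    from torch import multiprocessing as mp

    class _QuietProcess(mp.Process):  # hypothetical stand-in for _ProcessNoWarn
        def __init__(self, *args, num_threads=None, _start_method=None, **kwargs):
            self.num_threads = num_threads
            if _start_method is not None:
                self._start_method = _start_method  # request a specific start method
            super().__init__(*args, **kwargs)

        def run(self, *args, **kwargs):
            # runs in the child process, before the target function
            if self.num_threads is not None:
                torch.set_num_threads(self.num_threads)
            return super().run(*args, **kwargs)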
35 changes: 29 additions & 6 deletions torchrl/envs/batched_envs.py
@@ -5,6 +5,8 @@

  from __future__ import annotations

+ import functools
+
  import gc

  import os
@@ -172,6 +174,10 @@ class BatchedEnvBase(EnvBase):
          non_blocking (bool, optional): if ``True``, device moves will be done using the
              ``non_blocking=True`` option. Defaults to ``True`` for batched environments
              on cuda devices, and ``False`` otherwise.
+         mp_start_method (str, optional): the multiprocessing start method.
+             If not provided, the default start method is used (``'spawn'`` by default
+             in TorchRL, unless it was set otherwise before the first import).
+             Can only be used with :class:`~torchrl.envs.ParallelEnv` and its subclasses.

      Examples:
          >>> from torchrl.envs import GymEnv, ParallelEnv, SerialEnv, EnvCreator
@@ -275,6 +281,7 @@ def __init__(
          num_sub_threads: int = 1,
          serial_for_single: bool = False,
          non_blocking: bool = False,
+         mp_start_method: str = None,
      ):
          super().__init__(device=device)
          self.serial_for_single = serial_for_single
@@ -333,6 +340,11 @@ def __init__(
          self._properties_set = False
          self._get_metadata(create_env_fn, create_env_kwargs)
          self._non_blocking = non_blocking
+         if mp_start_method is not None and not isinstance(self, ParallelEnv):
+             raise TypeError(
+                 f"Cannot use mp_start_method={mp_start_method} with envs of type {type(self)}."
+             )
+         self._mp_start_method = mp_start_method

      @property
      def non_blocking(self):
@@ -1059,7 +1071,18 @@ def _start_workers(self) -> None:

          torch.set_num_threads(self.num_threads)

-         ctx = mp.get_context("spawn")
+         if self._mp_start_method is not None:
+             ctx = mp.get_context(self._mp_start_method)
+             proc_fun = ctx.Process
+             num_sub_threads = self.num_sub_threads
+         else:
+             ctx = mp.get_context("spawn")
+             proc_fun = functools.partial(
+                 _ProcessNoWarn,
+                 num_threads=self.num_sub_threads,
+                 _start_method=self._mp_start_method,
+             )
+             num_sub_threads = None

          _num_workers = self.num_workers

@@ -1102,13 +1125,10 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
                      "_selected_reset_keys": self._selected_reset_keys,
                      "_selected_step_keys": self._selected_step_keys,
                      "has_lazy_inputs": self.has_lazy_inputs,
+                     "num_threads": num_sub_threads,
                  }
              )
-             process = _ProcessNoWarn(
-                 target=func,
-                 num_threads=self.num_sub_threads,
-                 kwargs=kwargs[idx],
-             )
+             process = proc_fun(target=func, kwargs=kwargs[idx])
              process.daemon = True
              process.start()
              child_pipe.close()
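
Here proc_fun is either the _ProcessNoWarn wrapper (default path, thread cap handled by the process class) or the plain Process class of the requested context, in which case the thread cap travels to the worker through the "num_threads" entry added to kwargs above. A small hedged illustration of what a start-method-specific context provides (the "fork" choice is an assumption and is unavailable on Windows):

    from torch import multiprocessing as mp

    if __name__ == "__main__":
        ctx = mp.get_context("fork")  # a context bound to a single start method
        proc = ctx.Process(target=print, args=("hello from a fork-started worker",))
        proc.start()
        proc.join()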
@@ -1474,7 +1494,10 @@ def _run_worker_pipe_shared_mem(
      _selected_step_keys=None,
      has_lazy_inputs: bool = False,
      verbose: bool = False,
+     num_threads: int | None = None,  # for fork start method
  ) -> None:
+     if num_threads is not None:
+         torch.set_num_threads(num_threads)
      device = shared_tensordict.device
      if device is None or device.type != "cuda":
          # Check if some tensors are shared on cuda
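Putting the three files together: the constructor validates and stores the start method, _start_workers selects the matching context, the worker applies the thread cap itself when a custom start method is used, and the method actually used can be read back from the worker handles, exactly as the new test does. A hedged end-to-end sketch (environment name, worker count and "fork" availability are assumptions):

    import functools

    from torchrl.envs import GymEnv, ParallelEnv

    if __name__ == "__main__":
        env = ParallelEnv(
            2,
            functools.partial(GymEnv, "Pendulum-v1"),  # picklable env maker
            num_sub_threads=1,       # forwarded to the worker, which calls torch.set_num_threads
            mp_start_method="fork",  # the new argument; SerialEnv would raise a TypeError
        )
        try:
            env.rollout(5)
            assert env._workers[0]._start_method == "fork"  # same check as the new test
        finally:
            env.close()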
