[Feature] Use non-default mp start method in ParallelEnv #1966

Merged
merged 2 commits on Feb 27, 2024
Changes from 1 commit
init
vmoens committed Feb 26, 2024
commit 38e1ae82d406a5f3d06b02fcf73b92b4a0211941
4 changes: 3 additions & 1 deletion torchrl/_utils.py
@@ -636,11 +636,13 @@ class _ProcessNoWarn(mp.Process):
     """A private Process class that shuts down warnings on the subprocess and controls the number of threads in the subprocess."""
 
     @wraps(mp.Process.__init__)
-    def __init__(self, *args, num_threads=None, **kwargs):
+    def __init__(self, *args, num_threads=None, _start_method=None, **kwargs):
         import torchrl
 
         self.filter_warnings_subprocess = torchrl.filter_warnings_subprocess
         self.num_threads = num_threads
+        if _start_method is not None:
+            self._start_method = _start_method
         super().__init__(*args, **kwargs)
 
     def run(self, *args, **kwargs):
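
For context, the standard library already lets callers pick a start method per context rather than globally; this patch threads that same choice through TorchRL's private process wrapper. A minimal stdlib-only sketch of the underlying mechanism (names here are illustrative, not from the PR):

    import multiprocessing as mp

    def _work():
        # Runs in the child process.
        print("hello from the child")

    if __name__ == "__main__":
        # A context carries its own start method, so no global
        # mp.set_start_method() call is needed that would affect the whole program.
        ctx = mp.get_context("spawn")  # "fork" / "forkserver" also exist on POSIX
        p = ctx.Process(target=_work)
        p.start()
        p.join()
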
29 changes: 23 additions & 6 deletions torchrl/envs/batched_envs.py
@@ -5,6 +5,8 @@

 from __future__ import annotations
 
+import functools
+
 import gc
 
 import os
@@ -172,6 +174,10 @@ class BatchedEnvBase(EnvBase):
         non_blocking (bool, optional): if ``True``, device moves will be done using the
             ``non_blocking=True`` option. Defaults to ``True`` for batched environments
             on cuda devices, and ``False`` otherwise.
+        mp_start_method (str, optional): the multiprocessing start method.
+            Uses the default start method if not indicated ('spawn' by default in
+            TorchRL, unless set otherwise before the first import).
+            To be used only with :class:`~torchrl.envs.ParallelEnv` subclasses.
 
     Examples:
         >>> from torchrl.envs import GymEnv, ParallelEnv, SerialEnv, EnvCreator
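
A sketch of how the new argument would be used, assuming a Gym backend is installed ("CartPole-v1" is an arbitrary example env, not from the PR):

    from torchrl.envs import GymEnv, ParallelEnv

    if __name__ == "__main__":
        env = ParallelEnv(
            2,
            lambda: GymEnv("CartPole-v1"),
            # POSIX-only; omit to keep TorchRL's default ("spawn").
            # A lambda works under "fork"; "spawn" needs a picklable callable.
            mp_start_method="fork",
        )
        rollout = env.rollout(3)
        env.close()
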
@@ -275,6 +281,7 @@ def __init__(
         num_sub_threads: int = 1,
         serial_for_single: bool = False,
         non_blocking: bool = False,
+        mp_start_method: str = None,
     ):
         super().__init__(device=device)
         self.serial_for_single = serial_for_single
@@ -333,6 +340,11 @@ def __init__(
         self._properties_set = False
         self._get_metadata(create_env_fn, create_env_kwargs)
         self._non_blocking = non_blocking
+        if mp_start_method is not None and not isinstance(self, ParallelEnv):
+            raise TypeError(
+                f"Cannot use mp_start_method={mp_start_method} with envs of type {type(self)}."
+            )
+        self._mp_start_method = mp_start_method
 
     @property
     def non_blocking(self):
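
Per the guard above, passing the argument to a non-ParallelEnv batched env should raise; a small sketch (again assuming a Gym backend, with an arbitrary env id):

    from torchrl.envs import GymEnv, SerialEnv

    try:
        SerialEnv(2, lambda: GymEnv("CartPole-v1"), mp_start_method="fork")
    except TypeError as err:
        print(err)  # Cannot use mp_start_method=fork with envs of type ...
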
@@ -1059,7 +1071,16 @@ def _start_workers(self) -> None:

         torch.set_num_threads(self.num_threads)
 
-        ctx = mp.get_context("spawn")
+        if self._mp_start_method is not None:
+            ctx = mp.get_context(self._mp_start_method)
+            proc_fun = ctx.Process
+        else:
+            ctx = mp.get_context("spawn")
+            proc_fun = functools.partial(
+                _ProcessNoWarn,
+                num_threads=self.num_sub_threads,
+                _start_method=self._mp_start_method,
+            )
 
         _num_workers = self.num_workers
 
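
The functools.partial call above pre-binds the wrapper-specific keyword arguments so both branches expose the same process-factory interface. A standalone sketch of that pattern, with a hypothetical stand-in class mirroring _ProcessNoWarn's signature:

    import functools

    class FakeProcess:
        # Hypothetical stand-in; not part of TorchRL.
        def __init__(self, *args, num_threads=None, _start_method=None, **kwargs):
            self.num_threads = num_threads
            self._start_method = _start_method
            self.kwargs = kwargs

    proc_fun = functools.partial(FakeProcess, num_threads=4, _start_method=None)
    # Call sites only need to supply target/kwargs, as in _start_workers:
    p = proc_fun(target=print, kwargs={})
    print(p.num_threads)  # 4
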
@@ -1104,11 +1125,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
"has_lazy_inputs": self.has_lazy_inputs,
}
)
process = _ProcessNoWarn(
target=func,
num_threads=self.num_sub_threads,
kwargs=kwargs[idx],
)
process = proc_fun(target=func, kwargs=kwargs[idx])
process.daemon = True
process.start()
child_pipe.close()
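
Finally, a stdlib-only sketch of why the start method matters enough to warrant this feature: "fork" gives the child a copy of the parent's memory, while "spawn" re-imports the main module from scratch ("fork" is unavailable on Windows):

    import multiprocessing as mp

    state = {"initialized": False}

    def report():
        # "fork": the child inherits parent memory, so this prints True.
        # "spawn": the module is re-imported and the __main__ guard is
        # skipped in the child, so this prints False.
        print(state["initialized"])

    if __name__ == "__main__":
        state["initialized"] = True
        for method in ("fork", "spawn"):
            ctx = mp.get_context(method)
            p = ctx.Process(target=report)
            p.start()
            p.join()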