[Feature] Threaded collection and parallel envs (pytorch#1559)
Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
vmoens and matteobettini authored Sep 22, 2023
1 parent 95773f7 commit 09e148b
Showing 7 changed files with 111 additions and 8 deletions.
9 changes: 9 additions & 0 deletions test/_utils_internal.py
@@ -19,6 +19,7 @@

from tensordict import tensorclass, TensorDict
from torchrl._utils import implement_for, seed_generator
from torchrl.data.utils import CloudpickleWrapper

from torchrl.envs import MultiThreadedEnv, ObservationNorm
from torchrl.envs.batched_envs import ParallelEnv, SerialEnv
@@ -433,3 +434,11 @@ def check_rollout_consistency_multikey_env(td: TensorDict, max_steps: int):
== td["nested_2", "observation"][~action_is_count]
).all()
assert (td["next", "nested_2", "reward"][~action_is_count] == 0).all()


def decorate_thread_sub_func(func, num_threads):
    """Wrap ``func`` so that it asserts the calling process runs with ``num_threads`` threads."""

    def new_func(*args, **kwargs):
        # Runs in the subprocess: fail loudly if the thread count was not propagated.
        assert torch.get_num_threads() == num_threads
        return func(*args, **kwargs)

    # CloudpickleWrapper keeps the closure picklable across process boundaries.
    return CloudpickleWrapper(new_func)
31 changes: 31 additions & 0 deletions test/test_collector.py
@@ -4,13 +4,15 @@
# LICENSE file in the root directory of this source tree.

import argparse

import sys

import numpy as np
import pytest
import torch
from _utils_internal import (
check_rollout_consistency_multikey_env,
decorate_thread_sub_func,
generate_seeds,
PENDULUM_VERSIONED,
PONG_VERSIONED,
@@ -1783,6 +1785,35 @@ def make_env():
collector.shutdown()


def test_num_threads():
from torchrl.collectors import collectors

_main_async_collector_saved = collectors._main_async_collector
collectors._main_async_collector = decorate_thread_sub_func(
collectors._main_async_collector, num_threads=3
)
num_threads = torch.get_num_threads()
try:
env = ContinuousActionVecMockEnv()
c = MultiSyncDataCollector(
[env],
policy=RandomPolicy(env.action_spec),
num_threads=7,
num_sub_threads=3,
total_frames=200,
frames_per_batch=200,
)
assert torch.get_num_threads() == 7
for _ in c:
pass
c.shutdown()
del c
finally:
# reset vals
collectors._main_async_collector = _main_async_collector_saved
torch.set_num_threads(num_threads)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
24 changes: 24 additions & 0 deletions test/test_env.py
@@ -18,6 +18,7 @@
_make_envs,
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
decorate_thread_sub_func,
get_default_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
@@ -2088,6 +2089,29 @@ def test_mocking_envs(envclass):
check_env_specs(env, seed=100, return_contiguous=False)


def test_num_threads():
from torchrl.envs import batched_envs

_run_worker_pipe_shared_mem_save = batched_envs._run_worker_pipe_shared_mem
batched_envs._run_worker_pipe_shared_mem = decorate_thread_sub_func(
batched_envs._run_worker_pipe_shared_mem, num_threads=3
)
num_threads = torch.get_num_threads()
try:
env = ParallelEnv(
2, ContinuousActionVecMockEnv, num_sub_threads=3, num_threads=7
)
        # We could check that the thread count stays unchanged until the workers start.
        # Since the default thread count could already be 7, we keep this check disabled for safety:
        # assert torch.get_num_threads() != 7
env.rollout(3)
assert torch.get_num_threads() == 7
finally:
# reset vals
batched_envs._run_worker_pipe_shared_mem = _run_worker_pipe_shared_mem_save
torch.set_num_threads(num_threads)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
13 changes: 8 additions & 5 deletions torchrl/_utils.py
@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import collections

@@ -534,21 +535,23 @@ def get_trace():


class _ProcessNoWarn(mp.Process):
"""A private Process class that shuts down warnings on the subprocess."""
"""A private Process class that shuts down warnings on the subprocess and controls the number of threads in the subprocess."""

@wraps(mp.Process.__init__)
    def __init__(self, *args, num_threads=None, **kwargs):
        import torchrl

        self.filter_warnings_subprocess = torchrl.filter_warnings_subprocess
        self.num_threads = num_threads
        super().__init__(*args, **kwargs)

    def run(self, *args, **kwargs):
        if self.num_threads is not None:
            torch.set_num_threads(self.num_threads)
        if self.filter_warnings_subprocess:
            import warnings

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                return mp.Process.run(self, *args, **kwargs)
        return mp.Process.run(self, *args, **kwargs)
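
The thread cap is applied inside ``run()``, i.e. in the child process, before the target executes. A minimal sketch of the behavior, assuming the private class is imported directly from ``torchrl._utils`` (illustration only, not a public API):

import torch

from torchrl._utils import _ProcessNoWarn


def report():
    # Executed in the child process, after torch.set_num_threads was applied.
    print("child thread count:", torch.get_num_threads())


if __name__ == "__main__":
    p = _ProcessNoWarn(target=report, num_threads=2)
    p.start()
    p.join()  # prints: child thread count: 2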
21 changes: 20 additions & 1 deletion torchrl/collectors/collectors.py
@@ -1098,6 +1098,14 @@ class _MultiDataCollector(DataCollectorBase):
Defaults to ``False``.
preemptive_threshold (float, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
    that will be allowed to finish collecting their rollout before the rest are forced to end early.
num_threads (int, optional): number of threads for this process.
    Defaults to the number of workers plus one (the extra thread serves the main process).
num_sub_threads (int, optional): number of threads of the subprocesses.
    Should be equal to one plus the number of processes launched within
    each subprocess (or one if a single process is launched).
    Defaults to 1 for safety: if none is indicated, launching multiple
    workers may oversubscribe the CPU and harm performance.
"""

def __init__(
@@ -1127,11 +1135,17 @@ def __init__(
update_at_each_batch: bool = False,
devices=None,
storing_devices=None,
num_threads: Optional[int] = None,
num_sub_threads: int = 1,
):
exploration_type = _convert_exploration_type(
exploration_mode=exploration_mode, exploration_type=exploration_type
)
self.closed = True
if num_threads is None:
num_threads = len(create_env_fn) + 1 # 1 more thread for this proc
self.num_sub_threads = num_sub_threads
self.num_threads = num_threads
self.create_env_fn = create_env_fn
self.num_workers = len(create_env_fn)
self.create_env_kwargs = (
@@ -1308,6 +1322,7 @@ def _queue_len(self) -> int:
raise NotImplementedError

def _run_processes(self) -> None:
torch.set_num_threads(self.num_threads)
queue_out = mp.Queue(self._queue_len) # sends data from proc to main
self.procs = []
self.pipes = []
@@ -1339,7 +1354,11 @@ def _run_processes(self) -> None:
"idx": i,
"interruptor": self.interruptor,
}
proc = _ProcessNoWarn(target=_main_async_collector, kwargs=kwargs)
proc = _ProcessNoWarn(
target=_main_async_collector,
num_threads=self.num_sub_threads,
kwargs=kwargs,
)
# proc.daemon can't be set as daemonic processes may be launched by the process itself
try:
proc.start()
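
Taken together, the collector caps the parent's thread count in ``_run_processes`` and hands ``num_sub_threads`` to each worker through ``_ProcessNoWarn``. A minimal usage sketch, assuming Gym's ``Pendulum-v1`` is installed and that ``RandomPolicy`` is importable from ``torchrl.collectors`` (environment choice and frame counts are illustrative only):

from torchrl.collectors import MultiSyncDataCollector, RandomPolicy
from torchrl.envs.libs.gym import GymEnv


def make_env():
    return GymEnv("Pendulum-v1")


if __name__ == "__main__":
    collector = MultiSyncDataCollector(
        [make_env, make_env],
        policy=RandomPolicy(make_env().action_spec),
        frames_per_batch=64,
        total_frames=128,
        num_threads=3,  # 2 workers + 1 thread for the parent process
        num_sub_threads=1,  # each worker hosts a single env
    )
    for batch in collector:
        pass  # the parent now runs with torch.get_num_threads() == 3
    collector.shutdown()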
3 changes: 1 addition & 2 deletions torchrl/data/utils.py
@@ -261,9 +261,8 @@ def __setstate__(self, ob: bytes):
self.fn, self.kwargs = pickle.loads(ob)

def __call__(self, *args, **kwargs) -> Any:
        kwargs.update(self.kwargs)
        return self.fn(*args, **kwargs)


def _process_action_space_spec(action_space, spec):
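
The rewritten ``__call__`` now forwards positional arguments to the wrapped callable instead of dropping them. A minimal sketch of the behavior change (``add`` is a hypothetical helper; this assumes the wrapper is built from a bare callable):

from torchrl.data.utils import CloudpickleWrapper


def add(x, y=0):
    return x + y


wrapped = CloudpickleWrapper(add)
# Positional args now reach fn; before the fix this call raised a TypeError
# because only keyword arguments were forwarded.
assert wrapped(1, y=2) == 3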
18 changes: 18 additions & 0 deletions torchrl/envs/batched_envs.py
@@ -121,6 +121,15 @@ class _BatchedEnv(EnvBase):
It is assumed that all environments will run on the same device as a common shared
tensordict will be used to pass data from process to process. The device can be
changed after instantiation using :obj:`env.to(device)`.
num_threads (int, optional): number of threads for this process.
    Defaults to the number of workers plus one (the extra thread serves the main process).
    This parameter has no effect for the :class:`~SerialEnv` class.
num_sub_threads (int, optional): number of threads of the subprocesses.
    Should be equal to one plus the number of processes launched within
    each subprocess (or one if a single process is launched).
    Defaults to 1 for safety: if none is indicated, launching multiple
    workers may oversubscribe the CPU and harm performance.
    This parameter has no effect for the :class:`~SerialEnv` class.
"""

@@ -144,6 +153,8 @@ def __init__(
policy_proof: Optional[Callable] = None,
device: Optional[DEVICE_TYPING] = None,
allow_step_when_done: bool = False,
num_threads: Optional[int] = None,
num_sub_threads: int = 1,
):
if device is not None:
raise ValueError(
@@ -154,6 +165,10 @@

super().__init__(device=None)
self.is_closed = True
if num_threads is None:
num_threads = num_workers + 1 # 1 more thread for this proc
self.num_sub_threads = num_sub_threads
self.num_threads = num_threads
self._cache_in_keys = None

self._single_task = callable(create_env_fn) or (len(set(create_env_fn)) == 1)
@@ -692,6 +707,8 @@ class ParallelEnv(_BatchedEnv):
def _start_workers(self) -> None:
from torchrl.envs.env_creator import EnvCreator

torch.set_num_threads(self.num_threads)

ctx = mp.get_context("spawn")

_num_workers = self.num_workers
@@ -717,6 +734,7 @@ def _start_workers(self) -> None:

process = _ProcessNoWarn(
target=_run_worker_pipe_shared_mem,
num_threads=self.num_sub_threads,
args=(
parent_pipe,
child_pipe,
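
On the environment side the same mechanics apply: ``_start_workers`` caps the parent's thread count and each worker process is spawned with ``num_sub_threads`` threads. A minimal usage sketch, assuming Gym's ``Pendulum-v1`` is installed (the lambda is wrapped for pickling by TorchRL internally):

from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv

if __name__ == "__main__":
    env = ParallelEnv(
        2,
        lambda: GymEnv("Pendulum-v1"),
        num_threads=3,  # 2 workers + 1 thread for the parent process
        num_sub_threads=1,  # one thread per worker process
    )
    rollout = env.rollout(10)
    print(rollout.shape)  # torch.Size([2, 10])
    env.close()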
