Skip to content

Commit

Permalink
[Feature, BugFix] Better thread control in penv and collectors (pytor…
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jan 30, 2024
1 parent b1cc796 commit 967bad2
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 64 deletions.
101 changes: 71 additions & 30 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import argparse
import gc
import logging

import sys
Expand Down Expand Up @@ -2357,39 +2358,79 @@ def make_env():
del collector


@pytest.mark.skipif(
IS_OSX, reason="setting different threads across workeres can randomly fail on OSX."
)
def test_num_threads():
from torchrl.collectors import collectors

_main_async_collector_saved = collectors._main_async_collector
collectors._main_async_collector = decorate_thread_sub_func(
collectors._main_async_collector, num_threads=3
class TestLibThreading:
@pytest.mark.skipif(
IS_OSX,
reason="setting different threads across workeres can randomly fail on OSX.",
)
num_threads = torch.get_num_threads()
try:
env = ContinuousActionVecMockEnv()
c = MultiSyncDataCollector(
[env],
policy=RandomPolicy(env.action_spec),
num_threads=7,
num_sub_threads=3,
total_frames=200,
frames_per_batch=200,
def test_num_threads(self):
from torchrl.collectors import collectors

_main_async_collector_saved = collectors._main_async_collector
collectors._main_async_collector = decorate_thread_sub_func(
collectors._main_async_collector, num_threads=3
)
assert torch.get_num_threads() == 7
for _ in c:
pass
finally:
num_threads = torch.get_num_threads()
try:
env = ContinuousActionVecMockEnv()
c = MultiSyncDataCollector(
[env],
policy=RandomPolicy(env.action_spec),
num_threads=7,
num_sub_threads=3,
total_frames=200,
frames_per_batch=200,
)
assert torch.get_num_threads() == 7
for _ in c:
pass
finally:
try:
c.shutdown()
del c
except Exception:
logging.info("Failed to shut down collector")
# reset vals
collectors._main_async_collector = _main_async_collector_saved
torch.set_num_threads(num_threads)

@pytest.mark.skipif(
IS_OSX,
reason="setting different threads across workeres can randomly fail on OSX.",
)
def test_auto_num_threads(self):
init_threads = torch.get_num_threads()
try:
collector = MultiSyncDataCollector(
[ContinuousActionVecMockEnv],
RandomPolicy(ContinuousActionVecMockEnv().full_action_spec),
frames_per_batch=3,
)
for _ in collector:
assert torch.get_num_threads() == init_threads - 1
break
collector.shutdown()
assert torch.get_num_threads() == init_threads
del collector
gc.collect()
finally:
torch.set_num_threads(init_threads)

try:
c.shutdown()
del c
except Exception:
logging.info("Failed to shut down collector")
# reset vals
collectors._main_async_collector = _main_async_collector_saved
torch.set_num_threads(num_threads)
collector = MultiSyncDataCollector(
[ParallelEnv(2, ContinuousActionVecMockEnv)],
RandomPolicy(ContinuousActionVecMockEnv().full_action_spec.expand(2)),
frames_per_batch=3,
)
for _ in collector:
assert torch.get_num_threads() == init_threads - 2
break
collector.shutdown()
assert torch.get_num_threads() == init_threads
del collector
gc.collect()
finally:
torch.set_num_threads(init_threads)


if __name__ == "__main__":
Expand Down
79 changes: 57 additions & 22 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import argparse
import gc
import os.path
import re
from collections import defaultdict
Expand Down Expand Up @@ -2333,30 +2334,64 @@ def test_terminated_or_truncated_spec(self):
assert not data["nested", "_reset"].any()


@pytest.mark.skipif(
IS_OSX, reason="setting different threads across workeres can randomly fail on OSX."
)
def test_num_threads():
from torchrl.envs import batched_envs

_run_worker_pipe_shared_mem_save = batched_envs._run_worker_pipe_shared_mem
batched_envs._run_worker_pipe_shared_mem = decorate_thread_sub_func(
batched_envs._run_worker_pipe_shared_mem, num_threads=3
class TestLibThreading:
@pytest.mark.skipif(
IS_OSX,
reason="setting different threads across workeres can randomly fail on OSX.",
)
num_threads = torch.get_num_threads()
try:
env = ParallelEnv(
2, ContinuousActionVecMockEnv, num_sub_threads=3, num_threads=7
def test_num_threads(self):
from torchrl.envs import batched_envs

_run_worker_pipe_shared_mem_save = batched_envs._run_worker_pipe_shared_mem
batched_envs._run_worker_pipe_shared_mem = decorate_thread_sub_func(
batched_envs._run_worker_pipe_shared_mem, num_threads=3
)
# We could test that the number of threads isn't changed until we start the procs.
# Even though it's unlikely that we have 7 threads, we still disable this for safety
# assert torch.get_num_threads() != 7
env.rollout(3)
assert torch.get_num_threads() == 7
finally:
# reset vals
batched_envs._run_worker_pipe_shared_mem = _run_worker_pipe_shared_mem_save
torch.set_num_threads(num_threads)
num_threads = torch.get_num_threads()
try:
env = ParallelEnv(
2, ContinuousActionVecMockEnv, num_sub_threads=3, num_threads=7
)
# We could test that the number of threads isn't changed until we start the procs.
# Even though it's unlikely that we have 7 threads, we still disable this for safety
# assert torch.get_num_threads() != 7
env.rollout(3)
assert torch.get_num_threads() == 7
finally:
# reset vals
batched_envs._run_worker_pipe_shared_mem = _run_worker_pipe_shared_mem_save
torch.set_num_threads(num_threads)

@pytest.mark.skipif(
    IS_OSX,
    reason="setting different threads across workeres can randomly fail on OSX.",
)
def test_auto_num_threads(self):
    """Check the main-process thread count as ParallelEnv instances start and close.

    Two environments (3 and 2 workers) are started one after the other and
    then closed in reverse order. After each step the test compares
    ``torch.get_num_threads()`` with the initial count minus the number of
    workers currently running, floored at 1. Once both environments are
    closed the count must equal the initial value exactly.

    The initial thread count is restored in ``finally`` so a failure here
    does not leak a modified thread setting into other tests.
    """
    baseline = torch.get_num_threads()

    def expected_with(running_workers):
        # Thread count expected in this process while `running_workers`
        # workers are alive; never below one.
        return max(1, baseline - running_workers)

    try:
        # Workers are started by the first rollout.
        three_worker_env = ParallelEnv(3, lambda: GymEnv("Pendulum-v1"))
        three_worker_env.rollout(2)

        assert torch.get_num_threads() == expected_with(3)

        two_worker_env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))
        two_worker_env.rollout(2)

        assert torch.get_num_threads() == expected_with(5)

        # Close in reverse order of creation and check the count after each.
        two_worker_env.close()
        del two_worker_env
        gc.collect()

        assert torch.get_num_threads() == expected_with(3)

        three_worker_env.close()
        del three_worker_env
        gc.collect()

        assert torch.get_num_threads() == baseline
    finally:
        torch.set_num_threads(baseline)


def test_run_type_checks():
Expand Down
3 changes: 3 additions & 0 deletions torchrl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,6 @@
# Filter warnings in subprocesses: True by default given the multiple optional
# deps of the library. This can be turned off via `torchrl.filter_warnings_subprocess = False`.
filter_warnings_subprocess = True

_THREAD_POOL_INIT = torch.get_num_threads()
_THREAD_POOL = torch.get_num_threads()
37 changes: 34 additions & 3 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,12 +1368,11 @@ def __init__(
exploration_mode=exploration_mode, exploration_type=exploration_type
)
self.closed = True
if num_threads is None:
num_threads = len(create_env_fn) + 1 # 1 more thread for this proc
self.num_workers = len(create_env_fn)

self.num_sub_threads = num_sub_threads
self.num_threads = num_threads
self.create_env_fn = create_env_fn
self.num_workers = len(create_env_fn)
self.create_env_kwargs = (
create_env_kwargs
if create_env_kwargs is not None
Expand Down Expand Up @@ -1521,6 +1520,18 @@ def _get_weight_fn(weights=policy_weights):
self._frames = 0
self._iter = -1

@classmethod
def _total_workers_from_env(cls, env_creators):
    """Count the workers represented by ``env_creators``.

    Args:
        env_creators: a single environment (or environment constructor), or
            an arbitrarily nested list/tuple of them.

    Returns:
        int: for a list or tuple, the sum of the counts of its items
        (computed recursively); for a ``ParallelEnv`` instance, its
        ``num_workers``; for anything else, ``1``.
    """
    if not isinstance(env_creators, (tuple, list)):
        # Imported lazily, only when a leaf item has to be inspected.
        from torchrl.envs import ParallelEnv

        if isinstance(env_creators, ParallelEnv):
            return env_creators.num_workers
        return 1

    worker_count = 0
    for creator in env_creators:
        worker_count += cls._total_workers_from_env(creator)
    return worker_count

def _get_devices(
self,
*,
Expand Down Expand Up @@ -1595,7 +1606,19 @@ def _queue_len(self) -> int:
raise NotImplementedError

def _run_processes(self) -> None:
if self.num_threads is None:
import torchrl

total_workers = self._total_workers_from_env(self.create_env_fn)
self.num_threads = max(
1, torchrl._THREAD_POOL - total_workers
) # 1 more thread for this proc

torch.set_num_threads(self.num_threads)
assert torch.get_num_threads() == self.num_threads
import torchrl

torchrl._THREAD_POOL = self.num_threads
queue_out = mp.Queue(self._queue_len) # sends data from proc to main
self.procs = []
self.pipes = []
Expand Down Expand Up @@ -1702,6 +1725,14 @@ def _shutdown_main(self) -> None:
for proc in self.procs:
proc.join(1.0)
finally:
import torchrl

torchrl._THREAD_POOL = min(
torchrl._THREAD_POOL_INIT,
torchrl._THREAD_POOL + self._total_workers_from_env(self.create_env_fn),
)
torch.set_num_threads(torchrl._THREAD_POOL)

for proc in self.procs:
if proc.is_alive():
proc.terminate()
Expand Down
18 changes: 16 additions & 2 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,6 @@ def __init__(
super().__init__(device=device)
self.serial_for_single = serial_for_single
self.is_closed = True
if num_threads is None:
num_threads = num_workers + 1 # 1 more thread for this proc
self.num_sub_threads = num_sub_threads
self.num_threads = num_threads
self._cache_in_keys = None
Expand Down Expand Up @@ -633,6 +631,12 @@ def close(self) -> None:

self._shutdown_workers()
self.is_closed = True
import torchrl

torchrl._THREAD_POOL = min(
torchrl._THREAD_POOL_INIT, torchrl._THREAD_POOL + self.num_workers
)
torch.set_num_threads(torchrl._THREAD_POOL)

def _shutdown_workers(self) -> None:
raise NotImplementedError
Expand Down Expand Up @@ -1010,7 +1014,17 @@ class ParallelEnv(_BatchedEnv, metaclass=_PEnvMeta):
def _start_workers(self) -> None:
from torchrl.envs.env_creator import EnvCreator

if self.num_threads is None:
import torchrl

self.num_threads = max(
1, torchrl._THREAD_POOL - self.num_workers
) # 1 more thread for this proc

torch.set_num_threads(self.num_threads)
import torchrl

torchrl._THREAD_POOL = self.num_threads

ctx = mp.get_context("spawn")

Expand Down
14 changes: 7 additions & 7 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,16 +306,16 @@ def _gym_to_torchrl_spec_transform(
shape = torch.Size([1])
if dtype is None:
dtype = numpy_to_torch_dtype_dict[spec.dtype]
low = torch.tensor(spec.low, device=device, dtype=dtype)
high = torch.tensor(spec.high, device=device, dtype=dtype)
low = torch.as_tensor(spec.low, device=device, dtype=dtype)
high = torch.as_tensor(spec.high, device=device, dtype=dtype)
is_unbounded = low.isinf().all() and high.isinf().all()

minval, maxval = _minmax_dtype(dtype)
minval = torch.as_tensor(minval).to(low.device, dtype)
maxval = torch.as_tensor(maxval).to(low.device, dtype)
is_unbounded = is_unbounded or (
torch.isclose(low, torch.tensor(minval, dtype=dtype)).all()
and torch.isclose(high, torch.tensor(maxval, dtype=dtype)).all()
torch.isclose(low, torch.as_tensor(minval, dtype=dtype)).all()
and torch.isclose(high, torch.as_tensor(maxval, dtype=dtype)).all()
)
return (
UnboundedContinuousTensorSpec(shape, device=device, dtype=dtype)
Expand Down Expand Up @@ -1480,7 +1480,7 @@ def _read_obs(self, obs, key, tensor, index):
# Simplest case: there is one observation,
# presented as a np.ndarray. The key should be pixels or observation.
# We just write that value at its location in the tensor
tensor[index] = torch.tensor(obs, device=tensor.device)
tensor[index] = torch.as_tensor(obs, device=tensor.device)
elif isinstance(obs, dict):
if key not in obs:
raise KeyError(
Expand All @@ -1491,13 +1491,13 @@ def _read_obs(self, obs, key, tensor, index):
# if the obs is a dict, we expect that the key points also to
# a value in the obs. We retrieve this value and write it in the
# tensor
tensor[index] = torch.tensor(subobs, device=tensor.device)
tensor[index] = torch.as_tensor(subobs, device=tensor.device)

elif isinstance(obs, (list, tuple)):
# tuples are stacked along the first dimension when passing gym spaces
# to torchrl specs. As such, we can simply stack the tuple and set it
# at the relevant index (assuming stacking can be achieved)
tensor[index] = torch.tensor(obs, device=tensor.device)
tensor[index] = torch.as_tensor(obs, device=tensor.device)
else:
raise NotImplementedError(
f"Observations of type {type(obs)} are not supported yet."
Expand Down

0 comments on commit 967bad2

Please sign in to comment.