[BugFix] thread setting bug (pytorch#1852)
vmoens authored Jan 31, 2024
1 parent 017bcd0 commit 86b8918
Showing 5 changed files with 15 additions and 26 deletions.
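
The change replaces torchrl's module-level _THREAD_POOL counter, which could drift out of sync with PyTorch's actual thread pool (anything else calling torch.set_num_threads would invalidate it), with live reads of torch.get_num_threads(). Only _THREAD_POOL_INIT survives, as a cap when threads are handed back after workers shut down. Below is a minimal sketch of the pattern the commit converges on; start_workers/stop_workers are hypothetical helper names standing in for the collector and batched-env code paths, not the torchrl API.

    # Sketch only: start_workers/stop_workers are hypothetical names, not
    # part of the torchrl API. Only the initial pool size is cached; the
    # current size is always read live from torch.
    import torch

    _THREAD_POOL_INIT = torch.get_num_threads()  # cached once at import

    def start_workers(num_workers: int) -> None:
        # Reserve one thread per spawned worker, never dropping below 1.
        torch.set_num_threads(max(1, torch.get_num_threads() - num_workers))

    def stop_workers(num_workers: int) -> None:
        # Hand the workers' threads back, capped at the original pool size.
        torch.set_num_threads(
            min(_THREAD_POOL_INIT, torch.get_num_threads() + num_workers)
        )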
4 changes: 2 additions & 2 deletions test/test_collector.py
@@ -2361,7 +2361,7 @@ def make_env():
 class TestLibThreading:
     @pytest.mark.skipif(
         IS_OSX,
-        reason="setting different threads across workeres can randomly fail on OSX.",
+        reason="setting different threads across workers can randomly fail on OSX.",
     )
     def test_num_threads(self):
         from torchrl.collectors import collectors
@@ -2396,7 +2396,7 @@ def test_num_threads(self):
 
     @pytest.mark.skipif(
         IS_OSX,
-        reason="setting different threads across workeres can randomly fail on OSX.",
+        reason="setting different threads across workers can randomly fail on OSX.",
     )
     def test_auto_num_threads(self):
         init_threads = torch.get_num_threads()
8 changes: 4 additions & 4 deletions test/test_env.py
@@ -2337,7 +2337,7 @@ def test_terminated_or_truncated_spec(self):
 class TestLibThreading:
     @pytest.mark.skipif(
         IS_OSX,
-        reason="setting different threads across workeres can randomly fail on OSX.",
+        reason="setting different threads across workers can randomly fail on OSX.",
     )
     def test_num_threads(self):
         from torchrl.envs import batched_envs
@@ -2363,18 +2363,18 @@ def test_num_threads(self):
 
     @pytest.mark.skipif(
         IS_OSX,
-        reason="setting different threads across workeres can randomly fail on OSX.",
+        reason="setting different threads across workers can randomly fail on OSX.",
     )
     def test_auto_num_threads(self):
         init_threads = torch.get_num_threads()
 
         try:
-            env3 = ParallelEnv(3, lambda: GymEnv("Pendulum-v1"))
+            env3 = ParallelEnv(3, ContinuousActionVecMockEnv)
            env3.rollout(2)
 
             assert torch.get_num_threads() == max(1, init_threads - 3)
 
-            env2 = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))
+            env2 = ParallelEnv(2, ContinuousActionVecMockEnv)
             env2.rollout(2)
 
             assert torch.get_num_threads() == max(1, init_threads - 5)
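
Because the available thread count is now read live rather than from the stale global, the budget shrinks cumulatively as parallel envs start: the 3-worker env leaves init_threads - 3, and starting a 2-worker env from that state leaves init_threads - 5, which is what the updated assertions check. A worked example of the arithmetic, assuming init_threads is large enough that the max(1, ...) floor never triggers:

    # Assumed example values; mirrors the assertions above.
    init_threads = 8
    after_env3 = max(1, init_threads - 3)   # 5 threads left for the main proc
    after_env2 = max(1, after_env3 - 2)     # 3 threads, i.e. init_threads - 5
    assert after_env2 == max(1, init_threads - 5)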
1 change: 0 additions & 1 deletion torchrl/__init__.py
@@ -51,4 +51,3 @@
 filter_warnings_subprocess = True
 
 _THREAD_POOL_INIT = torch.get_num_threads()
-_THREAD_POOL = torch.get_num_threads()
15 changes: 5 additions & 10 deletions torchrl/collectors/collectors.py
@@ -1607,18 +1607,12 @@ def _queue_len(self) -> int:
 
     def _run_processes(self) -> None:
         if self.num_threads is None:
-            import torchrl
-
             total_workers = self._total_workers_from_env(self.create_env_fn)
             self.num_threads = max(
-                1, torchrl._THREAD_POOL - total_workers
+                1, torch.get_num_threads() - total_workers
             )  # 1 more thread for this proc
 
         torch.set_num_threads(self.num_threads)
-        assert torch.get_num_threads() == self.num_threads
-        import torchrl
-
-        torchrl._THREAD_POOL = self.num_threads
         queue_out = mp.Queue(self._queue_len)  # sends data from proc to main
         self.procs = []
         self.pipes = []
@@ -1727,11 +1721,12 @@ def _shutdown_main(self) -> None:
         finally:
             import torchrl
 
-            torchrl._THREAD_POOL = min(
+            num_threads = min(
                 torchrl._THREAD_POOL_INIT,
-                torchrl._THREAD_POOL + self._total_workers_from_env(self.create_env_fn),
+                torch.get_num_threads()
+                + self._total_workers_from_env(self.create_env_fn),
             )
-            torch.set_num_threads(torchrl._THREAD_POOL)
+            torch.set_num_threads(num_threads)
 
             for proc in self.procs:
                 if proc.is_alive():
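
On shutdown the collector now computes a local num_threads instead of writing back to the global: the workers' threads are returned to the main process, but min() caps the result at _THREAD_POOL_INIT, so repeated start/stop cycles can never inflate the pool beyond its original size. (BatchedEnvBase.close in the next file applies the same cap.) A small sketch of why the cap matters, with assumed numbers:

    import torch

    _THREAD_POOL_INIT = 8        # assumed: pool size cached at import time
    torch.set_num_threads(7)     # suppose the pool was lowered in the meantime

    # Returning 4 workers' threads: without the cap this would set
    # 7 + 4 = 11 threads, more than the process started with.
    num_threads = min(_THREAD_POOL_INIT, torch.get_num_threads() + 4)
    torch.set_num_threads(num_threads)   # 8, never above the initial pool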
13 changes: 4 additions & 9 deletions torchrl/envs/batched_envs.py
@@ -633,10 +633,10 @@ def close(self) -> None:
         self.is_closed = True
         import torchrl
 
-        torchrl._THREAD_POOL = min(
-            torchrl._THREAD_POOL_INIT, torchrl._THREAD_POOL + self.num_workers
+        num_threads = min(
+            torchrl._THREAD_POOL_INIT, torch.get_num_threads() + self.num_workers
         )
-        torch.set_num_threads(torchrl._THREAD_POOL)
+        torch.set_num_threads(num_threads)
 
     def _shutdown_workers(self) -> None:
         raise NotImplementedError
@@ -1015,16 +1015,11 @@ def _start_workers(self) -> None:
         from torchrl.envs.env_creator import EnvCreator
 
         if self.num_threads is None:
-            import torchrl
-
             self.num_threads = max(
-                1, torchrl._THREAD_POOL - self.num_workers
+                1, torch.get_num_threads() - self.num_workers
             )  # 1 more thread for this proc
 
         torch.set_num_threads(self.num_threads)
-        import torchrl
-
-        torchrl._THREAD_POOL = self.num_threads
-
         ctx = mp.get_context("spawn")
