Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] thread setting bug #1852

Merged
merged 2 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
init
  • Loading branch information
vmoens committed Jan 31, 2024
commit 1337562049135d504f60b2e63dff0d05ab459f95
1 change: 0 additions & 1 deletion torchrl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,3 @@
filter_warnings_subprocess = True

_THREAD_POOL_INIT = torch.get_num_threads()
_THREAD_POOL = torch.get_num_threads()
14 changes: 4 additions & 10 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1607,18 +1607,12 @@ def _queue_len(self) -> int:

def _run_processes(self) -> None:
if self.num_threads is None:
import torchrl

total_workers = self._total_workers_from_env(self.create_env_fn)
self.num_threads = max(
1, torchrl._THREAD_POOL - total_workers
1, torch.get_num_threads() - total_workers
) # 1 more thread for this proc

torch.set_num_threads(self.num_threads)
assert torch.get_num_threads() == self.num_threads
import torchrl

torchrl._THREAD_POOL = self.num_threads
queue_out = mp.Queue(self._queue_len) # sends data from proc to main
self.procs = []
self.pipes = []
Expand Down Expand Up @@ -1727,11 +1721,11 @@ def _shutdown_main(self) -> None:
finally:
import torchrl

torchrl._THREAD_POOL = min(
num_threads = min(
torchrl._THREAD_POOL_INIT,
torchrl._THREAD_POOL + self._total_workers_from_env(self.create_env_fn),
torch.get_num_threads() + self._total_workers_from_env(self.create_env_fn),
)
torch.set_num_threads(torchrl._THREAD_POOL)
torch.set_num_threads(num_threads)

for proc in self.procs:
if proc.is_alive():
Expand Down
13 changes: 4 additions & 9 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,10 +633,10 @@ def close(self) -> None:
self.is_closed = True
import torchrl

torchrl._THREAD_POOL = min(
torchrl._THREAD_POOL_INIT, torchrl._THREAD_POOL + self.num_workers
num_threads = min(
torchrl._THREAD_POOL_INIT, torch.get_num_threads() + self.num_workers
)
torch.set_num_threads(torchrl._THREAD_POOL)
torch.set_num_threads(num_threads)

def _shutdown_workers(self) -> None:
raise NotImplementedError
Expand Down Expand Up @@ -1015,16 +1015,11 @@ def _start_workers(self) -> None:
from torchrl.envs.env_creator import EnvCreator

if self.num_threads is None:
import torchrl

self.num_threads = max(
1, torchrl._THREAD_POOL - self.num_workers
1, torch.get_num_threads() - self.num_workers
) # 1 more thread for this proc

torch.set_num_threads(self.num_threads)
import torchrl

torchrl._THREAD_POOL = self.num_threads

ctx = mp.get_context("spawn")

Expand Down
Loading