[Quality] Filter out warnings in subprocs (pytorch#1552)
vmoens authored Sep 20, 2023
1 parent f13cd77 commit 1301d6c
Showing 9 changed files with 49 additions and 12 deletions.
5 changes: 5 additions & 0 deletions docs/source/reference/collectors.rst
@@ -105,6 +105,11 @@ node or across multiple nodes.
building a parallel environment or collector can result in a slower collection
than using ``device="cuda"`` when available.

.. note::
    Given the library's many optional dependencies (e.g., Gym, Gymnasium, and many others),
    warnings can quickly become quite annoying in multiprocessed / distributed settings.
    By default, TorchRL filters out these warnings in sub-processes. If you still wish to
    see them, they can be displayed by setting ``torchrl.filter_warnings_subprocess=False``.
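
A minimal sketch (assuming Gym is installed; the environment, frame counts and the use of a
random policy are only illustrative) of turning the filtering off so that warnings raised in
the workers of a multiprocessed collector become visible:

.. code-block:: python

    import torchrl
    from torchrl.collectors import MultiaSyncDataCollector
    from torchrl.envs.libs.gym import GymEnv

    # Show warnings raised inside the collector's worker processes as well.
    torchrl.filter_warnings_subprocess = False

    collector = MultiaSyncDataCollector(
        [lambda: GymEnv("Pendulum-v1")] * 2,  # two worker processes
        policy=None,                          # defaults to a random policy
        frames_per_batch=64,
        total_frames=256,
    )
    for batch in collector:
        pass  # warnings emitted by the workers are no longer suppressed
    collector.shutdown()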

.. currentmodule:: torchrl.collectors.distributed

6 changes: 6 additions & 0 deletions docs/source/reference/envs.rst
@@ -136,6 +136,12 @@ environments in parallel.
As this class inherits from :class:`SerialEnv`, it enjoys the exact same API as other environments.
Of course, a :class:`ParallelEnv` will have a batch size that corresponds to its environment count:

.. note::
    Given the library's many optional dependencies (e.g., Gym, Gymnasium, and many others),
    warnings can quickly become quite annoying in multiprocessed / distributed settings.
    By default, TorchRL filters out these warnings in sub-processes. If you still wish to
    see them, they can be displayed by setting ``torchrl.filter_warnings_subprocess=False``.

It is important that your environment specs match the inputs and outputs it sends and receives, as
:class:`ParallelEnv` will create buffers from these specs to communicate with the spawned processes.
Use the :func:`~torchrl.envs.utils.check_env_specs` function as a sanity check.
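
A minimal sketch (assuming Gym is installed; the Pendulum environment is only illustrative) of
building a :class:`ParallelEnv` and running this sanity check:

.. code-block:: python

    from torchrl.envs import ParallelEnv
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs.utils import check_env_specs

    # The wrapped environments' specs are used to allocate the buffers
    # shared with the two worker processes.
    env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))

    # Raises if the declared specs do not match what the environment
    # actually sends and receives.
    check_env_specs(env)
    env.close()
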
4 changes: 4 additions & 0 deletions torchrl/__init__.py
@@ -41,3 +41,7 @@
import torchrl.modules
import torchrl.objectives
import torchrl.trainers

# Filter warnings in sub-processes: True by default given the library's many optional
# dependencies. Filtering can be turned off via `torchrl.filter_warnings_subprocess = False`.
filter_warnings_subprocess = True
21 changes: 21 additions & 0 deletions torchrl/_utils.py
@@ -23,6 +23,8 @@
import numpy as np
import torch
from packaging.version import parse
from torch import multiprocessing as mp


VERBOSE = strtobool(os.environ.get("VERBOSE", "0"))
_os_is_windows = sys.platform == "win32"
@@ -529,3 +531,22 @@ def clone(self):
def get_trace():
"""A simple debugging util to spot where a function is being called."""
traceback.print_stack()


class _ProcessNoWarn(mp.Process):
    """A private Process class that shuts down warnings on the subprocess."""

    @wraps(mp.Process.__init__)
    def __init__(self, *args, **kwargs):
        import torchrl

        # Capture the flag at construction time (unconditionally, so the
        # attribute always exists when run() is executed in the child).
        self.filter_warnings_subprocess = torchrl.filter_warnings_subprocess
        super().__init__(*args, **kwargs)

    def run(self, *args, **kwargs):
        if self.filter_warnings_subprocess:
            import warnings

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                return mp.Process.run(self, *args, **kwargs)
        return mp.Process.run(self, *args, **kwargs)
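
For illustration, a minimal sketch (the worker function below is a placeholder) of how this
private helper is used as a drop-in replacement for mp.Process when launching a worker:

import warnings

from torchrl._utils import _ProcessNoWarn


def _worker():
    # Warnings emitted here are suppressed in the child process as long as
    # torchrl.filter_warnings_subprocess is True (the default).
    warnings.warn("this is silenced in the sub-process")


if __name__ == "__main__":
    proc = _ProcessNoWarn(target=_worker)
    proc.start()
    proc.join()
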
3 changes: 2 additions & 1 deletion torchrl/collectors/collectors.py
@@ -31,6 +31,7 @@
from torchrl._utils import (
_check_for_faulty_process,
accept_remote_rref_udf_invocation,
_ProcessNoWarn,
prod,
RL_WARNINGS,
VERBOSE,
@@ -1335,7 +1336,7 @@ def _run_processes(self) -> None:
"idx": i,
"interruptor": self.interruptor,
}
proc = mp.Process(target=_main_async_collector, kwargs=kwargs)
proc = _ProcessNoWarn(target=_main_async_collector, kwargs=kwargs)
# proc.daemon can't be set as daemonic processes may be launched by the process itself
try:
proc.start()
6 changes: 3 additions & 3 deletions torchrl/collectors/distributed/generic.py
@@ -14,9 +14,9 @@

import torch.cuda
from tensordict import TensorDict
from torch import multiprocessing as mp, nn
from torch import nn

from torchrl._utils import VERBOSE
from torchrl._utils import _ProcessNoWarn, VERBOSE
from torchrl.collectors import MultiaSyncDataCollector
from torchrl.collectors.collectors import (
DataCollectorBase,
@@ -611,7 +611,7 @@ def _init_worker_dist_mp(self, i):
if not isinstance(env_make, (EnvBase, EnvCreator)):
env_make = CloudpickleWrapper(env_make)
TCP_PORT = self.tcp_port
job = mp.Process(
job = _ProcessNoWarn(
target=_distributed_init_collection_node,
args=(
i + 1,
6 changes: 3 additions & 3 deletions torchrl/collectors/distributed/rpc.py
@@ -32,10 +32,10 @@
SUBMITIT_ERR = err
import torch.cuda
from tensordict import TensorDict
from torch import multiprocessing as mp, nn
from torch import nn

from torch.distributed import rpc
from torchrl._utils import VERBOSE
from torchrl._utils import _ProcessNoWarn, VERBOSE

from torchrl.collectors import MultiaSyncDataCollector
from torchrl.collectors.collectors import (
@@ -447,7 +447,7 @@ def _init_worker_rpc(self, executor, i):
print("job id", job.job_id) # ID of your job
return job
elif self.launcher == "mp":
job = mp.Process(
job = _ProcessNoWarn(
target=_rpc_init_collection_node,
args=(
i + 1,
6 changes: 3 additions & 3 deletions torchrl/collectors/distributed/sync.py
@@ -13,8 +13,8 @@

import torch.cuda
from tensordict import TensorDict
from torch import multiprocessing as mp, nn
from torchrl._utils import VERBOSE
from torch import nn
from torchrl._utils import _ProcessNoWarn, VERBOSE

from torchrl.collectors import MultiaSyncDataCollector
from torchrl.collectors.collectors import (
@@ -397,7 +397,7 @@ def _init_worker_dist_mp(self, i):
env_make = self.env_constructors[i]
if not isinstance(env_make, (EnvBase, EnvCreator)):
env_make = CloudpickleWrapper(env_make)
job = mp.Process(
job = _ProcessNoWarn(
target=_distributed_init_collection_node,
args=(
i + 1,
4 changes: 2 additions & 2 deletions torchrl/envs/batched_envs.py
@@ -20,7 +20,7 @@
from tensordict._tensordict import _unravel_key_to_tuple
from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
from torch import multiprocessing as mp
from torchrl._utils import _check_for_faulty_process, VERBOSE
from torchrl._utils import _check_for_faulty_process, _ProcessNoWarn, VERBOSE
from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING
from torchrl.envs.common import EnvBase
from torchrl.envs.env_creator import get_env_metadata
@@ -715,7 +715,7 @@ def _start_workers(self) -> None:
if not isinstance(env_fun, EnvCreator):
env_fun = CloudpickleWrapper(env_fun)

process = ctx.Process(
process = _ProcessNoWarn(
target=_run_worker_pipe_shared_mem,
args=(
parent_pipe,
