[Quality] Filter out warnings in subprocs #1552

Merged
merged 6 commits on Sep 20, 2023
Changes from 1 commit
init
vmoens committed Sep 20, 2023
commit 81fddf279d651bc712eb20dd04bf1f41139f4810
10 changes: 9 additions & 1 deletion torchrl/_utils.py
@@ -23,7 +23,7 @@
 import numpy as np
 import torch
 from packaging.version import parse
-
+from torch import multiprocessing as mp
 VERBOSE = strtobool(os.environ.get("VERBOSE", "0"))
 _os_is_windows = sys.platform == "win32"
 RL_WARNINGS = strtobool(os.environ.get("RL_WARNINGS", "1"))
@@ -529,3 +529,11 @@ def clone(self):
 def get_trace():
     """A simple debugging util to spot where a function is being called."""
     traceback.print_stack()
+
+class ProcessNoWarn(mp.Process):
+    def run(self, *args, **kwargs):
+        import warnings
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            return mp.Process.run(self, *args, **kwargs)
+
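For context (not part of the diff), a minimal usage sketch of the new helper, assuming a torchrl checkout that contains this commit; noisy_worker and the __main__ guard are illustrative only. Because the filter is installed inside run(), it takes effect only in the child process, so warnings raised by the worker are suppressed without touching the parent's warning filters:

import warnings

from torchrl._utils import ProcessNoWarn


def noisy_worker():
    # This warning would normally reach the console from the child process.
    warnings.warn("silenced when launched through ProcessNoWarn")


if __name__ == "__main__":
    # Drop-in replacement for mp.Process: same constructor signature.
    p = ProcessNoWarn(target=noisy_worker)
    p.start()
    p.join()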
4 changes: 2 additions & 2 deletions torchrl/collectors/collectors.py
@@ -33,7 +33,7 @@
     accept_remote_rref_udf_invocation,
     prod,
     RL_WARNINGS,
-    VERBOSE,
+    VERBOSE, ProcessNoWarn,
 )
 from torchrl.collectors.utils import split_trajectories
 from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
@@ -1335,7 +1335,7 @@ def _run_processes(self) -> None:
                 "idx": i,
                 "interruptor": self.interruptor,
             }
-            proc = mp.Process(target=_main_async_collector, kwargs=kwargs)
+            proc = ProcessNoWarn(target=_main_async_collector, kwargs=kwargs)
             # proc.daemon can't be set as daemonic processes may be launched by the process itself
             try:
                 proc.start()
4 changes: 2 additions & 2 deletions torchrl/collectors/distributed/generic.py
@@ -16,7 +16,7 @@
 from tensordict import TensorDict
 from torch import multiprocessing as mp, nn

-from torchrl._utils import VERBOSE
+from torchrl._utils import VERBOSE, ProcessNoWarn
 from torchrl.collectors import MultiaSyncDataCollector
 from torchrl.collectors.collectors import (
     DataCollectorBase,
@@ -611,7 +611,7 @@ def _init_worker_dist_mp(self, i):
         if not isinstance(env_make, (EnvBase, EnvCreator)):
             env_make = CloudpickleWrapper(env_make)
         TCP_PORT = self.tcp_port
-        job = mp.Process(
+        job = ProcessNoWarn(
             target=_distributed_init_collection_node,
             args=(
                 i + 1,
4 changes: 2 additions & 2 deletions torchrl/collectors/distributed/rpc.py
@@ -35,7 +35,7 @@
 from torch import multiprocessing as mp, nn

 from torch.distributed import rpc
-from torchrl._utils import VERBOSE
+from torchrl._utils import VERBOSE, ProcessNoWarn

 from torchrl.collectors import MultiaSyncDataCollector
 from torchrl.collectors.collectors import (
@@ -447,7 +447,7 @@ def _init_worker_rpc(self, executor, i):
             print("job id", job.job_id)  # ID of your job
             return job
         elif self.launcher == "mp":
-            job = mp.Process(
+            job = ProcessNoWarn(
                 target=_rpc_init_collection_node,
                 args=(
                     i + 1,
4 changes: 2 additions & 2 deletions torchrl/collectors/distributed/sync.py
@@ -14,7 +14,7 @@
 import torch.cuda
 from tensordict import TensorDict
 from torch import multiprocessing as mp, nn
-from torchrl._utils import VERBOSE
+from torchrl._utils import VERBOSE, ProcessNoWarn

 from torchrl.collectors import MultiaSyncDataCollector
 from torchrl.collectors.collectors import (
@@ -397,7 +397,7 @@ def _init_worker_dist_mp(self, i):
         env_make = self.env_constructors[i]
         if not isinstance(env_make, (EnvBase, EnvCreator)):
             env_make = CloudpickleWrapper(env_make)
-        job = mp.Process(
+        job = ProcessNoWarn(
             target=_distributed_init_collection_node,
             args=(
                 i + 1,
4 changes: 2 additions & 2 deletions torchrl/envs/batched_envs.py
@@ -20,7 +20,7 @@
 from tensordict._tensordict import _unravel_key_to_tuple
 from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
 from torch import multiprocessing as mp
-from torchrl._utils import _check_for_faulty_process, VERBOSE
+from torchrl._utils import _check_for_faulty_process, VERBOSE, ProcessNoWarn
 from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING
 from torchrl.envs.common import EnvBase
 from torchrl.envs.env_creator import get_env_metadata
@@ -715,7 +715,7 @@ def _start_workers(self) -> None:
             if not isinstance(env_fun, EnvCreator):
                 env_fun = CloudpickleWrapper(env_fun)

-            process = ctx.Process(
+            process = ProcessNoWarn(
                 target=_run_worker_pipe_shared_mem,
                 args=(
                     parent_pipe,
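One nuance in this last hunk: the old call used ctx.Process, which is bound to whichever start method the local ctx was created with, while ProcessNoWarn derives from the default torch.multiprocessing.Process. If an explicit start method had to be preserved, a context-bound variant could look roughly like the sketch below (illustrative only, not part of this PR; make_nowarn_process_cls is a hypothetical helper):

import warnings

from torch import multiprocessing as mp


def make_nowarn_process_cls(ctx):
    # Build a Process subclass tied to `ctx` whose run() silences warnings.
    class _ProcessNoWarn(ctx.Process):
        def run(self, *args, **kwargs):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                return super().run(*args, **kwargs)

    return _ProcessNoWarn


# Example: keep the "spawn" start method while still filtering warnings.
SpawnProcessNoWarn = make_nowarn_process_cls(mp.get_context("spawn"))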