From 1853aa577642435f07dc72152d6329a9ddae58a8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 20 Sep 2023 05:41:02 -0400 Subject: [PATCH 1/5] init --- test/test_libs.py | 13 +++++++++++-- torchrl/__init__.py | 2 -- torchrl/envs/libs/robohive.py | 12 +++++++----- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index ef2be615c2e..053a24536c1 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -74,7 +74,7 @@ from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv from torchrl.envs.libs.openml import OpenMLEnv from torchrl.envs.libs.pettingzoo import _has_pettingzoo, PettingZooEnv -from torchrl.envs.libs.robohive import RoboHiveEnv +from torchrl.envs.libs.robohive import RoboHiveEnv, _has_robohive from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper from torchrl.envs.utils import check_env_specs, ExplorationType, MarlGroupMapType @@ -1977,9 +1977,18 @@ def test_collector(self, task, parallel): break +@pytest.mark.skipif(not _has_robohive, reason="RoboHive not found") class TestRoboHive: - @pytest.mark.parametrize("envname", RoboHiveEnv.env_list) + # unfortunately we must import robohive to get the available envs + # and this import will occur whenever pytest is run on this file. + # The other option would be not to use parametrize but that also + # means less informative error trace stacks. + # In the CI, robohive should not coexist with other libs so that's fine. + # Locally these imports can be annoying, especially given the amount of + # stuff printed by robohive. 
+ @pytest.mark.parametrize("envname", RoboHiveEnv.available_envs) @pytest.mark.parametrize("from_pixels", [True, False]) + @set_gym_backend("gym") def test_robohive(self, envname, from_pixels): if any(substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s")): print("not testing envs with prebuilt rendering") diff --git a/torchrl/__init__.py b/torchrl/__init__.py index 5f928cd1ca6..bc240bbeed0 100644 --- a/torchrl/__init__.py +++ b/torchrl/__init__.py @@ -12,8 +12,6 @@ if torch.cuda.device_count() > 1: n = torch.cuda.device_count() - 1 os.environ["MUJOCO_EGL_DEVICE_ID"] = str(1 + (os.getpid() % n)) - # if VERBOSE: - print("MUJOCO_EGL_DEVICE_ID: ", os.environ["MUJOCO_EGL_DEVICE_ID"]) from ._extension import _init_extension diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index a1fc1d6d1ba..fd332062813 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -78,9 +78,15 @@ def CURR_DIR(cls): else: return None + @_classproperty + def available_envs(cls): + if not _has_robohive: + return + RoboHiveEnv.register_envs() + yield from cls.env_list + @classmethod def register_envs(cls): - if not _has_robohive: raise ImportError( "Cannot load robohive from the current virtual environment." 
@@ -333,7 +339,3 @@ def get_available_cams(cls, env_name): env = gym.make(env_name) cams = [env.sim.model.id2name(ic, 7) for ic in range(env.sim.model.ncam)] return cams - - -if _has_robohive: - RoboHiveEnv.register_envs() From a0fa6c172365049f08c8ec0f67fc144320b95bf3 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 20 Sep 2023 05:45:33 -0400 Subject: [PATCH 2/5] lint --- test/test_libs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_libs.py b/test/test_libs.py index 053a24536c1..b23708c0924 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -74,7 +74,7 @@ from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv from torchrl.envs.libs.openml import OpenMLEnv from torchrl.envs.libs.pettingzoo import _has_pettingzoo, PettingZooEnv -from torchrl.envs.libs.robohive import RoboHiveEnv, _has_robohive +from torchrl.envs.libs.robohive import _has_robohive, RoboHiveEnv from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper from torchrl.envs.utils import check_env_specs, ExplorationType, MarlGroupMapType From 81fddf279d651bc712eb20dd04bf1f41139f4810 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 20 Sep 2023 06:04:02 -0400 Subject: [PATCH 3/5] init --- torchrl/_utils.py | 10 +++++++++- torchrl/collectors/collectors.py | 4 ++-- torchrl/collectors/distributed/generic.py | 4 ++-- torchrl/collectors/distributed/rpc.py | 4 ++-- torchrl/collectors/distributed/sync.py | 4 ++-- torchrl/envs/batched_envs.py | 4 ++-- 6 files changed, 19 insertions(+), 11 deletions(-) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 7f5b80175be..d6f6f6249cf 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -23,7 +23,7 @@ import numpy as np import torch from packaging.version import parse - +from torch import multiprocessing as mp VERBOSE = strtobool(os.environ.get("VERBOSE", "0")) _os_is_windows = sys.platform == "win32" RL_WARNINGS = 
strtobool(os.environ.get("RL_WARNINGS", "1")) @@ -529,3 +529,11 @@ def clone(self): def get_trace(): """A simple debugging util to spot where a function is being called.""" traceback.print_stack() + +class ProcessNoWarn(mp.Process): + def run(self, *args, **kwargs): + import warnings + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return mp.Process.run(self, *args, **kwargs) + diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index b06a9b15252..e41a5bd58b9 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -33,7 +33,7 @@ accept_remote_rref_udf_invocation, prod, RL_WARNINGS, - VERBOSE, + VERBOSE, ProcessNoWarn, ) from torchrl.collectors.utils import split_trajectories from torchrl.data.tensor_specs import CompositeSpec, TensorSpec @@ -1335,7 +1335,7 @@ def _run_processes(self) -> None: "idx": i, "interruptor": self.interruptor, } - proc = mp.Process(target=_main_async_collector, kwargs=kwargs) + proc = ProcessNoWarn(target=_main_async_collector, kwargs=kwargs) # proc.daemon can't be set as daemonic processes may be launched by the process itself try: proc.start() diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 8f37e2a8458..786f30dca79 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -16,7 +16,7 @@ from tensordict import TensorDict from torch import multiprocessing as mp, nn -from torchrl._utils import VERBOSE +from torchrl._utils import VERBOSE, ProcessNoWarn from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( DataCollectorBase, @@ -611,7 +611,7 @@ def _init_worker_dist_mp(self, i): if not isinstance(env_make, (EnvBase, EnvCreator)): env_make = CloudpickleWrapper(env_make) TCP_PORT = self.tcp_port - job = mp.Process( + job = ProcessNoWarn( target=_distributed_init_collection_node, args=( i + 1, diff --git 
a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 84eb963f04b..cf9ab4b6dec 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -35,7 +35,7 @@ from torch import multiprocessing as mp, nn from torch.distributed import rpc -from torchrl._utils import VERBOSE +from torchrl._utils import VERBOSE, ProcessNoWarn from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( @@ -447,7 +447,7 @@ def _init_worker_rpc(self, executor, i): print("job id", job.job_id) # ID of your job return job elif self.launcher == "mp": - job = mp.Process( + job = ProcessNoWarn( target=_rpc_init_collection_node, args=( i + 1, diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 4fe3a4aa460..81d54d86960 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -14,7 +14,7 @@ import torch.cuda from tensordict import TensorDict from torch import multiprocessing as mp, nn -from torchrl._utils import VERBOSE +from torchrl._utils import VERBOSE, ProcessNoWarn from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( @@ -397,7 +397,7 @@ def _init_worker_dist_mp(self, i): env_make = self.env_constructors[i] if not isinstance(env_make, (EnvBase, EnvCreator)): env_make = CloudpickleWrapper(env_make) - job = mp.Process( + job = ProcessNoWarn( target=_distributed_init_collection_node, args=( i + 1, diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 56a368d1053..99a7f15fd6c 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -20,7 +20,7 @@ from tensordict._tensordict import _unravel_key_to_tuple from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase from torch import multiprocessing as mp -from torchrl._utils import _check_for_faulty_process, VERBOSE +from torchrl._utils import 
_check_for_faulty_process, VERBOSE, ProcessNoWarn from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import get_env_metadata @@ -715,7 +715,7 @@ def _start_workers(self) -> None: if not isinstance(env_fun, EnvCreator): env_fun = CloudpickleWrapper(env_fun) - process = ctx.Process( + process = ProcessNoWarn( target=_run_worker_pipe_shared_mem, args=( parent_pipe, From 8ae8edd83ed66a74a72e0b76bbd9d2dd9fcfa0b6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 20 Sep 2023 06:12:39 -0400 Subject: [PATCH 4/5] aha! --- docs/source/reference/collectors.rst | 5 +++++ docs/source/reference/envs.rst | 6 ++++++ torchrl/__init__.py | 4 ++++ torchrl/_utils.py | 17 ++++++++++++----- torchrl/collectors/collectors.py | 3 ++- torchrl/collectors/distributed/generic.py | 4 ++-- torchrl/collectors/distributed/rpc.py | 4 ++-- torchrl/collectors/distributed/sync.py | 4 ++-- torchrl/envs/batched_envs.py | 2 +- 9 files changed, 36 insertions(+), 13 deletions(-) diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst index d34a266d3db..aa8de179f20 100644 --- a/docs/source/reference/collectors.rst +++ b/docs/source/reference/collectors.rst @@ -105,6 +105,11 @@ node or across multiple nodes. building a parallel environment or collector can result in a slower collection than using ``device="cuda"`` when available. +.. note:: + Given the library's many optional dependencies (eg, Gym, Gymnasium, and many others) + warnings can quickly become quite annoying in multiprocessed / distributed settings. + By default, TorchRL filters out these warnings in sub-processes. If one still wishes to + see these warnings, they can be displayed by setting ``torchrl.filter_warnings_subprocess=False``. .. 
currentmodule:: torchrl.collectors.distributed diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index dc081183156..c49e12fc2b0 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -136,6 +136,12 @@ environments in parallel. As this class inherits from :class:`SerialEnv`, it enjoys the exact same API as other environment. Of course, a :class:`ParallelEnv` will have a batch size that corresponds to its environment count: +.. note:: + Given the library's many optional dependencies (eg, Gym, Gymnasium, and many others) + warnings can quickly become quite annoying in multiprocessed / distributed settings. + By default, TorchRL filters out these warnings in sub-processes. If one still wishes to + see these warnings, they can be displayed by setting ``torchrl.filter_warnings_subprocess=False``. + It is important that your environment specs match the input and output that it sends and receives, as :class:`ParallelEnv` will create buffers from these specs to communicate with the spawn processes. Check the :func:`~torchrl.envs.utils.check_env_specs` method for a sanity check. diff --git a/torchrl/__init__.py b/torchrl/__init__.py index bc240bbeed0..7d807244f70 100644 --- a/torchrl/__init__.py +++ b/torchrl/__init__.py @@ -41,3 +41,7 @@ import torchrl.modules import torchrl.objectives import torchrl.trainers + +# Filter warnings in subprocesses: True by default given the multiple optional +# deps of the library. This can be turned on via `torchrl.filter_warnings_subprocess = False`. 
+filter_warnings_subprocess = True diff --git a/torchrl/_utils.py b/torchrl/_utils.py index d6f6f6249cf..047798508ce 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -24,6 +24,8 @@ import torch from packaging.version import parse from torch import multiprocessing as mp + + VERBOSE = strtobool(os.environ.get("VERBOSE", "0")) _os_is_windows = sys.platform == "win32" RL_WARNINGS = strtobool(os.environ.get("RL_WARNINGS", "1")) @@ -530,10 +532,15 @@ def get_trace(): """A simple debugging util to spot where a function is being called.""" traceback.print_stack() + class ProcessNoWarn(mp.Process): - def run(self, *args, **kwargs): - import warnings - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - return mp.Process.run(self, *args, **kwargs) + def run(self, *args, **kwargs): + import torchrl + + if torchrl.filter_warnings_subprocess: + import warnings + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return mp.Process.run(self, *args, **kwargs) + return mp.Process.run(self, *args, **kwargs) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index e41a5bd58b9..e2002581871 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -31,9 +31,10 @@ from torchrl._utils import ( _check_for_faulty_process, accept_remote_rref_udf_invocation, + ProcessNoWarn, prod, RL_WARNINGS, - VERBOSE, ProcessNoWarn, + VERBOSE, ) from torchrl.collectors.utils import split_trajectories from torchrl.data.tensor_specs import CompositeSpec, TensorSpec diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 786f30dca79..8cef9d5afde 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -14,9 +14,9 @@ import torch.cuda from tensordict import TensorDict -from torch import multiprocessing as mp, nn +from torch import nn -from torchrl._utils import VERBOSE, ProcessNoWarn +from torchrl._utils import 
ProcessNoWarn, VERBOSE from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( DataCollectorBase, diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index cf9ab4b6dec..6f86ef5f97e 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -32,10 +32,10 @@ SUBMITIT_ERR = err import torch.cuda from tensordict import TensorDict -from torch import multiprocessing as mp, nn +from torch import nn from torch.distributed import rpc -from torchrl._utils import VERBOSE, ProcessNoWarn +from torchrl._utils import ProcessNoWarn, VERBOSE from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 81d54d86960..08c706c6d9d 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -13,8 +13,8 @@ import torch.cuda from tensordict import TensorDict -from torch import multiprocessing as mp, nn -from torchrl._utils import VERBOSE, ProcessNoWarn +from torch import nn +from torchrl._utils import ProcessNoWarn, VERBOSE from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 99a7f15fd6c..c577f5ed45e 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -20,7 +20,7 @@ from tensordict._tensordict import _unravel_key_to_tuple from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase from torch import multiprocessing as mp -from torchrl._utils import _check_for_faulty_process, VERBOSE, ProcessNoWarn +from torchrl._utils import _check_for_faulty_process, ProcessNoWarn, VERBOSE from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import 
get_env_metadata From 8acbeb3ae92d0bba390292658fd4bd9b9dc16905 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 20 Sep 2023 09:28:14 -0400 Subject: [PATCH 5/5] amend --- torchrl/_utils.py | 12 +++++++++--- torchrl/collectors/collectors.py | 4 ++-- torchrl/collectors/distributed/generic.py | 4 ++-- torchrl/collectors/distributed/rpc.py | 4 ++-- torchrl/collectors/distributed/sync.py | 4 ++-- torchrl/envs/batched_envs.py | 4 ++-- 6 files changed, 19 insertions(+), 13 deletions(-) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 047798508ce..48670d3e8de 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -533,13 +533,19 @@ def get_trace(): traceback.print_stack() -class ProcessNoWarn(mp.Process): - def run(self, *args, **kwargs): +class _ProcessNoWarn(mp.Process): + """A private Process class that shuts down warnings on the subprocess.""" + @wraps(mp.Process.__init__) + def __init__(self, *args, **kwargs): import torchrl if torchrl.filter_warnings_subprocess: - import warnings + self.filter_warnings_subprocess = torchrl.filter_warnings_subprocess + super().__init__(*args, **kwargs) + def run(self, *args, **kwargs): + if getattr(self, "filter_warnings_subprocess", False): + import warnings with warnings.catch_warnings(): warnings.simplefilter("ignore") return mp.Process.run(self, *args, **kwargs) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index e2002581871..16a50df1c0f 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -31,7 +31,7 @@ from torchrl._utils import ( _check_for_faulty_process, accept_remote_rref_udf_invocation, - ProcessNoWarn, + _ProcessNoWarn, prod, RL_WARNINGS, VERBOSE, @@ -1336,7 +1336,7 @@ def _run_processes(self) -> None: "idx": i, "interruptor": self.interruptor, } - proc = ProcessNoWarn(target=_main_async_collector, kwargs=kwargs) + proc = _ProcessNoWarn(target=_main_async_collector, kwargs=kwargs) # proc.daemon can't be set as daemonic processes may be launched by the process 
itself try: proc.start() diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 8cef9d5afde..752a09231c0 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -16,7 +16,7 @@ from tensordict import TensorDict from torch import nn -from torchrl._utils import ProcessNoWarn, VERBOSE +from torchrl._utils import _ProcessNoWarn, VERBOSE from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( DataCollectorBase, @@ -611,7 +611,7 @@ def _init_worker_dist_mp(self, i): if not isinstance(env_make, (EnvBase, EnvCreator)): env_make = CloudpickleWrapper(env_make) TCP_PORT = self.tcp_port - job = ProcessNoWarn( + job = _ProcessNoWarn( target=_distributed_init_collection_node, args=( i + 1, diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 6f86ef5f97e..5fef2dd1666 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -35,7 +35,7 @@ from torch import nn from torch.distributed import rpc -from torchrl._utils import ProcessNoWarn, VERBOSE +from torchrl._utils import _ProcessNoWarn, VERBOSE from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( @@ -447,7 +447,7 @@ def _init_worker_rpc(self, executor, i): print("job id", job.job_id) # ID of your job return job elif self.launcher == "mp": - job = ProcessNoWarn( + job = _ProcessNoWarn( target=_rpc_init_collection_node, args=( i + 1, diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 08c706c6d9d..66e55318832 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -14,7 +14,7 @@ import torch.cuda from tensordict import TensorDict from torch import nn -from torchrl._utils import ProcessNoWarn, VERBOSE +from torchrl._utils import _ProcessNoWarn, VERBOSE from torchrl.collectors 
import MultiaSyncDataCollector from torchrl.collectors.collectors import ( @@ -397,7 +397,7 @@ def _init_worker_dist_mp(self, i): env_make = self.env_constructors[i] if not isinstance(env_make, (EnvBase, EnvCreator)): env_make = CloudpickleWrapper(env_make) - job = ProcessNoWarn( + job = _ProcessNoWarn( target=_distributed_init_collection_node, args=( i + 1, diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index c577f5ed45e..67944af1f35 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -20,7 +20,7 @@ from tensordict._tensordict import _unravel_key_to_tuple from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase from torch import multiprocessing as mp -from torchrl._utils import _check_for_faulty_process, ProcessNoWarn, VERBOSE +from torchrl._utils import _check_for_faulty_process, _ProcessNoWarn, VERBOSE from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import get_env_metadata @@ -715,7 +715,7 @@ def _start_workers(self) -> None: if not isinstance(env_fun, EnvCreator): env_fun = CloudpickleWrapper(env_fun) - process = ProcessNoWarn( + process = _ProcessNoWarn( target=_run_worker_pipe_shared_mem, args=( parent_pipe,