From 1f539dd4f7a10c7c8ee67d5dc1d6eb60ad6c49f0 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Fri, 6 Oct 2023 10:16:01 +0100
Subject: [PATCH 01/67] init

---
 torchrl/collectors/collectors.py |  9 +-------
 torchrl/envs/common.py           | 39 +++++++++++++++++++++++++-------
 torchrl/envs/utils.py            | 36 +++++++++++++++++++++++++++++
 3 files changed, 68 insertions(+), 16 deletions(-)

diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 0d5443b22b4..44e9a357790 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -794,16 +794,9 @@ def _step_and_maybe_reset(self) -> None:
             traj_ids = traj_ids.clone()
             # collectors do not support passing other tensors than `"_reset"`
             # to `reset()`.
-            traj_sop = _aggregate_resets(td_reset, reset_keys=self.env.reset_keys)
             td_reset = self.env.reset(td_reset)
-            if td_reset.batch_dims:
-                # better cloning here than when passing the td for stacking
-                # cloning is necessary to avoid modifying entries in-place
-                self._tensordict = torch.where(traj_sop, td_reset, self._tensordict)
-            else:
-                self._tensordict.update(td_reset)
-
+            traj_sop = _aggregate_resets(td_reset, reset_keys=self.env.reset_keys)
             traj_ids[traj_sop] = traj_ids.max() + torch.arange(
                 1, traj_sop.sum() + 1, device=traj_ids.device
             )
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index 55e057ffa47..2bd3bccd9e5 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -7,7 +7,7 @@
 import abc
 from copy import deepcopy
-from typing import Any, Callable, Dict, Iterator, List, Optional, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple
 
 import numpy as np
 import torch
@@ -26,6 +26,7 @@ from torchrl.data.utils import DEVICE_TYPING
 from torchrl.envs.utils import (
     _replace_last,
+    _update_during_reset,
     get_available_libraries,
     step_mdp,
     terminated_or_truncated,
@@ -1535,13 +1536,11 @@ def reset(
             raise RuntimeError(
                 f"Env done entry '{done_key}' was (partially) True after a call to reset(). This is not allowed."
             )
-
-        if tensordict is not None:
-            tensordict.update(tensordict_reset)
-        else:
-            tensordict = tensordict_reset
-        tensordict.exclude(*self.reset_keys, inplace=True)
-        return tensordict
+        return (
+            _update_during_reset(tensordict_reset, tensordict, self.reset_keys)
+            if tensordict is not None
+            else tensordict_reset
+        )
 
     def numel(self) -> int:
         return prod(self.batch_size)
@@ -1836,6 +1835,30 @@ def policy(td):
         out_td.refine_names(..., "time")
         return out_td
 
+    def step_and_maybe_reset(
+        self, tensordict: TensorDictBase
+    ) -> Tuple[TensorDictBase, TensorDictBase]:
+        tensordict = self.step(tensordict)
+        tensordict_ = step_mdp(
+            tensordict,
+            keep_other=True,
+            exclude_action=False,
+            exclude_reward=True,
+            reward_keys=self.reward_keys,
+            action_keys=self.action_keys,
+            done_keys=self.done_keys,
+        )
+        # done and truncated are in done_keys
+        # We read if any key is done.
+        any_done = terminated_or_truncated(
+            tensordict,
+            full_done_spec=self.output_spec["full_done_spec"],
+            key="_reset",
+        )
+        if any_done:
+            tensordict_ = self.reset(tensordict_)
+        return tensordict, tensordict_
+
     @property
     def reset_keys(self) -> List[NestedKey]:
         """Returns a list of reset keys.
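The new `step_and_maybe_reset` returns two tensordicts: the transition that was just collected (with its `"next"` entries) and the root tensordict to carry into the next step, already reset wherever a done flag was raised. A minimal usage sketch, not part of the patch (it assumes a TorchRL install with this change and a gym backend; the environment name is only an example):

    import torch
    from torchrl.envs.libs.gym import GymEnv

    env = GymEnv("Pendulum-v1")
    td = env.reset()
    steps = []
    for _ in range(10):
        td = env.rand_action(td)              # sample a random action in place
        td, td_next_root = env.step_and_maybe_reset(td)
        steps.append(td)                      # transition with ("next", ...) entries
        td = td_next_root                     # root for the next step, reset if needed
    data = torch.stack(steps, 0)              # lazily stacked trajectory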
diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index cc1f05d3ffb..42d4da3919e 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -899,3 +899,39 @@ def skim_through(td, reset=reset):
         reset = skim_through(data)
     return reset
+
+
+def _update_during_reset(
+    tensordict_reset: TensorDictBase,
+    tensordict: TensorDictBase,
+    reset_keys: List[NestedKey],
+):
+    for reset_key in reset_keys:
+        # get the node of the reset key
+        if isinstance(reset_key, tuple):
+            # the reset key *must* have gone through unravel_key
+            # we don't test it to avoid induced overhead
+            node_key = reset_key[:-1]
+            node_reset = tensordict_reset.get(node_key)
+            node = tensordict.get(node_key)
+        else:
+            node_reset = tensordict_reset
+            node = tensordict
+        # get the reset signal
+        reset = tensordict.pop(reset_key, None)
+        if reset is None or reset.all():
+            # perform simple update, at a single level.
+            # by contract, a reset signal at one level cannot
+            # be followed by other resets at nested levels, so it's safe to
+            # simply update
+            node.update(node_reset)
+        else:
+            # there can be two cases: (1) the key is present in both tds,
+            # in which case we use the reset mask to update
+            # (2) the key is not present in the input tensordict, in which
+            # case we just return the data
+
+            # empty tensordicts won't be returned
+            reset = reset.reshape(node)
+            node.where(reset, node_reset, out=node, pad=0)
+    return tensordict
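The merge performed by `_update_during_reset` can be pictured on plain tensors: entries coming from `reset()` overwrite the previous root values only where the corresponding `"_reset"` mask is set, and the mask itself is popped from the input. A rough illustration with made-up entry names, independent of the helper itself:

    import torch
    from tensordict import TensorDict

    node = TensorDict({"obs": torch.arange(4.0)}, batch_size=[4])     # previous root
    node_reset = TensorDict({"obs": torch.zeros(4)}, batch_size=[4])  # data from reset()
    reset = torch.tensor([False, True, False, True])                  # partial "_reset" mask

    # keep the old value where reset is False, take the reset value elsewhere
    node["obs"] = torch.where(reset, node_reset["obs"], node["obs"])
    print(node["obs"])  # tensor([0., 0., 2., 0.])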
From d4c16e19dd1f221ec1322e8caf013bb3ad547f70 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Fri, 6 Oct 2023 12:22:34 +0100
Subject: [PATCH 02/67] amend

---
 test/test_collector.py           | 56 +++++++++++++++++++-------------
 torchrl/collectors/collectors.py | 41 ++++++++---------------
 torchrl/envs/batched_envs.py     | 52 ++++++++++++++++++++++++++++-
 torchrl/envs/common.py           | 17 +++++-----
 torchrl/envs/utils.py            | 11 ++++---
 5 files changed, 114 insertions(+), 63 deletions(-)

diff --git a/test/test_collector.py b/test/test_collector.py
index 3d71bb09a8c..4e93e351fe6 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -337,7 +337,8 @@ def env_fn(seed):
 
 @pytest.mark.skipif(not _has_gym, reason="gym library is not installed")
-def test_collector_env_reset():
+@pytest.mark.parametrize("parallel", [False, True])
+def test_collector_env_reset(parallel):
     torch.manual_seed(0)
 
     def make_env():
@@ -346,27 +347,38 @@ def make_env():
         with set_gym_backend(gym_backend()):
            return TransformedEnv(GymEnv(PONG_VERSIONED, frame_skip=4), StepCounter())
 
-    env = SerialEnv(2, make_env)
-    # env = SerialEnv(2, lambda: GymEnv("CartPole-v1", frame_skip=4))
-    env.set_seed(0)
-    collector = SyncDataCollector(
-        env, policy=None, total_frames=10000, frames_per_batch=10000, split_trajs=False
-    )
-    for _data in collector:
-        continue
-    steps = _data["next", "step_count"][..., 1:, :]
-    done = _data["next", "done"][..., :-1, :]
-    # we don't want just one done
-    assert done.sum() > 3
-    # check that after a done, the next step count is always 1
-    assert (steps[done] == 1).all()
-    # check that if the env is not done, the next step count is > 1
-    assert (steps[~done] > 1).all()
-    # check that if step is 1, then the env was done before
-    assert (steps == 1)[done].all()
-    # check that split traj has a minimum total reward of -21 (for pong only)
-    _data = split_trajectories(_data, prefix="collector")
-    assert _data["next", "reward"].sum(-2).min() == -21
+    if parallel:
+        env = ParallelEnv(2, make_env)
+    else:
+        env = SerialEnv(2, make_env)
+    try:
+        # env = SerialEnv(2, lambda: GymEnv("CartPole-v1", frame_skip=4))
+        env.set_seed(0)
+        collector = SyncDataCollector(
+            env,
+            policy=None,
+            total_frames=10001,
+            frames_per_batch=10000,
+            split_trajs=False,
+        )
+        for _data in collector:
+            break
+        steps = _data["next", "step_count"][..., 1:, :]
+        done = _data["next", "done"][..., :-1, :]
+        # we don't want just one done
+        assert done.sum() > 3
+        # check that after a done, the next step count is always 1
+        assert (steps[done] == 1).all()
+        # check that if the env is not done, the next step count is > 1
+        assert (steps[~done] > 1).all()
+        # check that if step is 1, then the env was done before
+        assert (steps == 1)[done].all()
+        # check that split traj has a minimum total reward of -21 (for pong only)
+        _data = split_trajectories(_data, prefix="collector")
+        assert _data["next", "reward"].sum(-2).min() == -21
+    finally:
+        env.close()
+        del env
 
 
 # Deprecated reset_when_done
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 44e9a357790..d25b99414ef 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -772,31 +772,14 @@ def iterator(self) -> Iterator[TensorDictBase]:
             # >>> assert data0["done"] is not data1["done"]
             yield tensordict_out.clone()
 
-    def _step_and_maybe_reset(self) -> None:
-
-        self._tensordict = step_mdp(
-            self._tensordict,
-            reward_keys=self.env.reward_keys,
-            done_keys=self.env.done_keys,
-            action_keys=self.env.action_keys,
-        )
-        if not self.reset_when_done:
-            return
-        td_reset = self._tensordict.clone(False)
-        any_done = terminated_or_truncated(
-            td_reset,
-            full_done_spec=self.env.output_spec["full_done_spec"],
-            key="_reset",
+    def _update_traj_ids(self, tensordict) -> None:
+        # we can't use the reset keys because they're gone
+        traj_sop = _aggregate_resets(
+            tensordict.get("next"), done_keys=self.env.done_keys
         )
-
-        if any_done:
+        if traj_sop.any():
             traj_ids = self._tensordict.get(("collector", "traj_ids"))
             traj_ids = traj_ids.clone()
-            # collectors do not support passing other tensors than `"_reset"`
-            # to `reset()`.
-            td_reset = self.env.reset(td_reset)
-
-            traj_sop = _aggregate_resets(td_reset, reset_keys=self.env.reset_keys)
             traj_ids[traj_sop] = traj_ids.max() + torch.arange(
                 1, traj_sop.sum() + 1, device=traj_ids.device
             )
@@ -822,14 +805,18 @@ def rollout(self) -> TensorDictBase:
                 self.init_random_frames is not None
                 and self._frames < self.init_random_frames
             ):
-                self.env.rand_step(self._tensordict)
+                self.env.rand_action(self._tensordict)
             else:
                 self.policy(self._tensordict)
-            self.env.step(self._tensordict)
-            # we must clone all the values, since the step / traj_id updates are done in-place
-            tensordicts.append(self._tensordict.to(self.storing_device))
+            tensordict, tensordict_ = self.env.step_and_maybe_reset(
+                self._tensordict
+            )
+            self._tensordict = tensordict_.set(
+                "collector", tensordict.get("collector").clone(False)
+            )
+            tensordicts.append(tensordict.to(self.storing_device))
 
-            self._step_and_maybe_reset()
+            self._update_traj_ids(tensordict)
             if (
                 self.interruptor is not None
                 and self.interruptor.collection_stopped()
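With this change the collector no longer calls `reset()` itself; new trajectory ids are derived from the done flags of the step that just completed, as `_update_traj_ids` does above. The bookkeeping boils down to the following (a toy sketch on a batch of four environments, outside of any collector class):

    import torch

    traj_ids = torch.tensor([0, 1, 2, 3])
    done = torch.tensor([False, True, False, True])   # flags from the finished step

    # environments that are done start a brand-new trajectory id
    traj_ids = traj_ids.clone()
    traj_ids[done] = traj_ids.max() + torch.arange(1, done.sum() + 1)
    print(traj_ids)  # tensor([0, 4, 2, 5])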
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index aa1f256c070..fe2024fe32c 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -11,7 +11,7 @@
 from functools import wraps
 from multiprocessing import connection
 from multiprocessing.synchronize import Lock as MpLock
-from typing import Any, Callable, Dict, List, Optional, Sequence, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 from warnings import warn
 
 import torch
@@ -795,6 +795,43 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:
             event.wait()
             event.clear()
 
+    @_check_start
+    def step_and_maybe_reset(
+        self, tensordict: TensorDictBase
+    ) -> Tuple[TensorDictBase, TensorDictBase]:
+        if self._single_task and not self.has_lazy_inputs:
+            # this is faster than update_ but won't work for lazy stacks
+            for key in self._env_input_keys:
+                key = _unravel_key_to_tuple(key)
+                self.shared_tensordict_parent._set_tuple(
+                    key,
+                    tensordict._get_tuple(key, None),
+                    inplace=True,
+                    validated=True,
+                )
+        else:
+            self.shared_tensordict_parent.update_(
+                tensordict.select(*self._env_input_keys, strict=False)
+            )
+        if self.event is not None:
+            self.event.record()
+            self.event.synchronize()
+        for i in range(self.num_workers):
+            self.parent_channels[i].send(("step_and_maybe_reset", None))
+
+        for i in range(self.num_workers):
+            event = self._events[i]
+            event.wait()
+            event.clear()
+
+        # We must pass a clone of the tensordict, as the values of this tensordict
+        # will be modified in-place at further steps
+        tensordict.set("next", self.shared_tensordict_parent.get("next").clone())
+        tensordict_ = self.shared_tensordict_parent.exclude(
+            "next", *self.reset_keys
+        ).clone()
+        return tensordict, tensordict_
+
     @_check_start
     def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         if self._single_task and not self.has_lazy_inputs:
@@ -1090,6 +1127,19 @@ def _run_worker_pipe_shared_mem(
                 event.synchronize()
             mp_event.set()
 
+        elif cmd == "step_and_maybe_reset":
+            if not initialized:
+                raise RuntimeError("called 'init' before step")
+            i += 1
+            td, root_next_td = env.step_and_maybe_reset(shared_tensordict.clone(False))
+            assert "_reset" not in td.get("next").keys()
+            next_shared_tensordict.update_(td.get("next"))
+            shared_tensordict.update_(root_next_td)
+            if event is not None:
+                event.record()
+                event.synchronize()
+            mp_event.set()
+
         elif cmd == "close":
             del shared_tensordict, data
             if not initialized:
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index 2bd3bccd9e5..fa2ebae1493 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -7,7 +7,7 @@
 import abc
 from copy import deepcopy
-from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -1536,11 +1536,10 @@ def reset(
             raise RuntimeError(
                 f"Env done entry '{done_key}' was (partially) True after a call to reset(). This is not allowed."
             )
-        return (
-            _update_during_reset(tensordict_reset, tensordict, self.reset_keys)
-            if tensordict is not None
-            else tensordict_reset
-        )
+        if tensordict is not None:
+            result = _update_during_reset(tensordict_reset, tensordict, self.reset_keys)
+            return result
+        return tensordict_reset
 
     def numel(self) -> int:
         return prod(self.batch_size)
@@ -1839,6 +1838,8 @@ def step_and_maybe_reset(
         self, tensordict: TensorDictBase
     ) -> Tuple[TensorDictBase, TensorDictBase]:
         tensordict = self.step(tensordict)
+        # done and truncated are in done_keys
+        # We read if any key is done.
         tensordict_ = step_mdp(
             tensordict,
             keep_other=True,
@@ -1848,10 +1849,8 @@ def step_and_maybe_reset(
             action_keys=self.action_keys,
             done_keys=self.done_keys,
         )
-        # done and truncated are in done_keys
-        # We read if any key is done.
         any_done = terminated_or_truncated(
-            tensordict,
+            tensordict_,
             full_done_spec=self.output_spec["full_done_spec"],
             key="_reset",
         )
diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index 42d4da3919e..85e56294de6 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -851,12 +851,15 @@ def inner_terminated_or_truncated(data, full_done_spec, key, curr_done_key=()):
 PARTIAL_MISSING_ERR = "Some reset keys were present but not all. Either all the `'_reset'` entries must be present, or none."
 
-def _aggregate_resets(data: TensorDictBase, reset_keys=None) -> torch.Tensor:
+def _aggregate_resets(
+    data: TensorDictBase, reset_keys=None, done_keys=None
+) -> torch.Tensor:
     # goes through the tensordict and brings the _reset information to
     # a boolean tensor of the shape of the tensordict.
batch_size = data.batch_size n = len(batch_size) - + if done_keys is not None and reset_keys is None: + reset_keys = {_replace_last(key, "done") for key in done_keys} if reset_keys is not None: reset = False has_missing = None @@ -932,6 +935,6 @@ def _update_during_reset( # case we just return the data # empty tensordicts won't be returned - reset = reset.reshape(node) - node.where(reset, node_reset, out=node, pad=0) + reset = reset.reshape(node.shape) + node.where(~reset, other=node_reset, out=node, pad=0) return tensordict From d2321aa33706faef1b1ff995753d20492940b9e0 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 6 Oct 2023 13:57:19 +0100 Subject: [PATCH 03/67] amend --- benchmarks/ecosystem/gym_env_throughput.py | 4 +- torchrl/envs/common.py | 46 +++++++++++++++++++--- 2 files changed, 43 insertions(+), 7 deletions(-) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index 457f15a2b5a..04749ca6a68 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -219,7 +219,7 @@ def make_env( penv = EnvCreator( lambda num_workers=num_workers // num_collectors: make_env( - num_workers + num_workers=num_workers ) ) collector = MultiaSyncDataCollector( @@ -306,7 +306,7 @@ def make_env( penv = EnvCreator( lambda num_workers=num_workers // num_collectors: make_env( - num_workers + num_workers=num_workers ) ) collector = MultiSyncDataCollector( diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index fa2ebae1493..fe855a82fe8 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1789,7 +1789,21 @@ def rollout( def policy(td): self.rand_action(td) return td - + kwargs = { + "tensordict": tensordict, + "auto_cast_to_device": auto_cast_to_device, + "max_steps": max_steps, + "policy": policy, + "policy_device": policy_device, + "env_device": env_device, + "callback": callback, + "return_contiguous": return_contiguous, + } + if break_when_any_done: + return self._rollout_stop_early(**kwargs) + return self._rollout_nonstop(**kwargs) + + def _rollout_stop_early(self, *, tensordict, auto_cast_to_device, max_steps, policy, policy_device, env_device, callback, return_contiguous): tensordicts = [] for i in range(max_steps): if auto_cast_to_device: @@ -1817,12 +1831,10 @@ def policy(td): any_done = terminated_or_truncated( tensordict, full_done_spec=self.output_spec["full_done_spec"], - key=None if break_when_any_done else "_reset", + key=None, ) - if break_when_any_done and any_done: + if any_done: break - if not break_when_any_done and any_done: - tensordict = self.reset(tensordict) if callback is not None: callback(self, tensordict) @@ -1834,6 +1846,30 @@ def policy(td): out_td.refine_names(..., "time") return out_td + def _rollout_nonstop(self, *, tensordict, auto_cast_to_device, max_steps, policy, policy_device, env_device, callback, return_contiguous): + tensordicts = [] + tensordict_ = tensordict + for i in range(max_steps): + if auto_cast_to_device: + tensordict_ = tensordict_.to(policy_device, non_blocking=True) + tensordict = policy(tensordict_) + if auto_cast_to_device: + tensordict_ = tensordict.to(env_device, non_blocking=True) + tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_) + tensordicts.append(tensordict.clone(False)) + if i == max_steps - 1: + # we don't truncated as one could potentially continue the run + break + if callback is not None: + callback(self, tensordict) + + batch_size = self.batch_size if tensordict is None else tensordict.batch_size + 
out_td = torch.stack(tensordicts, len(batch_size)) + if return_contiguous: + out_td = out_td.contiguous() + out_td.refine_names(..., "time") + return out_td + def step_and_maybe_reset( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: From 565115a4617978550cb2a91e26f905c5754cc8e7 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 6 Oct 2023 14:07:12 +0100 Subject: [PATCH 04/67] amend --- benchmarks/ecosystem/gym_env_throughput.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index 04749ca6a68..a15c5f82012 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -39,7 +39,7 @@ # the number of collectors won't affect the resources, just impacts how the envs are split in sub-sub-processes for num_workers, num_collectors in zip((8, 16, 32, 64), (2, 4, 8, 8)): with open( - f"atari_{envname}_{num_workers}.txt".replace("/", "-"), "w+" + f"{envname}_{num_workers}.txt".replace("/", "-"), "w+" ) as log: if "myo" in envname: gym_backend = "gym" From 3c46136c39c46b21ccfd01e676c28d3acb9cffbf Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 6 Oct 2023 15:07:42 +0100 Subject: [PATCH 05/67] amend --- test/test_collector.py | 27 ++++++++++++++------------- torchrl/envs/utils.py | 1 + 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 4e93e351fe6..da7fc1e890b 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1395,19 +1395,20 @@ def test_reset_heterogeneous_envs(): collector = SyncDataCollector( env, RandomPolicy(env.action_spec), total_frames=10_000, frames_per_batch=1000 ) - for data in collector: # noqa: B007 - break - collector.shutdown() - del collector - assert ( - data[0]["next", "truncated"].squeeze() - == torch.tensor([False, True]).repeat(250)[:500] - ).all() - assert ( - data[1]["next", "truncated"].squeeze() - == torch.tensor([False, False, True]).repeat(168)[:500] - ).all() - + try: + for data in collector: # noqa: B007 + break + assert ( + data[0]["next", "truncated"].squeeze() + == torch.tensor([False, True]).repeat(250)[:500] + ).all(), data[0]["next", "truncated"][:10] + assert ( + data[1]["next", "truncated"].squeeze() + == torch.tensor([False, False, True]).repeat(168)[:500] + ).all(), data[1]["next", "truncated"][:10] + finally: + collector.shutdown() + del collector def test_policy_with_mask(): env = CountingBatchedEnv(start_val=torch.tensor(10), max_steps=torch.tensor(1e5)) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 85e56294de6..d93bcb9ddba 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -936,5 +936,6 @@ def _update_during_reset( # empty tensordicts won't be returned reset = reset.reshape(node.shape) + # node.update(node.where(~reset, other=node_reset, pad=0)) node.where(~reset, other=node_reset, out=node, pad=0) return tensordict From 78cfa41015f9845ae17beddfb0d9942a89a560d2 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 6 Oct 2023 15:32:27 +0100 Subject: [PATCH 06/67] amend --- benchmarks/ecosystem/gym_env_throughput.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index a15c5f82012..ac7abad46a5 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -30,14 +30,15 @@ if __name__ == "__main__": for envname in [ - 
"HalfCheetah-v4", "CartPole-v1", + "HalfCheetah-v4", "myoHandReachRandom-v0", "ALE/Breakout-v5", "CartPole-v1", ]: # the number of collectors won't affect the resources, just impacts how the envs are split in sub-sub-processes - for num_workers, num_collectors in zip((8, 16, 32, 64), (2, 4, 8, 8)): + for num_workers, num_collectors in zip((32, 64, 8, 16), (8,82, 4)): + # for num_workers, num_collectors in zip((8, 16, 32, 64), (2, 4, 8, 8)): with open( f"{envname}_{num_workers}.txt".replace("/", "-"), "w+" ) as log: From a6bd8eb779328834da17dd55ad7020e14cc14164 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 6 Oct 2023 15:57:59 +0100 Subject: [PATCH 07/67] amend --- benchmarks/ecosystem/gym_env_throughput.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index ac7abad46a5..5c17a72823d 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -37,8 +37,7 @@ "CartPole-v1", ]: # the number of collectors won't affect the resources, just impacts how the envs are split in sub-sub-processes - for num_workers, num_collectors in zip((32, 64, 8, 16), (8,82, 4)): - # for num_workers, num_collectors in zip((8, 16, 32, 64), (2, 4, 8, 8)): + for num_workers, num_collectors in zip((32, 64, 8, 16), (8,8, 2, 4)): with open( f"{envname}_{num_workers}.txt".replace("/", "-"), "w+" ) as log: From 04d4ae7120b7323687f256888f760356ad32fc58 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 6 Oct 2023 15:58:10 +0100 Subject: [PATCH 08/67] amend --- benchmarks/ecosystem/gym_env_throughput.py | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index 5c17a72823d..1b065ecce89 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -34,7 +34,6 @@ "HalfCheetah-v4", "myoHandReachRandom-v0", "ALE/Breakout-v5", - "CartPole-v1", ]: # the number of collectors won't affect the resources, just impacts how the envs are split in sub-sub-processes for num_workers, num_collectors in zip((32, 64, 8, 16), (8,8, 2, 4)): From f1b0ea4c5d33bc6c71478911b249bd33bea8f0f5 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 11:35:06 +0100 Subject: [PATCH 09/67] tensordict_ --- torchrl/envs/common.py | 10 +++++----- torchrl/envs/gym_like.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index fe855a82fe8..d2b3ec23e1a 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1850,11 +1850,11 @@ def _rollout_nonstop(self, *, tensordict, auto_cast_to_device, max_steps, policy tensordicts = [] tensordict_ = tensordict for i in range(max_steps): - if auto_cast_to_device: - tensordict_ = tensordict_.to(policy_device, non_blocking=True) - tensordict = policy(tensordict_) - if auto_cast_to_device: - tensordict_ = tensordict.to(env_device, non_blocking=True) + # if auto_cast_to_device: + # tensordict_ = tensordict_.to(policy_device, non_blocking=True) + tensordict_ = policy(tensordict_) + # if auto_cast_to_device: + # tensordict_ = tensordict.to(env_device, non_blocking=True) tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_) tensordicts.append(tensordict.clone(False)) if i == max_steps - 1: diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 79dc8c4ab64..90f89de79cc 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ 
-262,7 +262,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: obs_dict["done"] = done obs_dict["terminated"] = terminated - tensordict_out = TensorDict(obs_dict, batch_size=tensordict.batch_size) + tensordict_out = TensorDict(obs_dict, batch_size=tensordict.batch_size, device=self.device) if self.info_dict_reader and info is not None: if not isinstance(info, dict): @@ -274,7 +274,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: out = info_dict_reader(info, tensordict_out) if out is not None: tensordict_out = out - tensordict_out = tensordict_out.to(self.device, non_blocking=True) + # tensordict_out = tensordict_out.to(self.device, non_blocking=True) return tensordict_out def _reset( From 16b3538e1d99f8b019d9ab742eace4c4573c6c15 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 11:41:20 +0100 Subject: [PATCH 10/67] amend rollout logic --- torchrl/envs/common.py | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index d2b3ec23e1a..dd4f3fa3344 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1800,8 +1800,15 @@ def policy(td): "return_contiguous": return_contiguous, } if break_when_any_done: - return self._rollout_stop_early(**kwargs) - return self._rollout_nonstop(**kwargs) + tensordicts = self._rollout_stop_early(**kwargs) + else: + tensordicts = self._rollout_nonstop(**kwargs) + batch_size = self.batch_size if tensordict is None else tensordict.batch_size + out_td = torch.stack(tensordicts, len(batch_size)) + if return_contiguous: + out_td = out_td.contiguous() + out_td.refine_names(..., "time") + return out_td def _rollout_stop_early(self, *, tensordict, auto_cast_to_device, max_steps, policy, policy_device, env_device, callback, return_contiguous): tensordicts = [] @@ -1838,37 +1845,26 @@ def _rollout_stop_early(self, *, tensordict, auto_cast_to_device, max_steps, pol if callback is not None: callback(self, tensordict) - - batch_size = self.batch_size if tensordict is None else tensordict.batch_size - out_td = torch.stack(tensordicts, len(batch_size)) - if return_contiguous: - out_td = out_td.contiguous() - out_td.refine_names(..., "time") - return out_td + return tensordicts def _rollout_nonstop(self, *, tensordict, auto_cast_to_device, max_steps, policy, policy_device, env_device, callback, return_contiguous): tensordicts = [] tensordict_ = tensordict for i in range(max_steps): - # if auto_cast_to_device: - # tensordict_ = tensordict_.to(policy_device, non_blocking=True) + if auto_cast_to_device: + tensordict_ = tensordict_.to(policy_device, non_blocking=True) tensordict_ = policy(tensordict_) - # if auto_cast_to_device: - # tensordict_ = tensordict.to(env_device, non_blocking=True) + if auto_cast_to_device: + tensordict_ = tensordict.to(env_device, non_blocking=True) tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_) - tensordicts.append(tensordict.clone(False)) + tensordicts.append(tensordict) if i == max_steps - 1: # we don't truncated as one could potentially continue the run break if callback is not None: callback(self, tensordict) - batch_size = self.batch_size if tensordict is None else tensordict.batch_size - out_td = torch.stack(tensordicts, len(batch_size)) - if return_contiguous: - out_td = out_td.contiguous() - out_td.refine_names(..., "time") - return out_td + return tensordicts def step_and_maybe_reset( self, tensordict: TensorDictBase From 7ad08649f741c0b6148f44fe20c5b79636344e8d Mon Sep 17 
00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 11:52:53 +0100 Subject: [PATCH 11/67] amend --- benchmarks/ecosystem/gym_env_throughput.py | 2 +- test/test_collector.py | 1 + torchrl/collectors/collectors.py | 4 +-- torchrl/envs/common.py | 29 +++++++++++++++++++--- torchrl/envs/gym_like.py | 4 ++- 5 files changed, 31 insertions(+), 9 deletions(-) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index 8319abf0999..146d011442d 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -36,7 +36,7 @@ "ALE/Breakout-v5", ]: # the number of collectors won't affect the resources, just impacts how the envs are split in sub-sub-processes - for num_workers, num_collectors in zip((32, 64, 8, 16), (8,8, 2, 4)): + for num_workers, num_collectors in zip((32, 64, 8, 16), (8, 8, 2, 4)): with open(f"{envname}_{num_workers}.txt".replace("/", "-"), "w+") as log: if "myo" in envname: gym_backend = "gym" diff --git a/test/test_collector.py b/test/test_collector.py index da7fc1e890b..75a192e647c 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1410,6 +1410,7 @@ def test_reset_heterogeneous_envs(): collector.shutdown() del collector + def test_policy_with_mask(): env = CountingBatchedEnv(start_val=torch.tensor(10), max_steps=torch.tensor(1e5)) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index c28d70a307e..584fcbfbca1 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -48,8 +48,6 @@ _convert_exploration_type, ExplorationType, set_exploration_type, - step_mdp, - terminated_or_truncated, ) _TIMEOUT = 1.0 @@ -828,7 +826,7 @@ def rollout(self) -> TensorDictBase: self._tensordict = tensordict_.set( "collector", tensordict.get("collector").clone(False) ) - tensordicts.append(tensordict.to(self.storing_device)) + tensordicts.append(tensordict.to(self.storing_device, non_blocking=True)) self._update_traj_ids(tensordict) if ( diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index dd4f3fa3344..600fd780ee4 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1650,6 +1650,7 @@ def rollout( break_when_any_done: bool = True, return_contiguous: bool = True, tensordict: Optional[TensorDictBase] = None, + out = None, ): """Executes a rollout in the environment. 
@@ -1789,6 +1790,7 @@ def rollout( def policy(td): self.rand_action(td) return td + kwargs = { "tensordict": tensordict, "auto_cast_to_device": auto_cast_to_device, @@ -1797,20 +1799,29 @@ def policy(td): "policy_device": policy_device, "env_device": env_device, "callback": callback, - "return_contiguous": return_contiguous, } if break_when_any_done: tensordicts = self._rollout_stop_early(**kwargs) else: tensordicts = self._rollout_nonstop(**kwargs) batch_size = self.batch_size if tensordict is None else tensordict.batch_size - out_td = torch.stack(tensordicts, len(batch_size)) + out_td = torch.stack(tensordicts, len(batch_size), out=out) if return_contiguous: out_td = out_td.contiguous() out_td.refine_names(..., "time") return out_td - def _rollout_stop_early(self, *, tensordict, auto_cast_to_device, max_steps, policy, policy_device, env_device, callback, return_contiguous): + def _rollout_stop_early( + self, + *, + tensordict, + auto_cast_to_device, + max_steps, + policy, + policy_device, + env_device, + callback, + ): tensordicts = [] for i in range(max_steps): if auto_cast_to_device: @@ -1847,7 +1858,17 @@ def _rollout_stop_early(self, *, tensordict, auto_cast_to_device, max_steps, pol callback(self, tensordict) return tensordicts - def _rollout_nonstop(self, *, tensordict, auto_cast_to_device, max_steps, policy, policy_device, env_device, callback, return_contiguous): + def _rollout_nonstop( + self, + *, + tensordict, + auto_cast_to_device, + max_steps, + policy, + policy_device, + env_device, + callback, + ): tensordicts = [] tensordict_ = tensordict for i in range(max_steps): diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 90f89de79cc..07bcab87506 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -262,7 +262,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: obs_dict["done"] = done obs_dict["terminated"] = terminated - tensordict_out = TensorDict(obs_dict, batch_size=tensordict.batch_size, device=self.device) + tensordict_out = TensorDict( + obs_dict, batch_size=tensordict.batch_size, device=self.device + ) if self.info_dict_reader and info is not None: if not isinstance(info, dict): From bcac3981a966272a53d78f4a146d84230767ee3f Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 11:53:55 +0100 Subject: [PATCH 12/67] amend --- benchmarks/ecosystem/gym_env_throughput.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index 146d011442d..2d0a2fab1b3 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -84,8 +84,9 @@ def make(envname=envname, gym_backend=gym_backend, device=device): penv.rollout(2) pbar = tqdm.tqdm(total=num_workers * 10_000) t0 = time.time() + data = None for _ in range(100): - data = penv.rollout(100, break_when_any_done=False) + data = penv.rollout(100, break_when_any_done=False, out=data) pbar.update(100 * num_workers) log.write( f"penv {device}: {num_workers * 10_000 / (time.time() - t0): 4.4f} fps\n" From 428f8ee60eeae11d8247450204d435341d388651 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 11:57:16 +0100 Subject: [PATCH 13/67] inference --- benchmarks/ecosystem/gym_env_throughput.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index 2d0a2fab1b3..6982315f50f 100644 --- 
a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -17,6 +17,7 @@ import time import myosuite # noqa: F401 +import torch import tqdm from torchrl._utils import timeit from torchrl.collectors import ( @@ -81,13 +82,14 @@ def make(envname=envname, gym_backend=gym_backend, device=device): env_make = EnvCreator(make) penv = ParallelEnv(num_workers, env_make) # warmup - penv.rollout(2) - pbar = tqdm.tqdm(total=num_workers * 10_000) - t0 = time.time() - data = None - for _ in range(100): - data = penv.rollout(100, break_when_any_done=False, out=data) - pbar.update(100 * num_workers) + with torch.inference_mode(): + penv.rollout(2) + pbar = tqdm.tqdm(total=num_workers * 10_000) + t0 = time.time() + data = None + for _ in range(100): + data = penv.rollout(100, break_when_any_done=False, out=data) + pbar.update(100 * num_workers) log.write( f"penv {device}: {num_workers * 10_000 / (time.time() - t0): 4.4f} fps\n" ) From 6fbd0bd78116d90afe1ad3c932d89eb95e59e8ce Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 12:00:34 +0100 Subject: [PATCH 14/67] cpu -> cuda --- benchmarks/ecosystem/gym_env_throughput.py | 31 +++++++--------------- torchrl/envs/common.py | 6 ++--- 2 files changed, 11 insertions(+), 26 deletions(-) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index 6982315f50f..eb320a6d076 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -70,10 +70,7 @@ def make(envname=envname, gym_backend=gym_backend): log.flush() # regular parallel env - for device in ( - "cuda:0", - "cpu", - ): + for device in ("cpu", "cuda:0"): def make(envname=envname, gym_backend=gym_backend, device=device): with set_gym_backend(gym_backend): @@ -88,7 +85,9 @@ def make(envname=envname, gym_backend=gym_backend, device=device): t0 = time.time() data = None for _ in range(100): - data = penv.rollout(100, break_when_any_done=False, out=data) + data = penv.rollout( + 100, break_when_any_done=False, out=data + ) pbar.update(100 * num_workers) log.write( f"penv {device}: {num_workers * 10_000 / (time.time() - t0): 4.4f} fps\n" @@ -98,7 +97,7 @@ def make(envname=envname, gym_backend=gym_backend, device=device): timeit.print() del penv - for device in ("cuda:0", "cpu"): + for device in ("cpu", "cuda:0"): def make(envname=envname, gym_backend=gym_backend, device=device): with set_gym_backend(gym_backend): @@ -131,10 +130,7 @@ def make(envname=envname, gym_backend=gym_backend, device=device): collector.shutdown() del collector - for device in ( - "cuda:0", - "cpu", - ): + for device in ("cpu", "cuda:0"): # gym parallel env def make_env( envname=envname, @@ -201,10 +197,7 @@ def make_env( collector.shutdown() del collector - for device in ( - "cuda:0", - "cpu", - ): + for device in ("cpu", "cuda:0"): # async collector # + gym async env def make_env( @@ -248,10 +241,7 @@ def make_env( collector.shutdown() del collector - for device in ( - "cuda:0", - "cpu", - ): + for device in ("cpu", "cuda:0"): # sync collector # + torchrl parallel env def make_env( @@ -288,10 +278,7 @@ def make_env( collector.shutdown() del collector - for device in ( - "cuda:0", - "cpu", - ): + for device in ("cpu", "cuda:0"): # sync collector # + gym async env def make_env( diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 600fd780ee4..46572bf69f0 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1650,7 +1650,7 @@ def rollout( break_when_any_done: bool = True, 
return_contiguous: bool = True, tensordict: Optional[TensorDictBase] = None, - out = None, + out=None, ): """Executes a rollout in the environment. @@ -1787,9 +1787,7 @@ def rollout( raise RuntimeError("tensordict must be provided when auto_reset is False") if policy is None: - def policy(td): - self.rand_action(td) - return td + policy = self.rand_action kwargs = { "tensordict": tensordict, From 02db623727eea37eeef3fa8dc9159a3149d6d2f2 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 12:11:26 +0100 Subject: [PATCH 15/67] checks --- benchmarks/ecosystem/gym_env_throughput.py | 23 ++++++++++------------ torchrl/collectors/collectors.py | 4 +++- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index eb320a6d076..33f08a6a8dc 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -78,8 +78,8 @@ def make(envname=envname, gym_backend=gym_backend, device=device): env_make = EnvCreator(make) penv = ParallelEnv(num_workers, env_make) - # warmup with torch.inference_mode(): + # warmup penv.rollout(2) pbar = tqdm.tqdm(total=num_workers * 10_000) t0 = time.time() @@ -112,17 +112,17 @@ def make(envname=envname, gym_backend=gym_backend, device=device): frames_per_batch=1024, total_frames=num_workers * 10_000, ) + assert collector.env.device == torch.device(device) pbar = tqdm.tqdm(total=num_workers * 10_000) total_frames = 0 for i, data in enumerate(collector): - if i == num_collectors: - t0 = time.time() - if i >= num_collectors: - total_frames += data.numel() - pbar.update(data.numel()) - pbar.set_description( - f"single collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps" - ) + t0 = time.time() + total_frames += data.numel() + pbar.update(data.numel()) + pbar.set_description( + f"single collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps" + ) + assert data.device == torch.device(device) log.write( f"single collector + torchrl penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n" ) @@ -157,10 +157,7 @@ def make_env( penv.close() del penv - for device in ( - "cuda:0", - "cpu", - ): + for device in ("cpu", "cuda:0"): # async collector # + torchrl parallel env def make_env( diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 584fcbfbca1..85ac4ab5fe5 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -826,7 +826,9 @@ def rollout(self) -> TensorDictBase: self._tensordict = tensordict_.set( "collector", tensordict.get("collector").clone(False) ) - tensordicts.append(tensordict.to(self.storing_device, non_blocking=True)) + tensordicts.append( + tensordict.to(self.storing_device, non_blocking=True) + ) self._update_traj_ids(tensordict) if ( From 08a8f4728b7f15184229a307385710ea3d7f6378 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 14:19:00 +0100 Subject: [PATCH 16/67] using pipe instead of event --- torchrl/envs/batched_envs.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index fe2024fe32c..9c8a2312c60 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -820,9 +820,11 @@ def step_and_maybe_reset( self.parent_channels[i].send(("step_and_maybe_reset", None)) for i in range(self.num_workers): - event = self._events[i] - event.wait() - event.clear() + msg = self.parent_channels[i].recv() + 
assert msg == "smr done" + # event = self._events[i] + # event.wait() + # event.clear() # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps @@ -1138,7 +1140,8 @@ def _run_worker_pipe_shared_mem( if event is not None: event.record() event.synchronize() - mp_event.set() + # mp_event.set() + child_pipe.send("smr done") elif cmd == "close": del shared_tensordict, data From 45e64f7121d6ad34cd981831d70012cde8988088 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 14:22:31 +0100 Subject: [PATCH 17/67] amend --- torchrl/envs/batched_envs.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 9c8a2312c60..9576e584d27 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -820,11 +820,9 @@ def step_and_maybe_reset( self.parent_channels[i].send(("step_and_maybe_reset", None)) for i in range(self.num_workers): - msg = self.parent_channels[i].recv() - assert msg == "smr done" - # event = self._events[i] - # event.wait() - # event.clear() + event = self._events[i] + event.wait() + event.clear() # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps @@ -1134,14 +1132,14 @@ def _run_worker_pipe_shared_mem( raise RuntimeError("called 'init' before step") i += 1 td, root_next_td = env.step_and_maybe_reset(shared_tensordict.clone(False)) - assert "_reset" not in td.get("next").keys() + assert td.device == next_shared_tensordict.device + assert root_next_td.device == shared_tensordict.device next_shared_tensordict.update_(td.get("next")) shared_tensordict.update_(root_next_td) if event is not None: event.record() event.synchronize() - # mp_event.set() - child_pipe.send("smr done") + mp_event.set() elif cmd == "close": del shared_tensordict, data From 7dd48210747e7a88f730db51a0858f40979a1bcc Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 14:30:11 +0100 Subject: [PATCH 18/67] amend --- torchrl/envs/batched_envs.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 9576e584d27..eb84a98060a 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1132,10 +1132,12 @@ def _run_worker_pipe_shared_mem( raise RuntimeError("called 'init' before step") i += 1 td, root_next_td = env.step_and_maybe_reset(shared_tensordict.clone(False)) - assert td.device == next_shared_tensordict.device - assert root_next_td.device == shared_tensordict.device - next_shared_tensordict.update_(td.get("next")) - shared_tensordict.update_(root_next_td) + for key, val in td.get("next").items(True, True): + next_shared_tensordict.get(key).copy_(val, non_blocking=True) + # next_shared_tensordict.update_(td.get("next")) + for key, val in root_next_td.items(True, True): + shared_tensordict.get(key).copy_(val, non_blocking=True) + # shared_tensordict.update_(root_next_td) if event is not None: event.record() event.synchronize() From e1a2206130109a76530730381a174b6bb9fd1af4 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 14:37:06 +0100 Subject: [PATCH 19/67] rm cuda event --- torchrl/envs/batched_envs.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index eb84a98060a..865d560b0a7 100644 --- a/torchrl/envs/batched_envs.py +++ 
b/torchrl/envs/batched_envs.py @@ -722,10 +722,10 @@ def _start_workers(self) -> None: self.parent_channels = [] self._workers = [] self._events = [] - if self.device.type == "cuda": - self.event = torch.cuda.Event() - else: - self.event = None + # if self.device.type == "cuda": + # self.event = torch.cuda.Event() + # else: + self.event = None with clear_mpi_env_vars(): for idx in range(_num_workers): if self._verbose: @@ -1052,10 +1052,10 @@ def _run_worker_pipe_shared_mem( ) -> None: if device is None: device = torch.device("cpu") - if device.type == "cuda": - event = torch.cuda.Event() - else: - event = None + # if device.type == "cuda": + # event = torch.cuda.Event() + # else: + event = None parent_pipe.close() pid = os.getpid() @@ -1132,12 +1132,12 @@ def _run_worker_pipe_shared_mem( raise RuntimeError("called 'init' before step") i += 1 td, root_next_td = env.step_and_maybe_reset(shared_tensordict.clone(False)) - for key, val in td.get("next").items(True, True): - next_shared_tensordict.get(key).copy_(val, non_blocking=True) - # next_shared_tensordict.update_(td.get("next")) - for key, val in root_next_td.items(True, True): - shared_tensordict.get(key).copy_(val, non_blocking=True) - # shared_tensordict.update_(root_next_td) + # for key, val in td.get("next").items(True, True): + # next_shared_tensordict.get(key).copy_(val, non_blocking=True) + next_shared_tensordict.update_(td.get("next")) + # for key, val in root_next_td.items(True, True): + # shared_tensordict.get(key).copy_(val, non_blocking=True) + shared_tensordict.update_(root_next_td) if event is not None: event.record() event.synchronize() From dc2caabb3787470aca86c56533230093d2144f80 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 14:40:36 +0100 Subject: [PATCH 20/67] amend --- benchmarks/ecosystem/gym_env_throughput.py | 4 ++-- torchrl/envs/batched_envs.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index 33f08a6a8dc..ca77cb3a464 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -76,8 +76,8 @@ def make(envname=envname, gym_backend=gym_backend, device=device): with set_gym_backend(gym_backend): return GymEnv(envname, device=device) - env_make = EnvCreator(make) - penv = ParallelEnv(num_workers, env_make) + # env_make = EnvCreator(make) + penv = ParallelEnv(num_workers, [EnvCreator(make) for _ in range(num_workers)]) with torch.inference_mode(): # warmup penv.rollout(2) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 865d560b0a7..a2b01d7dc99 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -722,10 +722,10 @@ def _start_workers(self) -> None: self.parent_channels = [] self._workers = [] self._events = [] - # if self.device.type == "cuda": - # self.event = torch.cuda.Event() - # else: - self.event = None + if self.device.type == "cuda": + self.event = torch.cuda.Event() + else: + self.event = None with clear_mpi_env_vars(): for idx in range(_num_workers): if self._verbose: @@ -1052,10 +1052,10 @@ def _run_worker_pipe_shared_mem( ) -> None: if device is None: device = torch.device("cpu") - # if device.type == "cuda": - # event = torch.cuda.Event() - # else: - event = None + if device.type == "cuda": + event = torch.cuda.Event() + else: + event = None parent_pipe.close() pid = os.getpid() From 01ffbf9686c750fda2ad7863c29a2130eb65512e Mon Sep 17 00:00:00 2001 From: vmoens 
Date: Tue, 10 Oct 2023 15:27:47 +0100 Subject: [PATCH 21/67] amend --- benchmarks/ecosystem/gym_env_throughput.py | 11 +++++++---- torchrl/envs/batched_envs.py | 20 +++++++++++--------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index ca77cb3a464..c47e2177b2f 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -70,14 +70,14 @@ def make(envname=envname, gym_backend=gym_backend): log.flush() # regular parallel env - for device in ("cpu", "cuda:0"): + for device in ("cuda:0", "cpu"): def make(envname=envname, gym_backend=gym_backend, device=device): with set_gym_backend(gym_backend): return GymEnv(envname, device=device) # env_make = EnvCreator(make) - penv = ParallelEnv(num_workers, [EnvCreator(make) for _ in range(num_workers)]) + penv = ParallelEnv(num_workers, EnvCreator(make)) with torch.inference_mode(): # warmup penv.rollout(2) @@ -111,8 +111,9 @@ def make(envname=envname, gym_backend=gym_backend, device=device): RandomPolicy(penv.action_spec), frames_per_batch=1024, total_frames=num_workers * 10_000, + device=device, + storing_device=device, ) - assert collector.env.device == torch.device(device) pbar = tqdm.tqdm(total=num_workers * 10_000) total_frames = 0 for i, data in enumerate(collector): @@ -122,7 +123,6 @@ def make(envname=envname, gym_backend=gym_backend, device=device): pbar.set_description( f"single collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps" ) - assert data.device == torch.device(device) log.write( f"single collector + torchrl penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n" ) @@ -175,6 +175,7 @@ def make_env( frames_per_batch=1024, total_frames=num_workers * 10_000, device=device, + storing_device=device, ) pbar = tqdm.tqdm(total=num_workers * 10_000) total_frames = 0 @@ -219,6 +220,7 @@ def make_env( total_frames=num_workers * 10_000, num_sub_threads=num_workers // num_collectors, device=device, + storing_device=device, ) pbar = tqdm.tqdm(total=num_workers * 10_000) total_frames = 0 @@ -256,6 +258,7 @@ def make_env( frames_per_batch=1024, total_frames=num_workers * 10_000, device=device, + storing_device=device, ) pbar = tqdm.tqdm(total=num_workers * 10_000) total_frames = 0 diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index a2b01d7dc99..f57c665da5f 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -477,7 +477,7 @@ def close(self) -> None: self.__dict__["_input_spec"] = None self.__dict__["_output_spec"] = None self._properties_set = False - self.event = None + self.cuda_events = None self._shutdown_workers() self.is_closed = True @@ -723,9 +723,11 @@ def _start_workers(self) -> None: self._workers = [] self._events = [] if self.device.type == "cuda": - self.event = torch.cuda.Event() + # self.streams = [torch.cuda.Stream(self.device) for _ in range(_num_workers)] + # self.cuda_events = [torch.cuda.Event(interprocess=True) for _ in range(_num_workers)] + self.cuda_events = torch.cuda.Event(interprocess=True) else: - self.event = None + self.cuda_events = None with clear_mpi_env_vars(): for idx in range(_num_workers): if self._verbose: @@ -813,9 +815,9 @@ def step_and_maybe_reset( self.shared_tensordict_parent.update_( tensordict.select(*self._env_input_keys, strict=False) ) - if self.event is not None: - self.event.record() - self.event.synchronize() + if self.cuda_events is not None: + 
self.cuda_events.record() + self.cuda_events.synchronize() for i in range(self.num_workers): self.parent_channels[i].send(("step_and_maybe_reset", None)) @@ -848,9 +850,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: self.shared_tensordict_parent.update_( tensordict.select(*self._env_input_keys, strict=False) ) - if self.event is not None: - self.event.record() - self.event.synchronize() + if self.cuda_events is not None: + self.cuda_events.record() + self.cuda_events.synchronize() for i in range(self.num_workers): self.parent_channels[i].send(("step", None)) From ac76ec3bc4efba8c201ccac9e547de86b2633165 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 15:35:22 +0100 Subject: [PATCH 22/67] amend --- torchrl/envs/batched_envs.py | 270 ++++++++++++++++++----------------- 1 file changed, 137 insertions(+), 133 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index f57c665da5f..537e8bc2ca7 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -723,11 +723,12 @@ def _start_workers(self) -> None: self._workers = [] self._events = [] if self.device.type == "cuda": - # self.streams = [torch.cuda.Stream(self.device) for _ in range(_num_workers)] - # self.cuda_events = [torch.cuda.Event(interprocess=True) for _ in range(_num_workers)] - self.cuda_events = torch.cuda.Event(interprocess=True) + self.streams = [torch.cuda.Stream(self.device) for _ in range(_num_workers)] + self.cuda_events = [torch.cuda.Event(interprocess=True) for _ in range(_num_workers)] + # self.cuda_events = torch.cuda.Event(interprocess=True) else: - self.cuda_events = None + self.streams = [None] * _num_workers + self.cuda_events = [None] * _num_workers with clear_mpi_env_vars(): for idx in range(_num_workers): if self._verbose: @@ -1055,138 +1056,141 @@ def _run_worker_pipe_shared_mem( if device is None: device = torch.device("cpu") if device.type == "cuda": - event = torch.cuda.Event() + cuda_event = torch.cuda.Event() + stream = torch.cuda.Stream(device) else: - event = None - - parent_pipe.close() - pid = os.getpid() - if not isinstance(env_fun, EnvBase): - env = env_fun(**env_fun_kwargs) - else: - if env_fun_kwargs: - raise RuntimeError( - "env_fun_kwargs must be empty if an environment is passed to a process." 
- ) - env = env_fun - env = env.to(device) - del env_fun - - i = -1 - initialized = False - - child_pipe.send("started") - - while True: - try: - cmd, data = child_pipe.recv() - except EOFError as err: - raise EOFError(f"proc {pid} failed, last command: {cmd}.") from err - if cmd == "seed": - if not initialized: - raise RuntimeError("call 'init' before closing") - # torch.manual_seed(data) - # np.random.seed(data) - new_seed = env.set_seed(data[0], static_seed=data[1]) - child_pipe.send(("seeded", new_seed)) - - elif cmd == "init": - if verbose: - print(f"initializing {pid}") - if initialized: - raise RuntimeError("worker already initialized") - i = 0 - next_shared_tensordict = shared_tensordict.get("next") - shared_tensordict = shared_tensordict.clone(False) - del shared_tensordict["next"] - - if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): + cuda_event = None + stream = None + + with torch.cuda.StreamContext(stream): + parent_pipe.close() + pid = os.getpid() + if not isinstance(env_fun, EnvBase): + env = env_fun(**env_fun_kwargs) + else: + if env_fun_kwargs: raise RuntimeError( - "tensordict must be placed in shared memory (share_memory_() or memmap_())" + "env_fun_kwargs must be empty if an environment is passed to a process." ) - initialized = True - - elif cmd == "reset": - if verbose: - print(f"resetting worker {pid}") - if not initialized: - raise RuntimeError("call 'init' before resetting") - cur_td = env._reset(tensordict=data) - shared_tensordict.update_(cur_td) - if event is not None: - event.record() - event.synchronize() - mp_event.set() - - elif cmd == "step": - if not initialized: - raise RuntimeError("called 'init' before step") - i += 1 - next_td = env._step(shared_tensordict) - next_shared_tensordict.update_(next_td) - if event is not None: - event.record() - event.synchronize() - mp_event.set() - - elif cmd == "step_and_maybe_reset": - if not initialized: - raise RuntimeError("called 'init' before step") - i += 1 - td, root_next_td = env.step_and_maybe_reset(shared_tensordict.clone(False)) - # for key, val in td.get("next").items(True, True): - # next_shared_tensordict.get(key).copy_(val, non_blocking=True) - next_shared_tensordict.update_(td.get("next")) - # for key, val in root_next_td.items(True, True): - # shared_tensordict.get(key).copy_(val, non_blocking=True) - shared_tensordict.update_(root_next_td) - if event is not None: - event.record() - event.synchronize() - mp_event.set() - - elif cmd == "close": - del shared_tensordict, data - if not initialized: - raise RuntimeError("call 'init' before closing") - env.close() - del env - mp_event.set() - child_pipe.close() - if verbose: - print(f"{pid} closed") - break - - elif cmd == "load_state_dict": - env.load_state_dict(data) - mp_event.set() - - elif cmd == "state_dict": - state_dict = _recursively_strip_locks_from_state_dict(env.state_dict()) - msg = "state_dict" - child_pipe.send((msg, state_dict)) + env = env_fun + env = env.to(device) + del env_fun - else: - err_msg = f"{cmd} from env" + i = -1 + initialized = False + + child_pipe.send("started") + + while True: try: - attr = getattr(env, cmd) - if callable(attr): - args, kwargs = data - args_replace = [] - for _arg in args: - if isinstance(_arg, str) and _arg == "_self": - continue - else: - args_replace.append(_arg) - result = attr(*args_replace, **kwargs) - else: - result = attr - except Exception as err: - raise AttributeError( - f"querying {err_msg} resulted in an error." 
- ) from err - if cmd not in ("to"): - child_pipe.send(("_".join([cmd, "done"]), result)) + cmd, data = child_pipe.recv() + except EOFError as err: + raise EOFError(f"proc {pid} failed, last command: {cmd}.") from err + if cmd == "seed": + if not initialized: + raise RuntimeError("call 'init' before closing") + # torch.manual_seed(data) + # np.random.seed(data) + new_seed = env.set_seed(data[0], static_seed=data[1]) + child_pipe.send(("seeded", new_seed)) + + elif cmd == "init": + if verbose: + print(f"initializing {pid}") + if initialized: + raise RuntimeError("worker already initialized") + i = 0 + next_shared_tensordict = shared_tensordict.get("next") + shared_tensordict = shared_tensordict.clone(False) + del shared_tensordict["next"] + + if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): + raise RuntimeError( + "tensordict must be placed in shared memory (share_memory_() or memmap_())" + ) + initialized = True + + elif cmd == "reset": + if verbose: + print(f"resetting worker {pid}") + if not initialized: + raise RuntimeError("call 'init' before resetting") + cur_td = env._reset(tensordict=data) + shared_tensordict.update_(cur_td) + if cuda_event is not None: + cuda_event.record() + cuda_event.synchronize() + mp_event.set() + + elif cmd == "step": + if not initialized: + raise RuntimeError("called 'init' before step") + i += 1 + next_td = env._step(shared_tensordict) + next_shared_tensordict.update_(next_td) + if cuda_event is not None: + cuda_event.record() + cuda_event.synchronize() + mp_event.set() + + elif cmd == "step_and_maybe_reset": + if not initialized: + raise RuntimeError("called 'init' before step") + i += 1 + td, root_next_td = env.step_and_maybe_reset(shared_tensordict.clone(False)) + # for key, val in td.get("next").items(True, True): + # next_shared_tensordict.get(key).copy_(val, non_blocking=True) + next_shared_tensordict.update_(td.get("next")) + # for key, val in root_next_td.items(True, True): + # shared_tensordict.get(key).copy_(val, non_blocking=True) + shared_tensordict.update_(root_next_td) + if cuda_event is not None: + cuda_event.record() + cuda_event.synchronize() + mp_event.set() + + elif cmd == "close": + del shared_tensordict, data + if not initialized: + raise RuntimeError("call 'init' before closing") + env.close() + del env + mp_event.set() + child_pipe.close() + if verbose: + print(f"{pid} closed") + break + + elif cmd == "load_state_dict": + env.load_state_dict(data) + mp_event.set() + + elif cmd == "state_dict": + state_dict = _recursively_strip_locks_from_state_dict(env.state_dict()) + msg = "state_dict" + child_pipe.send((msg, state_dict)) + else: - # don't send env through pipe - child_pipe.send(("_".join([cmd, "done"]), None)) + err_msg = f"{cmd} from env" + try: + attr = getattr(env, cmd) + if callable(attr): + args, kwargs = data + args_replace = [] + for _arg in args: + if isinstance(_arg, str) and _arg == "_self": + continue + else: + args_replace.append(_arg) + result = attr(*args_replace, **kwargs) + else: + result = attr + except Exception as err: + raise AttributeError( + f"querying {err_msg} resulted in an error." 
+ ) from err + if cmd not in ("to"): + child_pipe.send(("_".join([cmd, "done"]), result)) + else: + # don't send env through pipe + child_pipe.send(("_".join([cmd, "done"]), None)) From ceab010f099634e39d9e85a08dc2ff7002c14d0e Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 15:38:20 +0100 Subject: [PATCH 23/67] amend --- torchrl/envs/batched_envs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 537e8bc2ca7..cb248c0cec0 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -723,8 +723,8 @@ def _start_workers(self) -> None: self._workers = [] self._events = [] if self.device.type == "cuda": - self.streams = [torch.cuda.Stream(self.device) for _ in range(_num_workers)] - self.cuda_events = [torch.cuda.Event(interprocess=True) for _ in range(_num_workers)] + self.streams = torch.cuda.Stream(self.device) + self.cuda_events = torch.cuda.Event(interprocess=True) # self.cuda_events = torch.cuda.Event(interprocess=True) else: self.streams = [None] * _num_workers From f0327c9b4bc16ae499ef8ee6f019877407dc03c5 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 16:15:57 +0100 Subject: [PATCH 24/67] amend --- torchrl/envs/batched_envs.py | 561 +++++++++++++++++++++++++++++++---- 1 file changed, 500 insertions(+), 61 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index cb248c0cec0..444d7cde4cc 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -477,7 +477,7 @@ def close(self) -> None: self.__dict__["_input_spec"] = None self.__dict__["_output_spec"] = None self._properties_set = False - self.cuda_events = None + self._cuda_events = None self._shutdown_workers() self.is_closed = True @@ -721,42 +721,42 @@ def _start_workers(self) -> None: self.parent_channels = [] self._workers = [] - self._events = [] if self.device.type == "cuda": - self.streams = torch.cuda.Stream(self.device) - self.cuda_events = torch.cuda.Event(interprocess=True) - # self.cuda_events = torch.cuda.Event(interprocess=True) + func = _run_worker_pipe_cuda + self._cuda_streams = [torch.cuda.Stream(self.device) for _ in range(_num_workers)] + self._cuda_events = [torch.cuda.Event(interprocess=True) for _ in range(_num_workers)] + self._events = None + kwargs = [{"stream": self._cuda_streams[i], "cuda_event": self._cuda_events[i]} for i in range(_num_workers)] else: - self.streams = [None] * _num_workers - self.cuda_events = [None] * _num_workers + func = _run_worker_pipe_shared_mem + kwargs = [{} for i in range(_num_workers)] + self._cuda_streams = None + self._cuda_events = None + self._events = [ctx.Event() for _ in range(_num_workers)] with clear_mpi_env_vars(): for idx in range(_num_workers): if self._verbose: print(f"initiating worker {idx}") # No certainty which module multiprocessing_context is parent_pipe, child_pipe = ctx.Pipe() - event = ctx.Event() - self._events.append(event) env_fun = self.create_env_fn[idx] if not isinstance(env_fun, EnvCreator): env_fun = CloudpickleWrapper(env_fun) - + kwargs.update({ + "parent_pipe": parent_pipe, + "child_pipe": child_pipe, + "env_fun": env_fun, + "env_fun_kwargs": self.create_env_kwargs[idx], + "shared_tensordict": self.shared_tensordicts[idx], + "_selected_input_keys": _selected_input_keys, + "_selected_reset_keys": self._selected_reset_keys, + "_selected_step_keys": self._selected_step_keys, + "has_lazy_inputs": self.has_lazy_inputs, + }) process = _ProcessNoWarn( - 
target=_run_worker_pipe_shared_mem, + target=func, num_threads=self.num_sub_threads, - args=( - parent_pipe, - child_pipe, - env_fun, - self.create_env_kwargs[idx], - self.device, - event, - self.shared_tensordicts[idx], - self._selected_input_keys, - self._selected_reset_keys, - self._selected_step_keys, - self.has_lazy_inputs, - ), + kwargs=kwargs, ) process.daemon = True process.start() @@ -816,16 +816,21 @@ def step_and_maybe_reset( self.shared_tensordict_parent.update_( tensordict.select(*self._env_input_keys, strict=False) ) - if self.cuda_events is not None: - self.cuda_events.record() - self.cuda_events.synchronize() for i in range(self.num_workers): self.parent_channels[i].send(("step_and_maybe_reset", None)) - for i in range(self.num_workers): - event = self._events[i] - event.wait() - event.clear() + + if self._events is not None: + # CPU case + for i in range(self.num_workers): + event = self._events[i] + event.wait() + event.clear() + else: + # CUDA case + for i in range(self.num_workers): + event = self._cuda_events[i] + event.wait(self._cuda_streams[i]) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps @@ -851,16 +856,19 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: self.shared_tensordict_parent.update_( tensordict.select(*self._env_input_keys, strict=False) ) - if self.cuda_events is not None: - self.cuda_events.record() - self.cuda_events.synchronize() for i in range(self.num_workers): self.parent_channels[i].send(("step", None)) - for i in range(self.num_workers): - event = self._events[i] - event.wait() - event.clear() + if self._events is not None: + # CPU case + for i in range(self.num_workers): + event = self._events[i] + event.wait() + event.clear() + else: + for i in range(self.num_workers): + event = self._cuda_events[i] + event.wait(self._cuda_streams[i]) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps @@ -921,10 +929,16 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: channel.send(out) workers.append(i) - for i in workers: - event = self._events[i] - event.wait() - event.clear() + if self._events is not None: + # CPU case + for i in workers: + event = self._events[i] + event.wait() + event.clear() + else: + for i in workers: + event = self._cuda_events[i] + event.wait(self._cuda_streams[i]) selected_output_keys = self._selected_reset_keys_filt if self._single_task: @@ -1038,7 +1052,294 @@ def _recursively_strip_locks_from_state_dict(state_dict: OrderedDict) -> Ordered } ) +def _run_worker_pipe_shared_mem( + parent_pipe: connection.Connection, + child_pipe: connection.Connection, + env_fun: Union[EnvBase, Callable], + env_fun_kwargs: Dict[str, Any], + device: DEVICE_TYPING = None, + mp_event: mp.Event = None, + shared_tensordict: TensorDictBase = None, + _selected_input_keys=None, + _selected_reset_keys=None, + _selected_step_keys=None, + has_lazy_inputs: bool = False, + verbose: bool = False, +) -> None: + parent_pipe.close() + pid = os.getpid() + if not isinstance(env_fun, EnvBase): + env = env_fun(**env_fun_kwargs) + else: + if env_fun_kwargs: + raise RuntimeError( + "env_fun_kwargs must be empty if an environment is passed to a process." 
+ ) + env = env_fun + env = env.to(device) + del env_fun + + i = -1 + initialized = False + + child_pipe.send("started") + + while True: + try: + cmd, data = child_pipe.recv() + except EOFError as err: + raise EOFError(f"proc {pid} failed, last command: {cmd}.") from err + if cmd == "seed": + if not initialized: + raise RuntimeError("call 'init' before closing") + # torch.manual_seed(data) + # np.random.seed(data) + new_seed = env.set_seed(data[0], static_seed=data[1]) + child_pipe.send(("seeded", new_seed)) + + elif cmd == "init": + if verbose: + print(f"initializing {pid}") + if initialized: + raise RuntimeError("worker already initialized") + i = 0 + next_shared_tensordict = shared_tensordict.get("next") + shared_tensordict = shared_tensordict.clone(False) + del shared_tensordict["next"] + + if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): + raise RuntimeError( + "tensordict must be placed in shared memory (share_memory_() or memmap_())" + ) + initialized = True + + elif cmd == "reset": + if verbose: + print(f"resetting worker {pid}") + if not initialized: + raise RuntimeError("call 'init' before resetting") + cur_td = env._reset(tensordict=data) + shared_tensordict.update_(cur_td) + mp_event.set() + + elif cmd == "step": + if not initialized: + raise RuntimeError("called 'init' before step") + i += 1 + next_td = env._step(shared_tensordict) + next_shared_tensordict.update_(next_td) + mp_event.set() + + elif cmd == "step_and_maybe_reset": + if not initialized: + raise RuntimeError("called 'init' before step") + i += 1 + td, root_next_td = env.step_and_maybe_reset(shared_tensordict.clone(False)) + # for key, val in td.get("next").items(True, True): + # next_shared_tensordict.get(key).copy_(val, non_blocking=True) + next_shared_tensordict.update_(td.get("next")) + # for key, val in root_next_td.items(True, True): + # shared_tensordict.get(key).copy_(val, non_blocking=True) + shared_tensordict.update_(root_next_td) + mp_event.set() + + elif cmd == "close": + del shared_tensordict, data + if not initialized: + raise RuntimeError("call 'init' before closing") + env.close() + del env + mp_event.set() + child_pipe.close() + if verbose: + print(f"{pid} closed") + break + + elif cmd == "load_state_dict": + env.load_state_dict(data) + mp_event.set() + + elif cmd == "state_dict": + state_dict = _recursively_strip_locks_from_state_dict(env.state_dict()) + msg = "state_dict" + child_pipe.send((msg, state_dict)) + + else: + err_msg = f"{cmd} from env" + try: + attr = getattr(env, cmd) + if callable(attr): + args, kwargs = data + args_replace = [] + for _arg in args: + if isinstance(_arg, str) and _arg == "_self": + continue + else: + args_replace.append(_arg) + result = attr(*args_replace, **kwargs) + else: + result = attr + except Exception as err: + raise AttributeError( + f"querying {err_msg} resulted in an error." 
+ ) from err + if cmd not in ("to"): + child_pipe.send(("_".join([cmd, "done"]), result)) + else: + # don't send env through pipe + child_pipe.send(("_".join([cmd, "done"]), None)) + + +def _run_worker_pipe_shared_mem( + parent_pipe: connection.Connection, + child_pipe: connection.Connection, + env_fun: Union[EnvBase, Callable], + env_fun_kwargs: Dict[str, Any], + device: DEVICE_TYPING = None, + mp_event: mp.Event = None, + shared_tensordict: TensorDictBase = None, + _selected_input_keys=None, + _selected_reset_keys=None, + _selected_step_keys=None, + has_lazy_inputs: bool = False, + verbose: bool = False, +) -> None: + parent_pipe.close() + +pid = os.getpid() +if not isinstance(env_fun, EnvBase): + env = env_fun(**env_fun_kwargs) +else: + if env_fun_kwargs: + raise RuntimeError( + "env_fun_kwargs must be empty if an environment is passed to a process." + ) + env = env_fun +env = env.to(device) +del env_fun + +i = -1 +initialized = False + +child_pipe.send("started") + +while True: + try: + cmd, data = child_pipe.recv() + except EOFError as err: + raise EOFError(f"proc {pid} failed, last command: {cmd}.") from err + if cmd == "seed": + if not initialized: + raise RuntimeError("call 'init' before closing") + # torch.manual_seed(data) + # np.random.seed(data) + new_seed = env.set_seed(data[0], static_seed=data[1]) + child_pipe.send(("seeded", new_seed)) + + elif cmd == "init": + if verbose: + print(f"initializing {pid}") + if initialized: + raise RuntimeError("worker already initialized") + i = 0 + next_shared_tensordict = shared_tensordict.get("next") + shared_tensordict = shared_tensordict.clone(False) + del shared_tensordict["next"] + + if not ( + shared_tensordict.is_shared() or shared_tensordict.is_memmap()): + raise RuntimeError( + "tensordict must be placed in shared memory (share_memory_() or memmap_())" + ) + initialized = True + + elif cmd == "reset": + if verbose: + print(f"resetting worker {pid}") + if not initialized: + raise RuntimeError("call 'init' before resetting") + cur_td = env._reset(tensordict=data) + shared_tensordict.update_(cur_td) + if cuda_event is not None: + cuda_event.record() + cuda_event.synchronize() + mp_event.set() + + elif cmd == "step": + if not initialized: + raise RuntimeError("called 'init' before step") + i += 1 + next_td = env._step(shared_tensordict) + next_shared_tensordict.update_(next_td) + if cuda_event is not None: + cuda_event.record() + cuda_event.synchronize() + mp_event.set() + + elif cmd == "step_and_maybe_reset": + if not initialized: + raise RuntimeError("called 'init' before step") + i += 1 + td, root_next_td = env.step_and_maybe_reset( + shared_tensordict.clone(False) + ) + # for key, val in td.get("next").items(True, True): + # next_shared_tensordict.get(key).copy_(val, non_blocking=True) + next_shared_tensordict.update_(td.get("next")) + # for key, val in root_next_td.items(True, True): + # shared_tensordict.get(key).copy_(val, non_blocking=True) + shared_tensordict.update_(root_next_td) + if cuda_event is not None: + cuda_event.record() + cuda_event.synchronize() + mp_event.set() + + elif cmd == "close": + del shared_tensordict, data + if not initialized: + raise RuntimeError("call 'init' before closing") + env.close() + del env + mp_event.set() + child_pipe.close() + if verbose: + print(f"{pid} closed") + break + + elif cmd == "load_state_dict": + env.load_state_dict(data) + mp_event.set() + + elif cmd == "state_dict": + state_dict = _recursively_strip_locks_from_state_dict(env.state_dict()) + msg = "state_dict" + 
child_pipe.send((msg, state_dict)) + + else: + err_msg = f"{cmd} from env" + try: + attr = getattr(env, cmd) + if callable(attr): + args, kwargs = data + args_replace = [] + for _arg in args: + if isinstance(_arg, str) and _arg == "_self": + continue + else: + args_replace.append(_arg) + result = attr(*args_replace, **kwargs) + else: + result = attr + except Exception as err: + raise AttributeError( + f"querying {err_msg} resulted in an error." + ) from err + if cmd not in ("to"): + child_pipe.send(("_".join([cmd, "done"]), result)) + else: + # don't send env through pipe + child_pipe.send(("_".join([cmd, "done"]), None)) def _run_worker_pipe_shared_mem( parent_pipe: connection.Connection, child_pipe: connection.Connection, @@ -1053,15 +1354,154 @@ def _run_worker_pipe_shared_mem( has_lazy_inputs: bool = False, verbose: bool = False, ) -> None: - if device is None: - device = torch.device("cpu") - if device.type == "cuda": - cuda_event = torch.cuda.Event() - stream = torch.cuda.Stream(device) + parent_pipe.close() + pid = os.getpid() + if not isinstance(env_fun, EnvBase): + env = env_fun(**env_fun_kwargs) else: - cuda_event = None - stream = None + if env_fun_kwargs: + raise RuntimeError( + "env_fun_kwargs must be empty if an environment is passed to a process." + ) + env = env_fun + env = env.to(device) + del env_fun + + i = -1 + initialized = False + + child_pipe.send("started") + + while True: + try: + cmd, data = child_pipe.recv() + except EOFError as err: + raise EOFError(f"proc {pid} failed, last command: {cmd}.") from err + if cmd == "seed": + if not initialized: + raise RuntimeError("call 'init' before closing") + # torch.manual_seed(data) + # np.random.seed(data) + new_seed = env.set_seed(data[0], static_seed=data[1]) + child_pipe.send(("seeded", new_seed)) + + elif cmd == "init": + if verbose: + print(f"initializing {pid}") + if initialized: + raise RuntimeError("worker already initialized") + i = 0 + next_shared_tensordict = shared_tensordict.get("next") + shared_tensordict = shared_tensordict.clone(False) + del shared_tensordict["next"] + + if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): + raise RuntimeError( + "tensordict must be placed in shared memory (share_memory_() or memmap_())" + ) + initialized = True + + elif cmd == "reset": + if verbose: + print(f"resetting worker {pid}") + if not initialized: + raise RuntimeError("call 'init' before resetting") + cur_td = env._reset(tensordict=data) + shared_tensordict.update_(cur_td) + if cuda_event is not None: + cuda_event.record() + cuda_event.synchronize() + mp_event.set() + + elif cmd == "step": + if not initialized: + raise RuntimeError("called 'init' before step") + i += 1 + next_td = env._step(shared_tensordict) + next_shared_tensordict.update_(next_td) + if cuda_event is not None: + cuda_event.record() + cuda_event.synchronize() + mp_event.set() + + elif cmd == "step_and_maybe_reset": + if not initialized: + raise RuntimeError("called 'init' before step") + i += 1 + td, root_next_td = env.step_and_maybe_reset(shared_tensordict.clone(False)) + # for key, val in td.get("next").items(True, True): + # next_shared_tensordict.get(key).copy_(val, non_blocking=True) + next_shared_tensordict.update_(td.get("next")) + # for key, val in root_next_td.items(True, True): + # shared_tensordict.get(key).copy_(val, non_blocking=True) + shared_tensordict.update_(root_next_td) + if cuda_event is not None: + cuda_event.record() + cuda_event.synchronize() + mp_event.set() + + elif cmd == "close": + del 
shared_tensordict, data + if not initialized: + raise RuntimeError("call 'init' before closing") + env.close() + del env + mp_event.set() + child_pipe.close() + if verbose: + print(f"{pid} closed") + break + + elif cmd == "load_state_dict": + env.load_state_dict(data) + mp_event.set() + + elif cmd == "state_dict": + state_dict = _recursively_strip_locks_from_state_dict(env.state_dict()) + msg = "state_dict" + child_pipe.send((msg, state_dict)) + else: + err_msg = f"{cmd} from env" + try: + attr = getattr(env, cmd) + if callable(attr): + args, kwargs = data + args_replace = [] + for _arg in args: + if isinstance(_arg, str) and _arg == "_self": + continue + else: + args_replace.append(_arg) + result = attr(*args_replace, **kwargs) + else: + result = attr + except Exception as err: + raise AttributeError( + f"querying {err_msg} resulted in an error." + ) from err + if cmd not in ("to"): + child_pipe.send(("_".join([cmd, "done"]), result)) + else: + # don't send env through pipe + child_pipe.send(("_".join([cmd, "done"]), None)) + + +def _run_worker_pipe_cuda( + parent_pipe: connection.Connection, + child_pipe: connection.Connection, + env_fun: Union[EnvBase, Callable], + env_fun_kwargs: Dict[str, Any], + device: DEVICE_TYPING = None, + stream: torch.cuda.Stream = None, + cuda_event: torch.cuda.Event = None, + shared_tensordict: TensorDictBase = None, + _selected_input_keys=None, + _selected_reset_keys=None, + _selected_step_keys=None, + has_lazy_inputs: bool = False, + verbose: bool = False, +) -> None: with torch.cuda.StreamContext(stream): parent_pipe.close() pid = os.getpid() @@ -1104,7 +1544,8 @@ def _run_worker_pipe_shared_mem( shared_tensordict = shared_tensordict.clone(False) del shared_tensordict["next"] - if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): + if not ( + shared_tensordict.is_shared() or shared_tensordict.is_memmap()): raise RuntimeError( "tensordict must be placed in shared memory (share_memory_() or memmap_())" ) @@ -1117,10 +1558,8 @@ def _run_worker_pipe_shared_mem( raise RuntimeError("call 'init' before resetting") cur_td = env._reset(tensordict=data) shared_tensordict.update_(cur_td) - if cuda_event is not None: - cuda_event.record() - cuda_event.synchronize() - mp_event.set() + stream.record(cuda_event) + stream.synchronize() elif cmd == "step": if not initialized: @@ -1128,25 +1567,25 @@ def _run_worker_pipe_shared_mem( i += 1 next_td = env._step(shared_tensordict) next_shared_tensordict.update_(next_td) - if cuda_event is not None: - cuda_event.record() - cuda_event.synchronize() + stream.record(cuda_event) + stream.synchronize() mp_event.set() elif cmd == "step_and_maybe_reset": if not initialized: raise RuntimeError("called 'init' before step") i += 1 - td, root_next_td = env.step_and_maybe_reset(shared_tensordict.clone(False)) + td, root_next_td = env.step_and_maybe_reset( + shared_tensordict.clone(False) + ) # for key, val in td.get("next").items(True, True): # next_shared_tensordict.get(key).copy_(val, non_blocking=True) next_shared_tensordict.update_(td.get("next")) # for key, val in root_next_td.items(True, True): # shared_tensordict.get(key).copy_(val, non_blocking=True) shared_tensordict.update_(root_next_td) - if cuda_event is not None: - cuda_event.record() - cuda_event.synchronize() + stream.record(cuda_event) + stream.synchronize() mp_event.set() elif cmd == "close": From 518b3d1ac4274b2ff1651cce973c1696928f3c00 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 16:18:17 +0100 Subject: [PATCH 25/67] amend --- 
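Note on the synchronization model these patches converge on: CPU workers signal the parent with a plain multiprocessing Event once they have written their results into the shared tensordict, while CUDA workers record an interprocess torch.cuda.Event on a dedicated stream, which the parent then waits on (and the worker also synchronizes its stream before handing control back). A minimal sketch of that handshake follows; the helper names (cpu_worker_step, cuda_worker_step, shared_td, ...) are illustrative placeholders, not the actual TorchRL API:

    import torch

    # CPU path (sketch): publish results into shared memory, then signal the parent.
    def cpu_worker_step(env, shared_td, next_shared_td, mp_event):
        next_td = env._step(shared_td)
        next_shared_td.update_(next_td)
        mp_event.set()  # parent side: mp_event.wait(); mp_event.clear()

    # CUDA path (sketch): same publication, but ordered on a dedicated stream and
    # signalled through an interprocess CUDA event instead of an mp.Event.
    def cuda_worker_step(env, shared_td, next_shared_td, stream, cuda_event):
        with torch.cuda.StreamContext(stream):
            next_td = env._step(shared_td)
            next_shared_td.update_(next_td)
            stream.record_event(cuda_event)  # parent side: parent_stream.wait_event(cuda_event)
            stream.synchronize()
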
torchrl/envs/batched_envs.py | 299 +---------------------------------- 1 file changed, 2 insertions(+), 297 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 444d7cde4cc..27a7f2bbb3a 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -866,6 +866,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: event.wait() event.clear() else: + # CUDA case for i in range(self.num_workers): event = self._cuda_events[i] event.wait(self._cuda_streams[i]) @@ -936,6 +937,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: event.wait() event.clear() else: + # CUDA case for i in workers: event = self._cuda_events[i] event.wait(self._cuda_streams[i]) @@ -1190,303 +1192,6 @@ def _run_worker_pipe_shared_mem( child_pipe.send(("_".join([cmd, "done"]), None)) -def _run_worker_pipe_shared_mem( - parent_pipe: connection.Connection, - child_pipe: connection.Connection, - env_fun: Union[EnvBase, Callable], - env_fun_kwargs: Dict[str, Any], - device: DEVICE_TYPING = None, - mp_event: mp.Event = None, - shared_tensordict: TensorDictBase = None, - _selected_input_keys=None, - _selected_reset_keys=None, - _selected_step_keys=None, - has_lazy_inputs: bool = False, - verbose: bool = False, -) -> None: - parent_pipe.close() - - -pid = os.getpid() -if not isinstance(env_fun, EnvBase): - env = env_fun(**env_fun_kwargs) -else: - if env_fun_kwargs: - raise RuntimeError( - "env_fun_kwargs must be empty if an environment is passed to a process." - ) - env = env_fun -env = env.to(device) -del env_fun - -i = -1 -initialized = False - -child_pipe.send("started") - -while True: - try: - cmd, data = child_pipe.recv() - except EOFError as err: - raise EOFError(f"proc {pid} failed, last command: {cmd}.") from err - if cmd == "seed": - if not initialized: - raise RuntimeError("call 'init' before closing") - # torch.manual_seed(data) - # np.random.seed(data) - new_seed = env.set_seed(data[0], static_seed=data[1]) - child_pipe.send(("seeded", new_seed)) - - elif cmd == "init": - if verbose: - print(f"initializing {pid}") - if initialized: - raise RuntimeError("worker already initialized") - i = 0 - next_shared_tensordict = shared_tensordict.get("next") - shared_tensordict = shared_tensordict.clone(False) - del shared_tensordict["next"] - - if not ( - shared_tensordict.is_shared() or shared_tensordict.is_memmap()): - raise RuntimeError( - "tensordict must be placed in shared memory (share_memory_() or memmap_())" - ) - initialized = True - - elif cmd == "reset": - if verbose: - print(f"resetting worker {pid}") - if not initialized: - raise RuntimeError("call 'init' before resetting") - cur_td = env._reset(tensordict=data) - shared_tensordict.update_(cur_td) - if cuda_event is not None: - cuda_event.record() - cuda_event.synchronize() - mp_event.set() - - elif cmd == "step": - if not initialized: - raise RuntimeError("called 'init' before step") - i += 1 - next_td = env._step(shared_tensordict) - next_shared_tensordict.update_(next_td) - if cuda_event is not None: - cuda_event.record() - cuda_event.synchronize() - mp_event.set() - - elif cmd == "step_and_maybe_reset": - if not initialized: - raise RuntimeError("called 'init' before step") - i += 1 - td, root_next_td = env.step_and_maybe_reset( - shared_tensordict.clone(False) - ) - # for key, val in td.get("next").items(True, True): - # next_shared_tensordict.get(key).copy_(val, non_blocking=True) - next_shared_tensordict.update_(td.get("next")) - # for key, val in 
root_next_td.items(True, True): - # shared_tensordict.get(key).copy_(val, non_blocking=True) - shared_tensordict.update_(root_next_td) - if cuda_event is not None: - cuda_event.record() - cuda_event.synchronize() - mp_event.set() - - elif cmd == "close": - del shared_tensordict, data - if not initialized: - raise RuntimeError("call 'init' before closing") - env.close() - del env - mp_event.set() - child_pipe.close() - if verbose: - print(f"{pid} closed") - break - - elif cmd == "load_state_dict": - env.load_state_dict(data) - mp_event.set() - - elif cmd == "state_dict": - state_dict = _recursively_strip_locks_from_state_dict(env.state_dict()) - msg = "state_dict" - child_pipe.send((msg, state_dict)) - - else: - err_msg = f"{cmd} from env" - try: - attr = getattr(env, cmd) - if callable(attr): - args, kwargs = data - args_replace = [] - for _arg in args: - if isinstance(_arg, str) and _arg == "_self": - continue - else: - args_replace.append(_arg) - result = attr(*args_replace, **kwargs) - else: - result = attr - except Exception as err: - raise AttributeError( - f"querying {err_msg} resulted in an error." - ) from err - if cmd not in ("to"): - child_pipe.send(("_".join([cmd, "done"]), result)) - else: - # don't send env through pipe - child_pipe.send(("_".join([cmd, "done"]), None)) -def _run_worker_pipe_shared_mem( - parent_pipe: connection.Connection, - child_pipe: connection.Connection, - env_fun: Union[EnvBase, Callable], - env_fun_kwargs: Dict[str, Any], - device: DEVICE_TYPING = None, - mp_event: mp.Event = None, - shared_tensordict: TensorDictBase = None, - _selected_input_keys=None, - _selected_reset_keys=None, - _selected_step_keys=None, - has_lazy_inputs: bool = False, - verbose: bool = False, -) -> None: - parent_pipe.close() - pid = os.getpid() - if not isinstance(env_fun, EnvBase): - env = env_fun(**env_fun_kwargs) - else: - if env_fun_kwargs: - raise RuntimeError( - "env_fun_kwargs must be empty if an environment is passed to a process." 
- ) - env = env_fun - env = env.to(device) - del env_fun - - i = -1 - initialized = False - - child_pipe.send("started") - - while True: - try: - cmd, data = child_pipe.recv() - except EOFError as err: - raise EOFError(f"proc {pid} failed, last command: {cmd}.") from err - if cmd == "seed": - if not initialized: - raise RuntimeError("call 'init' before closing") - # torch.manual_seed(data) - # np.random.seed(data) - new_seed = env.set_seed(data[0], static_seed=data[1]) - child_pipe.send(("seeded", new_seed)) - - elif cmd == "init": - if verbose: - print(f"initializing {pid}") - if initialized: - raise RuntimeError("worker already initialized") - i = 0 - next_shared_tensordict = shared_tensordict.get("next") - shared_tensordict = shared_tensordict.clone(False) - del shared_tensordict["next"] - - if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): - raise RuntimeError( - "tensordict must be placed in shared memory (share_memory_() or memmap_())" - ) - initialized = True - - elif cmd == "reset": - if verbose: - print(f"resetting worker {pid}") - if not initialized: - raise RuntimeError("call 'init' before resetting") - cur_td = env._reset(tensordict=data) - shared_tensordict.update_(cur_td) - if cuda_event is not None: - cuda_event.record() - cuda_event.synchronize() - mp_event.set() - - elif cmd == "step": - if not initialized: - raise RuntimeError("called 'init' before step") - i += 1 - next_td = env._step(shared_tensordict) - next_shared_tensordict.update_(next_td) - if cuda_event is not None: - cuda_event.record() - cuda_event.synchronize() - mp_event.set() - - elif cmd == "step_and_maybe_reset": - if not initialized: - raise RuntimeError("called 'init' before step") - i += 1 - td, root_next_td = env.step_and_maybe_reset(shared_tensordict.clone(False)) - # for key, val in td.get("next").items(True, True): - # next_shared_tensordict.get(key).copy_(val, non_blocking=True) - next_shared_tensordict.update_(td.get("next")) - # for key, val in root_next_td.items(True, True): - # shared_tensordict.get(key).copy_(val, non_blocking=True) - shared_tensordict.update_(root_next_td) - if cuda_event is not None: - cuda_event.record() - cuda_event.synchronize() - mp_event.set() - - elif cmd == "close": - del shared_tensordict, data - if not initialized: - raise RuntimeError("call 'init' before closing") - env.close() - del env - mp_event.set() - child_pipe.close() - if verbose: - print(f"{pid} closed") - break - - elif cmd == "load_state_dict": - env.load_state_dict(data) - mp_event.set() - - elif cmd == "state_dict": - state_dict = _recursively_strip_locks_from_state_dict(env.state_dict()) - msg = "state_dict" - child_pipe.send((msg, state_dict)) - - else: - err_msg = f"{cmd} from env" - try: - attr = getattr(env, cmd) - if callable(attr): - args, kwargs = data - args_replace = [] - for _arg in args: - if isinstance(_arg, str) and _arg == "_self": - continue - else: - args_replace.append(_arg) - result = attr(*args_replace, **kwargs) - else: - result = attr - except Exception as err: - raise AttributeError( - f"querying {err_msg} resulted in an error." 
- ) from err - if cmd not in ("to"): - child_pipe.send(("_".join([cmd, "done"]), result)) - else: - # don't send env through pipe - child_pipe.send(("_".join([cmd, "done"]), None)) - - def _run_worker_pipe_cuda( parent_pipe: connection.Connection, child_pipe: connection.Connection, From 47dd93b969e29729cc667f5212f1b3aad4367908 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 16:20:19 +0100 Subject: [PATCH 26/67] amend --- torchrl/envs/batched_envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 27a7f2bbb3a..ad6830f7830 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -742,7 +742,7 @@ def _start_workers(self) -> None: env_fun = self.create_env_fn[idx] if not isinstance(env_fun, EnvCreator): env_fun = CloudpickleWrapper(env_fun) - kwargs.update({ + kwargs[idx].update({ "parent_pipe": parent_pipe, "child_pipe": child_pipe, "env_fun": env_fun, From 354fb6f744fac94868230c6fa9ab86724f07128f Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 16:21:43 +0100 Subject: [PATCH 27/67] amend --- torchrl/envs/batched_envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index ad6830f7830..017db28d4a8 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -748,7 +748,7 @@ def _start_workers(self) -> None: "env_fun": env_fun, "env_fun_kwargs": self.create_env_kwargs[idx], "shared_tensordict": self.shared_tensordicts[idx], - "_selected_input_keys": _selected_input_keys, + "_selected_input_keys": self._selected_input_keys, "_selected_reset_keys": self._selected_reset_keys, "_selected_step_keys": self._selected_step_keys, "has_lazy_inputs": self.has_lazy_inputs, From 53d5f9ae4eb1e1a71da7bc0f6ec261066d27d9ab Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 16:23:36 +0100 Subject: [PATCH 28/67] amend --- torchrl/envs/batched_envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 017db28d4a8..139b4d97fd9 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -756,7 +756,7 @@ def _start_workers(self) -> None: process = _ProcessNoWarn( target=func, num_threads=self.num_sub_threads, - kwargs=kwargs, + kwargs=kwargs[idx], ) process.daemon = True process.start() From 78c00e84b753fc7a8c7531e7014949fe5b549f24 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 16:27:14 +0100 Subject: [PATCH 29/67] amend --- torchrl/envs/batched_envs.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 139b4d97fd9..a52200898a6 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -723,14 +723,14 @@ def _start_workers(self) -> None: self._workers = [] if self.device.type == "cuda": func = _run_worker_pipe_cuda - self._cuda_streams = [torch.cuda.Stream(self.device) for _ in range(_num_workers)] + self._cuda_stream = torch.cuda.Stream(self.device) self._cuda_events = [torch.cuda.Event(interprocess=True) for _ in range(_num_workers)] self._events = None - kwargs = [{"stream": self._cuda_streams[i], "cuda_event": self._cuda_events[i]} for i in range(_num_workers)] + kwargs = [{"cuda_event": self._cuda_events[i]} for i in range(_num_workers)] else: func = _run_worker_pipe_shared_mem kwargs = [{} for i in range(_num_workers)] - self._cuda_streams = None 
+ self._cuda_stream = None self._cuda_events = None self._events = [ctx.Event() for _ in range(_num_workers)] with clear_mpi_env_vars(): @@ -830,7 +830,7 @@ def step_and_maybe_reset( # CUDA case for i in range(self.num_workers): event = self._cuda_events[i] - event.wait(self._cuda_streams[i]) + self._cuda_stream.wait_event(event) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps @@ -869,7 +869,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # CUDA case for i in range(self.num_workers): event = self._cuda_events[i] - event.wait(self._cuda_streams[i]) + self._cuda_stream.wait_event(event) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps @@ -940,7 +940,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: # CUDA case for i in workers: event = self._cuda_events[i] - event.wait(self._cuda_streams[i]) + self._cuda_stream.wait_event(event) selected_output_keys = self._selected_reset_keys_filt if self._single_task: @@ -1198,7 +1198,6 @@ def _run_worker_pipe_cuda( env_fun: Union[EnvBase, Callable], env_fun_kwargs: Dict[str, Any], device: DEVICE_TYPING = None, - stream: torch.cuda.Stream = None, cuda_event: torch.cuda.Event = None, shared_tensordict: TensorDictBase = None, _selected_input_keys=None, @@ -1207,6 +1206,7 @@ def _run_worker_pipe_cuda( has_lazy_inputs: bool = False, verbose: bool = False, ) -> None: + stream = torch.cuda.Stream(device) with torch.cuda.StreamContext(stream): parent_pipe.close() pid = os.getpid() From f63480edb73b3e8d3a4357d14c4b8fd52cebbca4 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 16:35:15 +0100 Subject: [PATCH 30/67] amend --- torchrl/envs/batched_envs.py | 60 +++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 24 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index a52200898a6..666a20f13c8 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -724,7 +724,9 @@ def _start_workers(self) -> None: if self.device.type == "cuda": func = _run_worker_pipe_cuda self._cuda_stream = torch.cuda.Stream(self.device) - self._cuda_events = [torch.cuda.Event(interprocess=True) for _ in range(_num_workers)] + self._cuda_events = [ + torch.cuda.Event(interprocess=True) for _ in range(_num_workers) + ] self._events = None kwargs = [{"cuda_event": self._cuda_events[i]} for i in range(_num_workers)] else: @@ -742,17 +744,20 @@ def _start_workers(self) -> None: env_fun = self.create_env_fn[idx] if not isinstance(env_fun, EnvCreator): env_fun = CloudpickleWrapper(env_fun) - kwargs[idx].update({ - "parent_pipe": parent_pipe, - "child_pipe": child_pipe, - "env_fun": env_fun, - "env_fun_kwargs": self.create_env_kwargs[idx], - "shared_tensordict": self.shared_tensordicts[idx], - "_selected_input_keys": self._selected_input_keys, - "_selected_reset_keys": self._selected_reset_keys, - "_selected_step_keys": self._selected_step_keys, - "has_lazy_inputs": self.has_lazy_inputs, - }) + kwargs[idx].update( + { + "parent_pipe": parent_pipe, + "child_pipe": child_pipe, + "env_fun": env_fun, + "device": self.device, + "env_fun_kwargs": self.create_env_kwargs[idx], + "shared_tensordict": self.shared_tensordicts[idx], + "_selected_input_keys": self._selected_input_keys, + "_selected_reset_keys": self._selected_reset_keys, + "_selected_step_keys": self._selected_step_keys, + "has_lazy_inputs": self.has_lazy_inputs, + } + ) 
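            # In broad strokes: each worker receives its own pipe ends, the env factory and
            # its kwargs, the batched env's device, its slice of the shared tensordict and the
            # pre-selected key sets; the two target functions differ mainly in the signalling
            # primitive they expect (an mp.Event for _run_worker_pipe_shared_mem, a per-worker
            # interprocess torch.cuda.Event for _run_worker_pipe_cuda).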
process = _ProcessNoWarn( target=func, num_threads=self.num_sub_threads, @@ -794,9 +799,13 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: ) for i, channel in enumerate(self.parent_channels): channel.send(("load_state_dict", state_dict[f"worker{i}"])) - for event in self._events: - event.wait() - event.clear() + if self._events is not None: + for event in self._events: + event.wait() + event.clear() + else: + for event in self._cuda_events: + self._cuda_stream.wait_event(event) @_check_start def step_and_maybe_reset( @@ -819,7 +828,6 @@ def step_and_maybe_reset( for i in range(self.num_workers): self.parent_channels[i].send(("step_and_maybe_reset", None)) - if self._events is not None: # CPU case for i in range(self.num_workers): @@ -967,8 +975,12 @@ def _shutdown_workers(self) -> None: if self._verbose: print(f"closing {i}") channel.send(("close", None)) - self._events[i].wait() - self._events[i].clear() + if self._events is not None: + self._events[i].wait() + self._events[i].clear() + else: + for event in self._cuda_events: + self._cuda_stream.wait_event(event) del self.shared_tensordicts, self.shared_tensordict_parent @@ -1054,6 +1066,7 @@ def _recursively_strip_locks_from_state_dict(state_dict: OrderedDict) -> Ordered } ) + def _run_worker_pipe_shared_mem( parent_pipe: connection.Connection, child_pipe: connection.Connection, @@ -1249,8 +1262,7 @@ def _run_worker_pipe_cuda( shared_tensordict = shared_tensordict.clone(False) del shared_tensordict["next"] - if not ( - shared_tensordict.is_shared() or shared_tensordict.is_memmap()): + if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): raise RuntimeError( "tensordict must be placed in shared memory (share_memory_() or memmap_())" ) @@ -1274,7 +1286,6 @@ def _run_worker_pipe_cuda( next_shared_tensordict.update_(next_td) stream.record(cuda_event) stream.synchronize() - mp_event.set() elif cmd == "step_and_maybe_reset": if not initialized: @@ -1291,7 +1302,6 @@ def _run_worker_pipe_cuda( shared_tensordict.update_(root_next_td) stream.record(cuda_event) stream.synchronize() - mp_event.set() elif cmd == "close": del shared_tensordict, data @@ -1299,7 +1309,8 @@ def _run_worker_pipe_cuda( raise RuntimeError("call 'init' before closing") env.close() del env - mp_event.set() + stream.record(cuda_event) + stream.synchronize() child_pipe.close() if verbose: print(f"{pid} closed") @@ -1307,7 +1318,8 @@ def _run_worker_pipe_cuda( elif cmd == "load_state_dict": env.load_state_dict(data) - mp_event.set() + stream.record(cuda_event) + stream.synchronize() elif cmd == "state_dict": state_dict = _recursively_strip_locks_from_state_dict(env.state_dict()) From 9a3631fa58a5373a1de9e8ea217eaa781045916f Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 16:51:05 +0100 Subject: [PATCH 31/67] amend --- torchrl/envs/batched_envs.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 666a20f13c8..c747c51f635 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1275,7 +1275,7 @@ def _run_worker_pipe_cuda( raise RuntimeError("call 'init' before resetting") cur_td = env._reset(tensordict=data) shared_tensordict.update_(cur_td) - stream.record(cuda_event) + stream.record_event(cuda_event) stream.synchronize() elif cmd == "step": @@ -1284,7 +1284,7 @@ def _run_worker_pipe_cuda( i += 1 next_td = env._step(shared_tensordict) next_shared_tensordict.update_(next_td) - stream.record(cuda_event) + 
stream.record_event(cuda_event) stream.synchronize() elif cmd == "step_and_maybe_reset": @@ -1300,7 +1300,7 @@ def _run_worker_pipe_cuda( # for key, val in root_next_td.items(True, True): # shared_tensordict.get(key).copy_(val, non_blocking=True) shared_tensordict.update_(root_next_td) - stream.record(cuda_event) + stream.record_event(cuda_event) stream.synchronize() elif cmd == "close": @@ -1309,7 +1309,7 @@ def _run_worker_pipe_cuda( raise RuntimeError("call 'init' before closing") env.close() del env - stream.record(cuda_event) + stream.record_event(cuda_event) stream.synchronize() child_pipe.close() if verbose: @@ -1318,7 +1318,7 @@ def _run_worker_pipe_cuda( elif cmd == "load_state_dict": env.load_state_dict(data) - stream.record(cuda_event) + stream.record_event(cuda_event) stream.synchronize() elif cmd == "state_dict": From 2ceb438fb2da9fbe09dd74583c746f2e4fb85cbe Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 17:18:01 +0100 Subject: [PATCH 32/67] amend --- torchrl/envs/batched_envs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index c747c51f635..7d5ae6ed0ee 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -477,7 +477,6 @@ def close(self) -> None: self.__dict__["_input_spec"] = None self.__dict__["_output_spec"] = None self._properties_set = False - self._cuda_events = None self._shutdown_workers() self.is_closed = True @@ -983,13 +982,14 @@ def _shutdown_workers(self) -> None: self._cuda_stream.wait_event(event) del self.shared_tensordicts, self.shared_tensordict_parent - for channel in self.parent_channels: channel.close() for proc in self._workers: proc.join() del self._workers del self.parent_channels + self._cuda_events = None + self._events = None @_check_start def set_seed( From 6ecebdae6fe2f010ee7050b1105402c5a8987a90 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 17:41:05 +0100 Subject: [PATCH 33/67] amend --- torchrl/envs/batched_envs.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 7d5ae6ed0ee..636cd9a7aa3 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -814,12 +814,15 @@ def step_and_maybe_reset( # this is faster than update_ but won't work for lazy stacks for key in self._env_input_keys: key = _unravel_key_to_tuple(key) - self.shared_tensordict_parent._set_tuple( - key, - tensordict._get_tuple(key, None), - inplace=True, - validated=True, - ) + val = tensordict._get_tuple(key, None) + if val is not None: + self.shared_tensordict_parent.get(key).copy_(val, non_blocking=True) + # self.shared_tensordict_parent._set_tuple( + # key, + # val, + # inplace=True, + # validated=True, + # ) else: self.shared_tensordict_parent.update_( tensordict.select(*self._env_input_keys, strict=False) From 5c613c30d05c013199c06ca7fe69982b865b4be6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 18:43:12 +0100 Subject: [PATCH 34/67] amend --- torchrl/envs/batched_envs.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 636cd9a7aa3..3191acff096 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1223,6 +1223,7 @@ def _run_worker_pipe_cuda( verbose: bool = False, ) -> None: stream = torch.cuda.Stream(device) + env = env.to("cpu") with torch.cuda.StreamContext(stream): 
parent_pipe.close() pid = os.getpid() @@ -1297,12 +1298,12 @@ def _run_worker_pipe_cuda( td, root_next_td = env.step_and_maybe_reset( shared_tensordict.clone(False) ) - # for key, val in td.get("next").items(True, True): - # next_shared_tensordict.get(key).copy_(val, non_blocking=True) - next_shared_tensordict.update_(td.get("next")) - # for key, val in root_next_td.items(True, True): - # shared_tensordict.get(key).copy_(val, non_blocking=True) - shared_tensordict.update_(root_next_td) + for key, val in td.get("next").items(True, True): + next_shared_tensordict.get(key).copy_(val, non_blocking=True) + # next_shared_tensordict.update_(td.get("next")) + for key, val in root_next_td.items(True, True): + shared_tensordict.get(key).copy_(val, non_blocking=True) + # shared_tensordict.update_(root_next_td) stream.record_event(cuda_event) stream.synchronize() From 9f97e58d13af6aca3e3d183a70372cd040375b08 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 18:46:03 +0100 Subject: [PATCH 35/67] amend --- torchrl/envs/batched_envs.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 3191acff096..df6350bb532 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1223,7 +1223,6 @@ def _run_worker_pipe_cuda( verbose: bool = False, ) -> None: stream = torch.cuda.Stream(device) - env = env.to("cpu") with torch.cuda.StreamContext(stream): parent_pipe.close() pid = os.getpid() @@ -1235,7 +1234,7 @@ def _run_worker_pipe_cuda( "env_fun_kwargs must be empty if an environment is passed to a process." ) env = env_fun - env = env.to(device) + env = env.to("cpu") del env_fun i = -1 From 9cbcbb0d056eed2cdcc152160a6a5e0183d6f5bb Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 20:40:09 +0100 Subject: [PATCH 36/67] amend --- torchrl/envs/batched_envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index df6350bb532..449e78b4c1a 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1295,7 +1295,7 @@ def _run_worker_pipe_cuda( raise RuntimeError("called 'init' before step") i += 1 td, root_next_td = env.step_and_maybe_reset( - shared_tensordict.clone(False) + shared_tensordict.select(*_selected_input_keys).cpu() ) for key, val in td.get("next").items(True, True): next_shared_tensordict.get(key).copy_(val, non_blocking=True) From 72c4163b4279d24cb14ee17d8202916f1d3a2e7a Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 20:48:02 +0100 Subject: [PATCH 37/67] amend --- torchrl/envs/batched_envs.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 449e78b4c1a..563d6112ac0 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -730,10 +730,10 @@ def _start_workers(self) -> None: kwargs = [{"cuda_event": self._cuda_events[i]} for i in range(_num_workers)] else: func = _run_worker_pipe_shared_mem - kwargs = [{} for i in range(_num_workers)] self._cuda_stream = None self._cuda_events = None self._events = [ctx.Event() for _ in range(_num_workers)] + kwargs = [{"mp_event": self._events[i]} for i in range(_num_workers)] with clear_mpi_env_vars(): for idx in range(_num_workers): if self._verbose: @@ -1285,7 +1285,9 @@ def _run_worker_pipe_cuda( if not initialized: raise RuntimeError("called 'init' before step") i += 1 - next_td = env._step(shared_tensordict) + next_td = 
env._step( + shared_tensordict.select(*_selected_input_keys).cpu() + ) next_shared_tensordict.update_(next_td) stream.record_event(cuda_event) stream.synchronize() From 6f4c3749dc6f378142baca754e21aec77bc19571 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 20:54:04 +0100 Subject: [PATCH 38/67] amend --- benchmarks/ecosystem/gym_env_throughput.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index c47e2177b2f..2153f887806 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -116,8 +116,8 @@ def make(envname=envname, gym_backend=gym_backend, device=device): ) pbar = tqdm.tqdm(total=num_workers * 10_000) total_frames = 0 + t0 = time.time() for i, data in enumerate(collector): - t0 = time.time() total_frames += data.numel() pbar.update(data.numel()) pbar.set_description( From 3657b41b474eec7e7334ddd0afa570b8de7ef762 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 11 Oct 2023 10:36:06 +0100 Subject: [PATCH 39/67] empty From 9206b930f22d895b64a4e171f0bcc61e1d786aee Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 11 Oct 2023 11:28:19 +0100 Subject: [PATCH 40/67] amend --- torchrl/envs/batched_envs.py | 141 +++++++++++++++++++++-------------- 1 file changed, 87 insertions(+), 54 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 563d6112ac0..5f11d8b5cd3 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -121,11 +121,16 @@ class _BatchedEnv(EnvBase): memmap (bool): whether or not the returned tensordict will be placed in memory map. policy_proof (callable, optional): if provided, it'll be used to get the list of tensors to return through the :obj:`step()` and :obj:`reset()` methods, such as :obj:`"hidden"` etc. - device (str, int, torch.device): for consistency, this argument is kept. However this - argument should not be passed, as the device will be inferred from the environments. - It is assumed that all environments will run on the same device as a common shared - tensordict will be used to pass data from process to process. The device can be - changed after instantiation using :obj:`env.to(device)`. + device (str, int, torch.device): The device of the batched environment can be passed. + If not, it is inferred from the env. In this case, it is assumed that + the device of all environments match. If it is provided, it can differ + from the sub-environment device(s). In that case, the data will be + automatically cast to the appropriate device during collection. + This can be used to speed up collection in case casting to device + introduces an overhead (eg, numpy-based environents etc.): by using + a ``"cuda"`` device for the batched environment but a ``"cpu"`` + device for the nested environments, one can keep the overhead to a + minimum. num_threads (int, optional): number of threads for this process. Defaults to the number of workers. This parameter has no effect for the :class:`~SerialEnv` class. @@ -161,14 +166,7 @@ def __init__( num_threads: int = None, num_sub_threads: int = 1, ): - if device is not None: - raise ValueError( - "Device setting for batched environment can't be done at initialization. " - "The device will be inferred from the constructed environment. " - "It can be set through the `to(device)` method." 
- ) - - super().__init__(device=None) + super().__init__(device=device) self.is_closed = True if num_threads is None: num_threads = num_workers + 1 # 1 more thread for this proc @@ -217,7 +215,7 @@ def __init__( "memmap and shared memory are mutually exclusive features." ) self._batch_size = None - self._device = None + self._device = torch.device(device) self._dummy_env_str = None self._seeds = None self.__dict__["_input_spec"] = None @@ -272,7 +270,9 @@ def _set_properties(self): self._properties_set = True if self._single_task: self._batch_size = meta_data.batch_size - device = self._device = meta_data.device + device = meta_data.device + if self._device is None: + self._device = device input_spec = meta_data.specs["input_spec"].to(device) output_spec = meta_data.specs["output_spec"].to(device) @@ -288,8 +288,18 @@ def _set_properties(self): self._batch_locked = meta_data.batch_locked else: self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size]) - device = self._device = meta_data[0].device - # TODO: check that all action_spec and reward spec match (issue #351) + devices = set() + for _meta_data in meta_data: + device = _meta_data.device + devices.append(device) + if self._device is None: + if len(devices) > 1: + raise ValueError( + f"The device wasn't passed to {type(self)}, but more than one device was found in the sub-environments. " + f"Please indicate a device to be used for collection." + ) + device = list(devices)[0] + self._device = device input_spec = [] for md in meta_data: @@ -500,11 +510,6 @@ def to(self, device: DEVICE_TYPING): if device == self.device: return self self._device = device - self.meta_data = ( - self.meta_data.to(device) - if self._single_task - else [meta_data.to(device) for meta_data in self.meta_data] - ) if not self.is_closed: warn( "Casting an open environment to another device requires closing and re-opening it. " @@ -536,7 +541,7 @@ def _start_workers(self) -> None: for idx in range(_num_workers): env = self.create_env_fn[idx](**self.create_env_kwargs[idx]) - self._envs.append(env.to(self.device)) + self._envs.append(env) self.is_closed = False @_check_start @@ -566,7 +571,12 @@ def _step( for i in range(self.num_workers): # shared_tensordicts are locked, and we need to select the keys since we update in-place. # There may be unexpected keys, such as "_reset", that we should comfortably ignore here. 
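            # The device of the batched env may now differ from each sub-env's device: when
            # they differ, the per-worker input is cast to the sub-env's device before
            # _step()/_reset(), and the outputs are written back into the shared tensordict on
            # the batched env's device. An illustrative (hypothetical) usage, assuming a
            # Gym-backed sub-env running on CPU under a CUDA batched env:
            #
            #     env = ParallelEnv(2, lambda: GymEnv(PONG_VERSIONED), device="cuda:0")
            #     td = env.reset()          # data lives on cuda:0
            #     td = env.rand_step(td)    # sub-envs step on CPU, outputs cast back to cuda:0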
- out_td = self._envs[i]._step(tensordict_in[i]) + env_device = self._envs[i].device + if env_device != self.device: + data_in = tensordict_in[i].to(env_device) + else: + data_in = tensordict_in[i] + out_td = self._envs[i]._step(data_in) next_td[i].update_(out_td.select(*self._env_output_keys, strict=False)) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps @@ -617,6 +627,10 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: tensordict_ = tensordict[i] if tensordict_.is_empty(): tensordict_ = None + else: + env_device = _env.device + if env_device != self.device: + tensordict_ = tensordict_.to(env_device) else: tensordict_ = None @@ -637,6 +651,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: tensordict_.select(*self._selected_reset_keys, strict=False) ) continue + _td = _env._reset(tensordict=tensordict_, **kwargs) self.shared_tensordicts[i].update_( _td.select(*self._selected_reset_keys, strict=False) @@ -815,14 +830,12 @@ def step_and_maybe_reset( for key in self._env_input_keys: key = _unravel_key_to_tuple(key) val = tensordict._get_tuple(key, None) - if val is not None: - self.shared_tensordict_parent.get(key).copy_(val, non_blocking=True) - # self.shared_tensordict_parent._set_tuple( - # key, - # val, - # inplace=True, - # validated=True, - # ) + self.shared_tensordict_parent._set_tuple( + key, + val, + inplace=True, + validated=True, + ) else: self.shared_tensordict_parent.update_( tensordict.select(*self._env_input_keys, strict=False) @@ -1094,8 +1107,10 @@ def _run_worker_pipe_shared_mem( "env_fun_kwargs must be empty if an environment is passed to a process." ) env = env_fun - env = env.to(device) - del env_fun + env_device = env.device + # we check if the devices mismatch. 
This tells us that the data need + # to be cast onto the right device before any op + device_mismatch = device != env_device i = -1 initialized = False @@ -1136,6 +1151,8 @@ def _run_worker_pipe_shared_mem( print(f"resetting worker {pid}") if not initialized: raise RuntimeError("call 'init' before resetting") + if data is not None and device_mismatch: + data = data.to(env_device, non_blocking=True) cur_td = env._reset(tensordict=data) shared_tensordict.update_(cur_td) mp_event.set() @@ -1144,7 +1161,13 @@ def _run_worker_pipe_shared_mem( if not initialized: raise RuntimeError("called 'init' before step") i += 1 - next_td = env._step(shared_tensordict) + if device_mismatch: + env_input = shared_tensordict.select(*_selected_input_keys).to( + env_device, non_blocking=True + ) + else: + env_input = shared_tensordict + next_td = env._step(env_input) next_shared_tensordict.update_(next_td) mp_event.set() @@ -1152,12 +1175,14 @@ def _run_worker_pipe_shared_mem( if not initialized: raise RuntimeError("called 'init' before step") i += 1 - td, root_next_td = env.step_and_maybe_reset(shared_tensordict.clone(False)) - # for key, val in td.get("next").items(True, True): - # next_shared_tensordict.get(key).copy_(val, non_blocking=True) + if device_mismatch: + env_input = shared_tensordict.select(*_selected_input_keys).to( + env_device, non_blocking=True + ) + else: + env_input = shared_tensordict + td, root_next_td = env.step_and_maybe_reset(env_input) next_shared_tensordict.update_(td.get("next")) - # for key, val in root_next_td.items(True, True): - # shared_tensordict.get(key).copy_(val, non_blocking=True) shared_tensordict.update_(root_next_td) mp_event.set() @@ -1234,9 +1259,11 @@ def _run_worker_pipe_cuda( "env_fun_kwargs must be empty if an environment is passed to a process." ) env = env_fun - env = env.to("cpu") del env_fun - + env_device = env.device + # we check if the devices mismatch. 
This tells us that the data need + # to be cast onto the right device before any op + device_mismatch = device != env_device i = -1 initialized = False @@ -1276,6 +1303,8 @@ def _run_worker_pipe_cuda( print(f"resetting worker {pid}") if not initialized: raise RuntimeError("call 'init' before resetting") + if data is not None and device_mismatch: + data = data.to(env_device, non_blocking=True) cur_td = env._reset(tensordict=data) shared_tensordict.update_(cur_td) stream.record_event(cuda_event) @@ -1285,9 +1314,13 @@ def _run_worker_pipe_cuda( if not initialized: raise RuntimeError("called 'init' before step") i += 1 - next_td = env._step( - shared_tensordict.select(*_selected_input_keys).cpu() - ) + if device_mismatch: + env_input = shared_tensordict.select(*_selected_input_keys).to( + env_device, non_blocking=True + ) + else: + env_input = shared_tensordict + next_td = env._step(env_input) next_shared_tensordict.update_(next_td) stream.record_event(cuda_event) stream.synchronize() @@ -1296,15 +1329,15 @@ def _run_worker_pipe_cuda( if not initialized: raise RuntimeError("called 'init' before step") i += 1 - td, root_next_td = env.step_and_maybe_reset( - shared_tensordict.select(*_selected_input_keys).cpu() - ) - for key, val in td.get("next").items(True, True): - next_shared_tensordict.get(key).copy_(val, non_blocking=True) - # next_shared_tensordict.update_(td.get("next")) - for key, val in root_next_td.items(True, True): - shared_tensordict.get(key).copy_(val, non_blocking=True) - # shared_tensordict.update_(root_next_td) + if device_mismatch: + env_input = shared_tensordict.select(*_selected_input_keys).to( + env_device, non_blocking=True + ) + else: + env_input = shared_tensordict + td, root_next_td = env.step_and_maybe_reset(env_input) + next_shared_tensordict.update_(td.get("next")) + shared_tensordict.update_(root_next_td) stream.record_event(cuda_event) stream.synchronize() From 2bac78c073a8b014b70390ea788b125a2e5e82a7 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 10 Nov 2023 11:38:19 -0500 Subject: [PATCH 41/67] amend --- torchrl/envs/batched_envs.py | 167 ++++++++++++++++++----------------- 1 file changed, 86 insertions(+), 81 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 74d0f11f6ea..1aade430eeb 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -567,36 +567,6 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: for idx, env in enumerate(self._envs): env.load_state_dict(state_dict[f"worker{idx}"]) - @_check_start - def _step( - self, - tensordict: TensorDict, - ) -> TensorDict: - tensordict_in = tensordict.clone(False) - next_td = self.shared_tensordict_parent.get("next") - for i in range(self.num_workers): - # shared_tensordicts are locked, and we need to select the keys since we update in-place. - # There may be unexpected keys, such as "_reset", that we should comfortably ignore here. 
- env_device = self._envs[i].device - if env_device != self.device: - data_in = tensordict_in[i].to(env_device) - else: - data_in = tensordict_in[i] - out_td = self._envs[i]._step(data_in) - next_td[i].update_(out_td.select(*self._env_output_keys, strict=False)) - # We must pass a clone of the tensordict, as the values of this tensordict - # will be modified in-place at further steps - if self._single_task: - out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=self.device - ) - for key in self._selected_step_keys: - _set_single_key(next_td, out, key, clone=True) - else: - # strict=False ensures that non-homogeneous keys are still there - out = next_td.select(*self._selected_step_keys, strict=False).clone() - return out - def _shutdown_workers(self) -> None: if not self.is_closed: for env in self._envs: @@ -645,7 +615,6 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: else: tensordict_ = None - _td = _env.reset(tensordict=tensordict_, **kwargs) self.shared_tensordicts[i].update_( _td.select(*self._selected_reset_keys_filt, strict=False) @@ -671,6 +640,31 @@ def _reset_proc_data(self, tensordict, tensordict_reset): return _update_during_reset(tensordict_reset, tensordict, self.reset_keys) return tensordict_reset + # @_check_start + # def _step( + # self, + # tensordict: TensorDict, + # ) -> TensorDict: + # tensordict_in = tensordict.clone(False) + # next_td = self.shared_tensordict_parent.get("next") + # for i in range(self.num_workers): + # # shared_tensordicts are locked, and we need to select the keys since we update in-place. + # # There may be unexpected keys, such as "_reset", that we should comfortably ignore here. + # out_td = self._envs[i]._step(tensordict_in[i]) + # next_td[i].update_(out_td.select(*self._env_output_keys, strict=False)) + # # We must pass a clone of the tensordict, as the values of this tensordict + # # will be modified in-place at further steps + # if self._single_task: + # out = TensorDict( + # {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + # ) + # for key in self._selected_step_keys: + # _set_single_key(next_td, out, key, clone=True) + # else: + # # strict=False ensures that non-homogeneous keys are still there + # out = next_td.select(*self._selected_step_keys, strict=False).clone() + # return out + @_check_start def _step( self, @@ -681,7 +675,12 @@ def _step( for i in range(self.num_workers): # shared_tensordicts are locked, and we need to select the keys since we update in-place. # There may be unexpected keys, such as "_reset", that we should comfortably ignore here. 
- out_td = self._envs[i]._step(tensordict_in[i]) + env_device = self._envs[i].device + if env_device != self.device: + data_in = tensordict_in[i].to(env_device, non_blocking=True) + else: + data_in = tensordict_in[i] + out_td = self._envs[i]._step(data_in) next_td[i].update_(out_td.select(*self._env_output_keys, strict=False)) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps @@ -846,48 +845,6 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: for event in self._cuda_events: self._cuda_stream.wait_event(event) - @_check_start - def step_and_maybe_reset( - self, tensordict: TensorDictBase - ) -> Tuple[TensorDictBase, TensorDictBase]: - if self._single_task and not self.has_lazy_inputs: - # this is faster than update_ but won't work for lazy stacks - for key in self._env_input_keys: - key = _unravel_key_to_tuple(key) - val = tensordict._get_tuple(key, None) - self.shared_tensordict_parent._set_tuple( - key, - val, - inplace=True, - validated=True, - ) - else: - self.shared_tensordict_parent.update_( - tensordict.select(*self._env_input_keys, strict=False) - ) - for i in range(self.num_workers): - self.parent_channels[i].send(("step_and_maybe_reset", None)) - - if self._events is not None: - # CPU case - for i in range(self.num_workers): - event = self._events[i] - event.wait() - event.clear() - else: - # CUDA case - for i in range(self.num_workers): - event = self._cuda_events[i] - self._cuda_stream.wait_event(event) - - # We must pass a clone of the tensordict, as the values of this tensordict - # will be modified in-place at further steps - tensordict.set("next", self.shared_tensordict_parent.get("next").clone()) - tensordict_ = self.shared_tensordict_parent.exclude( - "next", *self.reset_keys - ).clone() - return tensordict, tensordict_ - @_check_start def step_and_maybe_reset( self, tensordict: TensorDictBase @@ -919,10 +876,17 @@ def step_and_maybe_reset( for i in range(self.num_workers): self.parent_channels[i].send(("step_and_maybe_reset", None)) - for i in range(self.num_workers): - event = self._events[i] - event.wait() - event.clear() + if self._events is not None: + # CPU case + for i in range(self.num_workers): + event = self._events[i] + event.wait() + event.clear() + else: + # CUDA case + for i in range(self.num_workers): + event = self._cuda_events[i] + self._cuda_stream.wait_event(event) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps @@ -932,6 +896,50 @@ def step_and_maybe_reset( ).clone() return tensordict, tensordict_ + # @_check_start + # def step_and_maybe_reset( + # self, tensordict: TensorDictBase + # ) -> Tuple[TensorDictBase, TensorDictBase]: + # if self._single_task and not self.has_lazy_inputs: + # # We must use the in_keys and nothing else for the following reasons: + # # - efficiency: copying all the keys will in practice mean doing a lot + # # of writing operations since the input tensordict may (and often will) + # # contain all the previous output data. + # # - value mismatch: if the batched env is placed within a transform + # # and this transform overrides an observation key (eg, CatFrames) + # # the shape, dtype or device may not necessarily match and writing + # # the value in-place will fail. 
+ # for key in tensordict.keys(True, True): + # # we copy the input keys as well as the keys in the 'next' td, if any + # # as this mechanism can be used by a policy to set anticipatively the + # # keys of the next call (eg, with recurrent nets) + # if key in self._env_input_keys or ( + # isinstance(key, tuple) + # and key[0] == "next" + # and key in self.shared_tensordict_parent.keys(True, True) + # ): + # val = tensordict.get(key) + # self.shared_tensordict_parent.set_(key, val) + # else: + # self.shared_tensordict_parent.update_( + # tensordict.select(*self._env_input_keys, "next", strict=False) + # ) + # for i in range(self.num_workers): + # self.parent_channels[i].send(("step_and_maybe_reset", None)) + # + # for i in range(self.num_workers): + # event = self._events[i] + # event.wait() + # event.clear() + # + # # We must pass a clone of the tensordict, as the values of this tensordict + # # will be modified in-place at further steps + # tensordict.set("next", self.shared_tensordict_parent.get("next").clone()) + # tensordict_ = self.shared_tensordict_parent.exclude( + # "next", *self.reset_keys + # ).clone() + # return tensordict, tensordict_ + @_check_start def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if self._single_task and not self.has_lazy_inputs: @@ -1276,9 +1284,6 @@ def _run_worker_pipe_shared_mem( td, root_next_td = env.step_and_maybe_reset(shared_tensordict.clone(False)) next_shared_tensordict.update_(td.get("next")) root_shared_tensordict.update_(root_next_td) - if event is not None: - event.record() - event.synchronize() mp_event.set() elif cmd == "close": From 51a8856f8813790c361ac72891b106c4898fdbc4 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 10 Nov 2023 11:47:51 -0500 Subject: [PATCH 42/67] amend --- torchrl/envs/batched_envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 1aade430eeb..bac0868ad14 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -216,7 +216,7 @@ def __init__( "memmap and shared memory are mutually exclusive features." 
) self._batch_size = None - self._device = torch.device(device) + self._device = torch.device(device) if device is not None else device self._dummy_env_str = None self._seeds = None self.__dict__["_input_spec"] = None From 7a300f5121b827562ad83643c0bd202e52f0ed4d Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 10 Nov 2023 11:53:47 -0500 Subject: [PATCH 43/67] amend --- benchmarks/ecosystem/gym_env_throughput.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index 71b7a481ce0..8b649fcf29c 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -76,12 +76,12 @@ def make(envname=envname, gym_backend=gym_backend): # regular parallel env for device in avail_devices: - def make(envname=envname, gym_backend=gym_backend, device=device): + def make(envname=envname, gym_backend=gym_backend): with set_gym_backend(gym_backend): - return GymEnv(envname, device=device) + return GymEnv(envname, device="cpu") # env_make = EnvCreator(make) - penv = ParallelEnv(num_workers, EnvCreator(make)) + penv = ParallelEnv(num_workers, EnvCreator(make), device=device) with torch.inference_mode(): # warmup penv.rollout(2) From b25921d7e24091a9729d75f444ee103d13bfd285 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 10 Nov 2023 12:07:29 -0500 Subject: [PATCH 44/67] amend --- benchmarks/ecosystem/gym_env_throughput.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index 8b649fcf29c..783e42e079d 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -81,7 +81,7 @@ def make(envname=envname, gym_backend=gym_backend): return GymEnv(envname, device="cpu") # env_make = EnvCreator(make) - penv = ParallelEnv(num_workers, EnvCreator(make), device=device) + penv = ParallelEnv(num_workers, (EnvCreator(make),)*num_workers, device=device) with torch.inference_mode(): # warmup penv.rollout(2) From d42348da115f2dfa019282309cbddf960c287fa2 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 10 Nov 2023 12:19:49 -0500 Subject: [PATCH 45/67] amend --- benchmarks/ecosystem/gym_env_throughput.py | 2 +- torchrl/envs/batched_envs.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index 783e42e079d..8b649fcf29c 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -81,7 +81,7 @@ def make(envname=envname, gym_backend=gym_backend): return GymEnv(envname, device="cpu") # env_make = EnvCreator(make) - penv = ParallelEnv(num_workers, (EnvCreator(make),)*num_workers, device=device) + penv = ParallelEnv(num_workers, EnvCreator(make), device=device) with torch.inference_mode(): # warmup penv.rollout(2) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index bac0868ad14..e24a7130310 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1364,6 +1364,7 @@ def _run_worker_pipe_cuda( # we check if the devices mismatch. 
This tells us that the data need # to be cast onto the right device before any op device_mismatch = device != env_device + env_device_cpu = env_device.type == "cpu" i = -1 initialized = False @@ -1436,8 +1437,12 @@ def _run_worker_pipe_cuda( else: env_input = shared_tensordict td, root_next_td = env.step_and_maybe_reset(env_input) - next_shared_tensordict.update_(td.get("next")) - shared_tensordict.update_(root_next_td) + if env_device_cpu: + next_shared_tensordict.update_(td.get("next").pin_memory()) + shared_tensordict.update_(root_next_td.pin_memory()) + else: + next_shared_tensordict.update_(td.get("next")) + shared_tensordict.update_(root_next_td) stream.record_event(cuda_event) stream.synchronize() From f3421aa721648238d9a445bad9fd4ec87a6eb747 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 10 Nov 2023 12:27:41 -0500 Subject: [PATCH 46/67] amend --- torchrl/envs/batched_envs.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index e24a7130310..4a1a00b7f36 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1438,8 +1438,8 @@ def _run_worker_pipe_cuda( env_input = shared_tensordict td, root_next_td = env.step_and_maybe_reset(env_input) if env_device_cpu: - next_shared_tensordict.update_(td.get("next").pin_memory()) - shared_tensordict.update_(root_next_td.pin_memory()) + next_shared_tensordict.apply(_update_cuda, td.get("next")) + shared_tensordict.apply(_update_cuda, root_next_td) else: next_shared_tensordict.update_(td.get("next")) shared_tensordict.update_(root_next_td) @@ -1493,3 +1493,7 @@ def _run_worker_pipe_cuda( else: # don't send env through pipe child_pipe.send(("_".join([cmd, "done"]), None)) + +def _update_cuda(t_dest, t_source): + t_dest.copy_(t_source.pin_memory(), non_blocking=True) + return None \ No newline at end of file From 939ece493da01ee00715d19456839278183fbf72 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 10 Nov 2023 12:27:56 -0500 Subject: [PATCH 47/67] amend --- torchrl/envs/batched_envs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 4a1a00b7f36..f8540045d1f 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1438,8 +1438,8 @@ def _run_worker_pipe_cuda( env_input = shared_tensordict td, root_next_td = env.step_and_maybe_reset(env_input) if env_device_cpu: - next_shared_tensordict.apply(_update_cuda, td.get("next")) - shared_tensordict.apply(_update_cuda, root_next_td) + next_shared_tensordict._fast_apply(_update_cuda, td.get("next")) + shared_tensordict._fast_apply(_update_cuda, root_next_td) else: next_shared_tensordict.update_(td.get("next")) shared_tensordict.update_(root_next_td) From 082ba9a5d8c68634e084b50018d06a7feaac8103 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 10 Nov 2023 12:33:38 -0500 Subject: [PATCH 48/67] amend --- torchrl/envs/batched_envs.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index f8540045d1f..d1a25501f33 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1438,8 +1438,8 @@ def _run_worker_pipe_cuda( env_input = shared_tensordict td, root_next_td = env.step_and_maybe_reset(env_input) if env_device_cpu: - next_shared_tensordict._fast_apply(_update_cuda, td.get("next")) - shared_tensordict._fast_apply(_update_cuda, root_next_td) + 
next_shared_tensordict._fast_apply(_update_cuda, td.get("next"), default=None) + shared_tensordict._fast_apply(_update_cuda, root_next_td, defaut=None) else: next_shared_tensordict.update_(td.get("next")) shared_tensordict.update_(root_next_td) @@ -1495,5 +1495,7 @@ def _run_worker_pipe_cuda( child_pipe.send(("_".join([cmd, "done"]), None)) def _update_cuda(t_dest, t_source): + if t_source is None: + return t_dest.copy_(t_source.pin_memory(), non_blocking=True) - return None \ No newline at end of file + return \ No newline at end of file From 4fd670f1fbff7d5d139deaea1be5992572f228c1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 10 Nov 2023 12:37:16 -0500 Subject: [PATCH 49/67] amend --- torchrl/envs/batched_envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index d1a25501f33..0b76876e8da 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1439,7 +1439,7 @@ def _run_worker_pipe_cuda( td, root_next_td = env.step_and_maybe_reset(env_input) if env_device_cpu: next_shared_tensordict._fast_apply(_update_cuda, td.get("next"), default=None) - shared_tensordict._fast_apply(_update_cuda, root_next_td, defaut=None) + shared_tensordict._fast_apply(_update_cuda, root_next_td, default=None) else: next_shared_tensordict.update_(td.get("next")) shared_tensordict.update_(root_next_td) From 2a773f3a392c2fc0e6ef7f1252be4fdab1e72f78 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 10 Nov 2023 12:46:25 -0500 Subject: [PATCH 50/67] amend --- benchmarks/ecosystem/gym_env_throughput.py | 32 ++++++++++------------ torchrl/envs/batched_envs.py | 11 ++++++-- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index 8b649fcf29c..246c5ee15f0 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -103,13 +103,13 @@ def make(envname=envname, gym_backend=gym_backend): for device in avail_devices: - def make(envname=envname, gym_backend=gym_backend, device=device): + def make(envname=envname, gym_backend=gym_backend): with set_gym_backend(gym_backend): - return GymEnv(envname, device=device) + return GymEnv(envname, device="cpu") env_make = EnvCreator(make) # penv = SerialEnv(num_workers, env_make) - penv = ParallelEnv(num_workers, env_make) + penv = ParallelEnv(num_workers, env_make, device=device) collector = SyncDataCollector( penv, RandomPolicy(penv.action_spec), @@ -164,14 +164,14 @@ def make_env( for device in avail_devices: # async collector # + torchrl parallel env - def make_env( - envname=envname, gym_backend=gym_backend, device=device - ): + def make_env(envname=envname, gym_backend=gym_backend): with set_gym_backend(gym_backend): - return GymEnv(envname, device=device) + return GymEnv(envname, device="cpu") penv = ParallelEnv( - num_workers // num_collectors, EnvCreator(make_env) + num_workers // num_collectors, + EnvCreator(make_env), + device=device, ) collector = MultiaSyncDataCollector( [penv] * num_collectors, @@ -206,10 +206,9 @@ def make_env( envname=envname, num_workers=num_workers, gym_backend=gym_backend, - device=device, ): with set_gym_backend(gym_backend): - penv = GymEnv(envname, num_envs=num_workers, device=device) + penv = GymEnv(envname, num_envs=num_workers, device="cpu") return penv penv = EnvCreator( @@ -247,14 +246,14 @@ def make_env( for device in avail_devices: # sync collector # + torchrl parallel env - def make_env( - 
envname=envname, gym_backend=gym_backend, device=device - ): + def make_env(envname=envname, gym_backend=gym_backend): with set_gym_backend(gym_backend): - return GymEnv(envname, device=device) + return GymEnv(envname, device="cpu") penv = ParallelEnv( - num_workers // num_collectors, EnvCreator(make_env) + num_workers // num_collectors, + EnvCreator(make_env), + device=device, ) collector = MultiSyncDataCollector( [penv] * num_collectors, @@ -289,10 +288,9 @@ def make_env( envname=envname, num_workers=num_workers, gym_backend=gym_backend, - device=device, ): with set_gym_backend(gym_backend): - penv = GymEnv(envname, num_envs=num_workers, device=device) + penv = GymEnv(envname, num_envs=num_workers, device="cpu") return penv penv = EnvCreator( diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 0b76876e8da..54179f76941 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1438,8 +1438,12 @@ def _run_worker_pipe_cuda( env_input = shared_tensordict td, root_next_td = env.step_and_maybe_reset(env_input) if env_device_cpu: - next_shared_tensordict._fast_apply(_update_cuda, td.get("next"), default=None) - shared_tensordict._fast_apply(_update_cuda, root_next_td, default=None) + next_shared_tensordict._fast_apply( + _update_cuda, td.get("next"), default=None + ) + shared_tensordict._fast_apply( + _update_cuda, root_next_td, default=None + ) else: next_shared_tensordict.update_(td.get("next")) shared_tensordict.update_(root_next_td) @@ -1494,8 +1498,9 @@ def _run_worker_pipe_cuda( # don't send env through pipe child_pipe.send(("_".join([cmd, "done"]), None)) + def _update_cuda(t_dest, t_source): if t_source is None: return t_dest.copy_(t_source.pin_memory(), non_blocking=True) - return \ No newline at end of file + return From 7c04f623fac6935d9a996b469d0dfc3bff628056 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 27 Nov 2023 16:31:10 +0000 Subject: [PATCH 51/67] amend --- torchrl/envs/batched_envs.py | 141 +++++++++++++---------------------- torchrl/envs/utils.py | 12 ++- 2 files changed, 60 insertions(+), 93 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 54179f76941..f1592c11e70 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -292,7 +292,7 @@ def _set_properties(self): devices = set() for _meta_data in meta_data: device = _meta_data.device - devices.append(device) + devices.add(device) if self._device is None: if len(devices) > 1: raise ValueError( @@ -620,19 +620,26 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: _td.select(*self._selected_reset_keys_filt, strict=False) ) selected_output_keys = self._selected_reset_keys_filt + device = self.device if self._single_task: # select + clone creates 2 tds, but we can create one only out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + {}, batch_size=self.shared_tensordict_parent.shape, device=device ) for key in selected_output_keys: - _set_single_key(self.shared_tensordict_parent, out, key, clone=True) - return out + _set_single_key( + self.shared_tensordict_parent, out, key, clone=True, device=device + ) else: - return self.shared_tensordict_parent.select( + out = self.shared_tensordict_parent.select( *selected_output_keys, strict=False, - ).clone() + ) + if out.device == device: + out = out.clone() + else: + out = out.to(self.device, non_blocking=True) + return out def _reset_proc_data(self, tensordict, tensordict_reset): # since we call `reset` 
directly, all the postproc has been completed @@ -640,31 +647,6 @@ def _reset_proc_data(self, tensordict, tensordict_reset): return _update_during_reset(tensordict_reset, tensordict, self.reset_keys) return tensordict_reset - # @_check_start - # def _step( - # self, - # tensordict: TensorDict, - # ) -> TensorDict: - # tensordict_in = tensordict.clone(False) - # next_td = self.shared_tensordict_parent.get("next") - # for i in range(self.num_workers): - # # shared_tensordicts are locked, and we need to select the keys since we update in-place. - # # There may be unexpected keys, such as "_reset", that we should comfortably ignore here. - # out_td = self._envs[i]._step(tensordict_in[i]) - # next_td[i].update_(out_td.select(*self._env_output_keys, strict=False)) - # # We must pass a clone of the tensordict, as the values of this tensordict - # # will be modified in-place at further steps - # if self._single_task: - # out = TensorDict( - # {}, batch_size=self.shared_tensordict_parent.shape, device=self.device - # ) - # for key in self._selected_step_keys: - # _set_single_key(next_td, out, key, clone=True) - # else: - # # strict=False ensures that non-homogeneous keys are still there - # out = next_td.select(*self._selected_step_keys, strict=False).clone() - # return out - @_check_start def _step( self, @@ -684,15 +666,20 @@ def _step( next_td[i].update_(out_td.select(*self._env_output_keys, strict=False)) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps + device = self.device if self._single_task: out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + {}, batch_size=self.shared_tensordict_parent.shape, device=device ) for key in self._selected_step_keys: - _set_single_key(next_td, out, key, clone=True) + _set_single_key(next_td, out, key, clone=True, device=device) else: # strict=False ensures that non-homogeneous keys are still there - out = next_td.select(*self._selected_step_keys, strict=False).clone() + out = next_td.select(*self._selected_step_keys, strict=False) + if out.device == device: + out = out.clone() + else: + out = out.to(self.device, non_blocking=True) return out def __getattr__(self, attr: str) -> Any: @@ -890,56 +877,18 @@ def step_and_maybe_reset( # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps - tensordict.set("next", self.shared_tensordict_parent.get("next").clone()) - tensordict_ = self.shared_tensordict_parent.exclude( - "next", *self.reset_keys - ).clone() + next_td = self.shared_tensordict_parent.get("next") + tensordict_ = self.shared_tensordict_parent.exclude("next", *self.reset_keys) + device = self.device + if self.shared_tensordict_parent.device == device: + next_td = next_td.clone() + tensordict_ = tensordict_.clone() + else: + next_td = next_td.to(device, non_blocking=True) + tensordict_ = tensordict_.to(device, non_blocking=True) + tensordict.set("next", next_td) return tensordict, tensordict_ - # @_check_start - # def step_and_maybe_reset( - # self, tensordict: TensorDictBase - # ) -> Tuple[TensorDictBase, TensorDictBase]: - # if self._single_task and not self.has_lazy_inputs: - # # We must use the in_keys and nothing else for the following reasons: - # # - efficiency: copying all the keys will in practice mean doing a lot - # # of writing operations since the input tensordict may (and often will) - # # contain all the previous output data. 
- # # - value mismatch: if the batched env is placed within a transform - # # and this transform overrides an observation key (eg, CatFrames) - # # the shape, dtype or device may not necessarily match and writing - # # the value in-place will fail. - # for key in tensordict.keys(True, True): - # # we copy the input keys as well as the keys in the 'next' td, if any - # # as this mechanism can be used by a policy to set anticipatively the - # # keys of the next call (eg, with recurrent nets) - # if key in self._env_input_keys or ( - # isinstance(key, tuple) - # and key[0] == "next" - # and key in self.shared_tensordict_parent.keys(True, True) - # ): - # val = tensordict.get(key) - # self.shared_tensordict_parent.set_(key, val) - # else: - # self.shared_tensordict_parent.update_( - # tensordict.select(*self._env_input_keys, "next", strict=False) - # ) - # for i in range(self.num_workers): - # self.parent_channels[i].send(("step_and_maybe_reset", None)) - # - # for i in range(self.num_workers): - # event = self._events[i] - # event.wait() - # event.clear() - # - # # We must pass a clone of the tensordict, as the values of this tensordict - # # will be modified in-place at further steps - # tensordict.set("next", self.shared_tensordict_parent.get("next").clone()) - # tensordict_ = self.shared_tensordict_parent.exclude( - # "next", *self.reset_keys - # ).clone() - # return tensordict, tensordict_ - @_check_start def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if self._single_task and not self.has_lazy_inputs: @@ -984,15 +933,20 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps next_td = self.shared_tensordict_parent.get("next") + device = self.device if self._single_task: out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + {}, batch_size=self.shared_tensordict_parent.shape, device=device ) for key in self._selected_step_keys: - _set_single_key(next_td, out, key, clone=True) + _set_single_key(next_td, out, key, clone=True, device=device) else: # strict=False ensures that non-homogeneous keys are still there - out = next_td.select(*self._selected_step_keys, strict=False).clone() + out = next_td.select(*self._selected_step_keys, strict=False) + if out.device == device: + out = out.clone() + else: + out = out.to(device, non_blocking=True) return out @_check_start @@ -1055,19 +1009,26 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: self._cuda_stream.wait_event(event) selected_output_keys = self._selected_reset_keys_filt + device = self.device if self._single_task: # select + clone creates 2 tds, but we can create one only out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + {}, batch_size=self.shared_tensordict_parent.shape, device=device ) for key in selected_output_keys: - _set_single_key(self.shared_tensordict_parent, out, key, clone=True) - return out + _set_single_key( + self.shared_tensordict_parent, out, key, clone=True, device=device + ) else: - return self.shared_tensordict_parent.select( + out = self.shared_tensordict_parent.select( *selected_output_keys, strict=False, - ).clone() + ) + if out.device == device: + out = out.clone() + else: + out = out.to(device, non_blocking=True) + return out @_check_start def _shutdown_workers(self) -> None: diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 06eec73be97..9a2a71f24bd 
100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -237,7 +237,11 @@ def step_mdp( def _set_single_key( - source: TensorDictBase, dest: TensorDictBase, key: str | tuple, clone: bool = False + source: TensorDictBase, + dest: TensorDictBase, + key: str | tuple, + clone: bool = False, + device=None, ): # key should be already unraveled if isinstance(key, str): @@ -253,7 +257,9 @@ def _set_single_key( source = val dest = new_val else: - if clone: + if device is not None and val.device != device: + val = val.to(device, non_blocking=True) + elif clone: val = val.clone() dest._set_str(k, val, inplace=False, validated=True) # This is a temporary solution to understand if a key is heterogeneous @@ -262,7 +268,7 @@ def _set_single_key( if re.match(r"Found more than one unique shape in the tensors", str(err)): # this is a het key for s_td, d_td in zip(source.tensordicts, dest.tensordicts): - _set_single_key(s_td, d_td, k, clone) + _set_single_key(s_td, d_td, k, clone=clone, device=device) break else: raise err From b42277577050fd27bd1f17ad221d89c31c5b24b8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 27 Nov 2023 16:42:21 +0000 Subject: [PATCH 52/67] amend --- torchrl/envs/batched_envs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index f1592c11e70..54b7be39ea7 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -746,7 +746,7 @@ def _start_workers(self) -> None: self.parent_channels = [] self._workers = [] - if self.device.type == "cuda": + if self.shared_tensordict_parent.device.type == "cuda": func = _run_worker_pipe_cuda self._cuda_stream = torch.cuda.Stream(self.device) self._cuda_events = [ @@ -1400,10 +1400,10 @@ def _run_worker_pipe_cuda( td, root_next_td = env.step_and_maybe_reset(env_input) if env_device_cpu: next_shared_tensordict._fast_apply( - _update_cuda, td.get("next"), default=None + _update_cuda, td.get("next", default=None), ) shared_tensordict._fast_apply( - _update_cuda, root_next_td, default=None + _update_cuda, root_next_td ) else: next_shared_tensordict.update_(td.get("next")) From ed0287d03c58257d3ab0b0242ba84c4881b5ac1e Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 27 Nov 2023 16:48:12 +0000 Subject: [PATCH 53/67] amend --- torchrl/envs/batched_envs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 54b7be39ea7..3b38ef109bd 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -423,7 +423,7 @@ def _create_td(self) -> None: *(unravel_key(("next", key)) for key in self._env_output_keys), strict=False, ) - self.shared_tensordict_parent = shared_tensordict_parent.to(self.device) + self.shared_tensordict_parent = shared_tensordict_parent else: # Multi-task: we share tensordict that *may* have different keys shared_tensordict_parent = [ @@ -431,7 +431,7 @@ def _create_td(self) -> None: *self._selected_keys, *(unravel_key(("next", key)) for key in self._env_output_keys), strict=False, - ).to(self.device) + ) for tensordict in shared_tensordict_parent ] shared_tensordict_parent = torch.stack( @@ -638,7 +638,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: if out.device == device: out = out.clone() else: - out = out.to(self.device, non_blocking=True) + out = out.to(device, non_blocking=True) return out def _reset_proc_data(self, tensordict, tensordict_reset): @@ -679,7 +679,7 @@ def _step( if out.device 
== device: out = out.clone() else: - out = out.to(self.device, non_blocking=True) + out = out.to(device, non_blocking=True) return out def __getattr__(self, attr: str) -> Any: From 05b0c08b9f3b6139f138f38605d5f5d8d48c7026 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 27 Nov 2023 21:02:58 +0000 Subject: [PATCH 54/67] amend --- test/test_env.py | 37 ++++++++++++++++++++++++++++++++++++ torchrl/envs/batched_envs.py | 7 +++---- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index 6cee7f545d7..8bf51263147 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -354,6 +354,43 @@ def test_mb_env_batch_lock(self, device, seed=0): class TestParallel: + @pytest.mark.skipif( + not torch.cuda.device_count(), reason="No cuda device detected." + ) + @pytest.mark.parametrize("parallel", [True, False]) + @pytest.mark.parametrize("hetero", [True, False]) + @pytest.mark.parametrize("pdevice", [None, "cpu", "cuda"]) + @pytest.mark.parametrize("edevice", ["cpu", "cuda"]) + def test_parallel_devices(self, parallel, hetero, pdevice, edevice): + if parallel: + cls = ParallelEnv + else: + cls = SerialEnv + if not hetero: + env = cls( + 2, lambda: ContinuousActionVecMockEnv(device=edevice), device=pdevice + ) + else: + env1 = lambda: ContinuousActionVecMockEnv(device=edevice) + env2 = lambda: TransformedEnv(ContinuousActionVecMockEnv(device=edevice)) + env = cls(2, [env1, env2], device=pdevice) + + r = env.rollout(2) + if pdevice is not None: + assert env.device == torch.device(pdevice) + assert r.device == torch.device(pdevice) + assert all( + item.device == torch.device(pdevice) for item in r.values(True, True) + ) + else: + assert env.device == torch.device(edevice) + assert r.device == torch.device(edevice) + assert all( + item.device == torch.device(edevice) for item in r.values(True, True) + ) + if parallel: + assert env.shared_tensordict_parent.device == torch.device(edevice) + @pytest.mark.parametrize("num_parallel_env", [1, 10]) @pytest.mark.parametrize("env_batch_size", [[], (32,), (32, 1), (32, 0)]) def test_env_with_batch_size(self, num_parallel_env, env_batch_size): diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 3b38ef109bd..e69a6fe4d65 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1400,11 +1400,10 @@ def _run_worker_pipe_cuda( td, root_next_td = env.step_and_maybe_reset(env_input) if env_device_cpu: next_shared_tensordict._fast_apply( - _update_cuda, td.get("next", default=None), - ) - shared_tensordict._fast_apply( - _update_cuda, root_next_td + _update_cuda, + td.get("next", default=None), ) + shared_tensordict._fast_apply(_update_cuda, root_next_td) else: next_shared_tensordict.update_(td.get("next")) shared_tensordict.update_(root_next_td) From 33de0fc52d24b36971cf6e82bba7baf2a3062bbf Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 28 Nov 2023 10:16:48 +0000 Subject: [PATCH 55/67] amend --- torchrl/envs/batched_envs.py | 77 ++++++++++-------------------------- 1 file changed, 20 insertions(+), 57 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index e69a6fe4d65..b6f1d9d229f 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -774,7 +774,6 @@ def _start_workers(self) -> None: "parent_pipe": parent_pipe, "child_pipe": child_pipe, "env_fun": env_fun, - "device": self.device, "env_fun_kwargs": self.create_env_kwargs[idx], "shared_tensordict": self.shared_tensordicts[idx], "_selected_input_keys": 
self._selected_input_keys, @@ -1138,7 +1137,6 @@ def _run_worker_pipe_shared_mem( child_pipe: connection.Connection, env_fun: Union[EnvBase, Callable], env_fun_kwargs: Dict[str, Any], - device: DEVICE_TYPING = None, mp_event: mp.Event = None, shared_tensordict: TensorDictBase = None, _selected_input_keys=None, @@ -1158,9 +1156,6 @@ def _run_worker_pipe_shared_mem( ) env = env_fun env_device = env.device - # we check if the devices mismatch. This tells us that the data need - # to be cast onto the right device before any op - device_mismatch = device != env_device i = -1 initialized = False @@ -1201,8 +1196,6 @@ def _run_worker_pipe_shared_mem( print(f"resetting worker {pid}") if not initialized: raise RuntimeError("call 'init' before resetting") - if data is not None and device_mismatch: - data = data.to(env_device, non_blocking=True) cur_td = env.reset(tensordict=data) shared_tensordict.update_( cur_td.select(*_selected_reset_keys, strict=False) @@ -1213,12 +1206,7 @@ def _run_worker_pipe_shared_mem( if not initialized: raise RuntimeError("called 'init' before step") i += 1 - if device_mismatch: - env_input = shared_tensordict.select(*_selected_input_keys).to( - env_device, non_blocking=True - ) - else: - env_input = shared_tensordict + env_input = shared_tensordict next_td = env._step(env_input) next_shared_tensordict.update_(next_td) mp_event.set() @@ -1227,12 +1215,7 @@ def _run_worker_pipe_shared_mem( if not initialized: raise RuntimeError("called 'init' before step") i += 1 - if device_mismatch: - env_input = shared_tensordict.select(*_selected_input_keys).to( - env_device, non_blocking=True - ) - else: - env_input = shared_tensordict + env_input = shared_tensordict td, root_next_td = env.step_and_maybe_reset(env_input) next_shared_tensordict.update_(td.get("next")) shared_tensordict.update_(root_next_td) @@ -1299,7 +1282,6 @@ def _run_worker_pipe_cuda( child_pipe: connection.Connection, env_fun: Union[EnvBase, Callable], env_fun_kwargs: Dict[str, Any], - device: DEVICE_TYPING = None, cuda_event: torch.cuda.Event = None, shared_tensordict: TensorDictBase = None, _selected_input_keys=None, @@ -1308,23 +1290,23 @@ def _run_worker_pipe_cuda( has_lazy_inputs: bool = False, verbose: bool = False, ) -> None: - stream = torch.cuda.Stream(device) + parent_pipe.close() + pid = os.getpid() + if not isinstance(env_fun, EnvBase): + env = env_fun(**env_fun_kwargs) + else: + if env_fun_kwargs: + raise RuntimeError( + "env_fun_kwargs must be empty if an environment is passed to a process." + ) + env = env_fun + del env_fun + env_device = env.device + + stream = torch.cuda.Stream(env_device) with torch.cuda.StreamContext(stream): - parent_pipe.close() - pid = os.getpid() - if not isinstance(env_fun, EnvBase): - env = env_fun(**env_fun_kwargs) - else: - if env_fun_kwargs: - raise RuntimeError( - "env_fun_kwargs must be empty if an environment is passed to a process." - ) - env = env_fun - del env_fun - env_device = env.device # we check if the devices mismatch. 
This tells us that the data need # to be cast onto the right device before any op - device_mismatch = device != env_device env_device_cpu = env_device.type == "cpu" i = -1 initialized = False @@ -1365,8 +1347,6 @@ def _run_worker_pipe_cuda( print(f"resetting worker {pid}") if not initialized: raise RuntimeError("call 'init' before resetting") - if data is not None and device_mismatch: - data = data.to(env_device, non_blocking=True) cur_td = env._reset(tensordict=data) shared_tensordict.update_(cur_td) stream.record_event(cuda_event) @@ -1376,12 +1356,7 @@ def _run_worker_pipe_cuda( if not initialized: raise RuntimeError("called 'init' before step") i += 1 - if device_mismatch: - env_input = shared_tensordict.select(*_selected_input_keys).to( - env_device, non_blocking=True - ) - else: - env_input = shared_tensordict + env_input = shared_tensordict next_td = env._step(env_input) next_shared_tensordict.update_(next_td) stream.record_event(cuda_event) @@ -1391,22 +1366,10 @@ def _run_worker_pipe_cuda( if not initialized: raise RuntimeError("called 'init' before step") i += 1 - if device_mismatch: - env_input = shared_tensordict.select(*_selected_input_keys).to( - env_device, non_blocking=True - ) - else: - env_input = shared_tensordict + env_input = shared_tensordict td, root_next_td = env.step_and_maybe_reset(env_input) - if env_device_cpu: - next_shared_tensordict._fast_apply( - _update_cuda, - td.get("next", default=None), - ) - shared_tensordict._fast_apply(_update_cuda, root_next_td) - else: - next_shared_tensordict.update_(td.get("next")) - shared_tensordict.update_(root_next_td) + next_shared_tensordict.update_(td.get("next")) + shared_tensordict.update_(root_next_td) stream.record_event(cuda_event) stream.synchronize() From 492a884a3a0afcc81354a926cda5289287204d9d Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 28 Nov 2023 11:05:46 +0000 Subject: [PATCH 56/67] amend --- test/test_env.py | 5 +++-- torchrl/envs/batched_envs.py | 12 +++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index 8bf51263147..917d7003534 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -361,7 +361,8 @@ class TestParallel: @pytest.mark.parametrize("hetero", [True, False]) @pytest.mark.parametrize("pdevice", [None, "cpu", "cuda"]) @pytest.mark.parametrize("edevice", ["cpu", "cuda"]) - def test_parallel_devices(self, parallel, hetero, pdevice, edevice): + @pytest.mark.parametrize("bwad", [True, False]) + def test_parallel_devices(self, parallel, hetero, pdevice, edevice, bwad): if parallel: cls = ParallelEnv else: @@ -375,7 +376,7 @@ def test_parallel_devices(self, parallel, hetero, pdevice, edevice): env2 = lambda: TransformedEnv(ContinuousActionVecMockEnv(device=edevice)) env = cls(2, [env1, env2], device=pdevice) - r = env.rollout(2) + r = env.rollout(2, break_when_any_done=bwad) if pdevice is not None: assert env.device == torch.device(pdevice) assert r.device == torch.device(pdevice) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index b6f1d9d229f..2167e9db81c 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -450,13 +450,11 @@ def _create_td(self) -> None: # Multi-task: we share tensordict that *may* have different keys # LazyStacked already stores this so we don't need to do anything self.shared_tensordicts = self.shared_tensordict_parent - if self.device.type == "cpu": + if self.shared_tensordict_parent.device.type == "cpu": if self._share_memory: - for td in self.shared_tensordicts: 
- td.share_memory_() + self.shared_tensordict_parent.share_memory_() elif self._memmap: - for td in self.shared_tensordicts: - td.memmap_() + self.shared_tensordict_parent.memmap_() else: if self._share_memory: self.shared_tensordict_parent.share_memory_() @@ -946,6 +944,10 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: out = out.clone() else: out = out.to(device, non_blocking=True) + assert all( + val.device == device for val in + out.values(True, True) + ) return out @_check_start From ff4799d7825957c90725fc5213374ffae80c254e Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 28 Nov 2023 11:14:57 +0000 Subject: [PATCH 57/67] amend --- test/test_env.py | 14 +++++++------- torchrl/envs/batched_envs.py | 4 ---- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index 917d7003534..0199d7be83c 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -378,19 +378,19 @@ def test_parallel_devices(self, parallel, hetero, pdevice, edevice, bwad): r = env.rollout(2, break_when_any_done=bwad) if pdevice is not None: - assert env.device == torch.device(pdevice) - assert r.device == torch.device(pdevice) + assert env.device.type == torch.device(pdevice).type + assert r.device.type == torch.device(pdevice).type assert all( - item.device == torch.device(pdevice) for item in r.values(True, True) + item.device.type == torch.device(pdevice).type for item in r.values(True, True) ) else: - assert env.device == torch.device(edevice) - assert r.device == torch.device(edevice) + assert env.device.type == torch.device(edevice).type + assert r.device.type == torch.device(edevice).type assert all( - item.device == torch.device(edevice) for item in r.values(True, True) + item.device.type == torch.device(edevice).type for item in r.values(True, True) ) if parallel: - assert env.shared_tensordict_parent.device == torch.device(edevice) + assert env.shared_tensordict_parent.device.type == torch.device(edevice).type @pytest.mark.parametrize("num_parallel_env", [1, 10]) @pytest.mark.parametrize("env_batch_size", [[], (32,), (32, 1), (32, 0)]) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 2167e9db81c..29c361d4f86 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -944,10 +944,6 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: out = out.clone() else: out = out.to(device, non_blocking=True) - assert all( - val.device == device for val in - out.values(True, True) - ) return out @_check_start From f25b95744a9a9e53cd460b0a5516264658cf2c9a Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 28 Nov 2023 11:19:59 +0000 Subject: [PATCH 58/67] amend --- torchrl/envs/batched_envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 29c361d4f86..8bbf8b30634 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -746,7 +746,7 @@ def _start_workers(self) -> None: self._workers = [] if self.shared_tensordict_parent.device.type == "cuda": func = _run_worker_pipe_cuda - self._cuda_stream = torch.cuda.Stream(self.device) + self._cuda_stream = torch.cuda.Stream(self.shared_tensordict_parent.device) self._cuda_events = [ torch.cuda.Event(interprocess=True) for _ in range(_num_workers) ] From 8899dbd31701463d67521a677fa73042018acfb1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 28 Nov 2023 14:50:47 +0000 Subject: [PATCH 59/67] amend --- test/test_env.py | 10 +- torchrl/envs/batched_envs.py | 210 
+++-------------------------------- 2 files changed, 24 insertions(+), 196 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index 0199d7be83c..aed4e07b0b7 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -381,16 +381,20 @@ def test_parallel_devices(self, parallel, hetero, pdevice, edevice, bwad): assert env.device.type == torch.device(pdevice).type assert r.device.type == torch.device(pdevice).type assert all( - item.device.type == torch.device(pdevice).type for item in r.values(True, True) + item.device.type == torch.device(pdevice).type + for item in r.values(True, True) ) else: assert env.device.type == torch.device(edevice).type assert r.device.type == torch.device(edevice).type assert all( - item.device.type == torch.device(edevice).type for item in r.values(True, True) + item.device.type == torch.device(edevice).type + for item in r.values(True, True) ) if parallel: - assert env.shared_tensordict_parent.device.type == torch.device(edevice).type + assert ( + env.shared_tensordict_parent.device.type == torch.device(edevice).type + ) @pytest.mark.parametrize("num_parallel_env", [1, 10]) @pytest.mark.parametrize("env_batch_size", [[], (32,), (32, 1), (32, 0)]) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 8bbf8b30634..557b0281eaa 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -744,20 +744,11 @@ def _start_workers(self) -> None: self.parent_channels = [] self._workers = [] - if self.shared_tensordict_parent.device.type == "cuda": - func = _run_worker_pipe_cuda - self._cuda_stream = torch.cuda.Stream(self.shared_tensordict_parent.device) - self._cuda_events = [ - torch.cuda.Event(interprocess=True) for _ in range(_num_workers) - ] - self._events = None - kwargs = [{"cuda_event": self._cuda_events[i]} for i in range(_num_workers)] - else: - func = _run_worker_pipe_shared_mem - self._cuda_stream = None - self._cuda_events = None - self._events = [ctx.Event() for _ in range(_num_workers)] - kwargs = [{"mp_event": self._events[i]} for i in range(_num_workers)] + func = _run_worker_pipe_shared_mem + self._cuda_stream = None + self._cuda_events = None + self._events = [ctx.Event() for _ in range(_num_workers)] + kwargs = [{"mp_event": self._events[i]} for i in range(_num_workers)] with clear_mpi_env_vars(): for idx in range(_num_workers): if self._verbose: @@ -860,17 +851,10 @@ def step_and_maybe_reset( for i in range(self.num_workers): self.parent_channels[i].send(("step_and_maybe_reset", None)) - if self._events is not None: - # CPU case - for i in range(self.num_workers): - event = self._events[i] - event.wait() - event.clear() - else: - # CUDA case - for i in range(self.num_workers): - event = self._cuda_events[i] - self._cuda_stream.wait_event(event) + for i in range(self.num_workers): + event = self._events[i] + event.wait() + event.clear() # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps @@ -915,17 +899,10 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: for i in range(self.num_workers): self.parent_channels[i].send(("step", None)) - if self._events is not None: - # CPU case - for i in range(self.num_workers): - event = self._events[i] - event.wait() - event.clear() - else: - # CUDA case - for i in range(self.num_workers): - event = self._cuda_events[i] - self._cuda_stream.wait_event(event) + for i in range(self.num_workers): + event = self._events[i] + event.wait() + event.clear() # We must pass a clone of the 
tensordict, as the values of this tensordict # will be modified in-place at further steps @@ -993,17 +970,10 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: channel.send(out) workers.append(i) - if self._events is not None: - # CPU case - for i in workers: - event = self._events[i] - event.wait() - event.clear() - else: - # CUDA case - for i in workers: - event = self._cuda_events[i] - self._cuda_stream.wait_event(event) + for i in workers: + event = self._events[i] + event.wait() + event.clear() selected_output_keys = self._selected_reset_keys_filt device = self.device @@ -1153,7 +1123,6 @@ def _run_worker_pipe_shared_mem( "env_fun_kwargs must be empty if an environment is passed to a process." ) env = env_fun - env_device = env.device i = -1 initialized = False @@ -1275,151 +1244,6 @@ def _run_worker_pipe_shared_mem( child_pipe.send(("_".join([cmd, "done"]), None)) -def _run_worker_pipe_cuda( - parent_pipe: connection.Connection, - child_pipe: connection.Connection, - env_fun: Union[EnvBase, Callable], - env_fun_kwargs: Dict[str, Any], - cuda_event: torch.cuda.Event = None, - shared_tensordict: TensorDictBase = None, - _selected_input_keys=None, - _selected_reset_keys=None, - _selected_step_keys=None, - has_lazy_inputs: bool = False, - verbose: bool = False, -) -> None: - parent_pipe.close() - pid = os.getpid() - if not isinstance(env_fun, EnvBase): - env = env_fun(**env_fun_kwargs) - else: - if env_fun_kwargs: - raise RuntimeError( - "env_fun_kwargs must be empty if an environment is passed to a process." - ) - env = env_fun - del env_fun - env_device = env.device - - stream = torch.cuda.Stream(env_device) - with torch.cuda.StreamContext(stream): - # we check if the devices mismatch. This tells us that the data need - # to be cast onto the right device before any op - env_device_cpu = env_device.type == "cpu" - i = -1 - initialized = False - - child_pipe.send("started") - - while True: - try: - cmd, data = child_pipe.recv() - except EOFError as err: - raise EOFError(f"proc {pid} failed, last command: {cmd}.") from err - if cmd == "seed": - if not initialized: - raise RuntimeError("call 'init' before closing") - # torch.manual_seed(data) - # np.random.seed(data) - new_seed = env.set_seed(data[0], static_seed=data[1]) - child_pipe.send(("seeded", new_seed)) - - elif cmd == "init": - if verbose: - print(f"initializing {pid}") - if initialized: - raise RuntimeError("worker already initialized") - i = 0 - next_shared_tensordict = shared_tensordict.get("next") - shared_tensordict = shared_tensordict.clone(False) - del shared_tensordict["next"] - - if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): - raise RuntimeError( - "tensordict must be placed in shared memory (share_memory_() or memmap_())" - ) - initialized = True - - elif cmd == "reset": - if verbose: - print(f"resetting worker {pid}") - if not initialized: - raise RuntimeError("call 'init' before resetting") - cur_td = env._reset(tensordict=data) - shared_tensordict.update_(cur_td) - stream.record_event(cuda_event) - stream.synchronize() - - elif cmd == "step": - if not initialized: - raise RuntimeError("called 'init' before step") - i += 1 - env_input = shared_tensordict - next_td = env._step(env_input) - next_shared_tensordict.update_(next_td) - stream.record_event(cuda_event) - stream.synchronize() - - elif cmd == "step_and_maybe_reset": - if not initialized: - raise RuntimeError("called 'init' before step") - i += 1 - env_input = shared_tensordict - td, root_next_td = 
env.step_and_maybe_reset(env_input) - next_shared_tensordict.update_(td.get("next")) - shared_tensordict.update_(root_next_td) - stream.record_event(cuda_event) - stream.synchronize() - - elif cmd == "close": - del shared_tensordict, data - if not initialized: - raise RuntimeError("call 'init' before closing") - env.close() - del env - stream.record_event(cuda_event) - stream.synchronize() - child_pipe.close() - if verbose: - print(f"{pid} closed") - break - - elif cmd == "load_state_dict": - env.load_state_dict(data) - stream.record_event(cuda_event) - stream.synchronize() - - elif cmd == "state_dict": - state_dict = _recursively_strip_locks_from_state_dict(env.state_dict()) - msg = "state_dict" - child_pipe.send((msg, state_dict)) - - else: - err_msg = f"{cmd} from env" - try: - attr = getattr(env, cmd) - if callable(attr): - args, kwargs = data - args_replace = [] - for _arg in args: - if isinstance(_arg, str) and _arg == "_self": - continue - else: - args_replace.append(_arg) - result = attr(*args_replace, **kwargs) - else: - result = attr - except Exception as err: - raise AttributeError( - f"querying {err_msg} resulted in an error." - ) from err - if cmd not in ("to"): - child_pipe.send(("_".join([cmd, "done"]), result)) - else: - # don't send env through pipe - child_pipe.send(("_".join([cmd, "done"]), None)) - - def _update_cuda(t_dest, t_source): if t_source is None: return From 65c9debc5377eee4c7f3750d8996c3e8fd637362 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 28 Nov 2023 15:23:20 +0000 Subject: [PATCH 60/67] amend --- torchrl/envs/batched_envs.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 557b0281eaa..cd08edcc320 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -745,8 +745,10 @@ def _start_workers(self) -> None: self.parent_channels = [] self._workers = [] func = _run_worker_pipe_shared_mem - self._cuda_stream = None - self._cuda_events = None + if self.shared_tensordict_parent.device.type == "cuda": + self.event = torch.cuda.Event() + else: + self.event = None self._events = [ctx.Event() for _ in range(_num_workers)] kwargs = [{"mp_event": self._events[i]} for i in range(_num_workers)] with clear_mpi_env_vars(): @@ -812,13 +814,9 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: ) for i, channel in enumerate(self.parent_channels): channel.send(("load_state_dict", state_dict[f"worker{i}"])) - if self._events is not None: - for event in self._events: - event.wait() - event.clear() - else: - for event in self._cuda_events: - self._cuda_stream.wait_event(event) + for event in self._events: + event.wait() + event.clear() @_check_start def step_and_maybe_reset( @@ -896,6 +894,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: self.shared_tensordict_parent.update_( tensordict.select(*self._env_input_keys, "next", strict=False) ) + if self.event is not None: + self.event.record() + self.event.synchronize() for i in range(self.num_workers): self.parent_channels[i].send(("step", None)) From 77c2d6bb5d993f4e4a7c89c97988f49b4137dbd7 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 28 Nov 2023 15:25:44 +0000 Subject: [PATCH 61/67] amend --- torchrl/envs/batched_envs.py | 40 +++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index cd08edcc320..9ccdc1ffc57 100644 --- a/torchrl/envs/batched_envs.py +++ 
b/torchrl/envs/batched_envs.py @@ -1008,14 +1008,11 @@ def _shutdown_workers(self) -> None: if self._verbose: print(f"closing {i}") channel.send(("close", None)) - if self._events is not None: - self._events[i].wait() - self._events[i].clear() - else: - for event in self._cuda_events: - self._cuda_stream.wait_event(event) + self._events[i].wait() + self._events[i].clear() del self.shared_tensordicts, self.shared_tensordict_parent + for channel in self.parent_channels: channel.close() for proc in self._workers: @@ -1114,6 +1111,11 @@ def _run_worker_pipe_shared_mem( has_lazy_inputs: bool = False, verbose: bool = False, ) -> None: + device = shared_tensordict.device + if device.type == "cuda": + event = torch.cuda.Event() + else: + event = None parent_pipe.close() pid = os.getpid() if not isinstance(env_fun, EnvBase): @@ -1124,6 +1126,7 @@ def _run_worker_pipe_shared_mem( "env_fun_kwargs must be empty if an environment is passed to a process." ) env = env_fun + del env_fun i = -1 initialized = False @@ -1168,6 +1171,9 @@ def _run_worker_pipe_shared_mem( shared_tensordict.update_( cur_td.select(*_selected_reset_keys, strict=False) ) + if event is not None: + event.record() + event.synchronize() mp_event.set() elif cmd == "step": @@ -1177,6 +1183,9 @@ def _run_worker_pipe_shared_mem( env_input = shared_tensordict next_td = env._step(env_input) next_shared_tensordict.update_(next_td) + if event is not None: + event.record() + event.synchronize() mp_event.set() elif cmd == "step_and_maybe_reset": @@ -1187,16 +1196,19 @@ def _run_worker_pipe_shared_mem( td, root_next_td = env.step_and_maybe_reset(env_input) next_shared_tensordict.update_(td.get("next")) shared_tensordict.update_(root_next_td) + if event is not None: + event.record() + event.synchronize() mp_event.set() - elif cmd == "step_and_maybe_reset": - if not initialized: - raise RuntimeError("called 'init' before step") - i += 1 - td, root_next_td = env.step_and_maybe_reset(shared_tensordict.clone(False)) - next_shared_tensordict.update_(td.get("next")) - root_shared_tensordict.update_(root_next_td) - mp_event.set() + # elif cmd == "step_and_maybe_reset": + # if not initialized: + # raise RuntimeError("called 'init' before step") + # i += 1 + # td, root_next_td = env.step_and_maybe_reset(shared_tensordict.clone(False)) + # next_shared_tensordict.update_(td.get("next")) + # root_shared_tensordict.update_(root_next_td) + # mp_event.set() elif cmd == "close": del shared_tensordict, data From 7928744c01cdf2764c21748eed1d148c8d0435ec Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 28 Nov 2023 15:27:01 +0000 Subject: [PATCH 62/67] amend --- torchrl/envs/batched_envs.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 9ccdc1ffc57..decc83b496a 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1201,15 +1201,6 @@ def _run_worker_pipe_shared_mem( event.synchronize() mp_event.set() - # elif cmd == "step_and_maybe_reset": - # if not initialized: - # raise RuntimeError("called 'init' before step") - # i += 1 - # td, root_next_td = env.step_and_maybe_reset(shared_tensordict.clone(False)) - # next_shared_tensordict.update_(td.get("next")) - # root_shared_tensordict.update_(root_next_td) - # mp_event.set() - elif cmd == "close": del shared_tensordict, data if not initialized: From e7fda36133dcdfe92f87671457c174c5a2aecee2 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 29 Nov 2023 17:10:13 +0000 Subject: [PATCH 63/67] amend --- 
torchrl/envs/batched_envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index decc83b496a..b442e5be119 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1195,7 +1195,7 @@ def _run_worker_pipe_shared_mem( env_input = shared_tensordict td, root_next_td = env.step_and_maybe_reset(env_input) next_shared_tensordict.update_(td.get("next")) - shared_tensordict.update_(root_next_td) + root_shared_tensordict.update_(root_next_td) if event is not None: event.record() event.synchronize() From fb9a03a1cff3dbffa9e98f6bba29ebf97b1e040b Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 29 Nov 2023 17:31:14 +0000 Subject: [PATCH 64/67] amend --- examples/dreamer/dreamer_utils.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index fba4247e2a7..d2b1d85d05d 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -147,6 +147,7 @@ def transformed_env_constructor( state_dim_gsde: Optional[int] = None, batch_dims: Optional[int] = 0, obs_norm_state_dict: Optional[dict] = None, +ignore_device: bool=False, ) -> Union[Callable, EnvCreator]: """ Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor. @@ -179,6 +180,7 @@ def transformed_env_constructor( it should be set to 1 (or the number of dims of the batch). obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform to be loaded into the environment + ignore_device (bool, optional): if True, the device is ignored. """ def make_transformed_env(**kwargs) -> TransformedEnv: @@ -189,14 +191,17 @@ def make_transformed_env(**kwargs) -> TransformedEnv: from_pixels = cfg.from_pixels if custom_env is None and custom_env_maker is None: - if isinstance(cfg.collector_device, str): - device = cfg.collector_device - elif isinstance(cfg.collector_device, Sequence): - device = cfg.collector_device[0] + if not ignore_device: + if isinstance(cfg.collector_device, str): + device = cfg.collector_device + elif isinstance(cfg.collector_device, Sequence): + device = cfg.collector_device[0] + else: + raise ValueError( + "collector_device must be either a string or a sequence of strings" + ) else: - raise ValueError( - "collector_device must be either a string or a sequence of strings" - ) + device = None env_kwargs = { "env_name": env_name, "device": device, @@ -252,19 +257,19 @@ def parallel_env_constructor( kwargs: keyword arguments for the `transformed_env_constructor` method. 
""" batch_transform = cfg.batch_transform + kwargs.update({"cfg": cfg, "use_env_creator": True}) if cfg.env_per_collector == 1: - kwargs.update({"cfg": cfg, "use_env_creator": True}) make_transformed_env = transformed_env_constructor(**kwargs) return make_transformed_env - kwargs.update({"cfg": cfg, "use_env_creator": True}) make_transformed_env = transformed_env_constructor( - return_transformed_envs=not batch_transform, **kwargs + return_transformed_envs=not batch_transform, ignore_device=True, **kwargs ) parallel_env = ParallelEnv( num_workers=cfg.env_per_collector, create_env_fn=make_transformed_env, create_env_kwargs=None, pin_memory=cfg.pin_memory, + device=cfg.collector_device, ) if batch_transform: kwargs.update( From d73ca2239ca9734244ec351ad8adec9a92c55c8c Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 29 Nov 2023 17:35:48 +0000 Subject: [PATCH 65/67] amend --- examples/dreamer/dreamer_utils.py | 4 +++- torchrl/envs/batched_envs.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index d2b1d85d05d..4c2186e0501 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -262,7 +262,9 @@ def parallel_env_constructor( make_transformed_env = transformed_env_constructor(**kwargs) return make_transformed_env make_transformed_env = transformed_env_constructor( - return_transformed_envs=not batch_transform, ignore_device=True, **kwargs + return_transformed_envs=not batch_transform, + # ignore_device=True, + **kwargs ) parallel_env = ParallelEnv( num_workers=cfg.env_per_collector, diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index b442e5be119..0851ed15fd4 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1021,6 +1021,7 @@ def _shutdown_workers(self) -> None: del self.parent_channels self._cuda_events = None self._events = None + self.event = None @_check_start def set_seed( From c4d4c6bce3cc55e2d95edc17cb2c010ebb7c6769 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 29 Nov 2023 17:37:16 +0000 Subject: [PATCH 66/67] amend --- examples/dreamer/dreamer_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index 4c2186e0501..385e4a53aab 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -147,7 +147,7 @@ def transformed_env_constructor( state_dim_gsde: Optional[int] = None, batch_dims: Optional[int] = 0, obs_norm_state_dict: Optional[dict] = None, -ignore_device: bool=False, + ignore_device: bool = False, ) -> Union[Callable, EnvCreator]: """ Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor. 
@@ -262,9 +262,7 @@ def parallel_env_constructor( make_transformed_env = transformed_env_constructor(**kwargs) return make_transformed_env make_transformed_env = transformed_env_constructor( - return_transformed_envs=not batch_transform, - # ignore_device=True, - **kwargs + return_transformed_envs=not batch_transform, ignore_device=True, **kwargs ) parallel_env = ParallelEnv( num_workers=cfg.env_per_collector, From b2840b0beec701c0be571f641b99f94e3034b332 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 29 Nov 2023 21:12:40 +0000 Subject: [PATCH 67/67] doc --- torchrl/envs/batched_envs.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 0851ed15fd4..ac0a136c7f9 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -732,6 +732,32 @@ class ParallelEnv(_BatchedEnv): """ __doc__ += _BatchedEnv.__doc__ + __doc__ += """ + + .. note:: + The choice of the devices where ParallelEnv needs to be executed can + drastically influence its performance. The rule of thumbs is: + + - If the base environment (backend, e.g., Gym) is executed on CPU, the + sub-environments should be executed on CPU and the data should be + passed via shared physical memory. + - If the base environment is (or can be) executed on CUDA, the sub-environments + should be placed on CUDA too. + - If a CUDA device is available and the policy is to be executed on CUDA, + the ParallelEnv device should be set to CUDA. + + Therefore, supposing a CUDA device is available, we have the following scenarios: + + >>> # The sub-envs are executed on CPU, but the policy is on GPU + >>> env = ParallelEnv(N, MyEnv(..., device="cpu"), device="cuda") + >>> # The sub-envs are executed on CUDA + >>> env = ParallelEnv(N, MyEnv(..., device="cuda"), device="cuda") + >>> # this will create the exact same environment + >>> env = ParallelEnv(N, MyEnv(..., device="cuda")) + >>> # If no cuda device is available + >>> env = ParallelEnv(N, MyEnv(..., device="cpu")) + + """ def _start_workers(self) -> None: from torchrl.envs.env_creator import EnvCreator
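
To make the device guidance in the note above concrete, here is a minimal, hypothetical usage sketch in the same doctest style as the note. The environment name ("Pendulum-v1"), the worker count and the rollout length are illustrative assumptions only and do not appear in the patches; the `device` keyword of `ParallelEnv` is the one introduced in this series.

    >>> # Hypothetical sketch: CPU-backed Gym sub-envs, batched data gathered on CUDA.
    >>> import torch
    >>> from torchrl.envs import ParallelEnv
    >>> from torchrl.envs.libs.gym import GymEnv
    >>> make_env = lambda: GymEnv("Pendulum-v1", device="cpu")  # the backend runs on CPU
    >>> if torch.cuda.is_available():
    ...     # sub-envs stay on CPU, data is exposed on CUDA for a CUDA policy
    ...     env = ParallelEnv(4, make_env, device="cuda")
    ... else:
    ...     # no CUDA device: keep everything on CPU
    ...     env = ParallelEnv(4, make_env)
    >>> rollout = env.rollout(10)
    >>> assert rollout.device == env.device
    >>> env.close()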
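
The worker-pipe changes in patches 60-62 above follow the removal of the dedicated `_run_worker_pipe_cuda` worker: the single shared-memory worker now records and synchronizes a `torch.cuda.Event` before signalling the parent process through the multiprocessing event whenever the shared tensordict lives on a CUDA device. A minimal sketch of that synchronization pattern is below; the names `shared` and `notify_parent` are hypothetical and only stand in for the worker's shared tensordict and its signalling step.

    >>> import torch
    >>> def notify_parent(shared, mp_event):
    ...     # `shared` is a tensor/tensordict in shared (possibly CUDA) memory that
    ...     # the worker has just updated in-place.
    ...     if shared.device is not None and shared.device.type == "cuda":
    ...         event = torch.cuda.Event()
    ...         event.record()       # record after the writes are enqueued on the current stream
    ...         event.synchronize()  # block until those writes have actually completed
    ...     mp_event.set()           # only now is it safe for the parent to read `shared`

Keeping a single worker implementation for both cases is consistent with the rest of the series: the CPU path skips the branch entirely, while the CUDA path ensures the parent never reads device memory before the worker's kernels have finished.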