From 430f1bde95bc4f90ec70f791a72315443ca36b06 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 11 Aug 2024 12:34:08 -0400 Subject: [PATCH] [Feature] Partial steps in batched envs ghstack-source-id: a1a69e55cddf10290cb59dc1a3c6136bd257368a Pull Request resolved: https://github.com/pytorch/rl/pull/2377 --- test/mocking_classes.py | 5 +- test/test_env.py | 93 +++++++++++++ torchrl/envs/batched_envs.py | 254 ++++++++++++++++++++++++++--------- 3 files changed, 290 insertions(+), 62 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 795fda399de..4d86d8ec0ac 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1038,7 +1038,10 @@ def _step( tensordict: TensorDictBase, ) -> TensorDictBase: action = tensordict.get(self.action_key) - self.count += action.to(dtype=torch.int, device=self.device) + self.count += action.to( + dtype=torch.int, + device=self.action_spec.device if self.device is None else self.device, + ) tensordict = TensorDict( source={ "observation": self.count.clone(), diff --git a/test/test_env.py b/test/test_env.py index b945498573d..bbec29a0d78 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import argparse +import contextlib import functools import gc import os.path @@ -3340,6 +3341,98 @@ def test_pendulum_env(self): assert r.shape == torch.Size((5, 10)) +@pytest.mark.parametrize("device", [None, *get_default_devices()]) +@pytest.mark.parametrize("env_device", [None, *get_default_devices()]) +class TestPartialSteps: + @pytest.mark.parametrize("use_buffers", [False, True]) + def test_parallel_partial_steps( + self, use_buffers, device, env_device, maybe_fork_ParallelEnv + ): + with torch.device(device) if device is not None else contextlib.nullcontext(): + penv = maybe_fork_ParallelEnv( + 4, + lambda: CountingEnv(max_steps=10, start_val=2, device=env_device), + use_buffers=use_buffers, + device=device, + ) + td = penv.reset() + psteps = torch.zeros(4, dtype=torch.bool) + psteps[[1, 3]] = True + td.set("_step", psteps) + + td.set("action", penv.action_spec.one()) + td = penv.step(td) + assert (td[0].get("next") == 0).all() + assert (td[1].get("next") != 0).any() + assert (td[2].get("next") == 0).all() + assert (td[3].get("next") != 0).any() + + @pytest.mark.parametrize("use_buffers", [False, True]) + def test_parallel_partial_step_and_maybe_reset( + self, use_buffers, device, env_device, maybe_fork_ParallelEnv + ): + with torch.device(device) if device is not None else contextlib.nullcontext(): + penv = maybe_fork_ParallelEnv( + 4, + lambda: CountingEnv(max_steps=10, start_val=2, device=env_device), + use_buffers=use_buffers, + device=device, + ) + td = penv.reset() + psteps = torch.zeros(4, dtype=torch.bool) + psteps[[1, 3]] = True + td.set("_step", psteps) + + td.set("action", penv.action_spec.one()) + td, tdreset = penv.step_and_maybe_reset(td) + assert (td[0].get("next") == 0).all() + assert (td[1].get("next") != 0).any() + assert (td[2].get("next") == 0).all() + assert (td[3].get("next") != 0).any() + + @pytest.mark.parametrize("use_buffers", [False, True]) + def test_serial_partial_steps(self, use_buffers, device, env_device): + with torch.device(device) if device is not None else contextlib.nullcontext(): + penv = SerialEnv( + 4, + lambda: CountingEnv(max_steps=10, start_val=2, device=env_device), + use_buffers=use_buffers, + device=device, + ) + td = penv.reset() + psteps = torch.zeros(4, dtype=torch.bool) + psteps[[1, 3]] = True + 
td.set("_step", psteps) + + td.set("action", penv.action_spec.one()) + td = penv.step(td) + assert (td[0].get("next") == 0).all() + assert (td[1].get("next") != 0).any() + assert (td[2].get("next") == 0).all() + assert (td[3].get("next") != 0).any() + + @pytest.mark.parametrize("use_buffers", [False, True]) + def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_device): + with torch.device(device) if device is not None else contextlib.nullcontext(): + penv = SerialEnv( + 4, + lambda: CountingEnv(max_steps=10, start_val=2, device=env_device), + use_buffers=use_buffers, + device=device, + ) + td = penv.reset() + psteps = torch.zeros(4, dtype=torch.bool) + psteps[[1, 3]] = True + td.set("_step", psteps) + + td.set("action", penv.action_spec.one()) + td = penv.step(td) + assert (td[0].get("next") == 0).all() + assert (td[1].get("next") != 0).any() + assert (td[2].get("next") == 0).all() + assert (td[3].get("next") != 0).any() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index f915af52bcc..73ecdba64a9 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1031,12 +1031,18 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: if out_tds is not None: out_tds[i] = _td + device = self.device if not self._use_buffers: result = LazyStackedTensorDict.maybe_dense_stack(out_tds) + if result.device != device: + if device is None: + result = result.clear_device_() + else: + result = result.to(device, non_blocking=self.non_blocking) + self._sync_w2m() return result selected_output_keys = self._selected_reset_keys_filt - device = self.device # select + clone creates 2 tds, but we can create one only def select_and_clone(name, tensor): @@ -1066,18 +1072,29 @@ def _step( self, tensordict: TensorDict, ) -> TensorDict: - tensordict_in = tensordict.clone(False) + partial_steps = tensordict.get("_step", None) + tensordict_save = tensordict + if partial_steps is not None and partial_steps.all(): + partial_steps = None + if partial_steps is not None: + tensordict = tensordict[partial_steps] + workers_range = partial_steps.nonzero().squeeze().tolist() + tensordict_in = tensordict + else: + workers_range = range(self.num_workers) + tensordict_in = tensordict.clone(False) + # if self._use_buffers: + # shared_tensordict_parent = self.shared_tensordict_parent + data_in = [] - for i in range(self.num_workers): + for i, td_ in zip(workers_range, tensordict_in): # shared_tensordicts are locked, and we need to select the keys since we update in-place. # There may be unexpected keys, such as "_reset", that we should comfortably ignore here. 
env_device = self._envs[i].device if env_device != self.device and env_device is not None: - data_in.append( - tensordict_in[i].to(env_device, non_blocking=self.non_blocking) - ) + data_in.append(td_.to(env_device, non_blocking=self.non_blocking)) else: - data_in.append(tensordict_in[i]) + data_in.append(td_) self._sync_m2w() out_tds = None @@ -1086,7 +1103,7 @@ def _step( if self._use_buffers: next_td = self.shared_tensordict_parent.get("next") - for i, _data_in in enumerate(data_in): + for i, _data_in in zip(workers_range, data_in): out_td = self._envs[i]._step(_data_in) next_td[i].update_( out_td, @@ -1095,32 +1112,43 @@ def _step( ) if out_tds is not None: out_tds.append(out_td) - else: - for i, _data_in in enumerate(data_in): - out_td = self._envs[i]._step(_data_in) - out_tds.append(out_td) - return LazyStackedTensorDict.maybe_dense_stack(out_tds) - # We must pass a clone of the tensordict, as the values of this tensordict - # will be modified in-place at further steps - device = self.device + # We must pass a clone of the tensordict, as the values of this tensordict + # will be modified in-place at further steps + device = self.device - def select_and_clone(name, tensor): - if name in self._selected_step_keys: - return tensor.clone() + def select_and_clone(name, tensor): + if name in self._selected_step_keys: + return tensor.clone() - out = next_td.named_apply(select_and_clone, nested_keys=True, filter_empty=True) - if out_tds is not None: - out.update( - LazyStackedTensorDict(*out_tds), keys_to_update=self._non_tensor_keys + if partial_steps is not None: + next_td = TensorDict.lazy_stack([next_td[i] for i in workers_range]) + out = next_td.named_apply( + select_and_clone, nested_keys=True, filter_empty=True ) + if out_tds is not None: + out.update( + LazyStackedTensorDict(*out_tds), + keys_to_update=self._non_tensor_keys, + ) + + if out.device != device: + if device is None: + out = out.clear_device_() + elif out.device != device: + out = out.to(device, non_blocking=self.non_blocking) + self._sync_w2m() + else: + for i, _data_in in zip(workers_range, data_in): + out_td = self._envs[i]._step(_data_in) + out_tds.append(out_td) + out = LazyStackedTensorDict.maybe_dense_stack(out_tds) + + if partial_steps is not None: + result = out.new_zeros(tensordict_save.shape) + result[partial_steps] = out + return result - if out.device != device: - if device is None: - out = out.clear_device_() - elif out.device != device: - out = out.to(device, non_blocking=self.non_blocking) - self._sync_w2m() return out def __getattr__(self, attr: str) -> Any: @@ -1435,20 +1463,29 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: def _step_and_maybe_reset_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: + partial_steps = tensordict.get("_step", None) + tensordict_save = tensordict + if partial_steps is not None and partial_steps.all(): + partial_steps = None + if partial_steps is not None: + tensordict = tensordict[partial_steps] + workers_range = partial_steps.nonzero().squeeze().tolist() + else: + workers_range = range(self.num_workers) td = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1) - for i in range(td.shape[0]): + for i in workers_range: # We send the same td multiple times as it is in shared mem and we just need to index it # in each process. # If we don't do this, we need to unbind it but then the custom pickler will require # some extra metadata to be collected. 
self.parent_channels[i].send(("step_and_maybe_reset", (td, i))) - results = [None] * self.num_workers + results = [None] * len(workers_range) consumed_indices = [] - events = set(range(self.num_workers)) - while len(consumed_indices) < self.num_workers: + events = set(workers_range) + while len(consumed_indices) < len(workers_range): for i in list(events): if self._events[i].is_set(): results[i] = self.parent_channels[i].recv() @@ -1457,9 +1494,14 @@ def _step_and_maybe_reset_no_buffers( events.discard(i) out_next, out_root = zip(*(future for future in results)) - return TensorDict.maybe_dense_stack(out_next), TensorDict.maybe_dense_stack( + out = TensorDict.maybe_dense_stack(out_next), TensorDict.maybe_dense_stack( out_root ) + if partial_steps is not None: + result = out.new_zeros(tensordict_save.shape) + result[partial_steps] = out + return result + return out @torch.no_grad() @_check_start @@ -1471,6 +1513,41 @@ def step_and_maybe_reset( # return self._step_and_maybe_reset_no_buffers(tensordict) return super().step_and_maybe_reset(tensordict) + partial_steps = tensordict.get("_step", None) + tensordict_save = tensordict + if partial_steps is not None and partial_steps.all(): + partial_steps = None + if partial_steps is not None: + workers_range = partial_steps.nonzero().squeeze().tolist() + shared_tensordict_parent = TensorDict.lazy_stack( + [self.shared_tensordict_parent[i] for i in workers_range] + ) + next_td = TensorDict.lazy_stack( + [self._shared_tensordict_parent_next[i] for i in workers_range] + ) + tensordict_ = TensorDict.lazy_stack( + [self._shared_tensordict_parent_root[i] for i in workers_range] + ) + if self.shared_tensordict_parent.device is None: + tensordict = tensordict._fast_apply( + lambda x, y: x[partial_steps].to(y.device) + if y is not None + else x[partial_steps], + self.shared_tensordict_parent, + default=None, + device=None, + batch_size=shared_tensordict_parent.shape, + ) + else: + tensordict = tensordict[partial_steps].to( + self.shared_tensordict_parent.device + ) + else: + workers_range = range(self.num_workers) + shared_tensordict_parent = self.shared_tensordict_parent + next_td = self._shared_tensordict_parent_next + tensordict_ = self._shared_tensordict_parent_root + # We must use the in_keys and nothing else for the following reasons: # - efficiency: copying all the keys will in practice mean doing a lot # of writing operations since the input tensordict may (and often will) @@ -1479,7 +1556,7 @@ def step_and_maybe_reset( # and this transform overrides an observation key (eg, CatFrames) # the shape, dtype or device may not necessarily match and writing # the value in-place will fail. - self.shared_tensordict_parent.update_( + shared_tensordict_parent.update_( tensordict, keys_to_update=self._env_input_keys, non_blocking=self.non_blocking, @@ -1489,46 +1566,41 @@ def step_and_maybe_reset( # if we have input "next" data (eg, RNNs which pass the next state) # the sub-envs will need to process them through step_and_maybe_reset. 
# We keep track of which keys are present to let the worker know what - # should be passd to the env (we don't want to pass done states for instance) + # should be passed to the env (we don't want to pass done states for instance) next_td_keys = list(next_td_passthrough.keys(True, True)) - data = [ - {"next_td_passthrough_keys": next_td_keys} - for _ in range(self.num_workers) - ] - self.shared_tensordict_parent.get("next").update_( + data = [{"next_td_passthrough_keys": next_td_keys} for _ in workers_range] + shared_tensordict_parent.get("next").update_( next_td_passthrough, non_blocking=self.non_blocking ) else: # next_td_keys = None - data = [{} for _ in range(self.num_workers)] + data = [{} for _ in workers_range] if self._non_tensor_keys: - for i in range(self.num_workers): + for i in workers_range: data[i]["non_tensor_data"] = tensordict[i].select( *self._non_tensor_keys, strict=False ) self._sync_m2w() - for i in range(self.num_workers): - self.parent_channels[i].send(("step_and_maybe_reset", data[i])) + for i, _data in zip(workers_range, data): + self.parent_channels[i].send(("step_and_maybe_reset", _data)) - for i in range(self.num_workers): + for i in workers_range: event = self._events[i] event.wait(self._timeout) event.clear() if self._non_tensor_keys: non_tensor_tds = [] - for i in range(self.num_workers): + for i in workers_range: msg, non_tensor_td = self.parent_channels[i].recv() non_tensor_tds.append(non_tensor_td) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps - next_td = self._shared_tensordict_parent_next - tensordict_ = self._shared_tensordict_parent_root device = self.device - if self.shared_tensordict_parent.device == device: + if shared_tensordict_parent.device == device: next_td = next_td.clone() tensordict_ = tensordict_.clone() elif device is not None: @@ -1558,22 +1630,48 @@ def step_and_maybe_reset( keys_to_update=[("next", key) for key in self._non_tensor_keys], ) tensordict_.update(non_tensor_tds, keys_to_update=self._non_tensor_keys) + + if partial_steps is not None: + result = tensordict.new_zeros(tensordict_save.shape) + result_ = tensordict_.new_zeros(tensordict_save.shape) + result[partial_steps] = tensordict + result_[partial_steps] = tensordict_ + return result, result_ + return tensordict, tensordict_ def _step_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: + partial_steps = tensordict.get("_step", None) + tensordict_save = tensordict + if partial_steps is not None and partial_steps.all(): + partial_steps = None + if partial_steps is not None: + tensordict = tensordict[partial_steps] + workers_range = partial_steps.nonzero().squeeze().tolist() + else: + workers_range = range(self.num_workers) + data = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1) - for i, local_data in enumerate(data.unbind(0)): + for i, local_data in zip(workers_range, data.unbind(0)): self.parent_channels[i].send(("step", local_data)) # for i in range(data.shape[0]): # self.parent_channels[i].send(("step", (data, i))) out_tds = [] - for i, channel in enumerate(self.parent_channels): + for i in workers_range: + channel = self.parent_channels[i] self._events[i].wait() td = channel.recv() out_tds.append(td) - return LazyStackedTensorDict.maybe_dense_stack(out_tds) + out = LazyStackedTensorDict.maybe_dense_stack(out_tds) + if self.device is not None and out.device != self.device: + out = out.to(self.device, non_blocking=self.non_blocking) + if 
partial_steps is not None: + result = out.new_zeros(tensordict_save.shape) + result[partial_steps] = out + return result + return out @torch.no_grad() @_check_start @@ -1588,8 +1686,34 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # and this transform overrides an observation key (eg, CatFrames) # the shape, dtype or device may not necessarily match and writing # the value in-place will fail. + partial_steps = tensordict.get("_step", None) + tensordict_save = tensordict + if partial_steps is not None and partial_steps.all(): + partial_steps = None + if partial_steps is not None: + workers_range = partial_steps.nonzero().squeeze().tolist() + shared_tensordict_parent = TensorDict.lazy_stack( + [self.shared_tensordicts[i] for i in workers_range] + ) + if self.shared_tensordict_parent.device is None: + tensordict = tensordict._fast_apply( + lambda x, y: x[partial_steps].to(y.device) + if y is not None + else x[partial_steps], + self.shared_tensordict_parent, + default=None, + device=None, + batch_size=shared_tensordict_parent.shape, + ) + else: + tensordict = tensordict[partial_steps].to( + self.shared_tensordict_parent.device + ) + else: + workers_range = range(self.num_workers) + shared_tensordict_parent = self.shared_tensordict_parent - self.shared_tensordict_parent.update_( + shared_tensordict_parent.update_( tensordict, keys_to_update=list(self._env_input_keys), non_blocking=self.non_blocking, @@ -1605,14 +1729,14 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: {"next_td_passthrough_keys": next_td_keys} for _ in range(self.num_workers) ] - self.shared_tensordict_parent.get("next").update_( + shared_tensordict_parent.get("next").update_( next_td_passthrough, non_blocking=self.non_blocking ) else: data = [{} for _ in range(self.num_workers)] if self._non_tensor_keys: - for i in range(self.num_workers): + for i in workers_range: data[i]["non_tensor_data"] = tensordict[i].select( *self._non_tensor_keys, strict=False ) @@ -1622,23 +1746,23 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if self.event is not None: self.event.record() self.event.synchronize() - for i in range(self.num_workers): + for i in workers_range: self.parent_channels[i].send(("step", data[i])) - for i in range(self.num_workers): + for i in workers_range: event = self._events[i] event.wait(self._timeout) event.clear() if self._non_tensor_keys: non_tensor_tds = [] - for i in range(self.num_workers): + for i in workers_range: msg, non_tensor_td = self.parent_channels[i].recv() non_tensor_tds.append(non_tensor_td) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps - next_td = self.shared_tensordict_parent.get("next") + next_td = shared_tensordict_parent.get("next") device = self.device if next_td.device != device and device is not None: @@ -1665,6 +1789,10 @@ def select_and_clone(name, tensor): keys_to_update=self._non_tensor_keys, ) self._sync_w2m() + if partial_steps is not None: + result = out.new_zeros(tensordict_save.shape) + result[partial_steps] = out + return result return out def _reset_no_buffers( @@ -1698,7 +1826,11 @@ def _reset_no_buffers( self._events[i].wait() td = channel.recv() out_tds[i] = td - return LazyStackedTensorDict.maybe_dense_stack(out_tds) + result = LazyStackedTensorDict.maybe_dense_stack(out_tds) + device = self.device + if device is not None and result.device != device: + return result.to(self.device, non_blocking=self.non_blocking) + return result @torch.no_grad() 
@_check_start
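
A minimal usage sketch of the partial-step feature, adapted from the tests added in test/test_env.py above. It assumes the CountingEnv helper from test/mocking_classes.py purely for illustration; any EnvBase subclass batched through SerialEnv or ParallelEnv behaves the same way. Workers whose "_step" entry is False are skipped, and their slots in the returned tensordict are zero-filled placeholders (the out.new_zeros(...) scatter in batched_envs.py).

    import torch
    from torchrl.envs import SerialEnv
    from mocking_classes import CountingEnv  # test-only helper, assumed importable here

    env = SerialEnv(4, lambda: CountingEnv(max_steps=10, start_val=2))
    td = env.reset()

    # Boolean mask over the batch dimension: only workers 1 and 3 actually step.
    mask = torch.zeros(4, dtype=torch.bool)
    mask[[1, 3]] = True
    td.set("_step", mask)
    td.set("action", env.action_spec.one())

    td = env.step(td)
    # Skipped workers come back as zero-filled placeholders; stepped workers carry
    # real "next" data. This mirrors the assertions in the new tests.
    assert (td[0].get("next") == 0).all()
    assert (td[1].get("next") != 0).any()

If every entry of "_step" is True, the mask is dropped and the usual full-batch path runs (the partial_steps.all() short-circuit in batched_envs.py); the same mask is honored by step_and_maybe_reset.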