diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py
index 71b7a481ce0..246c5ee15f0 100644
--- a/benchmarks/ecosystem/gym_env_throughput.py
+++ b/benchmarks/ecosystem/gym_env_throughput.py
@@ -76,12 +76,12 @@ def make(envname=envname, gym_backend=gym_backend):
         # regular parallel env
         for device in avail_devices:

-            def make(envname=envname, gym_backend=gym_backend, device=device):
+            def make(envname=envname, gym_backend=gym_backend):
                 with set_gym_backend(gym_backend):
-                    return GymEnv(envname, device=device)
+                    return GymEnv(envname, device="cpu")

             # env_make = EnvCreator(make)
-            penv = ParallelEnv(num_workers, EnvCreator(make))
+            penv = ParallelEnv(num_workers, EnvCreator(make), device=device)
             with torch.inference_mode():
                 # warmup
                 penv.rollout(2)
@@ -103,13 +103,13 @@ def make(envname=envname, gym_backend=gym_backend, device=device):

         for device in avail_devices:

-            def make(envname=envname, gym_backend=gym_backend, device=device):
+            def make(envname=envname, gym_backend=gym_backend):
                 with set_gym_backend(gym_backend):
-                    return GymEnv(envname, device=device)
+                    return GymEnv(envname, device="cpu")

             env_make = EnvCreator(make)
             # penv = SerialEnv(num_workers, env_make)
-            penv = ParallelEnv(num_workers, env_make)
+            penv = ParallelEnv(num_workers, env_make, device=device)
             collector = SyncDataCollector(
                 penv,
                 RandomPolicy(penv.action_spec),
@@ -164,14 +164,14 @@ def make_env(
         for device in avail_devices:
             # async collector
             # + torchrl parallel env
-            def make_env(
-                envname=envname, gym_backend=gym_backend, device=device
-            ):
+            def make_env(envname=envname, gym_backend=gym_backend):
                 with set_gym_backend(gym_backend):
-                    return GymEnv(envname, device=device)
+                    return GymEnv(envname, device="cpu")

             penv = ParallelEnv(
-                num_workers // num_collectors, EnvCreator(make_env)
+                num_workers // num_collectors,
+                EnvCreator(make_env),
+                device=device,
             )
             collector = MultiaSyncDataCollector(
                 [penv] * num_collectors,
@@ -206,10 +206,9 @@ def make_env(
                 envname=envname,
                 num_workers=num_workers,
                 gym_backend=gym_backend,
-                device=device,
             ):
                 with set_gym_backend(gym_backend):
-                    penv = GymEnv(envname, num_envs=num_workers, device=device)
+                    penv = GymEnv(envname, num_envs=num_workers, device="cpu")
                 return penv

             penv = EnvCreator(
@@ -247,14 +246,14 @@ def make_env(
         for device in avail_devices:
             # sync collector
             # + torchrl parallel env
-            def make_env(
-                envname=envname, gym_backend=gym_backend, device=device
-            ):
+            def make_env(envname=envname, gym_backend=gym_backend):
                 with set_gym_backend(gym_backend):
-                    return GymEnv(envname, device=device)
+                    return GymEnv(envname, device="cpu")

             penv = ParallelEnv(
-                num_workers // num_collectors, EnvCreator(make_env)
+                num_workers // num_collectors,
+                EnvCreator(make_env),
+                device=device,
             )
             collector = MultiSyncDataCollector(
                 [penv] * num_collectors,
@@ -289,10 +288,9 @@ def make_env(
                 envname=envname,
                 num_workers=num_workers,
                 gym_backend=gym_backend,
-                device=device,
             ):
                 with set_gym_backend(gym_backend):
-                    penv = GymEnv(envname, num_envs=num_workers, device=device)
+                    penv = GymEnv(envname, num_envs=num_workers, device="cpu")
                 return penv

             penv = EnvCreator(
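
The benchmark change above captures the new usage pattern: each worker builds its GymEnv on "cpu" (next to the numpy-based backend) and the collection device is passed to ParallelEnv instead, which casts the batched data once per step. A minimal sketch of the same pattern outside the benchmark, assuming a gymnasium backend and the "Pendulum-v1" task (both placeholders):

>>> import torch
>>> from torchrl.envs import EnvCreator, ParallelEnv
>>> from torchrl.envs.libs.gym import GymEnv, set_gym_backend
>>> def make(envname="Pendulum-v1", gym_backend="gymnasium"):
...     # keep the sub-env on CPU, where the backend lives anyway
...     with set_gym_backend(gym_backend):
...         return GymEnv(envname, device="cpu")
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> penv = ParallelEnv(4, EnvCreator(make), device=device)
>>> with torch.inference_mode():
...     penv.rollout(2)  # warmup, as in the benchmark
>>> penv.close()
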
diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py
index fba4247e2a7..385e4a53aab 100644
--- a/examples/dreamer/dreamer_utils.py
+++ b/examples/dreamer/dreamer_utils.py
@@ -147,6 +147,7 @@ def transformed_env_constructor(
     state_dim_gsde: Optional[int] = None,
     batch_dims: Optional[int] = 0,
     obs_norm_state_dict: Optional[dict] = None,
+    ignore_device: bool = False,
 ) -> Union[Callable, EnvCreator]:
     """
     Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.
@@ -179,6 +180,10 @@ def transformed_env_constructor(
             it should be set to 1 (or the number of dims of the batch).
         obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform to be
             loaded into the environment
+        ignore_device (bool, optional): if ``True``, the device entry of the
+            config is ignored and the environment is created without an
+            explicit device, leaving device handling to the caller (e.g., a
+            ``ParallelEnv``). Defaults to ``False``.
     """

     def make_transformed_env(**kwargs) -> TransformedEnv:
@@ -189,14 +191,17 @@ def make_transformed_env(**kwargs) -> TransformedEnv:
         from_pixels = cfg.from_pixels

         if custom_env is None and custom_env_maker is None:
-            if isinstance(cfg.collector_device, str):
-                device = cfg.collector_device
-            elif isinstance(cfg.collector_device, Sequence):
-                device = cfg.collector_device[0]
+            if not ignore_device:
+                if isinstance(cfg.collector_device, str):
+                    device = cfg.collector_device
+                elif isinstance(cfg.collector_device, Sequence):
+                    device = cfg.collector_device[0]
+                else:
+                    raise ValueError(
+                        "collector_device must be either a string or a sequence of strings"
+                    )
             else:
-                raise ValueError(
-                    "collector_device must be either a string or a sequence of strings"
-                )
+                device = None
             env_kwargs = {
                 "env_name": env_name,
                 "device": device,
@@ -252,19 +257,19 @@ def parallel_env_constructor(
         kwargs: keyword arguments for the `transformed_env_constructor` method.
     """
     batch_transform = cfg.batch_transform
+    kwargs.update({"cfg": cfg, "use_env_creator": True})
     if cfg.env_per_collector == 1:
-        kwargs.update({"cfg": cfg, "use_env_creator": True})
         make_transformed_env = transformed_env_constructor(**kwargs)
         return make_transformed_env
-    kwargs.update({"cfg": cfg, "use_env_creator": True})
     make_transformed_env = transformed_env_constructor(
-        return_transformed_envs=not batch_transform, **kwargs
+        return_transformed_envs=not batch_transform, ignore_device=True, **kwargs
     )
     parallel_env = ParallelEnv(
         num_workers=cfg.env_per_collector,
         create_env_fn=make_transformed_env,
         create_env_kwargs=None,
         pin_memory=cfg.pin_memory,
+        device=cfg.collector_device,
     )
     if batch_transform:
         kwargs.update(
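
With ignore_device=True, each dreamer worker is created with device=None and the batched env owns cfg.collector_device, so the collected data is cast once, at the ParallelEnv boundary. A condensed sketch of the resulting wiring, with a hypothetical cfg carrying only the fields used here and make_transformed_env standing in for the constructor returned above:

>>> from types import SimpleNamespace
>>> from torchrl.envs import ParallelEnv
>>> cfg = SimpleNamespace(env_per_collector=4, pin_memory=False, collector_device="cuda")
>>> penv = ParallelEnv(
...     num_workers=cfg.env_per_collector,
...     create_env_fn=make_transformed_env,  # built with ignore_device=True
...     create_env_kwargs=None,
...     pin_memory=cfg.pin_memory,
...     device=cfg.collector_device,  # output tensordicts land here
... )
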
diff --git a/test/test_env.py b/test/test_env.py
index 6cee7f545d7..aed4e07b0b7 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -354,6 +354,48 @@ def test_mb_env_batch_lock(self, device, seed=0):


 class TestParallel:
+    @pytest.mark.skipif(
+        not torch.cuda.device_count(), reason="No cuda device detected."
+    )
+    @pytest.mark.parametrize("parallel", [True, False])
+    @pytest.mark.parametrize("hetero", [True, False])
+    @pytest.mark.parametrize("pdevice", [None, "cpu", "cuda"])
+    @pytest.mark.parametrize("edevice", ["cpu", "cuda"])
+    @pytest.mark.parametrize("bwad", [True, False])
+    def test_parallel_devices(self, parallel, hetero, pdevice, edevice, bwad):
+        if parallel:
+            cls = ParallelEnv
+        else:
+            cls = SerialEnv
+        if not hetero:
+            env = cls(
+                2, lambda: ContinuousActionVecMockEnv(device=edevice), device=pdevice
+            )
+        else:
+            env1 = lambda: ContinuousActionVecMockEnv(device=edevice)
+            env2 = lambda: TransformedEnv(ContinuousActionVecMockEnv(device=edevice))
+            env = cls(2, [env1, env2], device=pdevice)
+
+        r = env.rollout(2, break_when_any_done=bwad)
+        if pdevice is not None:
+            assert env.device.type == torch.device(pdevice).type
+            assert r.device.type == torch.device(pdevice).type
+            assert all(
+                item.device.type == torch.device(pdevice).type
+                for item in r.values(True, True)
+            )
+        else:
+            assert env.device.type == torch.device(edevice).type
+            assert r.device.type == torch.device(edevice).type
+            assert all(
+                item.device.type == torch.device(edevice).type
+                for item in r.values(True, True)
+            )
+        if parallel:
+            assert (
+                env.shared_tensordict_parent.device.type == torch.device(edevice).type
+            )
+
     @pytest.mark.parametrize("num_parallel_env", [1, 10])
     @pytest.mark.parametrize("env_batch_size", [[], (32,), (32, 1), (32, 0)])
     def test_env_with_batch_size(self, num_parallel_env, env_batch_size):
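
The test pins down the resolution rule implemented below: an explicit batched-env device (pdevice) wins over the sub-env device (edevice) for the env, its rollouts and every leaf tensor, while an omitted device is inherited from the sub-envs; for ParallelEnv, the shared tensordict used for inter-process communication stays on the sub-env device either way. For instance, on a CUDA machine, with the same mock env the test suite uses:

>>> env = ParallelEnv(2, lambda: ContinuousActionVecMockEnv(device="cpu"), device="cuda")
>>> r = env.rollout(2)
>>> assert r.device.type == "cuda"  # rollouts live on the batched-env device
>>> assert env.shared_tensordict_parent.device.type == "cpu"  # comm buffer stays with the workers
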
" - "It can be set through the `to(device)` method." - ) - - super().__init__(device=None) + super().__init__(device=device) self.is_closed = True if num_threads is None: num_threads = num_workers + 1 # 1 more thread for this proc @@ -218,7 +216,7 @@ def __init__( "memmap and shared memory are mutually exclusive features." ) self._batch_size = None - self._device = None + self._device = torch.device(device) if device is not None else device self._dummy_env_str = None self._seeds = None self.__dict__["_input_spec"] = None @@ -273,7 +271,9 @@ def _set_properties(self): self._properties_set = True if self._single_task: self._batch_size = meta_data.batch_size - device = self._device = meta_data.device + device = meta_data.device + if self._device is None: + self._device = device input_spec = meta_data.specs["input_spec"].to(device) output_spec = meta_data.specs["output_spec"].to(device) @@ -289,8 +289,18 @@ def _set_properties(self): self._batch_locked = meta_data.batch_locked else: self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size]) - device = self._device = meta_data[0].device - # TODO: check that all action_spec and reward spec match (issue #351) + devices = set() + for _meta_data in meta_data: + device = _meta_data.device + devices.add(device) + if self._device is None: + if len(devices) > 1: + raise ValueError( + f"The device wasn't passed to {type(self)}, but more than one device was found in the sub-environments. " + f"Please indicate a device to be used for collection." + ) + device = list(devices)[0] + self._device = device input_spec = [] for md in meta_data: @@ -413,7 +423,7 @@ def _create_td(self) -> None: *(unravel_key(("next", key)) for key in self._env_output_keys), strict=False, ) - self.shared_tensordict_parent = shared_tensordict_parent.to(self.device) + self.shared_tensordict_parent = shared_tensordict_parent else: # Multi-task: we share tensordict that *may* have different keys shared_tensordict_parent = [ @@ -421,7 +431,7 @@ def _create_td(self) -> None: *self._selected_keys, *(unravel_key(("next", key)) for key in self._env_output_keys), strict=False, - ).to(self.device) + ) for tensordict in shared_tensordict_parent ] shared_tensordict_parent = torch.stack( @@ -440,13 +450,11 @@ def _create_td(self) -> None: # Multi-task: we share tensordict that *may* have different keys # LazyStacked already stores this so we don't need to do anything self.shared_tensordicts = self.shared_tensordict_parent - if self.device.type == "cpu": + if self.shared_tensordict_parent.device.type == "cpu": if self._share_memory: - for td in self.shared_tensordicts: - td.share_memory_() + self.shared_tensordict_parent.share_memory_() elif self._memmap: - for td in self.shared_tensordicts: - td.memmap_() + self.shared_tensordict_parent.memmap_() else: if self._share_memory: self.shared_tensordict_parent.share_memory_() @@ -483,7 +491,6 @@ def close(self) -> None: self.__dict__["_input_spec"] = None self.__dict__["_output_spec"] = None self._properties_set = False - self.event = None self._shutdown_workers() self.is_closed = True @@ -507,11 +514,6 @@ def to(self, device: DEVICE_TYPING): if device == self.device: return self self._device = device - self.meta_data = ( - self.meta_data.to(device) - if self._single_task - else [meta_data.to(device) for meta_data in self.meta_data] - ) if not self.is_closed: warn( "Casting an open environment to another device requires closing and re-opening it. 
" @@ -543,7 +545,7 @@ def _start_workers(self) -> None: for idx in range(_num_workers): env = self.create_env_fn[idx](**self.create_env_kwargs[idx]) - self._envs.append(env.to(self.device)) + self._envs.append(env) self.is_closed = False @_check_start @@ -603,29 +605,39 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: if tensordict_.is_empty(): tensordict_ = None else: - # reset will do modifications in-place. We want the original - # tensorict to be unchaned, so we clone it - tensordict_ = tensordict_.clone(False) + env_device = _env.device + if env_device != self.device: + tensordict_ = tensordict_.to(env_device) + else: + tensordict_ = tensordict_.clone(False) else: tensordict_ = None + _td = _env.reset(tensordict=tensordict_, **kwargs) self.shared_tensordicts[i].update_( _td.select(*self._selected_reset_keys_filt, strict=False) ) selected_output_keys = self._selected_reset_keys_filt + device = self.device if self._single_task: # select + clone creates 2 tds, but we can create one only out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + {}, batch_size=self.shared_tensordict_parent.shape, device=device ) for key in selected_output_keys: - _set_single_key(self.shared_tensordict_parent, out, key, clone=True) - return out + _set_single_key( + self.shared_tensordict_parent, out, key, clone=True, device=device + ) else: - return self.shared_tensordict_parent.select( + out = self.shared_tensordict_parent.select( *selected_output_keys, strict=False, - ).clone() + ) + if out.device == device: + out = out.clone() + else: + out = out.to(device, non_blocking=True) + return out def _reset_proc_data(self, tensordict, tensordict_reset): # since we call `reset` directly, all the postproc has been completed @@ -643,19 +655,29 @@ def _step( for i in range(self.num_workers): # shared_tensordicts are locked, and we need to select the keys since we update in-place. # There may be unexpected keys, such as "_reset", that we should comfortably ignore here. - out_td = self._envs[i]._step(tensordict_in[i]) + env_device = self._envs[i].device + if env_device != self.device: + data_in = tensordict_in[i].to(env_device, non_blocking=True) + else: + data_in = tensordict_in[i] + out_td = self._envs[i]._step(data_in) next_td[i].update_(out_td.select(*self._env_output_keys, strict=False)) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps + device = self.device if self._single_task: out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + {}, batch_size=self.shared_tensordict_parent.shape, device=device ) for key in self._selected_step_keys: - _set_single_key(next_td, out, key, clone=True) + _set_single_key(next_td, out, key, clone=True, device=device) else: # strict=False ensures that non-homogeneous keys are still there - out = next_td.select(*self._selected_step_keys, strict=False).clone() + out = next_td.select(*self._selected_step_keys, strict=False) + if out.device == device: + out = out.clone() + else: + out = out.to(device, non_blocking=True) return out def __getattr__(self, attr: str) -> Any: @@ -710,6 +732,32 @@ class ParallelEnv(_BatchedEnv): """ __doc__ += _BatchedEnv.__doc__ + __doc__ += """ + + .. note:: + The choice of the devices where ParallelEnv needs to be executed can + drastically influence its performance. 
+    __doc__ += """

+    .. note::
+        The choice of the devices where ParallelEnv needs to be executed can
+        drastically influence its performance. The rule of thumb is:
+
+        - If the base environment (backend, e.g., Gym) is executed on CPU, the
+          sub-environments should be executed on CPU and the data should be
+          passed via shared physical memory.
+        - If the base environment is (or can be) executed on CUDA, the sub-environments
+          should be placed on CUDA too.
+        - If a CUDA device is available and the policy is to be executed on CUDA,
+          the ParallelEnv device should be set to CUDA.
+
+        Therefore, supposing a CUDA device is available, we have the following scenarios:
+
+        >>> # The sub-envs are executed on CPU, but the policy is on GPU
+        >>> env = ParallelEnv(N, MyEnv(..., device="cpu"), device="cuda")
+        >>> # The sub-envs are executed on CUDA
+        >>> env = ParallelEnv(N, MyEnv(..., device="cuda"), device="cuda")
+        >>> # this creates the exact same environment as the previous one
+        >>> env = ParallelEnv(N, MyEnv(..., device="cuda"))
+        >>> # If no cuda device is available
+        >>> env = ParallelEnv(N, MyEnv(..., device="cpu"))
+
+    """

     def _start_workers(self) -> None:
         from torchrl.envs.env_creator import EnvCreator
@@ -722,39 +770,39 @@ def _start_workers(self) -> None:

         self.parent_channels = []
         self._workers = []
-        self._events = []
-        if self.device.type == "cuda":
+        func = _run_worker_pipe_shared_mem
+        if self.shared_tensordict_parent.device.type == "cuda":
             self.event = torch.cuda.Event()
         else:
             self.event = None
+        self._events = [ctx.Event() for _ in range(_num_workers)]
+        kwargs = [{"mp_event": self._events[i]} for i in range(_num_workers)]
         with clear_mpi_env_vars():
             for idx in range(_num_workers):
                 if self._verbose:
                     print(f"initiating worker {idx}")
                 # No certainty which module multiprocessing_context is
                 parent_pipe, child_pipe = ctx.Pipe()
-                event = ctx.Event()
-                self._events.append(event)
                 env_fun = self.create_env_fn[idx]
                 if not isinstance(env_fun, EnvCreator):
                     env_fun = CloudpickleWrapper(env_fun)
-
+                kwargs[idx].update(
+                    {
+                        "parent_pipe": parent_pipe,
+                        "child_pipe": child_pipe,
+                        "env_fun": env_fun,
+                        "env_fun_kwargs": self.create_env_kwargs[idx],
+                        "shared_tensordict": self.shared_tensordicts[idx],
+                        "_selected_input_keys": self._selected_input_keys,
+                        "_selected_reset_keys": self._selected_reset_keys,
+                        "_selected_step_keys": self._selected_step_keys,
+                        "has_lazy_inputs": self.has_lazy_inputs,
+                    }
+                )
                 process = _ProcessNoWarn(
-                    target=_run_worker_pipe_shared_mem,
+                    target=func,
                     num_threads=self.num_sub_threads,
-                    args=(
-                        parent_pipe,
-                        child_pipe,
-                        env_fun,
-                        self.create_env_kwargs[idx],
-                        self.device,
-                        event,
-                        self.shared_tensordicts[idx],
-                        self._selected_input_keys,
-                        self._selected_reset_keys,
-                        self._selected_step_keys,
-                        self.has_lazy_inputs,
-                    ),
+                    kwargs=kwargs[idx],
                 )
                 process.daemon = True
                 process.start()
@@ -834,10 +882,16 @@ def step_and_maybe_reset(

         # We must pass a clone of the tensordict, as the values of this tensordict
         # will be modified in-place at further steps
-        tensordict.set("next", self.shared_tensordict_parent.get("next").clone())
-        tensordict_ = self.shared_tensordict_parent.exclude(
-            "next", *self.reset_keys
-        ).clone()
+        next_td = self.shared_tensordict_parent.get("next")
+        tensordict_ = self.shared_tensordict_parent.exclude("next", *self.reset_keys)
+        device = self.device
+        if self.shared_tensordict_parent.device == device:
+            next_td = next_td.clone()
+            tensordict_ = tensordict_.clone()
+        else:
+            next_td = next_td.to(device, non_blocking=True)
+            tensordict_ = tensordict_.to(device, non_blocking=True)
+        tensordict.set("next", next_td)
         return tensordict, tensordict_

     @_check_start
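
The clone-or-cast branch above, which also appears in _step and _reset below, guarantees that the caller never aliases the shared buffer: data already on the target device is cloned, data on another device is moved with non_blocking=True, the transfer being fenced by the CUDA events the workers record. The recurring pattern, sketched with illustrative names (shared, keys, device):

>>> out = shared.select(*keys, strict=False)
>>> out = out.clone() if out.device == device else out.to(device, non_blocking=True)
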
@@ -880,15 +934,20 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         # We must pass a clone of the tensordict, as the values of this tensordict
         # will be modified in-place at further steps
         next_td = self.shared_tensordict_parent.get("next")
+        device = self.device
         if self._single_task:
             out = TensorDict(
-                {}, batch_size=self.shared_tensordict_parent.shape, device=self.device
+                {}, batch_size=self.shared_tensordict_parent.shape, device=device
             )
             for key in self._selected_step_keys:
-                _set_single_key(next_td, out, key, clone=True)
+                _set_single_key(next_td, out, key, clone=True, device=device)
         else:
             # strict=False ensures that non-homogeneous keys are still there
-            out = next_td.select(*self._selected_step_keys, strict=False).clone()
+            out = next_td.select(*self._selected_step_keys, strict=False)
+            if out.device == device:
+                out = out.clone()
+            else:
+                out = out.to(device, non_blocking=True)
         return out

     @_check_start
@@ -944,19 +1003,26 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
             event.clear()

         selected_output_keys = self._selected_reset_keys_filt
+        device = self.device
         if self._single_task:
             # select + clone creates 2 tds, but we can create one only
             out = TensorDict(
-                {}, batch_size=self.shared_tensordict_parent.shape, device=self.device
+                {}, batch_size=self.shared_tensordict_parent.shape, device=device
             )
             for key in selected_output_keys:
-                _set_single_key(self.shared_tensordict_parent, out, key, clone=True)
-            return out
+                _set_single_key(
+                    self.shared_tensordict_parent, out, key, clone=True, device=device
+                )
         else:
-            return self.shared_tensordict_parent.select(
+            out = self.shared_tensordict_parent.select(
                 *selected_output_keys,
                 strict=False,
-            ).clone()
+            )
+            if out.device == device:
+                out = out.clone()
+            else:
+                out = out.to(device, non_blocking=True)
+        return out

     @_check_start
     def _shutdown_workers(self) -> None:
@@ -981,6 +1047,7 @@ def _shutdown_workers(self) -> None:
         del self.parent_channels
         self._cuda_events = None
         self._events = None
+        self.event = None

     @_check_start
     def set_seed(
@@ -1063,7 +1130,6 @@ def _run_worker_pipe_shared_mem(
     child_pipe: connection.Connection,
     env_fun: Union[EnvBase, Callable],
    env_fun_kwargs: Dict[str, Any],
-    device: DEVICE_TYPING = None,
     mp_event: mp.Event = None,
     shared_tensordict: TensorDictBase = None,
     _selected_input_keys=None,
@@ -1072,13 +1138,11 @@ def _run_worker_pipe_shared_mem(
     has_lazy_inputs: bool = False,
     verbose: bool = False,
 ) -> None:
-    if device is None:
-        device = torch.device("cpu")
+    device = shared_tensordict.device
     if device.type == "cuda":
         event = torch.cuda.Event()
     else:
         event = None
-
     parent_pipe.close()
     pid = os.getpid()
     if not isinstance(env_fun, EnvBase):
@@ -1089,7 +1153,6 @@ def _run_worker_pipe_shared_mem(
                 "env_fun_kwargs must be empty if an environment is passed to a process."
             )
         env = env_fun
-    env = env.to(device)
     del env_fun

     i = -1
@@ -1144,7 +1207,8 @@ def _run_worker_pipe_shared_mem(
             if not initialized:
                 raise RuntimeError("called 'init' before step")
             i += 1
-            next_td = env._step(shared_tensordict)
+            env_input = shared_tensordict
+            next_td = env._step(env_input)
             next_shared_tensordict.update_(next_td)
             if event is not None:
                 event.record()
@@ -1155,7 +1219,8 @@ def _run_worker_pipe_shared_mem(
             if not initialized:
                 raise RuntimeError("called 'init' before step")
             i += 1
-            td, root_next_td = env.step_and_maybe_reset(shared_tensordict.clone(False))
+            env_input = shared_tensordict
+            td, root_next_td = env.step_and_maybe_reset(env_input)
             next_shared_tensordict.update_(td.get("next"))
             root_shared_tensordict.update_(root_next_td)
             if event is not None:
@@ -1208,3 +1273,10 @@ def _run_worker_pipe_shared_mem(
         else:
             # don't send env through pipe
             child_pipe.send(("_".join([cmd, "done"]), None))
+
+
+def _update_cuda(t_dest, t_source):
+    if t_source is None:
+        return
+    t_dest.copy_(t_source.pin_memory(), non_blocking=True)
+    return
diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index 06eec73be97..9a2a71f24bd 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -237,7 +237,11 @@ def step_mdp(


 def _set_single_key(
-    source: TensorDictBase, dest: TensorDictBase, key: str | tuple, clone: bool = False
+    source: TensorDictBase,
+    dest: TensorDictBase,
+    key: str | tuple,
+    clone: bool = False,
+    device=None,
 ):
     # key should be already unraveled
     if isinstance(key, str):
@@ -253,7 +257,9 @@ def _set_single_key(
             source = val
             dest = new_val
         else:
-            if clone:
+            if device is not None and val.device != device:
+                val = val.to(device, non_blocking=True)
+            elif clone:
                 val = val.clone()
             dest._set_str(k, val, inplace=False, validated=True)
     # This is a temporary solution to understand if a key is heterogeneous
@@ -262,7 +268,7 @@ def _set_single_key(
         if re.match(r"Found more than one unique shape in the tensors", str(err)):
             # this is a het key
             for s_td, d_td in zip(source.tensordicts, dest.tensordicts):
-                _set_single_key(s_td, d_td, k, clone)
+                _set_single_key(s_td, d_td, k, clone=clone, device=device)
             break
         else:
             raise err
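
_set_single_key now accepts a device argument so a value can be cast to the output device while it is copied into the destination, sparing the extra clone() that a cast already implies. A usage sketch of this private helper (names and shapes are illustrative, a CUDA device is assumed):

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.utils import _set_single_key
>>> src = TensorDict({"obs": torch.zeros(4, 3)}, [4], device="cpu")
>>> dest = TensorDict({}, [4], device="cuda")
>>> _set_single_key(src, dest, ("obs",), clone=True, device=torch.device("cuda"))
>>> assert dest["obs"].device.type == "cuda"
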