From db7f08d76c0b1a99cd9fe5f3c586ecd879d379ad Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 Nov 2024 16:40:57 +0000 Subject: [PATCH] [Refactor] compile compatibility improvements ghstack-source-id: 95f8241b56e42b80e828485cb5f377288bff6f5e Pull Request resolved: https://github.com/pytorch/rl/pull/2578 --- test/test_collector.py | 22 ---- torchrl/collectors/collectors.py | 5 +- torchrl/data/tensor_specs.py | 6 +- torchrl/envs/batched_envs.py | 11 +- torchrl/modules/distributions/continuous.py | 21 ++- torchrl/modules/distributions/utils.py | 13 +- .../modules/models/decision_transformer.py | 5 + torchrl/modules/tensordict_module/actors.py | 4 + torchrl/modules/tensordict_module/common.py | 3 +- .../modules/tensordict_module/exploration.py | 122 ++++++++++-------- torchrl/objectives/common.py | 18 ++- torchrl/objectives/cql.py | 1 + torchrl/objectives/crossq.py | 14 +- torchrl/objectives/decision_transformer.py | 8 +- torchrl/objectives/value/advantages.py | 50 ++++--- 15 files changed, 176 insertions(+), 127 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 7c185830a92..1309254ce2d 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -3172,28 +3172,6 @@ def make_and_test_policy( ) -@pytest.mark.parametrize( - "ctype", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector] -) -def test_no_stopiteration(ctype): - # Tests that there is no StopIteration raised and that the length of the collector is properly set - if ctype is SyncDataCollector: - envs = SerialEnv(16, CountingEnv) - else: - envs = [SerialEnv(8, CountingEnv), SerialEnv(8, CountingEnv)] - - collector = ctype(create_env_fn=envs, frames_per_batch=173, total_frames=300) - try: - c_iter = iter(collector) - for i in range(len(collector)): # noqa: B007 - c = next(c_iter) - assert c is not None - assert i == 1 - finally: - collector.shutdown() - del collector - - if __name__ == "__main__": args, unknown = 
argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index fe1a796ea2d..319722a552e 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -147,7 +147,6 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): _iterator = None total_frames: int frames_per_batch: int - requested_frames_per_batch: int trust_policy: bool compiled_policy: bool cudagraphed_policy: bool @@ -306,7 +305,7 @@ def __class_getitem__(self, index): def __len__(self) -> int: if self.total_frames > 0: - return -(self.total_frames // -self.requested_frames_per_batch) + return -(self.total_frames // -self.frames_per_batch) raise RuntimeError("Non-terminating collectors do not have a length") @@ -701,7 +700,7 @@ def __init__( remainder = total_frames % frames_per_batch if remainder != 0 and RL_WARNINGS: warnings.warn( - f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). " + f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). " f"This means {frames_per_batch - remainder} additional frames will be collected." "To silence this message, set the environment variable RL_WARNINGS to False."
) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 7fbfaab3280..2ef74bb4521 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -2312,10 +2312,10 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Bounded: dest_device = torch.device(dest) if dest_device == self.device and dest_dtype == self.dtype: return self - self.space.device = dest_device + space = self.space.to(dest_device) return Bounded( - low=self.space.low, - high=self.space.high, + low=space.low, + high=space.high, shape=self.shape, device=dest_device, dtype=dest_dtype, diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 9e59e0f69d6..17bd28c8390 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1356,12 +1356,15 @@ def _start_workers(self) -> None: from torchrl.envs.env_creator import EnvCreator + num_threads = max( + 1, torch.get_num_threads() - self.num_workers + ) # 1 more thread for this proc + if self.num_threads is None: - self.num_threads = max( - 1, torch.get_num_threads() - self.num_workers - ) # 1 more thread for this proc + self.num_threads = num_threads - torch.set_num_threads(self.num_threads) + if self.num_threads != torch.get_num_threads(): + torch.set_num_threads(self.num_threads) if self._mp_start_method is not None: ctx = mp.get_context(self._mp_start_method) diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 62b5df5d14b..6c200c15ee4 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -397,7 +397,6 @@ def __init__( event_dims: int | None = None, tanh_loc: bool = False, safe_tanh: bool = True, - **kwargs, ): if not isinstance(loc, torch.Tensor): loc = torch.as_tensor(loc, dtype=torch.get_default_dtype()) @@ -683,6 +682,7 @@ def __init__( event_dims: int = 1, atol: float = 1e-6, rtol: float = 1e-6, + safe: bool = True, ): minmax_msg = "high value has 
been found to be equal or less than low value" if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): @@ -695,12 +695,19 @@ def __init__( if not all(high > low): raise ValueError(minmax_msg) - t = SafeTanhTransform() - non_trivial_min = (isinstance(low, torch.Tensor) and (low != -1.0).any()) or ( - not isinstance(low, torch.Tensor) and low != -1.0 + if safe: + if is_dynamo_compiling(): + _err_compile_safetanh() + t = SafeTanhTransform() + else: + t = torch.distributions.TanhTransform() + non_trivial_min = is_dynamo_compiling() or ( + (isinstance(low, torch.Tensor) and (low != -1.0).any()) + or (not isinstance(low, torch.Tensor) and low != -1.0) ) - non_trivial_max = (isinstance(high, torch.Tensor) and (high != 1.0).any()) or ( - not isinstance(high, torch.Tensor) and high != 1.0 + non_trivial_max = is_dynamo_compiling() or ( + (isinstance(high, torch.Tensor) and (high != 1.0).any()) + or (not isinstance(high, torch.Tensor) and high != 1.0) ) self.non_trivial = non_trivial_min or non_trivial_max @@ -778,7 +785,7 @@ def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor: def _err_compile_safetanh(): raise RuntimeError( "safe_tanh=True in TanhNormal is not compatible with torch.compile with torch pre 2.6.0. " - "To deactivate it, pass safe_tanh=False. " + "To deactivate it, pass safe_tanh=False. " "If you are using a ProbabilisticTensorDictModule, this can be done via " "`distribution_kwargs={'safe_tanh': False}`. " "See https://github.com/pytorch/pytorch/issues/133529 for more details."
diff --git a/torchrl/modules/distributions/utils.py b/torchrl/modules/distributions/utils.py index 546d93cb228..8c332c4efed 100644 --- a/torchrl/modules/distributions/utils.py +++ b/torchrl/modules/distributions/utils.py @@ -9,6 +9,11 @@ from torch import autograd, distributions as d from torch.distributions import Independent, Transform, TransformedDistribution +try: + from torch.compiler import is_dynamo_compiling +except ImportError: + from torch._dynamo import is_compiling as is_dynamo_compiling + def _cast_device(elt: Union[torch.Tensor, float], device) -> Union[torch.Tensor, float]: if isinstance(elt, torch.Tensor): @@ -40,10 +45,12 @@ class FasterTransformedDistribution(TransformedDistribution): __doc__ = __doc__ + TransformedDistribution.__doc__ def __init__(self, base_distribution, transforms, validate_args=None): + if is_dynamo_compiling(): + return super().__init__( + base_distribution, transforms, validate_args=validate_args + ) if isinstance(transforms, Transform): - self.transforms = [ - transforms, - ] + self.transforms = [transforms] elif isinstance(transforms, list): raise ValueError("Make a ComposeTransform first.") else: diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index 8eb72f1f9ea..8a20ad2eba8 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -90,7 +90,12 @@ def __init__( state_dim, action_dim, config: dict | DTConfig = None, + device: torch.device | None = None, ): + if device is not None: + with torch.device(device): + return self.__init__(state_dim, action_dim, config) + if not _has_transformers: raise ImportError( "transformers is not installed. Please install it with `pip install transformers`." 
diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 2ad5918861d..888729835b5 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -1783,6 +1783,7 @@ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper): For example for an observation input of shape [batch_size, context, obs_dim] with context=20 and inference_context=5, the first 15 entries of the context will be masked. Defaults to 5. spec (Optional[TensorSpec]): The spec of the input TensorDict. If None, it will be inferred from the policy module. + device (torch.device, optional): if provided, the device where the buffers / specs will be placed. Examples: >>> import torch @@ -1836,6 +1837,7 @@ def __init__( *, inference_context: int = 5, spec: Optional[TensorSpec] = None, + device: torch.device | None = None, ): super().__init__(policy) self.observation_key = "observation" @@ -1857,6 +1859,8 @@ def __init__( self._spec[self.action_key] = None else: self._spec = Composite({key: None for key in policy.out_keys}) + if device is not None: + self._spec = self._spec.to(device) self.checked = False @property diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 4018589bfa1..f722bc2bd7d 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -69,7 +69,8 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out): keys = [out_key] values = [spec] else: - keys = list(spec.keys(True, True)) + # Make dynamo happy with the list creation + keys = [key for key in spec.keys(True, True)] # noqa: C416 values = [spec[key] for key in keys] for _spec, _key in zip(values, keys): if _spec is None: diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 2ccdf599f2d..df947236970 100644 --- 
a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -133,11 +133,14 @@ def step(self, frames: int = 1) -> None: """ for _ in range(frames): - self.eps.data[0] = max( - self.eps_end.item(), - ( - self.eps - (self.eps_init - self.eps_end) / self.annealing_num_steps - ).item(), + self.eps.data.copy_( + torch.maximum( + self.eps_end, + ( + self.eps + - (self.eps_init - self.eps_end) / self.annealing_num_steps + ), + ) ) def forward(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -150,7 +153,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: action_key = self.action_key out = action_tensordict.get(action_key) - eps = self.eps.item() + eps = self.eps cond = torch.rand(action_tensordict.shape, device=out.device) < eps # cond = torch.zeros(action_tensordict.shape, device=out.device, dtype=torch.bool).bernoulli_(eps) cond = expand_as_right(cond, out) @@ -307,19 +310,20 @@ def step(self, frames: int = 1) -> None: """ for _ in range(frames): - self.sigma.data[0] = max( - self.sigma_end.item(), - ( - self.sigma - - (self.sigma_init - self.sigma_end) / self.annealing_num_steps - ).item(), + self.sigma.data.copy_( + torch.maximum( + self.sigma_end, + self.sigma + - (self.sigma_init - self.sigma_end) + / self.annealing_num_steps, + ) ) def _add_noise(self, action: torch.Tensor) -> torch.Tensor: - sigma = self.sigma.item() + sigma = self.sigma noise = torch.normal( - mean=torch.ones(action.shape) * self.mean.item(), - std=torch.ones(action.shape) * self.std.item(), + mean=torch.ones(action.shape) * self.mean, + std=torch.ones(action.shape) * self.std, ).to(action.device) action = action + noise * sigma spec = self.spec @@ -365,6 +369,9 @@ class AdditiveGaussianModule(TensorDictModuleBase): its output spec will be of type Composite. One needs to know where to find the action spec.
default: "action" + safe (bool): if ``True``, actions that are out of bounds given the action specs will be projected in the space + given the :obj:`TensorSpec.project` heuristic. + default: True .. note:: It is @@ -386,6 +393,7 @@ def __init__( std: float = 1.0, *, action_key: Optional[NestedKey] = "action", + safe: bool = True, ): if not isinstance(sigma_init, float): warnings.warn("eps_init should be a float.") @@ -410,7 +418,9 @@ def __init__( else: raise RuntimeError("spec cannot be None.") self._spec = spec - self.register_forward_hook(_forward_hook_safe_action) + self.safe = safe + if self.safe: + self.register_forward_hook(_forward_hook_safe_action) @property def spec(self): @@ -426,19 +436,21 @@ def step(self, frames: int = 1) -> None: """ for _ in range(frames): - self.sigma.data[0] = max( - self.sigma_end.item(), - ( - self.sigma - - (self.sigma_init - self.sigma_end) / self.annealing_num_steps - ).item(), + self.sigma.data.copy_( + torch.maximum( + self.sigma_end, + ( + self.sigma + - (self.sigma_init - self.sigma_end) / self.annealing_num_steps + ), + ) ) def _add_noise(self, action: torch.Tensor) -> torch.Tensor: - sigma = self.sigma.item() + sigma = self.sigma noise = torch.normal( - mean=torch.ones(action.shape) * self.mean.item(), - std=torch.ones(action.shape) * self.std.item(), + mean=torch.ones(action.shape) * self.mean, + std=torch.ones(action.shape) * self.std, ).to(action.device) action = action + noise * sigma spec = self.spec[self.action_key] @@ -636,12 +648,14 @@ def step(self, frames: int = 1) -> None: """ for _ in range(frames): if self.annealing_num_steps > 0: - self.eps.data[0] = max( - self.eps_end.item(), - ( - self.eps - - (self.eps_init - self.eps_end) / self.annealing_num_steps - ).item(), + self.eps.data.copy_( + torch.maximum( + self.eps_end, + ( + self.eps + - (self.eps_init - self.eps_end) / self.annealing_num_steps + ), + ) ) else: raise ValueError( @@ -664,9 +678,7 @@ def forward(self, tensordict: TensorDictBase) -> 
TensorDictBase: f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker " f"transform to your environment with `env = TransformedEnv(env, InitTracker())`." ) - tensordict = self.ou.add_sample( - tensordict, self.eps.item(), is_init=is_init - ) + tensordict = self.ou.add_sample(tensordict, self.eps, is_init=is_init) return tensordict @@ -730,6 +742,10 @@ class OrnsteinUhlenbeckProcessModule(TensorDictModuleBase): default: "action" is_init_key (NestedKey, optional): key where to find the is_init flag used to reset the noise steps. default: "is_init" + safe (boolean, optional): if False, the TensorSpec can be None. If it + is set to False but the spec is passed, the projection will still + happen. + Default is True. Examples: >>> import torch @@ -772,6 +788,7 @@ def __init__( *, action_key: Optional[NestedKey] = "action", is_init_key: Optional[NestedKey] = "is_init", + safe: bool = True, ): super().__init__() @@ -815,7 +832,9 @@ def __init__( self._spec.update(ou_specs) if len(set(self.out_keys)) != len(self.out_keys): raise RuntimeError(f"Got multiple identical output keys: {self.out_keys}") - self.register_forward_hook(_forward_hook_safe_action) + self.safe = safe + if self.safe: + self.register_forward_hook(_forward_hook_safe_action) @property def spec(self): @@ -830,12 +849,14 @@ def step(self, frames: int = 1) -> None: """ for _ in range(frames): if self.annealing_num_steps > 0: - self.eps.data[0] = max( - self.eps_end.item(), - ( - self.eps - - (self.eps_init - self.eps_end) / self.annealing_num_steps - ).item(), + self.eps.data.copy_( + torch.maximum( + self.eps_end, + ( + self.eps + - (self.eps_init - self.eps_end) / self.annealing_num_steps + ), + ) ) else: raise ValueError( @@ -857,9 +878,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker " f"transform to your environment with `env = TransformedEnv(env, InitTracker())`." 
) - tensordict = self.ou.add_sample( - tensordict, self.eps.item(), is_init=is_init - ) + tensordict = self.ou.add_sample(tensordict, self.eps, is_init=is_init) return tensordict @@ -923,11 +942,12 @@ def _make_noise_pair( tensordict.set(self.noise_key, noise) tensordict.set(self.steps_key, steps) else: - noise = tensordict.get(self.noise_key) - steps = tensordict.get(self.steps_key) + # We must clone for cudagraph, otherwise the same tensor may re-enter the compiled region + noise = tensordict.get(self.noise_key).clone() + steps = tensordict.get(self.steps_key).clone() if is_init is not None: - noise[is_init] = 0 - steps[is_init] = 0 + noise = torch.masked_fill(noise, is_init, 0) + steps = torch.masked_fill(steps, is_init, 0) return noise, steps def add_sample( @@ -977,9 +997,9 @@ def add_sample( * np.sqrt(self.dt) * torch.randn_like(prev_noise) ) - tensordict.set_(self.noise_key, noise - self.x0) - tensordict.set_(self.key, tensordict.get(self.key) + eps * noise) - tensordict.set_(self.steps_key, n_steps + 1) + tensordict.set(self.noise_key, noise - self.x0) + tensordict.set(self.key, tensordict.get(self.key) + eps * noise) + tensordict.set(self.steps_key, n_steps + 1) return tensordict def current_sigma(self, n_steps: torch.Tensor) -> torch.Tensor: diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index be05e2fa66b..57310a5fc3d 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -12,6 +12,7 @@ from dataclasses import dataclass from typing import Iterator, List, Optional, Tuple +import torch from tensordict import is_tensor_collection, TensorDict, TensorDictBase from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams @@ -515,7 +516,22 @@ def _default_value_estimator(self): from :obj:`torchrl.objectives.utils.DEFAULT_VALUE_FUN_PARAMS`. 
""" - self.make_value_estimator(self.default_value_estimator) + self.make_value_estimator( + self.default_value_estimator, device=self._default_device + ) + + @property + def _default_device(self) -> torch.device | None: + """A util to find the default device. + + Returns ``None`` if parameters are spread across multiple devices. + """ + devices = set() + for p in self.parameters(): + devices.add(p.device) + if len(devices) == 1: + return list(devices)[0] + return None def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): """Value-function constructor. diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 14c5d54a61d..191096e7492 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -375,6 +375,7 @@ def __init__( ) self._make_vmap() self.reduction = reduction + _ = self.target_entropy def _make_vmap(self): self._vmap_qvalue_networkN0 = _vmap_func( diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index c555b7a609c..ca6559ac5b8 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -340,6 +340,8 @@ def __init__( self._action_spec = action_spec self._make_vmap() self.reduction = reduction + # init target entropy + _ = self.target_entropy def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( @@ -513,15 +515,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: **metadata_actor, **value_metadata, } - td_out = TensorDict(out, []) - # td_out = td_out.named_apply( - # lambda name, value: ( - # _reduce(value, reduction=self.reduction) - # if name.startswith("loss_") - # else value - # ), - # batch_size=[], - # ) + td_out = TensorDict(out) return td_out @property @@ -543,6 +537,7 @@ def actor_loss( Returns: a differentiable tensor with the alpha loss along with a metadata dictionary containing the detached `"log_prob"` of the sampled action. 
""" + tensordict = tensordict.copy() with set_exploration_type( ExplorationType.RANDOM ), self.actor_network_params.to_module(self.actor_network): @@ -584,6 +579,7 @@ def qvalue_loss( Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing the detached `"td_error"` to be used for prioritized sampling. """ + tensordict = tensordict.copy() # # compute next action with torch.no_grad(): with set_exploration_type( diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 632d3e615b6..1b1f0aa4e0b 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -292,6 +292,7 @@ def __init__( *, loss_function: str = "l2", reduction: str = None, + device: torch.device | None = None, ) -> None: self._in_keys = None self._out_keys = None @@ -343,7 +344,7 @@ def out_keys(self, values): def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Compute the loss for the Online Decision Transformer.""" # extract action targets - tensordict = tensordict.clone(False) + tensordict = tensordict.copy() target_actions = tensordict.get(self.tensor_keys.action_target).detach() with self.actor_network_params.to_module(self.actor_network): @@ -356,8 +357,5 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: loss_function=self.loss_function, ) loss = _reduce(loss, reduction=self.reduction) - out = { - "loss": loss, - } - td_out = TensorDict(out, []) + td_out = TensorDict(loss=loss) return td_out diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 739fb9a018e..8ac64bf3d21 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -197,6 +197,8 @@ def forward( to be passed to the functional value network module. 
target_params (TensorDictBase, optional): A nested TensorDict containing the target params to be passed to the functional value network module. + device (torch.device, optional): the device where the buffers will be instantiated. + Defaults to ``torch.get_default_device()``. Returns: An updated TensorDict with an advantage and a value_error keys as defined in the constructor. @@ -213,8 +215,14 @@ def __init__( advantage_key: NestedKey = None, value_target_key: NestedKey = None, value_key: NestedKey = None, + device: torch.device | None = None, ): super().__init__() + if device is None: + device = torch.get_default_device() + # this is saved for tracking only and should not be used to cast anything else than buffers during + # init. + self._device = device self._tensor_keys = None self.differentiable = differentiable self.skip_existing = skip_existing @@ -518,7 +526,8 @@ class TD0Estimator(ValueEstimatorBase): of the advantage entry. Defaults to ``"value_target"``. value_key (str or tuple of str, optional): [Deprecated] the value key to read from the input tensordict. Defaults to ``"state_value"``. - device (torch.device, optional): device of the module. + device (torch.device, optional): the device where the buffers will be instantiated. + Defaults to ``torch.get_default_device()``. 
""" @@ -544,8 +553,9 @@ def __init__( value_target_key=value_target_key, value_key=value_key, skip_existing=skip_existing, + device=device, ) - self.register_buffer("gamma", torch.tensor(gamma, device=device)) + self.register_buffer("gamma", torch.tensor(gamma, device=self._device)) self.average_rewards = average_rewards @_self_set_skip_existing @@ -668,7 +678,6 @@ def value_estimate( if self.gamma.device != device: self.gamma = self.gamma.to(device) gamma = self.gamma - steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -727,7 +736,8 @@ class TD1Estimator(ValueEstimatorBase): estimation, for instance) and (2) when the parameters used at time ``t`` and ``t+1`` are identical (which is not the case when target parameters are to be used). Defaults to ``False``. - device (torch.device, optional): device of the module. + device (torch.device, optional): the device where the buffers will be instantiated. + Defaults to ``torch.get_default_device()``. time_dim (int, optional): the dimension corresponding to the time in the input tensordict. 
If not provided, defaults to the dimension markes with the ``"time"`` name if any, and to the last dimension @@ -761,8 +771,9 @@ def __init__( value_key=value_key, shifted=shifted, skip_existing=skip_existing, + device=device, ) - self.register_buffer("gamma", torch.tensor(gamma, device=device)) + self.register_buffer("gamma", torch.tensor(gamma, device=self._device)) self.average_rewards = average_rewards self.time_dim = time_dim @@ -887,7 +898,6 @@ def value_estimate( if self.gamma.device != device: self.gamma = self.gamma.to(device) gamma = self.gamma - steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -951,7 +961,8 @@ class TDLambdaEstimator(ValueEstimatorBase): estimation, for instance) and (2) when the parameters used at time ``t`` and ``t+1`` are identical (which is not the case when target parameters are to be used). Defaults to ``False``. - device (torch.device, optional): device of the module. + device (torch.device, optional): the device where the buffers will be instantiated. + Defaults to ``torch.get_default_device()``. time_dim (int, optional): the dimension corresponding to the time in the input tensordict. 
If not provided, defaults to the dimension markes with the ``"time"`` name if any, and to the last dimension @@ -987,9 +998,10 @@ def __init__( value_key=value_key, skip_existing=skip_existing, shifted=shifted, + device=device, ) - self.register_buffer("gamma", torch.tensor(gamma, device=device)) - self.register_buffer("lmbda", torch.tensor(lmbda, device=device)) + self.register_buffer("gamma", torch.tensor(gamma, device=self._device)) + self.register_buffer("lmbda", torch.tensor(lmbda, device=self._device)) self.average_rewards = average_rewards self.vectorized = vectorized self.time_dim = time_dim @@ -1115,7 +1127,6 @@ def value_estimate( if self.gamma.device != device: self.gamma = self.gamma.to(device) gamma = self.gamma - steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -1197,7 +1208,8 @@ class GAE(ValueEstimatorBase): estimation, for instance) and (2) when the parameters used at time ``t`` and ``t+1`` are identical (which is not the case when target parameters are to be used). Defaults to ``False``. - device (torch.device, optional): device of the module. + device (torch.device, optional): the device where the buffers will be instantiated. + Defaults to ``torch.get_default_device()``. time_dim (int, optional): the dimension corresponding to the time in the input tensordict. 
If not provided, defaults to the dimension marked with the ``"time"`` name if any, and to the last dimension @@ -1245,9 +1257,10 @@ def __init__( value_target_key=value_target_key, value_key=value_key, skip_existing=skip_existing, + device=device, ) - self.register_buffer("gamma", torch.tensor(gamma, device=device)) - self.register_buffer("lmbda", torch.tensor(lmbda, device=device)) + self.register_buffer("gamma", torch.tensor(gamma, device=self._device)) + self.register_buffer("lmbda", torch.tensor(lmbda, device=self._device)) self.average_gae = average_gae self.vectorized = vectorized self.time_dim = time_dim @@ -1530,7 +1543,8 @@ class VTrace(ValueEstimatorBase): estimation, for instance) and (2) when the parameters used at time ``t`` and ``t+1`` are identical (which is not the case when target parameters are to be used). Defaults to ``False``. - device (torch.device, optional): device of the module. + device (torch.device, optional): the device where the buffers will be instantiated. + Defaults to ``torch.get_default_device()``. time_dim (int, optional): the dimension corresponding to the time in the input tensordict. 
If not provided, defaults to the dimension markes with the ``"time"`` name if any, and to the last dimension @@ -1575,13 +1589,14 @@ def __init__( value_target_key=value_target_key, value_key=value_key, skip_existing=skip_existing, + device=device, ) if not isinstance(gamma, torch.Tensor): - gamma = torch.tensor(gamma, device=device) + gamma = torch.tensor(gamma, device=self._device) if not isinstance(rho_thresh, torch.Tensor): - rho_thresh = torch.tensor(rho_thresh, device=device) + rho_thresh = torch.tensor(rho_thresh, device=self._device) if not isinstance(c_thresh, torch.Tensor): - c_thresh = torch.tensor(c_thresh, device=device) + c_thresh = torch.tensor(c_thresh, device=self._device) self.register_buffer("gamma", gamma) self.register_buffer("rho_thresh", rho_thresh) @@ -1716,7 +1731,6 @@ def forward( if self.gamma.device != device: self.gamma = self.gamma.to(device) gamma = self.gamma - steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward)