From 69453a66db52256d060634e3b5a905bf12e35c4d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 31 Jan 2024 11:47:08 +0000 Subject: [PATCH] [BugFix] Fix flaky gym penv test (#1853) --- test/_utils_internal.py | 24 +++++++++---------- torchrl/data/datasets/minari_data.py | 4 ++-- torchrl/data/datasets/openx.py | 2 +- torchrl/data/replay_buffers/replay_buffers.py | 4 ++-- torchrl/data/replay_buffers/samplers.py | 4 ++-- torchrl/data/replay_buffers/storages.py | 2 +- torchrl/data/replay_buffers/utils.py | 4 ++-- torchrl/data/replay_buffers/writers.py | 4 ++-- torchrl/data/rlhf/utils.py | 6 ++--- torchrl/data/tensor_specs.py | 12 +++++----- torchrl/envs/libs/dm_control.py | 4 ++-- torchrl/envs/libs/envpool.py | 6 ++--- torchrl/envs/libs/gym.py | 1 + torchrl/envs/libs/pettingzoo.py | 6 ++--- torchrl/envs/transforms/gym_transforms.py | 4 ++-- torchrl/envs/transforms/r3m.py | 4 ++-- torchrl/envs/transforms/rlhf.py | 2 +- torchrl/envs/transforms/transforms.py | 8 +++---- torchrl/envs/transforms/vc1.py | 4 ++-- torchrl/envs/transforms/vip.py | 4 ++-- torchrl/modules/distributions/continuous.py | 4 ++-- torchrl/modules/models/exploration.py | 4 +++- torchrl/modules/planners/mppi.py | 2 +- .../modules/tensordict_module/exploration.py | 12 +++++----- torchrl/objectives/a2c.py | 6 +++-- torchrl/objectives/deprecated.py | 14 ++++++----- 26 files changed, 78 insertions(+), 73 deletions(-) diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 8d473dbf4ee..c9fdc7e39ba 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -330,9 +330,9 @@ def rollout_consistency_assertion( ): """Tests that observations in "next" match observations in the next root tensordict when done is False, and don't match otherwise.""" - done = rollout[:, :-1]["next", done_key].squeeze(-1) + done = rollout[..., :-1]["next", done_key].squeeze(-1) # data resulting from step, when it's not done - r_not_done = rollout[:, :-1]["next"][~done] + r_not_done = rollout[..., :-1]["next"][~done] # data resulting from step, when it's not done, after step_mdp r_not_done_tp1 = rollout[:, 1:][~done] torch.testing.assert_close( @@ -343,17 +343,15 @@ def rollout_consistency_assertion( if done_strict and not done.any(): raise RuntimeError("No done detected, test could not complete.") - - # data resulting from step, when it's done - r_done = rollout[:, :-1]["next"][done] - # data resulting from step, when it's done, after step_mdp and reset - r_done_tp1 = rollout[:, 1:][done] - assert ( - (r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) > 1e-1 - ).all(), ( - f"Entries in next tensordict do not match entries in root " - f"tensordict after reset : {(r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) < 1e-1}" - ) + if done.any(): + # data resulting from step, when it's done + r_done = rollout[..., :-1]["next"][done] + # data resulting from step, when it's done, after step_mdp and reset + r_done_tp1 = rollout[..., 1:][done] + # check that at least one obs after reset does not match the version before reset + assert not torch.isclose( + r_done[observation_key], r_done_tp1[observation_key] + ).all() def rand_reset(env): diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 5deeccd3253..babe5638c91 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -412,8 +412,8 @@ def _proc_spec(spec): ) return BoundedTensorSpec( shape=spec["shape"], - low=torch.tensor(spec["low"]), - high=torch.tensor(spec["high"]), + low=torch.as_tensor(spec["low"]), + high=torch.as_tensor(spec["high"]), dtype=_DTYPE_DIR[spec["dtype"]], ) elif spec["type"] == "Discrete": diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index 0b825188a5b..598ab782147 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -684,7 +684,7 @@ def _slice_data(data: TensorDict, slice_len, pad_value): truncated, dim=data.ndim - 1, value=True, - index=torch.tensor(-1, device=truncated.device), + index=torch.as_tensor(-1, device=truncated.device), ) done = data.get(("next", "done")) data.set(("next", "truncated"), truncated) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 79bf3b9b180..c3999806aaf 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -867,7 +867,7 @@ def add(self, data: TensorDictBase) -> int: device=data.device, ) if data.batch_size: - data_add["_rb_batch_size"] = torch.tensor(data.batch_size) + data_add["_rb_batch_size"] = torch.as_tensor(data.batch_size) else: data_add = data @@ -1441,7 +1441,7 @@ def __getitem__( if isinstance(index, slice) and index == slice(None): return self if isinstance(index, (list, range, np.ndarray)): - index = torch.tensor(index) + index = torch.as_tensor(index) if isinstance(index, torch.Tensor): if index.ndim > 1: raise RuntimeError( diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 3460f6ed51c..15e46ae1038 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -461,10 +461,10 @@ def dumps(self, path): filename=path / "mintree.memmap", ) mm_st.copy_( - torch.tensor([self._sum_tree[i] for i in range(self._max_capacity)]) + torch.as_tensor([self._sum_tree[i] for i in range(self._max_capacity)]) ) mm_mt.copy_( - torch.tensor([self._min_tree[i] for i in range(self._max_capacity)]) + torch.as_tensor([self._min_tree[i] for i in range(self._max_capacity)]) ) with open(path / "sampler_metadata.json", "w") as file: json.dump( diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 8a0510b11b4..fd847f25c74 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -1005,7 +1005,7 @@ def __getitem__(self, index): if isinstance(index, slice) and index == slice(None): return self if isinstance(index, (list, range, np.ndarray)): - index = torch.tensor(index) + index = torch.as_tensor(index) if isinstance(index, torch.Tensor): if index.ndim > 1: raise RuntimeError( diff --git a/torchrl/data/replay_buffers/utils.py b/torchrl/data/replay_buffers/utils.py index c042f54c652..7846a6bb9d4 100644 --- a/torchrl/data/replay_buffers/utils.py +++ b/torchrl/data/replay_buffers/utils.py @@ -28,11 +28,11 @@ def _to_torch( data: Tensor, device, pin_memory: bool = False, non_blocking: bool = False ) -> torch.Tensor: if isinstance(data, np.generic): - return torch.tensor(data, device=device) + return torch.as_tensor(data, device=device) elif isinstance(data, np.ndarray): data = torch.from_numpy(data) elif not isinstance(data, Tensor): - data = torch.tensor(data, device=device) + data = torch.as_tensor(data, device=device) if pin_memory: data = data.pin_memory() diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index 41d551535ac..156d32f9539 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -357,7 +357,7 @@ def __getstate__(self): def dumps(self, path): path = Path(path).absolute() path.mkdir(exist_ok=True) - t = torch.tensor(self._current_top_values) + t = torch.as_tensor(self._current_top_values) try: MemoryMappedTensor.from_filename( filename=path / "current_top_values.memmap", @@ -453,7 +453,7 @@ def __getitem__(self, index): if isinstance(index, slice) and index == slice(None): return self if isinstance(index, (list, range, np.ndarray)): - index = torch.tensor(index) + index = torch.as_tensor(index) if isinstance(index, torch.Tensor): if index.ndim > 1: raise RuntimeError( diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index ed7c7d1d35f..311b2584aa5 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -100,7 +100,7 @@ def update(self, kl_values: Sequence[float]): ) n_steps = len(kl_values) # renormalize kls - kl_value = -torch.tensor(kl_values).mean() / self.coef + kl_value = -torch.as_tensor(kl_values).mean() / self.coef proportional_error = np.clip(kl_value / self.target - 1, -0.2, 0.2) # ϵₜ mult = 1 + proportional_error * n_steps / self.horizon self.coef *= mult # βₜ₊₁ @@ -314,10 +314,10 @@ def _get_done_status(self, generated, batch): # of generated tokens done_idx = torch.minimum( (generated != self.EOS_TOKEN_ID).sum(dim=-1) - batch.prompt_rindex, - torch.tensor(self.max_new_tokens) - 1, + torch.as_tensor(self.max_new_tokens) - 1, ) truncated_idx = ( - torch.tensor(self.max_new_tokens, device=generated.device).expand_as( + torch.as_tensor(self.max_new_tokens, device=generated.device).expand_as( done_idx ) - 1 diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index b4d628a9051..1cfc970e61f 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1374,9 +1374,9 @@ def encode( ) -> torch.Tensor: if not isinstance(val, torch.Tensor): if ignore_device: - val = torch.tensor(val) + val = torch.as_tensor(val) else: - val = torch.tensor(val, device=self.device) + val = torch.as_tensor(val, device=self.device) if space is None: space = self.space @@ -1555,9 +1555,9 @@ def __init__( dtype = torch.get_default_dtype() if not isinstance(low, torch.Tensor): - low = torch.tensor(low, dtype=dtype, device=device) + low = torch.as_tensor(low, dtype=dtype, device=device) if not isinstance(high, torch.Tensor): - high = torch.tensor(high, dtype=dtype, device=device) + high = torch.as_tensor(high, dtype=dtype, device=device) if high.device != device: high = high.to(device) if low.device != device: @@ -1857,8 +1857,8 @@ def __init__( dtype, device = _default_dtype_and_device(dtype, device) box = ( ContinuousBox( - torch.tensor(-np.inf, device=device).expand(shape), - torch.tensor(np.inf, device=device).expand(shape), + torch.as_tensor(-np.inf, device=device).expand(shape), + torch.as_tensor(np.inf, device=device).expand(shape), ) if shape == _DEFAULT_SHAPE else None diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index b2fdac0a802..2e96efcaf6a 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -102,9 +102,9 @@ def _get_envs(to_dict: bool = True) -> Dict[str, Any]: def _robust_to_tensor(array: Union[float, np.ndarray]) -> torch.Tensor: if isinstance(array, np.ndarray): - return torch.tensor(array.copy()) + return torch.as_tensor(array.copy()) else: - return torch.tensor(array) + return torch.as_tensor(array) class DMControlWrapper(GymLikeEnv): diff --git a/torchrl/envs/libs/envpool.py b/torchrl/envs/libs/envpool.py index acf2da598b1..410e25a1b28 100644 --- a/torchrl/envs/libs/envpool.py +++ b/torchrl/envs/libs/envpool.py @@ -264,7 +264,7 @@ def _transform_step_output( f"The output of step was had {len(out)} elements, but only 4 or 5 are supported." ) obs = self._treevalue_or_numpy_to_tensor_or_dict(obs) - reward_and_done = {self.reward_key: torch.tensor(reward)} + reward_and_done = {self.reward_key: torch.as_tensor(reward)} reward_and_done["done"] = done reward_and_done["terminated"] = terminated reward_and_done["truncated"] = truncated @@ -290,7 +290,7 @@ def _treevalue_or_numpy_to_tensor_or_dict( if isinstance(x, treevalue.TreeValue): ret = self._treevalue_to_dict(x) elif not isinstance(x, dict): - ret = {"observation": torch.tensor(x)} + ret = {"observation": torch.as_tensor(x)} else: ret = x return ret @@ -304,7 +304,7 @@ def _treevalue_to_dict( """ import treevalue - return {k[0]: torch.tensor(v) for k, v in treevalue.flatten(tv)} + return {k[0]: torch.as_tensor(v) for k, v in treevalue.flatten(tv)} def _set_seed(self, seed: Optional[int]): if seed is not None: diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 48da354e7ba..59730c6df8c 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -1506,6 +1506,7 @@ def _read_obs(self, obs, key, tensor, index): def __call__(self, info_dict, tensordict): terminal_obs = info_dict.get(self.backend_key[self.backend], None) for key, item in self.info_spec.items(True, True): + key = (key,) if isinstance(key, str) else key final_obs_buffer = item.zero() if terminal_obs is not None: for i, obs in enumerate(terminal_obs): diff --git a/torchrl/envs/libs/pettingzoo.py b/torchrl/envs/libs/pettingzoo.py index 14e45eb4bc4..a1470776f10 100644 --- a/torchrl/envs/libs/pettingzoo.py +++ b/torchrl/envs/libs/pettingzoo.py @@ -462,7 +462,7 @@ def _init_env(self): "info": CompositeSpec( { key: UnboundedContinuousTensorSpec( - shape=torch.tensor(value).shape, + shape=torch.as_tensor(value).shape, device=self.device, ) for key, value in info_dict[agent].items() @@ -501,7 +501,7 @@ def _init_env(self): device=self.device, ) except AttributeError: - state_example = torch.tensor(self.state(), device=self.device) + state_example = torch.as_tensor(self.state(), device=self.device) state_spec = UnboundedContinuousTensorSpec( shape=state_example.shape, dtype=state_example.dtype, @@ -560,7 +560,7 @@ def _reset( if group_info is not None: agent_info_dict = info_dict[agent] for agent_info, value in agent_info_dict.items(): - group_info.get(agent_info)[index] = torch.tensor( + group_info.get(agent_info)[index] = torch.as_tensor( value, device=self.device ) diff --git a/torchrl/envs/transforms/gym_transforms.py b/torchrl/envs/transforms/gym_transforms.py index b5aed62d503..99f38ebb32c 100644 --- a/torchrl/envs/transforms/gym_transforms.py +++ b/torchrl/envs/transforms/gym_transforms.py @@ -135,7 +135,7 @@ def _get_lives(self): if callable(lives): lives = lives() elif isinstance(lives, list) and all(callable(_lives) for _lives in lives): - lives = torch.tensor([_lives() for _lives in lives]) + lives = torch.as_tensor([_lives() for _lives in lives]) return lives def _call(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -170,7 +170,7 @@ def _reset(self, tensordict, tensordict_reset): end_of_life = False tensordict_reset.set( self.eol_key, - torch.tensor(end_of_life).expand( + torch.as_tensor(end_of_life).expand( parent.full_done_spec[self.done_key].shape ), ) diff --git a/torchrl/envs/transforms/r3m.py b/torchrl/envs/transforms/r3m.py index 05017a8a8ec..1c12cf9be15 100644 --- a/torchrl/envs/transforms/r3m.py +++ b/torchrl/envs/transforms/r3m.py @@ -292,8 +292,8 @@ def _init(self): std = [0.229, 0.224, 0.225] normalize = ObservationNorm( in_keys=in_keys, - loc=torch.tensor(mean).view(3, 1, 1), - scale=torch.tensor(std).view(3, 1, 1), + loc=torch.as_tensor(mean).view(3, 1, 1), + scale=torch.as_tensor(std).view(3, 1, 1), standard_normal=True, ) transforms.append(normalize) diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index 79ee94318cb..623bc2864fe 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -146,7 +146,7 @@ def find_sample_log_prob(module): self.functional_actor.apply(find_sample_log_prob) if not isinstance(coef, torch.Tensor): - coef = torch.tensor(coef) + coef = torch.as_tensor(coef) self.register_buffer("coef", coef) def _reset( diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 52ac8e8f66d..e59c481419c 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1332,7 +1332,7 @@ def check_val(val): if val is None: return None, None, torch.finfo(torch.get_default_dtype()).max if not isinstance(val, torch.Tensor): - val = torch.tensor(val) + val = torch.as_tensor(val) if not val.dtype.is_floating_point: val = val.float() eps = torch.finfo(val.dtype).resolution @@ -1626,10 +1626,10 @@ def __init__( out_keys = copy(in_keys) super().__init__(in_keys=in_keys, out_keys=out_keys) clamp_min_tensor = ( - clamp_min if isinstance(clamp_min, Tensor) else torch.tensor(clamp_min) + clamp_min if isinstance(clamp_min, Tensor) else torch.as_tensor(clamp_min) ) clamp_max_tensor = ( - clamp_max if isinstance(clamp_max, Tensor) else torch.tensor(clamp_max) + clamp_max if isinstance(clamp_max, Tensor) else torch.as_tensor(clamp_max) ) self.register_buffer("clamp_min", clamp_min_tensor) self.register_buffer("clamp_max", clamp_max_tensor) @@ -2396,7 +2396,7 @@ def __init__( out_keys_inv=out_keys_inv, ) if not isinstance(standard_normal, torch.Tensor): - standard_normal = torch.tensor(standard_normal) + standard_normal = torch.as_tensor(standard_normal) self.register_buffer("standard_normal", standard_normal) self.eps = 1e-6 diff --git a/torchrl/envs/transforms/vc1.py b/torchrl/envs/transforms/vc1.py index 252ddfc4a90..746bfe52f1d 100644 --- a/torchrl/envs/transforms/vc1.py +++ b/torchrl/envs/transforms/vc1.py @@ -132,8 +132,8 @@ def _map_tv_to_torchrl( elif isinstance(model_transforms, transforms.Normalize): return ObservationNorm( in_keys=in_keys, - loc=torch.tensor(model_transforms.mean).reshape(3, 1, 1), - scale=torch.tensor(model_transforms.std).reshape(3, 1, 1), + loc=torch.as_tensor(model_transforms.mean).reshape(3, 1, 1), + scale=torch.as_tensor(model_transforms.std).reshape(3, 1, 1), standard_normal=True, ) elif isinstance(model_transforms, transforms.ToTensor): diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py index ed288fdea9e..9c272b42b89 100644 --- a/torchrl/envs/transforms/vip.py +++ b/torchrl/envs/transforms/vip.py @@ -266,8 +266,8 @@ def _init(self): std = [0.229, 0.224, 0.225] normalize = ObservationNorm( in_keys=in_keys, - loc=torch.tensor(mean).view(3, 1, 1), - scale=torch.tensor(std).view(3, 1, 1), + loc=torch.as_tensor(mean).view(3, 1, 1), + scale=torch.as_tensor(std).view(3, 1, 1), standard_normal=True, ) transforms.append(normalize) diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index d4256dcd61f..eb5f2a38944 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -240,11 +240,11 @@ def __init__( if isinstance(max, torch.Tensor): max = max.to(self.device) else: - max = torch.tensor(max, device=self.device) + max = torch.as_tensor(max, device=self.device) if isinstance(min, torch.Tensor): min = min.to(self.device) else: - min = torch.tensor(min, device=self.device) + min = torch.as_tensor(min, device=self.device) self.min = min self.max = max self.update(loc, scale) diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py index f909b6568c6..59819d940d0 100644 --- a/torchrl/modules/models/exploration.py +++ b/torchrl/modules/models/exploration.py @@ -345,7 +345,9 @@ def __init__( ) if sigma_init != 0.0: - self.register_buffer("sigma_init", torch.tensor(sigma_init, device=device)) + self.register_buffer( + "sigma_init", torch.as_tensor(sigma_init, device=device) + ) @property def sigma(self): diff --git a/torchrl/modules/planners/mppi.py b/torchrl/modules/planners/mppi.py index c65b81eb11d..9c0bbc8f147 100644 --- a/torchrl/modules/planners/mppi.py +++ b/torchrl/modules/planners/mppi.py @@ -145,7 +145,7 @@ def __init__( self.num_candidates = num_candidates self.top_k = top_k self.reward_key = reward_key - self.register_buffer("temperature", torch.tensor(temperature)) + self.register_buffer("temperature", torch.as_tensor(temperature)) def planning(self, tensordict: TensorDictBase) -> torch.Tensor: batch_size = tensordict.batch_size diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index f641fdfef88..c8fa9cc040f 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -107,10 +107,10 @@ def __init__( super().__init__() - self.register_buffer("eps_init", torch.tensor([eps_init])) - self.register_buffer("eps_end", torch.tensor([eps_end])) + self.register_buffer("eps_init", torch.as_tensor([eps_init])) + self.register_buffer("eps_end", torch.as_tensor([eps_end])) self.annealing_num_steps = annealing_num_steps - self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32)) + self.register_buffer("eps", torch.as_tensor([eps_init], dtype=torch.float32)) if spec is not None: if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: @@ -254,12 +254,12 @@ def __init__( ) super().__init__(policy) - self.register_buffer("eps_init", torch.tensor([eps_init])) - self.register_buffer("eps_end", torch.tensor([eps_end])) + self.register_buffer("eps_init", torch.as_tensor([eps_init])) + self.register_buffer("eps_end", torch.as_tensor([eps_end])) if self.eps_end > self.eps_init: raise RuntimeError("eps should decrease over time or be constant") self.annealing_num_steps = annealing_num_steps - self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32)) + self.register_buffer("eps", torch.as_tensor([eps_init], dtype=torch.float32)) self.action_key = action_key self.action_mask_key = action_mask_key if spec is not None: diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index c32a795a2a0..de963bcfdb9 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -283,8 +283,10 @@ def __init__( except AttributeError: device = torch.device("cpu") - self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device)) - self.register_buffer("critic_coef", torch.tensor(critic_coef, device=device)) + self.register_buffer( + "entropy_coef", torch.as_tensor(entropy_coef, device=device) + ) + self.register_buffer("critic_coef", torch.as_tensor(critic_coef, device=device)) if gamma is not None: warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) self.gamma = gamma diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index e920bc83960..6ef7ab7386e 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -174,22 +174,24 @@ def __init__( except AttributeError: device = torch.device("cpu") - self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) + self.register_buffer("alpha_init", torch.as_tensor(alpha_init, device=device)) self.register_buffer( - "min_log_alpha", torch.tensor(min_alpha, device=device).log() + "min_log_alpha", torch.as_tensor(min_alpha, device=device).log() ) self.register_buffer( - "max_log_alpha", torch.tensor(max_alpha, device=device).log() + "max_log_alpha", torch.as_tensor(max_alpha, device=device).log() ) self.fixed_alpha = fixed_alpha if fixed_alpha: self.register_buffer( - "log_alpha", torch.tensor(math.log(alpha_init), device=device) + "log_alpha", torch.as_tensor(math.log(alpha_init), device=device) ) else: self.register_parameter( "log_alpha", - torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), + torch.nn.Parameter( + torch.as_tensor(math.log(alpha_init), device=device) + ), ) self._target_entropy = target_entropy @@ -230,7 +232,7 @@ def target_entropy(self): np.prod(action_spec[self.tensor_keys.action].shape) ) self.register_buffer( - "target_entropy_buffer", torch.tensor(target_entropy, device=device) + "target_entropy_buffer", torch.as_tensor(target_entropy, device=device) ) return self.target_entropy_buffer return target_entropy