From 1428536530ef9d5c7d02907944b289c6dd58f499 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 30 Jun 2024 21:06:35 +0100 Subject: [PATCH] init --- torchrl/collectors/utils.py | 16 ++++++++++++++-- torchrl/data/tensor_specs.py | 9 ++++----- torchrl/envs/gym_like.py | 8 ++++---- torchrl/envs/libs/brax.py | 9 +++------ torchrl/objectives/sac.py | 3 +-- 5 files changed, 26 insertions(+), 19 deletions(-) diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 91b66b94420..d777da3de2a 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -210,8 +210,20 @@ def nest(x, splits=splits): layout = as_nested if as_nested is not bool else None - def nest(*x): - return torch.nested.nested_tensor(list(x), layout=layout) + if torch.__version__ < "2.4": + # Layout must be True, there is no other layout available + if layout not in (True,): + raise RuntimeError( + f"layout={layout} is only available for torch>=v2.4" + ) + + def nest(*x): + return torch.nested.nested_tensor(list(x)) + + else: + + def nest(*x): + return torch.nested.nested_tensor(list(x), layout=layout) return out_splits[0]._fast_apply( nest, diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 6f75f207293..04c24cb8d57 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -3982,7 +3982,7 @@ def encode( if isinstance(vals, TensorDict): out = vals.empty() # create and empty tensordict similar to vals else: - out = TensorDict({}, torch.Size([]), _run_checks=False) + out = TensorDict._new_unsafe({}, torch.Size([])) for key, item in vals.items(): if item is None: raise RuntimeError( @@ -4047,13 +4047,12 @@ def rand(self, shape=None) -> TensorDictBase: for key, item in self.items(): if item is not None: _dict[key] = item.rand(shape) - return TensorDict( + # No need to run checks since we know Composite is compliant with + # TensorDict requirements + return TensorDict._new_unsafe( _dict, batch_size=torch.Size([*shape, *self.shape]), device=self._device, - # No need to run checks since we know Composite is compliant with - # TensorDict requirements - _run_checks=False, ) def keys( diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index d9f80ffa391..47f93f09779 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -343,8 +343,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: for key, val in TensorDict(obs_dict, []).items(True, True) ) else: - tensordict_out = TensorDict( - obs_dict, batch_size=tensordict.batch_size, _run_checks=False + tensordict_out = TensorDict._new_unsafe( + obs_dict, + batch_size=tensordict.batch_size, ) if self.device is not None: tensordict_out = tensordict_out.to(self.device, non_blocking=True) @@ -377,10 +378,9 @@ def _reset( source = self.read_obs(obs) - tensordict_out = TensorDict( + tensordict_out = TensorDict._new_unsafe( source=source, batch_size=self.batch_size, - _run_checks=not self.validated, ) if self.info_dict_reader and info is not None: for info_dict_reader in self.info_dict_reader: diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index 3b4b4cee224..c86ba9a543c 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -321,7 +321,7 @@ def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase: state["reward"] = state.get("reward").view(*self.reward_spec.shape) state["done"] = state.get("done").view(*self.reward_spec.shape) done = state["done"].bool() - tensordict_out = TensorDict( + tensordict_out = TensorDict._new_unsafe( source={ "observation": state.get("obs"), # "reward": reward, @@ -331,7 +331,6 @@ def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase: }, batch_size=self.batch_size, device=self.device, - _run_checks=False, ) return tensordict_out @@ -357,7 +356,7 @@ def _step_without_grad(self, tensordict: TensorDictBase): next_state.set("done", next_state.get("done").view(self.reward_spec.shape)) done = next_state["done"].bool() reward = next_state["reward"] - tensordict_out = TensorDict( + tensordict_out = TensorDict._new_unsafe( source={ "observation": next_state.get("obs"), "reward": reward, @@ -367,7 +366,6 @@ def _step_without_grad(self, tensordict: TensorDictBase): }, batch_size=self.batch_size, device=self.device, - _run_checks=False, ) return tensordict_out @@ -396,7 +394,7 @@ def _step_with_grad(self, tensordict: TensorDictBase): next_state.get("pipeline_state").update(dict(zip(qp_keys, next_qp_values))) # build result - tensordict_out = TensorDict( + tensordict_out = TensorDict._new_unsafe( source={ "observation": next_obs, "reward": next_reward, @@ -406,7 +404,6 @@ def _step_with_grad(self, tensordict: TensorDictBase): }, batch_size=self.batch_size, device=self.device, - _run_checks=False, ) return tensordict_out diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 9594895394a..51017384dbe 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -632,7 +632,7 @@ def _actor_loss( @property @_cache_values def _cached_target_params_actor_value(self): - return TensorDict( + return TensorDict._new_unsafe( { "module": { "0": self.target_actor_network_params, @@ -640,7 +640,6 @@ def _cached_target_params_actor_value(self): } }, torch.Size([]), - _run_checks=False, ) def _qvalue_v1_loss(