diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 1e160515ca6..c03fb40f1ac 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1402,7 +1402,6 @@ def type_check(self, value: torch.Tensor, key: NestedKey | None = None) -> None: spec.type_check(val) def is_in(self, value) -> bool: - raise RuntimeError if self.dim == 0 and not hasattr(value, "unbind"): # We don't use unbind because value could be a tuple or a nested tensor return all( @@ -1834,7 +1833,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return val def is_in(self, val: torch.Tensor) -> bool: - raise RuntimeError if self.mask is None: shape = torch.broadcast_shapes(self._safe_shape, val.shape) shape_match = val.shape == shape @@ -2288,7 +2286,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return val def is_in(self, val: torch.Tensor) -> bool: - raise RuntimeError val_shape = _remove_neg_shapes(tensordict.utils._shape(val)) shape = torch.broadcast_shapes(self._safe_shape, val_shape) shape = list(shape) @@ -2489,7 +2486,6 @@ def one(self, shape=None): ) def is_in(self, val: Any) -> bool: - raise RuntimeError shape = torch.broadcast_shapes(self._safe_shape, val.shape) return ( is_non_tensor(val) @@ -2682,7 +2678,6 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor: return torch.empty(shape, device=self.device, dtype=self.dtype).random_() def is_in(self, val: torch.Tensor) -> bool: - raise RuntimeError shape = torch.broadcast_shapes(self._safe_shape, val.shape) return val.shape == shape and val.dtype == self.dtype @@ -3034,7 +3029,6 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten return torch.cat(out, -1) def is_in(self, val: torch.Tensor) -> bool: - raise RuntimeError vals = self._split(val) if vals is None: return False @@ -3435,7 +3429,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return val def is_in(self, val: torch.Tensor) -> bool: - raise RuntimeError if self.mask is None: shape = torch.broadcast_shapes(self._safe_shape, val.shape) shape_match = val.shape == shape @@ -4066,7 +4059,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return val.squeeze(0) if val_is_scalar else val def is_in(self, val: torch.Tensor) -> bool: - raise RuntimeError if self.mask is not None: vals = val.unbind(-1) splits = self._split_self()