diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 9bbd068b434..60c1009990e 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -397,16 +397,6 @@ def high(self, value): self.device = value.device self._high = value.cpu() - @low.setter - def low(self, value): - self.device = value.device - self._low = value.cpu() - - @high.setter - def high(self, value): - self.device = value.device - self._high = value.cpu() - def __post_init__(self): self.low = self.low.clone() self.high = self.high.clone() @@ -2269,9 +2259,10 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Bounded: dest_device = torch.device(dest) if dest_device == self.device and dest_dtype == self.dtype: return self + self.space.device = dest_device return Bounded( - low=self.space.low.to(dest), - high=self.space.high.to(dest), + low=self.space.low, + high=self.space.high, shape=self.shape, device=dest_device, dtype=dest_dtype,