From 57f05800e3ae631b242d60f5efd46887762d07d4 Mon Sep 17 00:00:00 2001 From: Chuanbo HUA Date: Thu, 5 Sep 2024 22:48:22 +0900 Subject: [PATCH] [BugFix] Fix invalid CUDA ID error when loading Bounded variables across devices (#2421) --- torchrl/data/tensor_specs.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 9bbd068b434..60c1009990e 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -397,16 +397,6 @@ def high(self, value): self.device = value.device self._high = value.cpu() - @low.setter - def low(self, value): - self.device = value.device - self._low = value.cpu() - - @high.setter - def high(self, value): - self.device = value.device - self._high = value.cpu() - def __post_init__(self): self.low = self.low.clone() self.high = self.high.clone() @@ -2269,9 +2259,10 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Bounded: dest_device = torch.device(dest) if dest_device == self.device and dest_dtype == self.dtype: return self + self.space.device = dest_device return Bounded( - low=self.space.low.to(dest), - high=self.space.high.to(dest), + low=self.space.low, + high=self.space.high, shape=self.shape, device=dest_device, dtype=dest_dtype,