Skip to content

Commit

Permalink
[BugFix] Fix invalid CUDA ID error when loading Bounded variables acr…
Browse files Browse the repository at this point in the history
…oss devices (pytorch#2421)
  • Loading branch information
cbhua authored Sep 5, 2024
1 parent df4fa78 commit 57f0580
Showing 1 changed file with 3 additions and 12 deletions.
15 changes: 3 additions & 12 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,16 +397,6 @@ def high(self, value):
self.device = value.device
self._high = value.cpu()

@low.setter
def low(self, value):
self.device = value.device
self._low = value.cpu()

@high.setter
def high(self, value):
self.device = value.device
self._high = value.cpu()

def __post_init__(self):
self.low = self.low.clone()
self.high = self.high.clone()
Expand Down Expand Up @@ -2269,9 +2259,10 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Bounded:
dest_device = torch.device(dest)
if dest_device == self.device and dest_dtype == self.dtype:
return self
self.space.device = dest_device
return Bounded(
low=self.space.low.to(dest),
high=self.space.high.to(dest),
low=self.space.low,
high=self.space.high,
shape=self.shape,
device=dest_device,
dtype=dest_dtype,
Expand Down

0 comments on commit 57f0580

Please sign in to comment.