Skip to content

Commit

Permalink
[BugFix] Fix to in MultiDiscreteTensorSpec (#2204)
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Moens <vmoens@meta.com>
  • Loading branch information
Quinticx and vmoens authored Jun 6, 2024
1 parent 013d110 commit 64c0b8e
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 1 deletion.
147 changes: 147 additions & 0 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1908,6 +1908,153 @@ def test_unboundeddiscrete(
assert spec == torch.stack(spec.unbind(-1), -1)


@pytest.mark.parametrize(
"device",
[torch.device("cpu")]
+ [torch.device(f"cuda:{i}" for i in range(torch.cuda.device_count()))],
)
class TestTo:
@pytest.mark.parametrize("shape1", [(5, 4)])
def test_binary(self, shape1, device):
spec = BinaryDiscreteTensorSpec(
n=4, shape=shape1, device="cpu", dtype=torch.bool
)
assert spec.to(device).device == device

@pytest.mark.parametrize(
"shape1,mini,maxi",
[
[(10,), -torch.ones([]), torch.ones([])],
[None, -torch.ones([10]), torch.ones([])],
[None, -torch.ones([]), torch.ones([10])],
[(10,), -torch.ones([]), torch.ones([10])],
[(10,), -torch.ones([10]), torch.ones([])],
[(10,), -torch.ones([10]), torch.ones([10])],
],
)
def test_bounded(self, shape1, mini, maxi, device):
spec = BoundedTensorSpec(
mini, maxi, shape=shape1, device="cpu", dtype=torch.bool
)
assert spec.to(device).device == device

def test_composite(self, device):
batch_size = (5,)
spec1 = BoundedTensorSpec(
-torch.ones([*batch_size, 10]),
torch.ones([*batch_size, 10]),
shape=(
*batch_size,
10,
),
device="cpu",
dtype=torch.bool,
)
spec2 = BinaryDiscreteTensorSpec(
n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool
)
spec3 = DiscreteTensorSpec(
n=4, shape=batch_size, device="cpu", dtype=torch.long
)
spec4 = MultiDiscreteTensorSpec(
nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long
)
spec5 = MultiOneHotDiscreteTensorSpec(
nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long
)
spec6 = OneHotDiscreteTensorSpec(
n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long
)
spec7 = UnboundedContinuousTensorSpec(
shape=(*batch_size, 9),
device="cpu",
dtype=torch.float64,
)
spec8 = UnboundedDiscreteTensorSpec(
shape=(*batch_size, 9),
device="cpu",
dtype=torch.long,
)
spec = CompositeSpec(
spec1=spec1,
spec2=spec2,
spec3=spec3,
spec4=spec4,
spec5=spec5,
spec6=spec6,
spec7=spec7,
spec8=spec8,
shape=batch_size,
)
assert spec.to(device).device == device

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_discrete(
self,
shape1,
device,
):
spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long)
assert spec.to(device).device == device

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_multidiscrete(self, shape1, device):
if shape1 is None:
shape1 = (3,)
else:
shape1 = (*shape1, 3)
spec = MultiDiscreteTensorSpec(
nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
)
assert spec.to(device).device == device

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_multionehot(self, shape1, device):
if shape1 is None:
shape1 = (15,)
else:
shape1 = (*shape1, 15)
spec = MultiOneHotDiscreteTensorSpec(
nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
)
assert spec.to(device).device == device

def test_non_tensor(self, device):
spec = NonTensorSpec(shape=(3, 4), device="cpu")
assert spec.to(device).device == device

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_onehot(self, shape1, device):
if shape1 is None:
shape1 = (15,)
else:
shape1 = (*shape1, 15)
spec = OneHotDiscreteTensorSpec(
n=15, shape=shape1, device="cpu", dtype=torch.long
)
assert spec.to(device).device == device

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_unbounded(self, shape1, device):
if shape1 is None:
shape1 = (15,)
else:
shape1 = (*shape1, 15)
spec = UnboundedContinuousTensorSpec(
shape=shape1, device="cpu", dtype=torch.float64
)
assert spec.to(device).device == device

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_unboundeddiscrete(self, shape1, device):
if shape1 is None:
shape1 = (15,)
else:
shape1 = (*shape1, 15)
spec = UnboundedDiscreteTensorSpec(shape=shape1, device="cpu", dtype=torch.long)
assert spec.to(device).device == device


@pytest.mark.parametrize(
"shape,stack_dim",
[[(), 0], [(2,), 0], [(2,), 1], [(2, 3), 0], [(2, 3), 1], [(2, 3), 2]],
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3369,7 +3369,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
return self
mask = self.mask.to(dest) if self.mask is not None else None
return self.__class__(
n=self.nvec.to(dest),
nvec=self.nvec.to(dest),
shape=None,
device=dest_device,
dtype=dest_dtype,
Expand Down

0 comments on commit 64c0b8e

Please sign in to comment.