From 64c0b8ef4e9516cbf7bcee13bd13f71576fee1ff Mon Sep 17 00:00:00 2001 From: Brianna Date: Thu, 6 Jun 2024 09:29:30 -0500 Subject: [PATCH] [BugFix] Fix `to` in MultiDiscreteTensorSpec (#2204) Co-authored-by: Vincent Moens --- test/test_specs.py | 147 +++++++++++++++++++++++++++++++++++ torchrl/data/tensor_specs.py | 2 +- 2 files changed, 148 insertions(+), 1 deletion(-) diff --git a/test/test_specs.py b/test/test_specs.py index 5f39aaadd34..d20dbdb71d5 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -1908,6 +1908,153 @@ def test_unboundeddiscrete( assert spec == torch.stack(spec.unbind(-1), -1) +@pytest.mark.parametrize( + "device", + [torch.device("cpu")] + + [torch.device(f"cuda:{i}" for i in range(torch.cuda.device_count()))], +) +class TestTo: + @pytest.mark.parametrize("shape1", [(5, 4)]) + def test_binary(self, shape1, device): + spec = BinaryDiscreteTensorSpec( + n=4, shape=shape1, device="cpu", dtype=torch.bool + ) + assert spec.to(device).device == device + + @pytest.mark.parametrize( + "shape1,mini,maxi", + [ + [(10,), -torch.ones([]), torch.ones([])], + [None, -torch.ones([10]), torch.ones([])], + [None, -torch.ones([]), torch.ones([10])], + [(10,), -torch.ones([]), torch.ones([10])], + [(10,), -torch.ones([10]), torch.ones([])], + [(10,), -torch.ones([10]), torch.ones([10])], + ], + ) + def test_bounded(self, shape1, mini, maxi, device): + spec = BoundedTensorSpec( + mini, maxi, shape=shape1, device="cpu", dtype=torch.bool + ) + assert spec.to(device).device == device + + def test_composite(self, device): + batch_size = (5,) + spec1 = BoundedTensorSpec( + -torch.ones([*batch_size, 10]), + torch.ones([*batch_size, 10]), + shape=( + *batch_size, + 10, + ), + device="cpu", + dtype=torch.bool, + ) + spec2 = BinaryDiscreteTensorSpec( + n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool + ) + spec3 = DiscreteTensorSpec( + n=4, shape=batch_size, device="cpu", dtype=torch.long + ) + spec4 = MultiDiscreteTensorSpec( + nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long + ) + spec5 = MultiOneHotDiscreteTensorSpec( + nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long + ) + spec6 = OneHotDiscreteTensorSpec( + n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long + ) + spec7 = UnboundedContinuousTensorSpec( + shape=(*batch_size, 9), + device="cpu", + dtype=torch.float64, + ) + spec8 = UnboundedDiscreteTensorSpec( + shape=(*batch_size, 9), + device="cpu", + dtype=torch.long, + ) + spec = CompositeSpec( + spec1=spec1, + spec2=spec2, + spec3=spec3, + spec4=spec4, + spec5=spec5, + spec6=spec6, + spec7=spec7, + spec8=spec8, + shape=batch_size, + ) + assert spec.to(device).device == device + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) + def test_discrete( + self, + shape1, + device, + ): + spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long) + assert spec.to(device).device == device + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) + def test_multidiscrete(self, shape1, device): + if shape1 is None: + shape1 = (3,) + else: + shape1 = (*shape1, 3) + spec = MultiDiscreteTensorSpec( + nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long + ) + assert spec.to(device).device == device + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) + def test_multionehot(self, shape1, device): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = MultiOneHotDiscreteTensorSpec( + nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long + ) + assert spec.to(device).device == device + + def test_non_tensor(self, device): + spec = NonTensorSpec(shape=(3, 4), device="cpu") + assert spec.to(device).device == device + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) + def test_onehot(self, shape1, device): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = OneHotDiscreteTensorSpec( + n=15, shape=shape1, device="cpu", dtype=torch.long + ) + assert spec.to(device).device == device + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) + def test_unbounded(self, shape1, device): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = UnboundedContinuousTensorSpec( + shape=shape1, device="cpu", dtype=torch.float64 + ) + assert spec.to(device).device == device + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) + def test_unboundeddiscrete(self, shape1, device): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = UnboundedDiscreteTensorSpec(shape=shape1, device="cpu", dtype=torch.long) + assert spec.to(device).device == device + + @pytest.mark.parametrize( "shape,stack_dim", [[(), 0], [(2,), 0], [(2,), 1], [(2, 3), 0], [(2, 3), 1], [(2, 3), 2]], diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index c56c3a00bd6..a6ccf0dbeaf 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -3369,7 +3369,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: return self mask = self.mask.to(dest) if self.mask is not None else None return self.__class__( - n=self.nvec.to(dest), + nvec=self.nvec.to(dest), shape=None, device=dest_device, dtype=dest_dtype,