Skip to content

Commit

Permalink
[Minor] Missing lint (pytorch#1556)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Sep 20, 2023
1 parent fda88bb commit 57f1220
Show file tree
Hide file tree
Showing 2 changed files with 339 additions and 93 deletions.
211 changes: 211 additions & 0 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1968,6 +1968,217 @@ def test_unboundeddiscrete(
assert spec == torch.stack(spec.unbind(0), 0)
assert spec == torch.stack(spec.unbind(-1), -1)

class TestUnbind:
@pytest.mark.parametrize("shape1", [(5, 4)])
def test_binary(self, shape1):
spec = BinaryDiscreteTensorSpec(
n=4, shape=shape1, device="cpu", dtype=torch.bool
)
assert spec == torch.stack(spec.unbind(0), 0)
with pytest.raises(ValueError):
spec.unbind(-1)

@pytest.mark.parametrize(
"shape1,mini,maxi",
[
[(10,), -torch.ones([]), torch.ones([])],
[None, -torch.ones([10]), torch.ones([])],
[None, -torch.ones([]), torch.ones([10])],
[(10,), -torch.ones([]), torch.ones([10])],
[(10,), -torch.ones([10]), torch.ones([])],
[(10,), -torch.ones([10]), torch.ones([10])],
],
)
def test_bounded(self, shape1, mini, maxi):
spec = BoundedTensorSpec(
mini, maxi, shape=shape1, device="cpu", dtype=torch.bool
)
assert spec == torch.stack(spec.unbind(0), 0)
with pytest.raises(ValueError):
spec.unbind(-1)

def test_composite(self):
batch_size = (5,)
spec1 = BoundedTensorSpec(
-torch.ones([*batch_size, 10]),
torch.ones([*batch_size, 10]),
shape=(
*batch_size,
10,
),
device="cpu",
dtype=torch.bool,
)
spec2 = BinaryDiscreteTensorSpec(
n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool
)
spec3 = DiscreteTensorSpec(
n=4, shape=batch_size, device="cpu", dtype=torch.long
)
spec4 = MultiDiscreteTensorSpec(
nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long
)
spec5 = MultiOneHotDiscreteTensorSpec(
nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long
)
spec6 = OneHotDiscreteTensorSpec(
n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long
)
spec7 = UnboundedContinuousTensorSpec(
shape=(*batch_size, 9),
device="cpu",
dtype=torch.float64,
)
spec8 = UnboundedDiscreteTensorSpec(
shape=(*batch_size, 9),
device="cpu",
dtype=torch.long,
)
spec = CompositeSpec(
spec1=spec1,
spec2=spec2,
spec3=spec3,
spec4=spec4,
spec5=spec5,
spec6=spec6,
spec7=spec7,
spec8=spec8,
shape=batch_size,
)
assert spec == torch.stack(spec.unbind(0), 0)
assert spec == torch.stack(spec.unbind(-1), -1)

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_discrete(
self,
shape1,
):
spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long)
assert spec == torch.stack(spec.unbind(0), 0)
assert spec == torch.stack(spec.unbind(-1), -1)

@pytest.mark.parametrize(
"shape1",
[
(5,),
(
5,
6,
),
],
)
def test_multidiscrete(
self,
shape1,
):
if shape1 is None:
shape1 = (3,)
else:
shape1 = (*shape1, 3)
spec = MultiDiscreteTensorSpec(
nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
)
assert spec == torch.stack(spec.unbind(0), 0)
with pytest.raises(ValueError):
spec.unbind(-1)

@pytest.mark.parametrize(
"shape1",
[
(5,),
(
5,
6,
),
],
)
def test_multionehot(
self,
shape1,
):
if shape1 is None:
shape1 = (15,)
else:
shape1 = (*shape1, 15)
spec = MultiOneHotDiscreteTensorSpec(
nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
)
assert spec == torch.stack(spec.unbind(0), 0)
with pytest.raises(ValueError):
spec.unbind(-1)

@pytest.mark.parametrize(
"shape1",
[
(5,),
(
5,
6,
),
],
)
def test_onehot(
self,
shape1,
):
if shape1 is None:
shape1 = (15,)
else:
shape1 = (*shape1, 15)
spec = OneHotDiscreteTensorSpec(
n=15, shape=shape1, device="cpu", dtype=torch.long
)
assert spec == torch.stack(spec.unbind(0), 0)
with pytest.raises(ValueError):
spec.unbind(-1)

@pytest.mark.parametrize(
"shape1",
[
(5,),
(
5,
6,
),
],
)
def test_unbounded(
self,
shape1,
):
if shape1 is None:
shape1 = (15,)
else:
shape1 = (*shape1, 15)
spec = UnboundedContinuousTensorSpec(
shape=shape1, device="cpu", dtype=torch.float64
)
assert spec == torch.stack(spec.unbind(0), 0)
assert spec == torch.stack(spec.unbind(-1), -1)

@pytest.mark.parametrize(
"shape1",
[
(5,),
(
5,
6,
),
],
)
def test_unboundeddiscrete(
self,
shape1,
):
if shape1 is None:
shape1 = (15,)
else:
shape1 = (*shape1, 15)
spec = UnboundedDiscreteTensorSpec(shape=shape1, device="cpu", dtype=torch.long)
assert spec == torch.stack(spec.unbind(0), 0)
assert spec == torch.stack(spec.unbind(-1), -1)


@pytest.mark.parametrize(
"shape,stack_dim",
[[(), 0], [(2,), 0], [(2,), 1], [(2, 3), 0], [(2, 3), 1], [(2, 3), 2]],
Expand Down
Loading

0 comments on commit 57f1220

Please sign in to comment.