Skip to content

Commit

Permalink
[Feature] Indexing specs (pytorch#1105)
Browse files Browse the repository at this point in the history
  • Loading branch information
remidomingues authored Apr 28, 2023
1 parent 69dc0fa commit e80930d
Show file tree
Hide file tree
Showing 2 changed files with 340 additions and 80 deletions.
194 changes: 145 additions & 49 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2348,11 +2348,21 @@ def test_to_numpy(self):
c.to_numpy(td_fail)


@pytest.mark.parametrize("spec", OneHotDiscreteTensorSpec(n=4, shape=[3, 4]))
# MultiDiscreteTensorSpec: Pending resolution of https://github.com/pytorch/pytorch/issues/100080.
@pytest.mark.parametrize(
"spec_class",
[
BinaryDiscreteTensorSpec,
OneHotDiscreteTensorSpec,
MultiOneHotDiscreteTensorSpec,
CompositeSpec,
],
)
@pytest.mark.parametrize(
"idx",
[
5,
(0, 1),
range(10),
np.array([[2, 10]]),
(slice(None), slice(1, 2), 1),
Expand All @@ -2361,74 +2371,112 @@ def test_to_numpy(self):
torch.tensor([10, 2]),
], # [:,1:2,1]
)
def test_invalid_indices(spec, idx):
def test_invalid_indexing(spec_class, idx):
if spec_class in [BinaryDiscreteTensorSpec, OneHotDiscreteTensorSpec]:
spec = spec_class(n=4, shape=[3, 4])
elif spec_class == MultiDiscreteTensorSpec:
spec = spec_class([2, 2, 2], shape=[3])
elif spec_class == MultiOneHotDiscreteTensorSpec:
spec = spec_class([4], shape=[3, 4])
elif spec_class == CompositeSpec:
spec = spec_class(k=UnboundedDiscreteTensorSpec(shape=(3, 4)), shape=(3,))
with pytest.raises(IndexError):
spec[idx]


@pytest.mark.parametrize("spec_class", [OneHotDiscreteTensorSpec, DiscreteTensorSpec])
def test_valid_indices(spec_class):
empty_spec = spec_class(0)
spec = spec_class(n=4, shape=[3, 4])
spec_3d = spec_class(n=4, shape=[5, 3, 4])
spec_4d = spec_class(n=6, shape=[5, 3, 4, 6])
spec_5d = spec_class(n=7, shape=[5, 3, 4, 6, 7])
# BoundedTensorSpec, MultiDiscreteTensorSpec: Pending resolution of https://github.com/pytorch/pytorch/issues/100080.
@pytest.mark.parametrize(
"spec_class",
[
BinaryDiscreteTensorSpec,
DiscreteTensorSpec,
MultiOneHotDiscreteTensorSpec,
OneHotDiscreteTensorSpec,
UnboundedContinuousTensorSpec,
UnboundedDiscreteTensorSpec,
CompositeSpec,
],
)
def test_valid_indexing(spec_class):
# Default args. UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec
args = {"0d": [], "2d": [], "3d": [], "4d": [], "5d": []}
kwargs = {}
if spec_class in [
BinaryDiscreteTensorSpec,
DiscreteTensorSpec,
OneHotDiscreteTensorSpec,
]:
args = {"0d": [0], "2d": [3], "3d": [4], "4d": [6], "5d": [7]}
elif spec_class == MultiOneHotDiscreteTensorSpec:
args = {"0d": [[0]], "2d": [[3]], "3d": [[4]], "4d": [[6]], "5d": [[7]]}
elif spec_class == MultiDiscreteTensorSpec:
args = {
"0d": [[0]],
"2d": [[2] * 3],
"3d": [[2] * 4],
"4d": [[1] * 6],
"5d": [[2] * 7],
}
elif spec_class == BoundedTensorSpec:
min_max = (-1, -1)
args = {
"0d": min_max,
"2d": min_max,
"3d": min_max,
"4d": min_max,
"5d": min_max,
}
elif spec_class == CompositeSpec:
kwargs = {
"k1": UnboundedDiscreteTensorSpec(shape=(5, 3, 4, 6, 7, 8)),
"k2": OneHotDiscreteTensorSpec(n=7, shape=(5, 3, 4, 6, 7)),
}

spec_0d = spec_class(*args["0d"], **kwargs)
if spec_class in [
UnboundedContinuousTensorSpec,
UnboundedDiscreteTensorSpec,
CompositeSpec,
]:
spec_0d = spec_class(*args["0d"], shape=[], **kwargs)
spec_2d = spec_class(*args["2d"], shape=[5, 3], **kwargs)
spec_3d = spec_class(*args["3d"], shape=[5, 3, 4], **kwargs)
spec_4d = spec_class(*args["4d"], shape=[5, 3, 4, 6], **kwargs)
spec_5d = spec_class(*args["5d"], shape=[5, 3, 4, 6, 7], **kwargs)

# Integers
assert spec[1].shape == torch.Size([4])
if not isinstance(spec, OneHotDiscreteTensorSpec):
assert spec[0, 1].shape == torch.Size([])
assert spec_2d[1].shape == torch.Size([3])
# Lists
assert spec_3d[[1, 2]].shape == torch.Size([2, 3, 4])
assert spec[[0]].shape == torch.Size([1, 4])
assert spec[[[[0]]]].shape == torch.Size([1, 1, 1, 4])
assert spec[[0, 1]].shape == torch.Size([2, 4])
assert spec[[[0, 1]]].shape == torch.Size([1, 2, 4])
assert spec_2d[[0]].shape == torch.Size([1, 3])
assert spec_2d[[[[0]]]].shape == torch.Size([1, 1, 1, 3])
assert spec_2d[[0, 1]].shape == torch.Size([2, 3])
assert spec_2d[[[0, 1]]].shape == torch.Size([1, 2, 3])
assert spec_3d[[0, 1], [0, 1]].shape == torch.Size([2, 4])
assert spec[[[0, 1], [0, 1]]].shape == torch.Size([2, 2, 4])
assert spec_2d[[[0, 1], [0, 1]]].shape == torch.Size([2, 2, 3])
# Tuples
assert spec_3d[1, 2].shape == torch.Size([4])
assert spec_3d[(1, 2)].shape == torch.Size([4])
assert spec_3d[((1, 2))].shape == torch.Size([4])
# Ranges
assert spec[range(2)].shape == torch.Size([2, 4])
assert spec_2d[range(2)].shape == torch.Size([2, 3])
# Slices
assert spec[:].shape == torch.Size([3, 4])
assert spec[10:].shape == torch.Size([0, 4])
assert spec[:1].shape == torch.Size([1, 4])
assert spec[1:2].shape == torch.Size([1, 4])
assert spec[10:1:-1].shape == torch.Size([1, 4])
assert spec[-5:-1].shape == torch.Size([2, 4])
assert spec_2d[:].shape == torch.Size([5, 3])
assert spec_2d[10:].shape == torch.Size([0, 3])
assert spec_2d[:1].shape == torch.Size([1, 3])
assert spec_2d[1:2].shape == torch.Size([1, 3])
assert spec_2d[10:1:-1].shape == torch.Size([3, 3])
assert spec_2d[-5:-1].shape == torch.Size([4, 3])
assert spec_3d[[1, 2], 3:].shape == torch.Size([2, 0, 4])
# None (adds a singleton dimension where needed)
assert spec[None].shape == torch.Size([1, 3, 4])
assert spec[None, :2].shape == torch.Size([1, 2, 4])
expected_shape = [1, 0] if isinstance(spec, OneHotDiscreteTensorSpec) else [1]
assert empty_spec[None].shape == torch.Size(expected_shape)
assert spec_2d[None].shape == torch.Size([1, 5, 3])
assert spec_2d[None, :2].shape == torch.Size([1, 2, 3])
# Ellipsis
expected_shape = [0] if isinstance(spec, OneHotDiscreteTensorSpec) else []
assert empty_spec[...].shape == torch.Size(expected_shape)
expected_shape = [2, 4] if isinstance(spec, OneHotDiscreteTensorSpec) else [3, 2]
assert spec[..., :2].shape == torch.Size(expected_shape)
expected_shape = (
[2, 1, 1, 4] if isinstance(spec, OneHotDiscreteTensorSpec) else [3, 2, 1, 1]
)
assert spec[..., :2, None, None].shape == torch.Size(expected_shape)
expected_shape = [3, 6] if isinstance(spec, OneHotDiscreteTensorSpec) else [3, 4]
assert spec_4d[1, ..., 2].shape == torch.Size(expected_shape)
assert spec[1, ...].shape == torch.Size([4])
expected_shape = [1, 4] if isinstance(spec, OneHotDiscreteTensorSpec) else [4, 1]
assert spec[1, ..., None].shape == torch.Size(expected_shape)
expected_shape = [2, 4] if isinstance(spec, OneHotDiscreteTensorSpec) else [5, 2]
assert spec_3d[..., [0, 1], [0]].shape == torch.Size(expected_shape)
expected_shape = (
[1, 3, 1, 4] if isinstance(spec, OneHotDiscreteTensorSpec) else [1, 3, 4, 1]
)
assert spec_3d[None, 1, ..., None].shape == torch.Size(expected_shape)
assert spec_2d[1, ...].shape == torch.Size([3])
# Numpy arrays
assert spec[np.array([[1, 2]])].shape == torch.Size([1, 2, 4])
assert spec_2d[np.array([[1, 2]])].shape == torch.Size([1, 2, 3])
# Tensors
assert spec[torch.randint(3, (3, 2))].shape == torch.Size([3, 2, 4])
assert spec_2d[torch.randint(3, (3, 2))].shape == torch.Size([3, 2, 3])
# Tuples
# Note: nested tuples are supported by specs but transformed into lists, similarity to numpy
assert spec_3d[(0, 1), (0, 1)].shape == torch.Size([2, 4])
Expand Down Expand Up @@ -2456,6 +2504,54 @@ def test_valid_indices(spec_class):
# assert spec_5d[2:, [[[0, 1]]], :3, [0]].shape == torch.Size([1, 1, 2, 3, 3, 7])
# assert spec_5d[2:, [[[0, 1]]], :3, [[[0, 1]]]].shape == torch.Size([1, 1, 2, 3, 3, 7])

# Specific tests when specs have non-indexable dimensions
if spec_class in [
BinaryDiscreteTensorSpec,
OneHotDiscreteTensorSpec,
MultiDiscreteTensorSpec,
MultiOneHotDiscreteTensorSpec,
]:
# Ellipsis
assert spec_0d[None].shape == torch.Size([1, 0])
assert spec_0d[...].shape == torch.Size([0])
assert spec_2d[..., :2].shape == torch.Size([2, 3])
assert spec_2d[..., :2, None, None].shape == torch.Size([2, 1, 1, 3])
assert spec_4d[1, ..., 2].shape == torch.Size([3, 6])
assert spec_2d[1, ..., None].shape == torch.Size([1, 3])
assert spec_3d[..., [0, 1], [0]].shape == torch.Size([2, 4])
assert spec_3d[None, 1, ..., None].shape == torch.Size([1, 3, 1, 4])
assert spec_4d[:, None, ..., None, :].shape == torch.Size([5, 1, 3, 1, 4, 6])

# BoundedTensorSpec, DiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, CompositeSpec
else:
# Integers
assert spec_2d[0, 1].shape == torch.Size([])

# Ellipsis
assert spec_0d[None].shape == torch.Size([1])
assert spec_0d[...].shape == torch.Size([])
assert spec_2d[..., :2].shape == torch.Size([5, 2])
assert spec_2d[..., :2, None, None].shape == torch.Size([5, 2, 1, 1])
assert spec_4d[1, ..., 2].shape == torch.Size([3, 4])
assert spec_2d[1, ..., None].shape == torch.Size([3, 1])
assert spec_3d[..., [0, 1], [0]].shape == torch.Size([5, 2])
assert spec_3d[None, 1, ..., None].shape == torch.Size([1, 3, 4, 1])
assert spec_4d[:, None, ..., None, :].shape == torch.Size([5, 1, 3, 4, 1, 6])

# Additional tests for composite spec
if spec_class == CompositeSpec:
assert spec_2d[1]["k1"].shape == torch.Size([3, 4, 6, 7, 8])
assert spec_3d[[1, 2]]["k1"].shape == torch.Size([2, 3, 4, 6, 7, 8])
assert spec_2d[torch.randint(3, (3, 2))]["k1"].shape == torch.Size(
[3, 2, 3, 4, 6, 7, 8]
)
assert spec_0d["k1"].shape == torch.Size([5, 3, 4, 6, 7, 8])
assert spec_0d[None]["k1"].shape == torch.Size([1, 5, 3, 4, 6, 7, 8])

assert spec_2d[..., 0]["k1"].shape == torch.Size([5, 4, 6, 7, 8])
assert spec_4d[1, ..., 2]["k2"].shape == torch.Size([3, 4, 7])
assert spec_2d[1, ..., None]["k2"].shape == torch.Size([3, 1, 4, 6, 7])


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down
Loading

0 comments on commit e80930d

Please sign in to comment.