[Feature] Indexing Discrete and OneHot specs (pytorch#1081)
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
remidomingues and vmoens authored Apr 25, 2023
1 parent bae693e commit ce88a95
Showing 2 changed files with 311 additions and 1 deletion.
109 changes: 109 additions & 0 deletions test/test_specs.py
@@ -2348,6 +2348,115 @@ def test_to_numpy(self):
            c.to_numpy(td_fail)


@pytest.mark.parametrize("spec", OneHotDiscreteTensorSpec(n=4, shape=[3, 4]))
@pytest.mark.parametrize(
"idx",
[
5,
range(10),
np.array([[2, 10]]),
(slice(None), slice(1, 2), 1),
(1, ..., 2, ..., 3),
(1, 1, 1, 1),
torch.tensor([10, 2]),
], # [:,1:2,1]
)
def test_invalid_indices(spec, idx):
with pytest.raises(IndexError):
spec[idx]


@pytest.mark.parametrize("spec_class", [OneHotDiscreteTensorSpec, DiscreteTensorSpec])
def test_valid_indices(spec_class):
empty_spec = spec_class(0)
spec = spec_class(n=4, shape=[3, 4])
spec_3d = spec_class(n=4, shape=[5, 3, 4])
spec_4d = spec_class(n=6, shape=[5, 3, 4, 6])
spec_5d = spec_class(n=7, shape=[5, 3, 4, 6, 7])

# Integers
assert spec[1].shape == torch.Size([4])
if not isinstance(spec, OneHotDiscreteTensorSpec):
assert spec[0, 1].shape == torch.Size([])
# Lists
assert spec_3d[[1, 2]].shape == torch.Size([2, 3, 4])
assert spec[[0]].shape == torch.Size([1, 4])
assert spec[[[[0]]]].shape == torch.Size([1, 1, 1, 4])
assert spec[[0, 1]].shape == torch.Size([2, 4])
assert spec[[[0, 1]]].shape == torch.Size([1, 2, 4])
assert spec_3d[[0, 1], [0, 1]].shape == torch.Size([2, 4])
assert spec[[[0, 1], [0, 1]]].shape == torch.Size([2, 2, 4])
# Tuples
assert spec_3d[1, 2].shape == torch.Size([4])
assert spec_3d[(1, 2)].shape == torch.Size([4])
assert spec_3d[((1, 2))].shape == torch.Size([4])
# Ranges
assert spec[range(2)].shape == torch.Size([2, 4])
# Slices
assert spec[:].shape == torch.Size([3, 4])
assert spec[10:].shape == torch.Size([0, 4])
assert spec[:1].shape == torch.Size([1, 4])
assert spec[1:2].shape == torch.Size([1, 4])
assert spec[10:1:-1].shape == torch.Size([1, 4])
assert spec[-5:-1].shape == torch.Size([2, 4])
assert spec_3d[[1, 2], 3:].shape == torch.Size([2, 0, 4])
# None (adds a singleton dimension where needed)
assert spec[None].shape == torch.Size([1, 3, 4])
assert spec[None, :2].shape == torch.Size([1, 2, 4])
expected_shape = [1, 0] if isinstance(spec, OneHotDiscreteTensorSpec) else [1]
assert empty_spec[None].shape == torch.Size(expected_shape)
# Ellipsis
expected_shape = [0] if isinstance(spec, OneHotDiscreteTensorSpec) else []
assert empty_spec[...].shape == torch.Size(expected_shape)
expected_shape = [2, 4] if isinstance(spec, OneHotDiscreteTensorSpec) else [3, 2]
assert spec[..., :2].shape == torch.Size(expected_shape)
expected_shape = (
[2, 1, 1, 4] if isinstance(spec, OneHotDiscreteTensorSpec) else [3, 2, 1, 1]
)
assert spec[..., :2, None, None].shape == torch.Size(expected_shape)
expected_shape = [3, 6] if isinstance(spec, OneHotDiscreteTensorSpec) else [3, 4]
assert spec_4d[1, ..., 2].shape == torch.Size(expected_shape)
assert spec[1, ...].shape == torch.Size([4])
expected_shape = [1, 4] if isinstance(spec, OneHotDiscreteTensorSpec) else [4, 1]
assert spec[1, ..., None].shape == torch.Size(expected_shape)
expected_shape = [2, 4] if isinstance(spec, OneHotDiscreteTensorSpec) else [5, 2]
assert spec_3d[..., [0, 1], [0]].shape == torch.Size(expected_shape)
expected_shape = (
[1, 3, 1, 4] if isinstance(spec, OneHotDiscreteTensorSpec) else [1, 3, 4, 1]
)
assert spec_3d[None, 1, ..., None].shape == torch.Size(expected_shape)
# Numpy arrays
assert spec[np.array([[1, 2]])].shape == torch.Size([1, 2, 4])
# Tensors
assert spec[torch.randint(3, (3, 2))].shape == torch.Size([3, 2, 4])
    # Mixed tuple indexing
    # Note: nested tuples are supported by specs but converted to lists, matching NumPy behavior
    assert spec_3d[(0, 1), (0, 1)].shape == torch.Size([2, 4])
    assert spec_3d[:2, (0, 1)].shape == torch.Size([2, 2, 4])
    assert spec_3d[:2, [0, 1]].shape == torch.Size([2, 2, 4])
    assert spec_3d[:2, torch.tensor([0, 1])].shape == torch.Size([2, 2, 4])
    assert spec_3d[:2, range(3)].shape == torch.Size([2, 3, 4])
    assert spec_3d[:2, np.array([[1, 2]])].shape == torch.Size([2, 1, 2, 4])
    assert spec_3d[:2, [0]].shape == torch.Size([2, 1, 4])
    assert spec_3d[:2, 0].shape == torch.Size([2, 4])
    assert spec_3d[[0, 1], [0]].shape == torch.Size([2, 4])
    assert spec_4d[:, 1:2, 1].shape == torch.Size([5, 1, 6])
    assert spec_3d[1:, range(3)].shape == torch.Size([4, 3, 4])
    assert spec_3d[[[[[0, 1]]]], [[0]]].shape == torch.Size([1, 1, 1, 2, 4])
    assert spec_3d[0, [[[[0, 1]]]]].shape == torch.Size([1, 1, 1, 2, 4])
    assert spec_3d[0, ((((0, 1))))].shape == torch.Size([2, 4])
    assert spec_3d[((((0, 1)))), [0, 2]].shape == torch.Size([2, 4])
    assert spec_4d[2:, [[[0, 1]]], :3].shape == torch.Size([3, 1, 1, 2, 3, 6])
    assert spec_5d[2:, [[[0, 1]]], [[0, 1]], :3].shape == torch.Size([3, 1, 1, 2, 3, 7])
    assert spec_5d[2:, [[[0, 1]]], 0, :3].shape == torch.Size([3, 1, 1, 2, 3, 7])
    assert spec_5d[2:, [[[0, 1]]], :3, 0].shape == torch.Size(
        [3, 1, 1, 2, 3, 7]
    )  # Matches tensordict & tensor's behavior. Numpy would return (1, 1, 2, 3, 3, 7).
    # TODO: Fix these tests.
    # assert spec_5d[2:, [[[0, 1]]], :3, [0]].shape == torch.Size([1, 1, 2, 3, 3, 7])
    # assert spec_5d[2:, [[[0, 1]]], :3, [[[0, 1]]]].shape == torch.Size([1, 1, 2, 3, 3, 7])


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
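
For reference, the indexing contract these tests pin down can be summarized in a short sketch (illustrative only; the shapes are taken from the assertions above):

    spec = DiscreteTensorSpec(n=4, shape=[5, 3, 4])
    spec[1, 2].shape        # torch.Size([4])
    spec[None].shape        # torch.Size([1, 5, 3, 4])
    spec[:2, [0, 1]].shape  # torch.Size([2, 2, 4])
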
203 changes: 202 additions & 1 deletion torchrl/data/tensor_specs.py
@@ -6,7 +6,9 @@
from __future__ import annotations

import abc
import math
import warnings
from collections.abc import Iterable
from copy import deepcopy
from dataclasses import dataclass
from functools import wraps
@@ -30,17 +32,39 @@
import numpy as np
import torch
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import _getitem_batch_size

from torchrl._utils import get_binary_env_var

DEVICE_TYPING = Union[torch.device, str, int]

INDEX_TYPING = Union[int, torch.Tensor, np.ndarray, slice, List]

SHAPE_INDEX_TYPING = Union[
    int,
    range,
    List[int],
    np.ndarray,
    slice,
    None,
    torch.Tensor,
    type(...),
    Tuple[
        int,
        range,
        List[int],
        np.ndarray,
        slice,
        None,
        torch.Tensor,
        type(...),
        Tuple[Any],
    ],
]
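# e.g. 0, range(2), [0, 1], np.array([0]), slice(None), None, torch.tensor([0]),
# ..., or a tuple combining these are all valid SHAPE_INDEX_TYPING indices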

# By default, we do not check that an obs is in the domain. This should be done when validating the env beforehand
_CHECK_SPEC_ENCODE = get_binary_env_var("CHECK_SPEC_ENCODE")


_DEFAULT_SHAPE = torch.Size((1,))

DEVICE_ERR_MSG = "device of empty CompositeSpec is not defined."
@@ -58,6 +82,167 @@ def _default_dtype_and_device(
    return dtype, device


def _validate_idx(shape: list[int], idx: int, axis: int = 0):
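    """Raise an IndexError if ``idx`` is out of bounds for dimension ``axis`` of ``shape``."""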
    if idx >= shape[axis] or (idx < 0 and -idx > shape[axis]):
        raise IndexError(
            f"index {idx} is out of bounds for axis {axis} with size {shape[axis]}"
        )


def _validate_iterable(
    idx: Iterable[Any], expected_type: type, iterable_classname: str
):
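    """Recursively check that every element of a (possibly nested) iterable is of ``expected_type``."""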
    for item in idx:
        if isinstance(item, Iterable):
            _validate_iterable(item, expected_type, iterable_classname)
        else:
            if not isinstance(item, expected_type):
                raise IndexError(
                    f"{iterable_classname} indexing expects {expected_type} indices"
                )


def _slice_indexing(shape: list[int], idx: slice):
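    """Compute the shape resulting from slicing the first dimension of ``shape``.

    Follows Python slice semantics; e.g. _slice_indexing([3, 4], slice(1, 3)) returns [2, 4].
    """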
    if idx.step == 0:
        raise ValueError("slice step cannot be zero")
    # Slicing an empty shape returns the shape
    if len(shape) == 0:
        return shape

    if idx.start is None:
        start = 0
    else:
        start = idx.start if idx.start >= 0 else max(shape[0] + idx.start, 0)

    if idx.stop is None:
        stop = shape[0]
    else:
        stop = idx.stop if idx.stop >= 0 else max(shape[0] + idx.stop, 0)

    step = 1 if idx.step is None else idx.step
    if step > 0:
        if start >= stop:
            n_items = 0
        else:
            stop = min(stop, shape[0])
            n_items = math.ceil((stop - start) / step)
    else:
        if start <= stop:
            n_items = 0
        else:
            start = min(start, shape[0] - 1)
            n_items = math.ceil((stop - start) / step)
    return [n_items] + shape[1:]


def _shape_indexing(shape: list[int], idx: SHAPE_INDEX_TYPING):
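    """Compute the shape that results from indexing ``shape`` with ``idx``.

    Raises an IndexError for invalid indices, mirroring NumPy/tensor semantics.
    For example, _shape_indexing([5, 3, 4], (slice(None, 2), [0, 1])) returns [2, 2, 4].
    """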
    if idx is Ellipsis or (
        isinstance(idx, slice) and (idx.step is idx.start is idx.stop is None)
    ):
        return shape

    if idx is None:
        return [1] + shape

    if len(shape) == 0 and (
        isinstance(idx, (int, range))
        or (isinstance(idx, list) and len(idx) > 0)
    ):
        raise IndexError(
            f"cannot use integer indices on 0-dimensional shape. `{idx}` received"
        )

    if isinstance(idx, int):
        _validate_idx(shape, idx)
        return shape[1:]

    if isinstance(idx, range):
        if len(idx) > 0 and (idx.start >= shape[0] or idx.stop > shape[0]):
            raise IndexError(f"index out of bounds for axis 0 with size {shape[0]}")
        return [len(idx)] + shape[1:]

    if isinstance(idx, slice):
        return _slice_indexing(shape, idx)

    if isinstance(idx, tuple):
        # Supports int, None, slice and ellipsis indices
        head_new_axes, tail_new_axes = 0, 0
        # Index on the current shape dimension
        shape_idx = 0
        ellipsis = False
        prev_is_list = False
        shape_len = len(shape)
        for item_idx, item in enumerate(idx):
            if item is None:
                if ellipsis:
                    tail_new_axes += 1
                else:
                    head_new_axes += 1
            elif isinstance(item, int):
                _validate_idx(shape, item, shape_idx)
                del shape[shape_idx]
            elif isinstance(item, slice):
                shape[shape_idx] = _slice_indexing([shape[shape_idx]], item)[0]
                shape_idx += 1
            elif item is Ellipsis:
                if ellipsis:
                    raise IndexError("an index can only have a single ellipsis (`...`)")
                # Jump to the end of the shape, minus the number of remaining
                # indices that consume a dimension (i.e. all except None and ...)
                shape_idx = len(shape) - len(
                    [i for i in idx[item_idx + 1 :] if not (i is None or i is Ellipsis)]
                )
                ellipsis = True
            elif any(
                isinstance(item, _type)
                for _type in [list, tuple, range, np.ndarray, torch.Tensor]
            ):
                while isinstance(idx, tuple) and len(idx) == 1:
                    idx = idx[0]

                # Nested tuples are handled as lists, matching NumPy behavior
                if isinstance(item, tuple):
                    item = list(item)

                if prev_is_list and isinstance(item, list):
                    del shape[shape_idx]
                    continue

                if isinstance(item, list):
                    prev_is_list = True

                if shape_idx >= len(shape):
                    raise IndexError("too many indices for array")

                res = _shape_indexing([shape[shape_idx]], item)
                shape = shape[:shape_idx] + res + shape[shape_idx + 1 :]
                shape_idx += len(res)
            else:
                raise IndexError(
                    f"tuple indexing only supports integers, slices (`:`), ellipsis (`...`), new axis (`None`), tuples and list indices. {type(item)} received"
                )

        if len(idx) - head_new_axes - tail_new_axes - int(ellipsis) > shape_len:
            raise IndexError(
                f"shape is {shape_len}-dimensional, but {len(idx) - head_new_axes - tail_new_axes - int(ellipsis)} dimensions were indexed"
            )
        return [1] * head_new_axes + shape + [1] * tail_new_axes

    if isinstance(idx, list):
        # int indexing only
        _validate_iterable(idx, int, "list")
        for item in np.array(idx).reshape(-1):
            _validate_idx(shape, item, 0)
        return list(np.array(idx).shape) + shape[1:]

    if isinstance(idx, (np.ndarray, torch.Tensor)):
        # Out of bounds check
        for item in idx.reshape(-1):
            _validate_idx(shape, item)
        return list(_getitem_batch_size(shape, idx))


class invertible_dict(dict):
"""An invertible dictionary.
@@ -922,6 +1107,14 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten
        index = index.expand((*tensor_to_index.shape[:-1], index.shape[-1]))
        return tensor_to_index.gather(-1, index)

    def __getitem__(self, idx: SHAPE_INDEX_TYPING):
        """Indexes the current TensorSpec based on the provided index."""
        # The last (one-hot encoding) dimension is excluded from indexing
        indexed_shape = _shape_indexing(list(self.shape[:-1]), idx)
        spec = deepcopy(self)
        spec.shape = torch.Size(indexed_shape + [self.shape[-1]])
        return spec
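
    # Illustrative example of the indexing above (shapes taken from the tests):
    #   OneHotDiscreteTensorSpec(n=4, shape=[3, 4])[0].shape -> torch.Size([4])
    #   OneHotDiscreteTensorSpec(n=4, shape=[3, 4])[None].shape -> torch.Size([1, 3, 4])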

    def _project(self, val: torch.Tensor) -> torch.Tensor:
        # idx = val.sum(-1) != 1
        out = torch.nn.functional.gumbel_softmax(val.to(torch.float))
@@ -1636,6 +1829,14 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
    def is_in(self, val: torch.Tensor) -> bool:
        return (0 <= val).all() and (val < self.space.n).all()

    def __getitem__(self, idx: SHAPE_INDEX_TYPING):
        """Indexes the current TensorSpec based on the provided index."""
        # Discrete specs have no trailing encoding dimension, so the full shape is indexable
        indexed_shape = _shape_indexing(list(self.shape), idx)
        spec = deepcopy(self)
        spec.shape = torch.Size(indexed_shape)
        return spec
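
    # Illustrative example of the indexing above (shapes taken from the tests):
    #   DiscreteTensorSpec(n=4, shape=[3, 4])[0].shape -> torch.Size([4])
    #   DiscreteTensorSpec(n=4, shape=[3, 4])[0, 1].shape -> torch.Size([])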

    def __eq__(self, other):
        return (
            type(self) == type(other)
