Skip to content

Commit

Permalink
Move parse_shape and parse_index to utils.
Browse files Browse the repository at this point in the history
  • Loading branch information
kduxin committed Nov 8, 2022
1 parent 1a73200 commit 194b21a
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 106 deletions.
109 changes: 5 additions & 104 deletions firelang/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from collections import OrderedDict, defaultdict
import inspect
import numpy as np
import torch
from torch import Tensor
from torch.nn import Module, ModuleList, ModuleDict
from .utils.index import parse_index, IndexLike
from .utils.shape import parse_shape

__all__ = [
"clear_cache",
Expand All @@ -25,11 +26,6 @@ def current_cache_sizes():
return {key: len(dct) for key, dct in _cache.items()}


IndexLike = Union[
int, slice, List, Tensor, None, Tuple[Union[int, slice, List, Tensor, None], ...]
]


class StackingSlicing(Module):
init_locals: Mapping[str, Any]
unsliceable_params: Set[str]
Expand Down Expand Up @@ -65,7 +61,7 @@ def _sanity_check(self):
)

def view(self, *shape, inplace: bool = False):
shape = _parse_shape(shape, int(np.prod(self.shape)))
shape = parse_shape(shape, int(np.prod(self.shape)))

if inplace:
self.shape = shape
Expand All @@ -89,7 +85,7 @@ def view(self, *shape, inplace: bool = False):
return new

def __getitem__(self, index: IndexLike):
idtensor: Tensor = _parse_index(index, self.shape)
idtensor: Tensor = parse_index(index, self.shape)
ids = idtensor.reshape(-1)
shape = tuple(idtensor.shape)

Expand Down Expand Up @@ -198,99 +194,4 @@ def _recover_args_from_locals(
positional.extend(value)
elif sign.kind == inspect.Parameter.VAR_KEYWORD:
keywords = {**keywords, **value}
return positional, keywords


def _complete_ellipsis(index: Tuple[Any | Ellipsis], ndim: int):
num_ellip = index.count(Ellipsis)
assert num_ellip <= 1, f"Invalid index {index}"
if num_ellip == 0:
return index

i = index.index(Ellipsis)
completed = (
list(index[:i]) + [slice(None)] * (ndim - len(index) + 1) + list(index[i + 1 :])
)
return tuple(completed)


def _parse_index(index: IndexLike, shape: Tuple[int]) -> Tensor:
if not isinstance(index, tuple):
index = (index,)
index = _complete_ellipsis(index, ndim=len(shape))

nindex = len(index)

nindex_notnan = len([idx for idx in index if idx is not None])
stride = int(np.prod(shape[nindex_notnan:]))

ids = torch.tensor([0], dtype=torch.long)
slice_shape = []
dim = len(index) - 1
shape_dim = nindex_notnan - 1
for dim in range(nindex - 1, -1, -1):
if index[dim] is None:
slice_shape.append(1)
continue

index_at_dim = index[dim]
size_at_dim = shape[shape_dim]
if isinstance(index_at_dim, int):
ids = ids + index_at_dim * stride
elif isinstance(index_at_dim, slice):
offsets = torch.arange(size_at_dim)[index_at_dim] * stride
ids = (offsets[:, None] + ids[None, :]).reshape(-1)
slice_shape.append(len(offsets))
elif isinstance(index_at_dim, list):
offsets = torch.tensor(index_at_dim) * stride
ids = (offsets[:, None] + ids[None, :]).reshape(-1)
slice_shape.append(len(offsets))
elif isinstance(index_at_dim, Tensor):
assert index_at_dim.ndim == 1, (
f"Index at dimension {dim} should be 1-dimensional, "
f"not {index_at_dim.ndim}-d."
)
ids = ids.to(index_at_dim.device)
offsets = index_at_dim * stride
ids = (offsets[:, None] + ids[None, :]).reshape(-1)
slice_shape.append(len(offsets))
else:
raise TypeError(
f"Index at dimension {dim} should be "
f"a `int`, a `slice`, or a `Tensor`, not {type(index_at_dim)}"
)

stride *= size_at_dim
shape_dim -= 1

slice_shape = list(reversed(slice_shape))
return ids.reshape(slice_shape)


def _replace_minus_one(shape: Tuple[int], num_elements: int):
num_minus_one = shape.count(-1)
assert num_minus_one <= 1, f"Invalid shape {shape}"
if num_minus_one == 0:
return shape

otherdim = int(np.prod([size for size in shape if size >= 1]))
inferred = num_elements // otherdim
assert (
inferred * otherdim == num_elements
), f"Invalid new shape {shape} for {num_elements} elements"
i = shape.index(-1)
replaced = list(shape[:i]) + [inferred] + list(shape[i + 1 :])
return tuple(replaced)


def _parse_shape(shape, num_elements):
if len(shape) == 1:
assert isinstance(shape, Iterable)
shape = tuple(shape[0])
shape = _replace_minus_one(shape, num_elements=num_elements)
given_num_elements = int(np.prod(shape))
assert (
given_num_elements == num_elements
), f"Inconsistent shape: should have {num_elements} elements, not {given_num_elements}."
return shape

return positional, keywords
75 changes: 75 additions & 0 deletions firelang/utils/index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from __future__ import annotations
from typing import List, Tuple, Union, Any
import numpy as np
import torch
from torch import Tensor

IndexLike = Union[
int, slice, List, Tensor, None, Tuple[Union[int, slice, List, Tensor, None]]
]


def parse_index(index: IndexLike, shape: Tuple[int]) -> Tensor:
if not isinstance(index, tuple):
index = (index,)
index = _complete_ellipsis(index, ndim=len(shape))

nindex = len(index)

nindex_notnan = len([idx for idx in index if idx is not None])
stride = int(np.prod(shape[nindex_notnan:]))

ids = torch.tensor([0], dtype=torch.long)
slice_shape = []
dim = len(index) - 1
shape_dim = nindex_notnan - 1
for dim in range(nindex - 1, -1, -1):
if index[dim] is None:
slice_shape.append(1)
continue

index_at_dim = index[dim]
size_at_dim = shape[shape_dim]
if isinstance(index_at_dim, int):
ids = ids + index_at_dim * stride
elif isinstance(index_at_dim, slice):
offsets = torch.arange(size_at_dim)[index_at_dim] * stride
ids = (offsets[:, None] + ids[None, :]).reshape(-1)
slice_shape.append(len(offsets))
elif isinstance(index_at_dim, list):
offsets = torch.tensor(index_at_dim) * stride
ids = (offsets[:, None] + ids[None, :]).reshape(-1)
slice_shape.append(len(offsets))
elif isinstance(index_at_dim, Tensor):
assert index_at_dim.ndim == 1, (
f"Index at dimension {dim} should be 1-dimensional, "
f"not {index_at_dim.ndim}-d."
)
ids = ids.to(index_at_dim.device)
offsets = index_at_dim * stride
ids = (offsets[:, None] + ids[None, :]).reshape(-1)
slice_shape.append(len(offsets))
else:
raise TypeError(
f"Index at dimension {dim} should be "
f"a `int`, a `slice`, or a `Tensor`, not {type(index_at_dim)}"
)

stride *= size_at_dim
shape_dim -= 1

slice_shape = list(reversed(slice_shape))
return ids.reshape(slice_shape)


def _complete_ellipsis(index: Tuple[Any | Ellipsis], ndim: int):
num_ellip = index.count(Ellipsis)
assert num_ellip <= 1, f"Invalid index {index}"
if num_ellip == 0:
return index

i = index.index(Ellipsis)
completed = (
list(index[:i]) + [slice(None)] * (ndim - len(index) + 1) + list(index[i + 1 :])
)
return tuple(completed)
40 changes: 38 additions & 2 deletions firelang/utils/shape.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,45 @@
from typing import Tuple
from __future__ import annotations
from typing import Tuple, Iterable
import numpy as np


def check_shape_consistency(shape1: Tuple[int], shape2: Tuple[int]):
assert len(shape1) == len(
shape2
), f"Shape inconsistent in number of dimension between {shape1} and {shape2}."
for d1, d2 in zip(shape1, shape2):
if d1 != 1 and d2 != 1 and d1 != d2:
raise ValueError(f"Inconsistent shape: {shape1} and {shape2}")
raise ValueError(f"Inconsistent shape: {shape1} and {shape2}")


def parse_shape(shape, num_elements):
if len(shape) == 1:
shape = shape[0]
if isinstance(shape, int):
shape = (shape,)
elif isinstance(shape, Iterable):
shape = tuple(shape)
else:
raise TypeError(f"Invalid shape {shape}")
shape = _replace_minus_one(shape, num_elements=num_elements)
given_num_elements = int(np.prod(shape))
assert (
given_num_elements == num_elements
), f"Inconsistent shape: should have {num_elements} elements, not {given_num_elements}."
return shape


def _replace_minus_one(shape: Tuple[int], num_elements: int):
num_minus_one = shape.count(-1)
assert num_minus_one <= 1, f"Invalid shape {shape}"
if num_minus_one == 0:
return shape

otherdim = int(np.prod([size for size in shape if size >= 1]))
inferred = num_elements // otherdim
assert (
inferred * otherdim == num_elements
), f"Invalid new shape {shape} for {num_elements} elements"
i = shape.index(-1)
replaced = list(shape[:i]) + [inferred] + list(shape[i + 1 :])
return tuple(replaced)

0 comments on commit 194b21a

Please sign in to comment.