Merge pull request #221 from arogozhnikov/type-hinting
Improve type hinting
arogozhnikov authored Nov 5, 2022
2 parents b1f905a + 4bd1170 commit e2c9bb2
Showing 13 changed files with 117 additions and 103 deletions.
30 changes: 17 additions & 13 deletions einops/_backends.py
@@ -15,7 +15,7 @@
 
 __author__ = 'Alex Rogozhnikov'
 
-_backends = {}
+_backends: dict = {}
 _debug_importing = False
 
 
@@ -54,7 +54,7 @@ def get_backend(tensor) -> 'AbstractBackend':
 
 class AbstractBackend:
     """ Base backend class, major part of methods are only for debugging purposes. """
-    framework_name = None
+    framework_name: str
 
     def is_appropriate_type(self, tensor):
         """ helper method should recognize tensors it can handle """
@@ -137,7 +137,7 @@ def __rmul__(self, other):
         return self
 
     def __hash__(self):
-        return None.__hash__()
+        return hash(None)
 
 
 class NumpyBackend(AbstractBackend):
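Both spellings return the same integer at runtime, so this hunk is about readability and type inference rather than behavior; `hash(None)` is simply the idiomatic form. A one-line check:

    assert hash(None) == None.__hash__()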
@@ -322,16 +322,20 @@ def arange(self, start, stop):
         return self.torch.arange(start, stop, dtype=self.torch.int64)
 
     def reduce(self, x, operation, reduced_axes):
-        for axis in sorted(reduced_axes, reverse=True):
-            if operation == 'min':
-                x, _ = x.min(dim=axis)
-            elif operation == 'max':
-                x, _ = x.max(dim=axis)
-            elif operation in ['sum', 'mean', 'prod']:
-                x = getattr(x, operation)(dim=axis)
-            else:
-                raise NotImplementedError('Unknown reduction ', operation)
-        return x
+        if operation == 'min':
+            return x.amin(dim=reduced_axes)
+        elif operation == 'max':
+            return x.amax(dim=reduced_axes)
+        elif operation == 'sum':
+            return x.sum(dim=reduced_axes)
+        elif operation == 'mean':
+            return x.mean(dim=reduced_axes)
+        elif operation == 'prod':
+            for i in list(sorted(reduced_axes))[::-1]:
+                x = x.prod(dim=i)
+            return x
+        else:
+            raise NotImplementedError('Unknown reduction ', operation)
 
     def transpose(self, x, axes):
         return x.permute(axes)
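The rewritten PyTorch `reduce` hands the whole tuple of reduced axes to `amin`/`amax`/`sum`/`mean` in a single call instead of peeling off one axis per loop iteration; only `prod` keeps the loop, because `torch.prod` reduces a single dim at a time. A minimal standalone sketch of the equivalence, assuming a torch version that provides `amin`/`amax` (1.7+):

    import torch

    x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

    # old path: squeeze one axis per iteration, highest axis first
    y_old = x
    for axis in sorted((1, 2), reverse=True):
        y_old, _ = y_old.max(dim=axis)   # .max(dim=...) returns (values, indices)

    # new path: a single call over all reduced axes
    y_new = x.amax(dim=(1, 2))

    assert torch.equal(y_old, y_new) and y_new.shape == (2,)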
53 changes: 27 additions & 26 deletions einops/einops.py
@@ -3,7 +3,7 @@
 import string
 import typing
 from collections import OrderedDict
-from typing import Tuple, List, Dict, Union, Callable, Optional, TypeVar
+from typing import Set, Tuple, List, Dict, Union, Callable, Optional, TypeVar, cast
 
 if typing.TYPE_CHECKING:
     import numpy as np
@@ -13,7 +13,7 @@
 from .parsing import ParsedExpression, _ellipsis, AnonymousAxis
 
 Tensor = TypeVar('Tensor')
-ReductionCallable = Callable[[Tensor, List[int]], Tensor]
+ReductionCallable = Callable[[Tensor, Tuple[int]], Tensor]
 Reduction = Union[str, ReductionCallable]
 
 _reductions = ('min', 'max', 'sum', 'mean', 'prod')
@@ -38,10 +38,9 @@ def _product(sequence: List[int]) -> int:
 
 
 def _reduce_axes(tensor, reduction_type: Reduction, reduced_axes: List[int], backend):
-    reduced_axes = tuple(reduced_axes)
     if callable(reduction_type):
         # custom callable
-        return reduction_type(tensor, reduced_axes)
+        return reduction_type(tensor, tuple(reduced_axes))
     else:
         # one of built-in operations
         if len(reduced_axes) == 0:
@@ -50,7 +49,7 @@ def _reduce_axes(tensor, reduction_type: Reduction, reduced_axes: List[int], backend):
         if reduction_type == 'mean':
             if not backend.is_float_type(tensor):
                 raise NotImplementedError('reduce_mean is not available for non-floating tensors')
-        return backend.reduce(tensor, reduction_type, reduced_axes)
+        return backend.reduce(tensor, reduction_type, tuple(reduced_axes))
 
 
 def _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes):
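With the tuple conversion moved to each call site, a user-supplied reduction callable always receives the reduced axis positions as a tuple, matching the updated `ReductionCallable` alias. A hedged sketch (the `variance` helper is illustrative, not part of einops):

    from typing import Tuple

    import numpy as np
    from einops import reduce

    def variance(tensor, reduced_axes: Tuple[int, ...]):
        # einops passes the positions of the reduced axes as a tuple
        return tensor.var(axis=reduced_axes)

    x = np.random.rand(2, 3, 4)
    out = reduce(x, 'b h w -> b', variance)  # reduces axes (1, 2)
    assert out.shape == (2,)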
@@ -171,7 +170,7 @@ def _reconstruct_from_shape_uncached(self: TransformRecipe, shape: List[int]) ->
         after_ellipsis = input_axis + len(shape) - len(self.input_composite_axes)
         if input_axis == self.ellipsis_position_in_lhs:
             assert len(known_axes) == 0 and len(unknown_axes) == 1
-            unknown_axis, = unknown_axes
+            unknown_axis: int = unknown_axes[0]
             ellipsis_shape = shape[before_ellipsis:after_ellipsis + 1]
             for d in ellipsis_shape:
                 if d is None:
@@ -201,7 +200,7 @@ def _reconstruct_from_shape_uncached(self: TransformRecipe, shape: List[int]) ->
                 raise EinopsError("Shape mismatch, can't divide axis of length {} in chunks of {}".format(
                     length, known_product))
 
-            unknown_axis: int = unknown_axes[0]
+            unknown_axis = unknown_axes[0]
             inferred_length: int = length // known_product
             axes_lengths[unknown_axis] = inferred_length
 
@@ -249,9 +248,9 @@ def _prepare_transformation_recipe(pattern: str,
     """ Perform initial parsing of pattern and provided supplementary info
     axes_lengths is a tuple of tuples (axis_name, axis_length)
     """
-    left, rght = pattern.split('->')
-    left = ParsedExpression(left)
-    rght = ParsedExpression(rght)
+    left_str, rght_str = pattern.split('->')
+    left = ParsedExpression(left_str)
+    rght = ParsedExpression(rght_str)
 
     # checking that axes are in agreement - new axes appear only in repeat, while disappear only in reduction
     if not left.has_ellipsis and rght.has_ellipsis:
@@ -280,7 +279,7 @@ def _prepare_transformation_recipe(pattern: str,
         raise EinopsError('Unknown reduction {}. Expect one of {}.'.format(operation, _reductions))
 
     # parsing all dimensions to find out lengths
-    axis_name2known_length = OrderedDict()
+    axis_name2known_length: Dict[Union[str, AnonymousAxis], int] = OrderedDict()
     for composite_axis in left.composition:
         for axis_name in composite_axis:
             if isinstance(axis_name, AnonymousAxis):
@@ -302,7 +301,7 @@ def _prepare_transformation_recipe(pattern: str,
     axis_name2position = {name: position for position, name in enumerate(axis_name2known_length)}
     reduced_axes: List[int] = [position for axis, position in axis_name2position.items() if
                                axis not in rght.identifiers]
-    reduced_axes: List[int] = list(sorted(reduced_axes))
+    reduced_axes = list(sorted(reduced_axes))
 
     for elementary_axis, axis_length in axes_lengths:
         if not ParsedExpression.check_axis_name(elementary_axis):
@@ -314,8 +313,8 @@ def _prepare_transformation_recipe(pattern: str,
     input_axes_known_unknown = []
     # some of shapes will be inferred later - all information is prepared for faster inference
     for composite_axis in left.composition:
-        known = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length}
-        unknown = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length}
+        known: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length}
+        unknown: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length}
         if len(unknown) > 1:
             raise EinopsError('Could not infer sizes for {}'.format(unknown))
         assert len(unknown) + len(known) == len(composite_axis)
@@ -324,7 +323,7 @@ def _prepare_transformation_recipe(pattern: str,
              [axis_name2position[axis] for axis in unknown])
         )
 
-    axis_position_after_reduction = {}
+    axis_position_after_reduction: Dict[str, int] = {}
     for axis_name in itertools.chain(*left.composition):
         if axis_name in rght.identifiers:
             axis_position_after_reduction[axis_name] = len(axis_position_after_reduction)
@@ -481,7 +480,7 @@ def rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths)
         if len(tensor) == 0:
             raise TypeError("Rearrange can't be applied to an empty list")
         tensor = get_backend(tensor[0]).stack_on_zeroth_dimension(tensor)
-    return reduce(tensor, pattern, reduction='rearrange', **axes_lengths)
+    return reduce(cast(Tensor, tensor), pattern, reduction='rearrange', **axes_lengths)
 
 
 def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor:
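`cast(Tensor, tensor)` changes nothing at runtime: after a list input is stacked onto a new leading axis, the variable's declared type is still `Union[Tensor, List[Tensor]]`, and the cast tells the checker the list branch has been resolved. Behavior is unchanged, e.g.:

    import numpy as np
    from einops import rearrange

    frames = [np.zeros((30, 40)) for _ in range(8)]
    # a list input is stacked on a fresh leading axis, then rearranged
    video = rearrange(frames, 'time h w -> h w time')
    assert video.shape == (30, 40, 8)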
@@ -581,7 +580,7 @@ def parse_shape(x, pattern: str) -> dict:
     else:
         composition = exp.composition
     result = {}
-    for (axis_name,), axis_length in zip(composition, shape):
+    for (axis_name,), axis_length in zip(composition, shape):  # type: ignore
         if axis_name != '_':
             result[axis_name] = axis_length
     return result
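The `# type: ignore` only quiets the checker about unpacking composite axes; `parse_shape` behaves as before:

    import numpy as np
    from einops import parse_shape

    x = np.zeros([2, 3, 5, 7])
    # axes named '_' are skipped, named axes map to their lengths
    assert parse_shape(x, 'batch _ h w') == {'batch': 2, 'h': 5, 'w': 7}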
@@ -644,15 +643,14 @@ def _compactify_pattern_for_einsum(pattern: str) -> str:
         # numpy allows this, so make sure users
         # don't accidentally do something like this.
         raise ValueError("Einsum pattern must contain '->'.")
-    lefts, right = pattern.split('->')
-    lefts = lefts.split(',')
+    lefts_str, right_str = pattern.split('->')
 
     lefts = [
         ParsedExpression(left, allow_underscore=True, allow_duplicates=True)
-        for left in lefts
+        for left in lefts_str.split(',')
     ]
 
-    right = ParsedExpression(right, allow_underscore=True)
+    right = ParsedExpression(right_str, allow_underscore=True)
 
     # Start from 'a' and go up to 'Z'
     output_axis_names = string.ascii_letters
@@ -697,17 +695,20 @@ def _compactify_pattern_for_einsum(pattern: str) -> str:
     return compact_pattern
 
 
+# dunders in overloads turn arguments into positional-only.
+# After python 3.7 EOL this should be replaced with '/' as the last argument.
+
 @typing.overload
-def einsum(tensor: Tensor, pattern: str) -> Tensor: ...
+def einsum(__tensor: Tensor, __pattern: str) -> Tensor: ...
 @typing.overload
-def einsum(tensor1: Tensor, tensor2: Tensor, pattern: str) -> Tensor: ...
+def einsum(__tensor1: Tensor, __tensor2: Tensor, __pattern: str) -> Tensor: ...
 @typing.overload
-def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, pattern: str) -> Tensor: ...
+def einsum(__tensor1: Tensor, __tensor2: Tensor, __tensor3: Tensor, __pattern: str) -> Tensor: ...
 @typing.overload
-def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, tensor4: Tensor, pattern: str) -> Tensor: ...
+def einsum(__tensor1: Tensor, __tensor2: Tensor, __tensor3: Tensor, __tensor4: Tensor, __pattern: str) -> Tensor: ...
 
 
-def einsum(*tensors_and_pattern: List[Union[Tensor, str]]) -> Tensor:
+def einsum(*tensors_and_pattern: Union[Tensor, str]) -> Tensor:
     """
     einops.einsum calls einsum operations with einops-style named
     axes indexing, computing tensor products with an arbitrary
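Double-underscore parameter names mark the overload arguments as positional-only for type checkers, which matches the real `*tensors_and_pattern` signature where nothing can be passed by keyword. Call style is unchanged, tensors first and the pattern last:

    import numpy as np
    from einops import einsum

    a, b = np.random.rand(3, 4), np.random.rand(4, 5)
    c = einsum(a, b, 'i j, j k -> i k')  # pattern is the final positional argument
    assert c.shape == (3, 5)

A keyword call such as einsum(tensor=a, pattern='...') is now rejected by the checker, just as it already fails at runtime against the *args signature.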
4 changes: 2 additions & 2 deletions einops/experimental/indexing.py
@@ -72,8 +72,8 @@ def __init__(
         for x in composed_shape:
             flat_shape.extend(x)
 
-        self.compose_transposition: Tuple[int] = tuple([decomposed_shape.index(x) for x in flat_shape])
-        self.decompose_transposition: Tuple[int] = tuple([flat_shape.index(x) for x in decomposed_shape])
+        self.compose_transposition: Tuple[int, ...] = tuple([decomposed_shape.index(x) for x in flat_shape])
+        self.decompose_transposition: Tuple[int, ...] = tuple([flat_shape.index(x) for x in decomposed_shape])
         self.composed_shape = composed_shape
         self.decomposed_shape = decomposed_shape
 
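The correction matters because `Tuple[int]` denotes a tuple of exactly one int, while these transpositions hold one entry per axis; the variadic form is the accurate annotation:

    from typing import Tuple

    one: Tuple[int] = (1,)             # fixed shape: exactly one element
    many: Tuple[int, ...] = (2, 0, 1)  # homogeneous tuple of any length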
7 changes: 4 additions & 3 deletions einops/layers/__init__.py
@@ -1,6 +1,7 @@
 __author__ = 'Alex Rogozhnikov'
 
 import functools
+from typing import Any
 
 from einops.einops import _apply_recipe
 
@@ -18,13 +19,13 @@ class RearrangeMixin:
     See einops.rearrange for source_examples.
     """
 
-    def __init__(self, pattern, **axes_lengths):
+    def __init__(self, pattern: str, **axes_lengths: Any) -> None:
         super().__init__()
         self.pattern = pattern
         self.axes_lengths = axes_lengths
         self._recipe = self.recipe()  # checking parameters
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         params = repr(self.pattern)
         for axis, length in self.axes_lengths.items():
             params += ', {}={}'.format(axis, length)
@@ -53,7 +54,7 @@ class ReduceMixin:
     See einops.reduce for source_examples.
     """
 
-    def __init__(self, pattern, reduction, **axes_lengths):
+    def __init__(self, pattern: str, reduction: str, **axes_lengths: Any):
         super().__init__()
         self.pattern = pattern
         self.reduction = reduction
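These mixin annotations propagate to every framework-specific layer class. For instance, with the torch flavour (behavior is untouched; only the signatures are typed now):

    from einops.layers.torch import Rearrange

    layer = Rearrange('b c h w -> b (c h w)', c=3)
    print(layer)  # the typed __repr__ prints the pattern plus axes lengths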
12 changes: 6 additions & 6 deletions einops/layers/_einmix.py
@@ -1,4 +1,4 @@
-from typing import Optional, Dict
+from typing import Any, List, Optional, Dict
 
 from einops import EinopsError
 from einops.parsing import ParsedExpression
@@ -13,7 +13,7 @@ def _report_axes(axes: set, report_message: str):
 
 
 class _EinmixMixin:
-    def __init__(self, pattern, weight_shape, bias_shape=None, **axes_lengths):
+    def __init__(self, pattern: str, weight_shape: str, bias_shape: Optional[str] = None, **axes_lengths: Any):
         """
         EinMix - Einstein summation with automated tensor management and axis packing/unpacking.
@@ -60,7 +60,7 @@ def __init__(self, pattern, weight_shape, bias_shape=None, **axes_lengths):
         self.axes_lengths = axes_lengths
         self.initialize_einmix(pattern=pattern, weight_shape=weight_shape, bias_shape=bias_shape, axes_lengths=axes_lengths)
 
-    def initialize_einmix(self, pattern, weight_shape, bias_shape, axes_lengths):
+    def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optional[str], axes_lengths: dict):
         left_pattern, right_pattern = pattern.split('->')
         left = ParsedExpression(left_pattern)
         right = ParsedExpression(right_pattern)
@@ -81,7 +81,7 @@ def initialize_einmix(self, pattern, weight_shape, bias_shape, axes_lengths):
         pre_reshape_lengths = None
         post_reshape_pattern = None
         if any(len(group) != 1 for group in left.composition):
-            names = []
+            names: List[str] = []
             for group in left.composition:
                 names += group
             composition = ' '.join(names)
@@ -143,8 +143,8 @@ def initialize_einmix(self, pattern, weight_shape, bias_shape, axes_lengths):
 
         # rewrite einsum expression with single-letter latin identifiers so that
         # expression will be understood by any framework
-        mapping2letters = {*left.identifiers, *right.identifiers, *weight.identifiers}
-        mapping2letters = {k: letter for letter, k in zip(string.ascii_lowercase, mapping2letters)}
+        mapped_identifiers = {*left.identifiers, *right.identifiers, *weight.identifiers}
+        mapping2letters = {k: letter for letter, k in zip(string.ascii_lowercase, mapped_identifiers)}
 
         def write_flat(axes: list):
             return ''.join(mapping2letters[axis] for axis in axes)
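The rename is needed because rebinding one name from a set of identifiers to a dict of letter mappings is an error under mypy, even though it runs fine. A minimal sketch of the rejected pattern:

    ids = {'a', 'b', 'c'}                    # inferred as Set[str]
    ids = {k: i for i, k in enumerate(ids)}  # mypy: incompatible types in assignment

Giving the intermediate set its own name keeps each binding consistently typed.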
6 changes: 3 additions & 3 deletions einops/layers/chainer.py
@@ -1,4 +1,4 @@
-from typing import Optional, Dict
+from typing import Optional, Dict, cast
 
 import chainer
 
@@ -36,11 +36,11 @@ def _create_rearrange_layers(self,
                                  ):
         self.pre_rearrange = None
         if pre_reshape_pattern is not None:
-            self.pre_rearrange = Rearrange(pre_reshape_pattern, **pre_reshape_lengths)
+            self.pre_rearrange = Rearrange(pre_reshape_pattern, **cast(dict, pre_reshape_lengths))
 
         self.post_rearrange = None
         if post_reshape_pattern is not None:
-            self.post_rearrange = Rearrange(post_reshape_pattern)
+            self.post_rearrange = Rearrange(post_reshape_pattern, **cast(dict, post_reshape_lengths))
 
     def __call__(self, input):
         if self.pre_rearrange is not None:
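Two things happen in this hunk: `cast(dict, ...)` narrows `Optional[Dict]` for the checker at zero runtime cost, and the post-rearrange branch now also forwards `post_reshape_lengths`, which the old chainer code dropped. The narrowing idiom in isolation (a sketch; `make_kwargs` is a hypothetical helper):

    from typing import Optional, Dict, cast

    def make_kwargs(lengths: Optional[Dict]) -> dict:
        # reached only on a branch where lengths is known to be set;
        # cast() documents that fact and does nothing at runtime
        return cast(dict, lengths)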
6 changes: 3 additions & 3 deletions einops/layers/flax.py
@@ -1,5 +1,5 @@
 from dataclasses import field
-from typing import Optional, Dict
+from typing import Optional, Dict, cast
 
 import flax.linen as nn
 import jax
@@ -63,11 +63,11 @@ def _create_rearrange_layers(self,
                                  post_reshape_lengths: Optional[Dict]):
         self.pre_rearrange = None
         if pre_reshape_pattern is not None:
-            self.pre_rearrange = Rearrange(pre_reshape_pattern, sizes=pre_reshape_lengths)
+            self.pre_rearrange = Rearrange(pre_reshape_pattern, sizes=cast(dict, pre_reshape_lengths))
 
         self.post_rearrange = None
         if post_reshape_pattern is not None:
-            self.post_rearrange = Rearrange(post_reshape_pattern, sizes=post_reshape_lengths)
+            self.post_rearrange = Rearrange(post_reshape_pattern, sizes=cast(dict, post_reshape_lengths))
 
     def __call__(self, input):
         if self.pre_rearrange is not None:
6 changes: 3 additions & 3 deletions einops/layers/oneflow.py
@@ -1,4 +1,4 @@
-from typing import Optional, Dict
+from typing import Optional, Dict, cast
 
 import oneflow as flow
 
@@ -36,11 +36,11 @@ def _create_rearrange_layers(self,
                                  ):
         self.pre_rearrange = None
         if pre_reshape_pattern is not None:
-            self.pre_rearrange = Rearrange(pre_reshape_pattern, **pre_reshape_lengths)
+            self.pre_rearrange = Rearrange(pre_reshape_pattern, **cast(dict, pre_reshape_lengths))
 
         self.post_rearrange = None
         if post_reshape_pattern is not None:
-            self.post_rearrange = Rearrange(post_reshape_pattern, **post_reshape_lengths)
+            self.post_rearrange = Rearrange(post_reshape_pattern, **cast(dict, post_reshape_lengths))
 
     def forward(self, input):
         if self.pre_rearrange is not None:
6 changes: 3 additions & 3 deletions einops/layers/tensorflow.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Dict
+from typing import List, Optional, Dict, cast
 
 import tensorflow as tf
 from tensorflow.keras.layers import Layer
@@ -59,11 +59,11 @@ def _create_rearrange_layers(self,
                                  ):
         self.pre_rearrange = None
         if pre_reshape_pattern is not None:
-            self.pre_rearrange = Rearrange(pre_reshape_pattern, **pre_reshape_lengths)
+            self.pre_rearrange = Rearrange(pre_reshape_pattern, **cast(dict, pre_reshape_lengths))
 
         self.post_rearrange = None
         if post_reshape_pattern is not None:
-            self.post_rearrange = Rearrange(post_reshape_pattern, **post_reshape_lengths)
+            self.post_rearrange = Rearrange(post_reshape_pattern, **cast(dict, post_reshape_lengths))
 
     def build(self, input_shape):
         pass
(4 more changed files were not loaded in this view)
