From 206b31addfb224b5b9692b2efe94f5db90c92f63 Mon Sep 17 00:00:00 2001
From: Alex Rogozhnikov
Date: Mon, 10 Oct 2022 04:10:16 +0000
Subject: [PATCH 01/12] minimal typehints in tests

---
 tests/test_einsum.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/test_einsum.py b/tests/test_einsum.py
index 71003603..8d15ecd2 100644
--- a/tests/test_einsum.py
+++ b/tests/test_einsum.py
@@ -1,3 +1,4 @@
+from typing import Any, Callable
 from venv import create
 from . import collect_test_backends
 from einops.einops import _compactify_pattern_for_einsum, einsum, EinopsError
@@ -7,11 +8,11 @@


 class Arguments:
-    def __init__(self, *args, **kargs):
+    def __init__(self, *args: Any, **kargs: Any):
         self.args = args
         self.kwargs = kargs

-    def __call__(self, function):
+    def __call__(self, function: Callable):
         return function(*self.args, **self.kwargs)



From 76082138f39a24795aaff5c2dd88f1057d4ad451 Mon Sep 17 00:00:00 2001
From: Alex Rogozhnikov
Date: Mon, 10 Oct 2022 04:13:45 +0000
Subject: [PATCH 02/12] type hinting in backends

---
 einops/_backends.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/einops/_backends.py b/einops/_backends.py
index 6275b475..4458a94d 100644
--- a/einops/_backends.py
+++ b/einops/_backends.py
@@ -15,7 +15,7 @@

 __author__ = 'Alex Rogozhnikov'

-_backends = {}
+_backends: dict = {}
 _debug_importing = False


@@ -54,7 +54,7 @@ def get_backend(tensor) -> 'AbstractBackend':

 class AbstractBackend:
     """ Base backend class, major part of methods are only for debugging purposes. """
-    framework_name = None
+    framework_name: str

     def is_appropriate_type(self, tensor):
         """ helper method should recognize tensors it can handle """
@@ -137,7 +137,7 @@ def __rmul__(self, other):
         return self

     def __hash__(self):
-        return None.__hash__()
+        return hash(None)


 class NumpyBackend(AbstractBackend):
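
[editor's note, not part of the series]
Patch 02 swaps the sentinel assignment `framework_name = None` for the bare class-level
annotation `framework_name: str`. A bare annotation declares the attribute's type for the
checker without creating the attribute at runtime; subclasses are expected to assign the
real value. A minimal sketch of the pattern; Base/Concrete are illustrative names, not
einops code:

    class Base:
        # declared but not assigned: no attribute exists on Base at runtime,
        # yet mypy treats every instance as carrying a str here
        name: str

        def greet(self) -> str:
            return 'I am ' + self.name

    class Concrete(Base):
        name = 'concrete'  # checked against the declared type

    print(Concrete().greet())        # I am concrete
    print('name' in Base.__dict__)   # False: the annotation alone creates nothing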
From 36c7bb16e57d6e57f8f3050f9e07abdf3f00469f Mon Sep 17 00:00:00 2001
From: Alex Rogozhnikov
Date: Mon, 10 Oct 2022 06:00:36 +0000
Subject: [PATCH 03/12] type hints + minor changes

---
 einops/parsing.py | 73 ++++++++++++++++++++++++-----------------------
 1 file changed, 37 insertions(+), 36 deletions(-)

diff --git a/einops/parsing.py b/einops/parsing.py
index 6b286a4f..df0f4c53 100644
--- a/einops/parsing.py
+++ b/einops/parsing.py
@@ -1,7 +1,7 @@
 from einops import EinopsError
 import keyword
 import warnings
-from typing import List, Optional, Set, Tuple
+from typing import List, Optional, Set, Tuple, Union

 _ellipsis: str = '…'  # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated

@@ -26,7 +26,7 @@ class ParsedExpression:
     non-mutable structure that contains information about one side of expression (e.g. 'b c (h w)')
     and keeps some information important for downstream
     """
-    def __init__(self, expression, *, allow_underscore: bool = False,
+    def __init__(self, expression: str, *, allow_underscore: bool = False,
                  allow_duplicates: bool = False):
         self.has_ellipsis: bool = False
         self.has_ellipsis_parenthesized: Optional[bool] = None
@@ -34,7 +34,7 @@ def __init__(self, expression, *, allow_underscore: bool = False,
         # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
         self.has_non_unitary_anonymous_axes: bool = False
         # composition keeps structure of composite axes, see how different corner cases are handled in tests
-        self.composition = []
+        self.composition: List[Union[List[str], str]] = []
         if '.' in expression:
             if '...' not in expression:
                 raise EinopsError('Expression may contain dots only inside ellipsis (...)')
@@ -44,47 +44,47 @@ def __init__(self, expression, *, allow_underscore: bool = False,
             expression = expression.replace('...', _ellipsis)
             self.has_ellipsis = True

-        bracket_group = None
+        bracket_group: Optional[List[str]] = None

         def add_axis_name(x):
-            if x is not None:
-                if x in self.identifiers:
-                    if not (allow_underscore and x == "_") and not allow_duplicates:
-                        raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x))
-                if x == _ellipsis:
-                    self.identifiers.add(_ellipsis)
-                    if bracket_group is None:
-                        self.composition.append(_ellipsis)
-                        self.has_ellipsis_parenthesized = False
-                    else:
-                        bracket_group.append(_ellipsis)
-                        self.has_ellipsis_parenthesized = True
+            if x in self.identifiers:
+                if not (allow_underscore and x == "_") and not allow_duplicates:
+                    raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x))
+            if x == _ellipsis:
+                self.identifiers.add(_ellipsis)
+                if bracket_group is None:
+                    self.composition.append(_ellipsis)
+                    self.has_ellipsis_parenthesized = False
                 else:
-                    is_number = str.isdecimal(x)
-                    if is_number and int(x) == 1:
-                        # handling the case of anonymous axis of length 1
-                        if bracket_group is None:
-                            self.composition.append([])
-                        else:
-                            pass  # no need to think about 1s inside parenthesis
-                        return
-                    is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore)
-                    if not (is_number or is_axis_name):
-                        raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason))
-                    if is_number:
-                        x = AnonymousAxis(x)
-                    self.identifiers.add(x)
-                    if is_number:
-                        self.has_non_unitary_anonymous_axes = True
+                    bracket_group.append(_ellipsis)
+                    self.has_ellipsis_parenthesized = True
+            else:
+                is_number = str.isdecimal(x)
+                if is_number and int(x) == 1:
+                    # handling the case of anonymous axis of length 1
                     if bracket_group is None:
-                        self.composition.append([x])
+                        self.composition.append([])
                     else:
-                        bracket_group.append(x)
+                        pass  # no need to think about 1s inside parenthesis
+                    return
+                is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore)
+                if not (is_number or is_axis_name):
+                    raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason))
+                if is_number:
+                    x = AnonymousAxis(x)
+                self.identifiers.add(x)
+                if is_number:
+                    self.has_non_unitary_anonymous_axes = True
+                if bracket_group is None:
+                    self.composition.append([x])
+                else:
+                    bracket_group.append(x)

         current_identifier = None
         for char in expression:
             if char in '() ':
-                add_axis_name(current_identifier)
+                if current_identifier is not None:
+                    add_axis_name(current_identifier)
                 current_identifier = None
                 if char == '(':
                     if bracket_group is not None:
@@ -105,7 +105,8 @@ def add_axis_name(x):
         if bracket_group is not None:
             raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(expression))

-        add_axis_name(current_identifier)
+        if current_identifier is not None:
+            add_axis_name(current_identifier)

     def flat_axes_order(self) -> List:
         result = []
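
[editor's note, not part of the series]
Besides annotations, patch 03 hoists the `x is not None` guard out of add_axis_name and
into its call sites, so the helper's argument is never Optional and the body loses one
level of nesting. Narrowing at the call site is also the form mypy understands best.
A minimal sketch under illustrative names (handle/process are not einops functions):

    from typing import Optional

    def handle(token: str) -> str:
        # token is guaranteed non-None here; no defensive branch needed
        return token.upper()

    def process(token: Optional[str]) -> None:
        if token is not None:
            # inside this branch mypy narrows token from Optional[str] to str
            print(handle(token))

    process('abc')   # prints ABC
    process(None)    # prints nothing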
From c054ad0daebf4c83fbbc9f265d72d06436ad0f1a Mon Sep 17 00:00:00 2001
From: Alex Rogozhnikov
Date: Mon, 10 Oct 2022 06:14:08 +0000
Subject: [PATCH 04/12] optimize reductions in torch where possible

---
 einops/_backends.py | 24 ++++++++++++++----------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/einops/_backends.py b/einops/_backends.py
index 4458a94d..8a4fde9b 100644
--- a/einops/_backends.py
+++ b/einops/_backends.py
@@ -322,16 +322,20 @@ def arange(self, start, stop):
         return self.torch.arange(start, stop, dtype=self.torch.int64)

     def reduce(self, x, operation, reduced_axes):
-        for axis in sorted(reduced_axes, reverse=True):
-            if operation == 'min':
-                x, _ = x.min(dim=axis)
-            elif operation == 'max':
-                x, _ = x.max(dim=axis)
-            elif operation in ['sum', 'mean', 'prod']:
-                x = getattr(x, operation)(dim=axis)
-            else:
-                raise NotImplementedError('Unknown reduction ', operation)
-        return x
+        if operation == 'min':
+            return x.amin(dim=reduced_axes)
+        elif operation == 'max':
+            return x.amax(dim=reduced_axes)
+        elif operation == 'sum':
+            return x.sum(dim=reduced_axes)
+        elif operation == 'mean':
+            return x.mean(dim=reduced_axes)
+        elif operation == 'prod':
+            for i in list(sorted(reduced_axes))[::-1]:
+                x = x.prod(dim=i)
+            return x
+        else:
+            raise NotImplementedError('Unknown reduction ', operation)

     def transpose(self, x, axes):
         return x.permute(axes)
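
[editor's note, not part of the series]
The rewrite above works because torch's amin/amax/sum/mean all accept a tuple of
dimensions, so a single call replaces the per-axis python loop; torch.prod only takes one
dim, which is why it keeps a loop, reducing the highest axis first so the remaining
indices stay valid. A small sketch, assuming a recent torch is installed:

    import torch

    x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

    # one call reduces several axes at once
    print(x.amin(dim=(0, 2)).shape)   # torch.Size([3])
    print(x.sum(dim=(0, 2)).shape)    # torch.Size([3])

    # prod: one axis per call, highest axis first
    y = x
    for axis in sorted((0, 2), reverse=True):
        y = y.prod(dim=axis)
    print(y.shape)                    # torch.Size([3])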
From 3ddb01e19a9bce3ba98f29562101bb48796a5242 Mon Sep 17 00:00:00 2001
From: Alex Rogozhnikov
Date: Mon, 10 Oct 2022 06:16:08 +0000
Subject: [PATCH 05/12] type hinting for einmix

---
 einops/layers/_einmix.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/einops/layers/_einmix.py b/einops/layers/_einmix.py
index c07d00fb..138234e5 100644
--- a/einops/layers/_einmix.py
+++ b/einops/layers/_einmix.py
@@ -1,4 +1,4 @@
-from typing import Optional, Dict
+from typing import Any, Optional, Dict

 from einops import EinopsError
 from einops.parsing import ParsedExpression
@@ -13,7 +13,7 @@ def _report_axes(axes: set, report_message: str):


 class _EinmixMixin:
-    def __init__(self, pattern, weight_shape, bias_shape=None, **axes_lengths):
+    def __init__(self, pattern: str, weight_shape: str, bias_shape: Optional[str]=None, **axes_lengths: Any):
         """
         EinMix - Einstein summation with automated tensor management and axis packing/unpacking.

@@ -60,7 +60,7 @@ def __init__(self, pattern, weight_shape, bias_shape=None, **axes_lengths):
         self.axes_lengths = axes_lengths
         self.initialize_einmix(pattern=pattern, weight_shape=weight_shape, bias_shape=bias_shape, axes_lengths=axes_lengths)

-    def initialize_einmix(self, pattern, weight_shape, bias_shape, axes_lengths):
+    def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optional[str], axes_lengths: dict):
         left_pattern, right_pattern = pattern.split('->')
         left = ParsedExpression(left_pattern)
         right = ParsedExpression(right_pattern)

From 28ab6020e0a1834e1f875ffda106f2a306a97ab1 Mon Sep 17 00:00:00 2001
From: Alex Rogozhnikov
Date: Mon, 10 Oct 2022 06:20:05 +0000
Subject: [PATCH 06/12] add type casts to layers' implementations to suppress
 mypy warnings

---
 einops/layers/chainer.py    | 6 +++---
 einops/layers/flax.py       | 6 +++---
 einops/layers/oneflow.py    | 6 +++---
 einops/layers/tensorflow.py | 6 +++---
 einops/layers/torch.py      | 6 +++---
 5 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/einops/layers/chainer.py b/einops/layers/chainer.py
index 05c3837b..0214bc32 100644
--- a/einops/layers/chainer.py
+++ b/einops/layers/chainer.py
@@ -1,4 +1,4 @@
-from typing import Optional, Dict
+from typing import Optional, Dict, cast

 import chainer

@@ -36,11 +36,11 @@ def _create_rearrange_layers(self,
                                  ):
         self.pre_rearrange = None
         if pre_reshape_pattern is not None:
-            self.pre_rearrange = Rearrange(pre_reshape_pattern, **pre_reshape_lengths)
+            self.pre_rearrange = Rearrange(pre_reshape_pattern, **cast(dict, pre_reshape_lengths))

         self.post_rearrange = None
         if post_reshape_pattern is not None:
-            self.post_rearrange = Rearrange(post_reshape_pattern)
+            self.post_rearrange = Rearrange(post_reshape_pattern, **cast(dict, post_reshape_lengths))

     def __call__(self, input):
         if self.pre_rearrange is not None:
diff --git a/einops/layers/flax.py b/einops/layers/flax.py
index 804f9aa3..abd4ec5b 100644
--- a/einops/layers/flax.py
+++ b/einops/layers/flax.py
@@ -1,5 +1,5 @@
 from dataclasses import field
-from typing import Optional, Dict
+from typing import Optional, Dict, cast

 import flax.linen as nn
 import jax
@@ -63,11 +63,11 @@ def _create_rearrange_layers(self,
                                  post_reshape_lengths: Optional[Dict]):
         self.pre_rearrange = None
         if pre_reshape_pattern is not None:
-            self.pre_rearrange = Rearrange(pre_reshape_pattern, sizes=pre_reshape_lengths)
+            self.pre_rearrange = Rearrange(pre_reshape_pattern, sizes=cast(dict, pre_reshape_lengths))

         self.post_rearrange = None
         if post_reshape_pattern is not None:
-            self.post_rearrange = Rearrange(post_reshape_pattern, sizes=post_reshape_lengths)
+            self.post_rearrange = Rearrange(post_reshape_pattern, sizes=cast(dict, post_reshape_lengths))

     def __call__(self, input):
         if self.pre_rearrange is not None:
diff --git a/einops/layers/oneflow.py b/einops/layers/oneflow.py
index 8a0ef877..2885404d 100644
--- a/einops/layers/oneflow.py
+++ b/einops/layers/oneflow.py
@@ -1,4 +1,4 @@
-from typing import Optional, Dict
+from typing import Optional, Dict, cast

 import oneflow as flow

@@ -36,11 +36,11 @@ def _create_rearrange_layers(self,
                                  ):
         self.pre_rearrange = None
         if pre_reshape_pattern is not None:
-            self.pre_rearrange = Rearrange(pre_reshape_pattern, **pre_reshape_lengths)
+            self.pre_rearrange = Rearrange(pre_reshape_pattern, **cast(dict, pre_reshape_lengths))

         self.post_rearrange = None
         if post_reshape_pattern is not None:
-            self.post_rearrange = Rearrange(post_reshape_pattern, **post_reshape_lengths)
+            self.post_rearrange = Rearrange(post_reshape_pattern, **cast(dict, post_reshape_lengths))

     def forward(self, input):
         if self.pre_rearrange is not None:
diff --git a/einops/layers/tensorflow.py b/einops/layers/tensorflow.py
index 0485ca6d..c89a71ad 100644
--- a/einops/layers/tensorflow.py
+++ b/einops/layers/tensorflow.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Dict
+from typing import List, Optional, Dict, cast

 import tensorflow as tf
 from tensorflow.keras.layers import Layer
@@ -59,11 +59,11 @@ def _create_rearrange_layers(self,
                                  ):
         self.pre_rearrange = None
         if pre_reshape_pattern is not None:
-            self.pre_rearrange = Rearrange(pre_reshape_pattern, **pre_reshape_lengths)
+            self.pre_rearrange = Rearrange(pre_reshape_pattern, **cast(dict, pre_reshape_lengths))

         self.post_rearrange = None
         if post_reshape_pattern is not None:
-            self.post_rearrange = Rearrange(post_reshape_pattern, **post_reshape_lengths)
+            self.post_rearrange = Rearrange(post_reshape_pattern, **cast(dict, post_reshape_lengths))

     def build(self, input_shape):
         pass
diff --git a/einops/layers/torch.py b/einops/layers/torch.py
index 90737f79..31992415 100644
--- a/einops/layers/torch.py
+++ b/einops/layers/torch.py
@@ -1,4 +1,4 @@
-from typing import Optional, Dict
+from typing import Optional, Dict, cast

 import torch

@@ -45,11 +45,11 @@ def _create_rearrange_layers(self,
                                  ):
         self.pre_rearrange = None
         if pre_reshape_pattern is not None:
-            self.pre_rearrange = Rearrange(pre_reshape_pattern, **pre_reshape_lengths)
+            self.pre_rearrange = Rearrange(pre_reshape_pattern, **cast(dict, pre_reshape_lengths))

         self.post_rearrange = None
         if post_reshape_pattern is not None:
-            self.post_rearrange = Rearrange(post_reshape_pattern, **post_reshape_lengths)
+            self.post_rearrange = Rearrange(post_reshape_pattern, **cast(dict, post_reshape_lengths))

     def forward(self, input):
         if self.pre_rearrange is not None:
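
[editor's note, not part of the series]
`cast(dict, pre_reshape_lengths)` in patch 06 tells mypy to treat an Optional[Dict] as a
plain dict on branches where the enclosing `if ... is not None` check already guarantees
it. cast() is purely a static hint: at runtime it returns its argument unchanged and
performs no validation. A minimal sketch; expand() is an illustrative function, not
einops API:

    from typing import Optional, cast

    def expand(lengths: Optional[dict]) -> dict:
        # caller contract: lengths is not None on this path.
        # cast() only informs the checker - nothing is verified at runtime.
        return {**cast(dict, lengths), 'checked': True}

    print(expand({'h': 2}))   # {'h': 2, 'checked': True}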
""" - def __init__(self, pattern, reduction, **axes_lengths): + def __init__(self, pattern: str, reduction: str, **axes_lengths: Any): super().__init__() self.pattern = pattern self.reduction = reduction From 240c7d87e20abbc7e7392416d34c8fd36a7a65f1 Mon Sep 17 00:00:00 2001 From: Alex Rogozhnikov Date: Mon, 10 Oct 2022 06:22:28 +0000 Subject: [PATCH 08/12] einsum: fix type hints, fix mypy warnigns, more hinting --- einops/einops.py | 50 +++++++++++++++++++++++------------------------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/einops/einops.py b/einops/einops.py index 9f4896bb..41a9b1eb 100644 --- a/einops/einops.py +++ b/einops/einops.py @@ -3,7 +3,7 @@ import string import typing from collections import OrderedDict -from typing import Tuple, List, Dict, Union, Callable, Optional, TypeVar +from typing import Set, Tuple, List, Dict, Union, Callable, Optional, TypeVar, cast if typing.TYPE_CHECKING: import numpy as np @@ -13,7 +13,7 @@ from .parsing import ParsedExpression, _ellipsis, AnonymousAxis Tensor = TypeVar('Tensor') -ReductionCallable = Callable[[Tensor, List[int]], Tensor] +ReductionCallable = Callable[[Tensor, Tuple[int]], Tensor] Reduction = Union[str, ReductionCallable] _reductions = ('min', 'max', 'sum', 'mean', 'prod') @@ -38,10 +38,9 @@ def _product(sequence: List[int]) -> int: def _reduce_axes(tensor, reduction_type: Reduction, reduced_axes: List[int], backend): - reduced_axes = tuple(reduced_axes) if callable(reduction_type): # custom callable - return reduction_type(tensor, reduced_axes) + return reduction_type(tensor, tuple(reduced_axes)) else: # one of built-in operations if len(reduced_axes) == 0: @@ -50,7 +49,7 @@ def _reduce_axes(tensor, reduction_type: Reduction, reduced_axes: List[int], bac if reduction_type == 'mean': if not backend.is_float_type(tensor): raise NotImplementedError('reduce_mean is not available for non-floating tensors') - return backend.reduce(tensor, reduction_type, reduced_axes) + return backend.reduce(tensor, reduction_type, tuple(reduced_axes)) def _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes): @@ -171,7 +170,7 @@ def _reconstruct_from_shape_uncached(self: TransformRecipe, shape: List[int]) -> after_ellipsis = input_axis + len(shape) - len(self.input_composite_axes) if input_axis == self.ellipsis_position_in_lhs: assert len(known_axes) == 0 and len(unknown_axes) == 1 - unknown_axis, = unknown_axes + unknown_axis: int = unknown_axes[0] ellipsis_shape = shape[before_ellipsis:after_ellipsis + 1] for d in ellipsis_shape: if d is None: @@ -201,7 +200,7 @@ def _reconstruct_from_shape_uncached(self: TransformRecipe, shape: List[int]) -> raise EinopsError("Shape mismatch, can't divide axis of length {} in chunks of {}".format( length, known_product)) - unknown_axis: int = unknown_axes[0] + unknown_axis = unknown_axes[0] inferred_length: int = length // known_product axes_lengths[unknown_axis] = inferred_length @@ -249,9 +248,9 @@ def _prepare_transformation_recipe(pattern: str, """ Perform initial parsing of pattern and provided supplementary info axes_lengths is a tuple of tuples (axis_name, axis_length) """ - left, rght = pattern.split('->') - left = ParsedExpression(left) - rght = ParsedExpression(rght) + left_str, rght_str = pattern.split('->') + left = ParsedExpression(left_str) + rght = ParsedExpression(rght_str) # checking that axes are in agreement - new axes appear only in repeat, while disappear only in reduction if not left.has_ellipsis and rght.has_ellipsis: @@ -280,7 +279,7 @@ def 
_prepare_transformation_recipe(pattern: str, raise EinopsError('Unknown reduction {}. Expect one of {}.'.format(operation, _reductions)) # parsing all dimensions to find out lengths - axis_name2known_length = OrderedDict() + axis_name2known_length: Dict[Union[str, AnonymousAxis], int]= OrderedDict() for composite_axis in left.composition: for axis_name in composite_axis: if isinstance(axis_name, AnonymousAxis): @@ -302,7 +301,7 @@ def _prepare_transformation_recipe(pattern: str, axis_name2position = {name: position for position, name in enumerate(axis_name2known_length)} reduced_axes: List[int] = [position for axis, position in axis_name2position.items() if axis not in rght.identifiers] - reduced_axes: List[int] = list(sorted(reduced_axes)) + reduced_axes = list(sorted(reduced_axes)) for elementary_axis, axis_length in axes_lengths: if not ParsedExpression.check_axis_name(elementary_axis): @@ -314,8 +313,8 @@ def _prepare_transformation_recipe(pattern: str, input_axes_known_unknown = [] # some of shapes will be inferred later - all information is prepared for faster inference for composite_axis in left.composition: - known = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length} - unknown = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length} + known: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length} + unknown: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length} if len(unknown) > 1: raise EinopsError('Could not infer sizes for {}'.format(unknown)) assert len(unknown) + len(known) == len(composite_axis) @@ -324,7 +323,7 @@ def _prepare_transformation_recipe(pattern: str, [axis_name2position[axis] for axis in unknown]) ) - axis_position_after_reduction = {} + axis_position_after_reduction: Dict[str, int] = {} for axis_name in itertools.chain(*left.composition): if axis_name in rght.identifiers: axis_position_after_reduction[axis_name] = len(axis_position_after_reduction) @@ -481,7 +480,7 @@ def rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) if len(tensor) == 0: raise TypeError("Rearrange can't be applied to an empty list") tensor = get_backend(tensor[0]).stack_on_zeroth_dimension(tensor) - return reduce(tensor, pattern, reduction='rearrange', **axes_lengths) + return reduce(cast(Tensor, tensor), pattern, reduction='rearrange', **axes_lengths) def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor: @@ -581,7 +580,7 @@ def parse_shape(x, pattern: str) -> dict: else: composition = exp.composition result = {} - for (axis_name,), axis_length in zip(composition, shape): + for (axis_name,), axis_length in zip(composition, shape): # type: ignore if axis_name != '_': result[axis_name] = axis_length return result @@ -644,15 +643,14 @@ def _compactify_pattern_for_einsum(pattern: str) -> str: # numpy allows this, so make sure users # don't accidentally do something like this. 
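
[editor's note, not part of the series]
Several renames in patch 08 (left -> left_str/left, lefts -> lefts_str/lefts, right ->
right_str/right) follow one mypy rule: a name keeps a single type for its whole lifetime.
Rebinding `left` first as a str and then as a ParsedExpression makes mypy pin the first
inferred type and reject the second assignment. A sketch, assuming einops is importable:

    from einops.parsing import ParsedExpression

    pattern = 'b c -> b'

    # what mypy rejects:
    #   left = pattern.split('->')[0]    # left: str
    #   left = ParsedExpression(left)    # error: incompatible redefinition

    # one name per type, as in the patch:
    left_str = pattern.split('->')[0]
    left = ParsedExpression(left_str)
    print(sorted(left.identifiers))      # ['b', 'c']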
raise ValueError("Einsum pattern must contain '->'.") - lefts, right = pattern.split('->') - lefts = lefts.split(',') + lefts_str, right_str = pattern.split('->') lefts = [ ParsedExpression(left, allow_underscore=True, allow_duplicates=True) - for left in lefts + for left in lefts_str.split(',') ] - right = ParsedExpression(right, allow_underscore=True) + right = ParsedExpression(right_str, allow_underscore=True) # Start from 'a' and go up to 'Z' output_axis_names = string.ascii_letters @@ -698,16 +696,16 @@ def _compactify_pattern_for_einsum(pattern: str) -> str: @typing.overload -def einsum(tensor: Tensor, pattern: str) -> Tensor: ... +def einsum(tensor: Tensor, pattern: str, /) -> Tensor: ... @typing.overload -def einsum(tensor1: Tensor, tensor2: Tensor, pattern: str) -> Tensor: ... +def einsum(tensor1: Tensor, tensor2: Tensor, pattern: str, /) -> Tensor: ... @typing.overload -def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, pattern: str) -> Tensor: ... +def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, pattern: str, /) -> Tensor: ... @typing.overload -def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, tensor4: Tensor, pattern: str) -> Tensor: ... +def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, tensor4: Tensor, pattern: str, /) -> Tensor: ... -def einsum(*tensors_and_pattern: List[Union[Tensor, str]]) -> Tensor: +def einsum(*tensors_and_pattern: Union[Tensor, str]) -> Tensor: """ einops.einsum calls einsum operations with einops-style named axes indexing, computing tensor products with an arbitrary From ceb1c4f291d65058863e13e53d90d6be2236cb1f Mon Sep 17 00:00:00 2001 From: Alex Rogozhnikov Date: Mon, 10 Oct 2022 06:23:32 +0000 Subject: [PATCH 09/12] fox type hints --- einops/experimental/indexing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/einops/experimental/indexing.py b/einops/experimental/indexing.py index a72558c4..35cef4b3 100644 --- a/einops/experimental/indexing.py +++ b/einops/experimental/indexing.py @@ -72,8 +72,8 @@ def __init__( for x in composed_shape: flat_shape.extend(x) - self.compose_transposition: Tuple[int] = tuple([decomposed_shape.index(x) for x in flat_shape]) - self.decompose_transposition: Tuple[int] = tuple([flat_shape.index(x) for x in decomposed_shape]) + self.compose_transposition: Tuple[int, ...] = tuple([decomposed_shape.index(x) for x in flat_shape]) + self.decompose_transposition: Tuple[int, ...] 
From 3823586e638cad45aed1903984058087a3f02f27 Mon Sep 17 00:00:00 2001
From: Alex Rogozhnikov
Date: Mon, 10 Oct 2022 06:24:13 +0000
Subject: [PATCH 10/12] add testing as one of hatch scripts

---
 pyproject.toml | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/pyproject.toml b/pyproject.toml
index 3cedc622..32b2fc1e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -66,3 +66,9 @@ deploy = "mkdocs build --clean --strict && mkdocs gh-deploy"
 # hatch run pypi:deploy_test
 deploy_test = "hatch build --clean && hatch publish -r test"
 deploy = "hatch build --clean && hatch publish"
+
+
+[tool.hatch.envs.testing.scripts]
+# hatch run testing:test
+test = "python test.py"
+

From 5af7e5b61645bffa8ebc1dc4a86dd08b3eac5383 Mon Sep 17 00:00:00 2001
From: Alex Rogozhnikov
Date: Mon, 10 Oct 2022 07:15:09 +0000
Subject: [PATCH 11/12] addressing mypy comments

---
 einops/layers/_einmix.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/einops/layers/_einmix.py b/einops/layers/_einmix.py
index 138234e5..dc9fd88d 100644
--- a/einops/layers/_einmix.py
+++ b/einops/layers/_einmix.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Dict
+from typing import Any, List, Optional, Dict

 from einops import EinopsError
 from einops.parsing import ParsedExpression
@@ -81,7 +81,7 @@ def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optiona
         pre_reshape_lengths = None
         post_reshape_pattern = None
         if any(len(group) != 1 for group in left.composition):
-            names = []
+            names: List[str] = []
             for group in left.composition:
                 names += group
             composition = ' '.join(names)
@@ -143,8 +143,8 @@ def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optiona

         # rewrite einsum expression with single-letter latin identifiers so that
         # expression will be understood by any framework
-        mapping2letters = {*left.identifiers, *right.identifiers, *weight.identifiers}
-        mapping2letters = {k: letter for letter, k in zip(string.ascii_lowercase, mapping2letters)}
+        mapped_identifiers = {*left.identifiers, *right.identifiers, *weight.identifiers}
+        mapping2letters = {k: letter for letter, k in zip(string.ascii_lowercase, mapped_identifiers)}

         def write_flat(axes: list):
             return ''.join(mapping2letters[axis] for axis in axes)
+def einsum(__tensor1: Tensor, __tensor2: Tensor, __tensor3: Tensor, __pattern: str) -> Tensor: ... @typing.overload -def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, tensor4: Tensor, pattern: str, /) -> Tensor: ... +def einsum(__tensor1: Tensor, __tensor2: Tensor, __tensor3: Tensor, __tensor4: Tensor, __pattern: str) -> Tensor: ... def einsum(*tensors_and_pattern: Union[Tensor, str]) -> Tensor: