forked from arogozhnikov/einops
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathparsing.py
144 lines (129 loc) · 6.24 KB
/
parsing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from einops import EinopsError
import keyword
import warnings
from typing import List, Optional, Set, Tuple
# Internal marker that replaces the three-character '...' in a pattern.
# It is a single code point (U+2026 HORIZONTAL ELLIPSIS), so the per-character
# scanner sees the whole ellipsis as one token. A str (rather than a list) is used
# because it is hashable yet can still be iterated.
_ellipsis: str = '\u2026'
class AnonymousAxis(object):
    """
    Axis written in a pattern as a literal number (e.g. the ``3`` in ``'b (3 w)'``).

    No ``__eq__`` is defined, so equality is identity: two instances are never
    equal to each other, even when they carry the same size.
    """

    def __init__(self, value: str):
        self.value = int(value)
        size = self.value
        # size 1 is expected to be filtered out by the parser before reaching here
        if size == 1:
            raise EinopsError('No need to create anonymous axis of length 1. Report this as an issue')
        if size < 1:
            raise EinopsError('Anonymous axis should have positive length, not {}'.format(size))

    def __repr__(self):
        return "%s-axis" % (self.value,)
class ParsedExpression:
    """
    Non-mutable structure that contains information about one side of an expression
    (e.g. 'b c (h w)') and keeps some information important for downstream processing.

    Attributes filled in by ``__init__``:
      - has_ellipsis: True when the expression contained ``...``.
      - has_ellipsis_parenthesized: None until an ellipsis token is seen, then whether it
        appeared inside brackets.
      - identifiers: axis names (str), the ellipsis marker, and AnonymousAxis instances.
      - has_non_unitary_anonymous_axes: True if a numeric axis other than 1 was seen.
      - composition: one entry per top-level position. Each entry is one of:
        a list of axes (names as str, numbers as AnonymousAxis, possibly the ellipsis
        marker when it is inside brackets); an empty list for a top-level ``1`` or for
        brackets holding only 1s / nothing; or the bare ellipsis marker string for an
        unparenthesized ``...``.
    """

    def __init__(self, expression):
        """
        Parse one side of an einops pattern.

        :param expression: pattern string such as ``'b c (h w)'`` or ``'b ... (h 2)'``.
        :raises EinopsError: on dots outside a single ellipsis, duplicate axis names,
            invalid axis identifiers or sizes, nested or unbalanced brackets,
            or unknown characters.
        """
        self.has_ellipsis: bool = False
        self.has_ellipsis_parenthesized: Optional[bool] = None
        # holds axis names (str), the ellipsis marker, and AnonymousAxis instances
        self.identifiers: Set[str] = set()
        # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
        self.has_non_unitary_anonymous_axes: bool = False
        # composition keeps structure of composite axes, see how different corner cases are handled in tests
        self.composition = []
        if '.' in expression:
            if '...' not in expression:
                raise EinopsError('Expression may contain dots only inside ellipsis (...)')
            if str.count(expression, '...') != 1 or str.count(expression, '.') != 3:
                raise EinopsError(
                    'Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor ')
            # collapse '...' into a single marker character so the per-character scanner
            # below treats it as one token
            expression = expression.replace('...', _ellipsis)
            self.has_ellipsis = True

        bracket_group = None

        def add_axis_name(x):
            # Reads ``bracket_group`` from the enclosing scope at call time, so the axis is
            # routed into the currently open bracket group if there is one.
            if x is None:
                return
            if x in self.identifiers:
                raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x))
            if x == _ellipsis:
                self.identifiers.add(_ellipsis)
                if bracket_group is None:
                    self.composition.append(_ellipsis)
                    self.has_ellipsis_parenthesized = False
                else:
                    bracket_group.append(_ellipsis)
                    self.has_ellipsis_parenthesized = True
                return

            is_number = str.isdecimal(x)
            if is_number and int(x) == 1:
                # anonymous axis of length 1: at top level it becomes an empty composition,
                # inside parenthesis it is simply dropped
                if bracket_group is None:
                    self.composition.append([])
                return
            is_axis_name, reason = self.check_axis_name_return_reason(x)
            if not (is_number or is_axis_name):
                raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason))
            if is_number:
                x = AnonymousAxis(x)
            self.identifiers.add(x)
            if is_number:
                self.has_non_unitary_anonymous_axes = True
            if bracket_group is None:
                self.composition.append([x])
            else:
                bracket_group.append(x)

        current_identifier = None
        for char in expression:
            if char in '() ':
                # a separator ends whatever identifier was being accumulated
                add_axis_name(current_identifier)
                current_identifier = None
                if char == '(':
                    if bracket_group is not None:
                        raise EinopsError("Axis composition is one-level (brackets inside brackets not allowed)")
                    bracket_group = []
                elif char == ')':
                    if bracket_group is None:
                        raise EinopsError('Brackets are not balanced')
                    self.composition.append(bracket_group)
                    bracket_group = None
            elif str.isalnum(char) or char in ['_', _ellipsis]:
                if current_identifier is None:
                    current_identifier = char
                else:
                    current_identifier += char
            else:
                raise EinopsError("Unknown character '{}'".format(char))

        if bracket_group is not None:
            raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(expression))
        add_axis_name(current_identifier)

    def flat_axes_order(self) -> List:
        """
        Return all axes of ``composition`` flattened in order.

        :raises AssertionError: if the composition contains an unparenthesized ellipsis
            (stored as a bare string rather than a list).
        """
        result = []
        for composed_axis in self.composition:
            # explicit raise instead of ``assert`` so the check survives ``python -O``;
            # otherwise the ellipsis string would be silently iterated into the result
            if not isinstance(composed_axis, list):
                raise AssertionError('does not work with ellipsis')
            result.extend(composed_axis)
        return result

    def has_composed_axes(self) -> bool:
        """Return True if any entry of ``composition`` groups more than one axis."""
        # this will ignore 1 inside brackets (they are never stored in a group)
        return any(isinstance(axes, list) and len(axes) > 1 for axes in self.composition)

    @staticmethod
    def check_axis_name_return_reason(name: str) -> Tuple[bool, str]:
        """
        Validate an axis name and explain a rejection.

        Python keywords and the name 'axis' are accepted but emit a warning.

        :return: ``(True, '')`` for an acceptable name, otherwise ``(False, reason)``.
        """
        if not str.isidentifier(name):
            return False, 'not a valid python identifier'
        elif name[0] == '_' or name[-1] == '_':
            return False, 'axis name should not start or end with underscore'
        else:
            if keyword.iskeyword(name):
                warnings.warn("It is discouraged to use axes names that are keywords: {}".format(name), RuntimeWarning)
            if name in ['axis']:
                warnings.warn("It is discouraged to use 'axis' as an axis name "
                              "and will raise an error in future", FutureWarning)
            return True, ''

    @staticmethod
    def check_axis_name(name: str) -> bool:
        """
        Valid axes names are python identifiers that do not start or end with an underscore.
        Python keywords and 'axis' are currently accepted, but trigger a warning.
        """
        is_valid, _reason = ParsedExpression.check_axis_name_return_reason(name)
        return is_valid