forked from python/mypy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathreachability.py
343 lines (292 loc) · 12.4 KB
/
reachability.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
"""Utilities related to determining the reachability of code (in semantic analysis)."""
from typing import Tuple, TypeVar, Union, Optional
from typing_extensions import Final
from mypy.nodes import (
Expression, IfStmt, Block, AssertStmt, MatchStmt, NameExpr, UnaryExpr, MemberExpr, OpExpr,
ComparisonExpr, StrExpr, UnicodeExpr, CallExpr, IntExpr, TupleExpr, IndexExpr, SliceExpr,
Import, ImportFrom, ImportAll, LITERAL_YES
)
from mypy.options import Options
from mypy.patterns import Pattern, AsPattern, OrPattern
from mypy.traverser import TraverserVisitor
from mypy.literals import literal
# Inferred truth value of an expression.
# The MYPY_* values describe conditions that evaluate differently under mypy
# than at runtime (infer_condition_value reports MYPY / TYPE_CHECKING as
# MYPY_TRUE).
ALWAYS_TRUE: Final = 1
MYPY_TRUE: Final = 2 # True in mypy, False at runtime
ALWAYS_FALSE: Final = 3
MYPY_FALSE: Final = 4 # False in mypy, True at runtime
TRUTH_VALUE_UNKNOWN: Final = 5
# Truth value of `not <expr>`, keyed by the truth value of `<expr>`.
# ALWAYS_* and MYPY_* values flip to their opposite; unknown stays unknown.
inverted_truth_mapping: Final = {
    ALWAYS_TRUE: ALWAYS_FALSE,
    ALWAYS_FALSE: ALWAYS_TRUE,
    TRUTH_VALUE_UNKNOWN: TRUTH_VALUE_UNKNOWN,
    MYPY_TRUE: MYPY_FALSE,
    MYPY_FALSE: MYPY_TRUE,
}
# Operator to use when the two operands of a comparison are swapped, so that
# e.g. `(3, 8) <= sys.version_info` can be evaluated as
# `sys.version_info >= (3, 8)`.
reverse_op: Final = {
    "==": "==",
    "!=": "!=",
    "<": ">",
    ">": "<",
    "<=": ">=",
    ">=": "<=",
}
def infer_reachability_of_if_statement(s: IfStmt, options: Options) -> None:
    """Mark if/elif/else bodies that can never execute as unreachable.

    Conditions are examined in order. A condition inferred to be always false
    makes only its own body unreachable. The first condition inferred to be
    always true makes every later elif body and the else body unreachable
    (creating an empty else body if there is none), and ends the scan.
    """
    for i, condition in enumerate(s.expr):
        value = infer_condition_value(condition, options)
        if value in (ALWAYS_FALSE, MYPY_FALSE):
            # Only this branch is skipped; later branches may still be taken.
            mark_block_unreachable(s.body[i])
            continue
        if value not in (ALWAYS_TRUE, MYPY_TRUE):
            continue
        if value == MYPY_TRUE:
            # True under mypy but false at runtime: imports inside this body
            # are mypy-only, which affects import priorities.
            mark_block_mypy_only(s.body[i])
        for later_body in s.body[i + 1:]:
            mark_block_unreachable(later_body)
        # Guarantee that an (unreachable) else body exists so the type
        # checker knows every control-flow path goes through this if body.
        if not s.else_body:
            s.else_body = Block([])
        mark_block_unreachable(s.else_body)
        break
def infer_reachability_of_match_statement(s: MatchStmt, options: Options) -> None:
    """Mark match-case bodies that can never execute as unreachable.

    A case whose pattern or guard is inferred to be always false has its own
    body marked unreachable. A case whose pattern always matches and whose
    guard (if any) is inferred to be always true makes every later case body
    unreachable. Bodies whose guard is MYPY_TRUE are marked mypy-only.
    """
    for i, guard in enumerate(s.guards):
        pattern_value = infer_pattern_value(s.patterns[i])
        if guard is not None:
            guard_value = infer_condition_value(guard, options)
        else:
            # A case without a guard behaves like one whose guard always holds.
            guard_value = ALWAYS_TRUE

        if (pattern_value in (ALWAYS_FALSE, MYPY_FALSE)
                or guard_value in (ALWAYS_FALSE, MYPY_FALSE)):
            # The case is considered always false, so we skip the case body.
            mark_block_unreachable(s.bodies[i])
        elif (pattern_value in (ALWAYS_TRUE, MYPY_TRUE)
                and guard_value in (ALWAYS_TRUE, MYPY_TRUE)):
            # The case always matches, so no later case can ever be reached.
            # (This previously tested `(ALWAYS_FALSE, MYPY_TRUE)`, which could
            # never be true here and left this branch dead.)
            for body in s.bodies[i + 1:]:
                mark_block_unreachable(body)

        if guard_value == MYPY_TRUE:
            # This condition is false at runtime; this will affect
            # import priorities.
            mark_block_mypy_only(s.bodies[i])
def assert_will_always_fail(s: AssertStmt, options: Options) -> bool:
    """Return True if the asserted condition is inferred to be false under mypy.

    Both ALWAYS_FALSE and MYPY_FALSE count, so for example
    ``assert not TYPE_CHECKING`` is reported as always failing.
    """
    value = infer_condition_value(s.expr, options)
    return value == ALWAYS_FALSE or value == MYPY_FALSE
def infer_condition_value(expr: Expression, options: Options) -> int:
    """Infer whether the given condition is always true/false.

    Return ALWAYS_TRUE if always true, ALWAYS_FALSE if always false,
    MYPY_TRUE if true under mypy and false at runtime, MYPY_FALSE if
    false under mypy and true at runtime, else TRUTH_VALUE_UNKNOWN.

    Recognized conditions, each optionally preceded by a single ``not``
    (which inverts the result for every form below):

    * a name or attribute named PY2, PY3, MYPY or TYPE_CHECKING, or listed in
      ``options.always_true`` / ``options.always_false``;
    * ``<a> and <b>`` / ``<a> or <b>``, short-circuiting on the inferred
      value of ``<a>``;
    * comparisons involving ``sys.version_info`` or ``sys.platform``.
    """
    pyversion = options.python_version
    negated = False
    if isinstance(expr, UnaryExpr) and expr.op == 'not':
        # Strip one leading `not`; it is applied to the final result below.
        expr = expr.expr
        negated = True

    result = TRUTH_VALUE_UNKNOWN
    if isinstance(expr, (NameExpr, MemberExpr)):
        # Only the last component of a dotted name is inspected, so `x.PY2`
        # is treated like `PY2`.
        name = expr.name
        if name == 'PY2':
            result = ALWAYS_TRUE if pyversion[0] == 2 else ALWAYS_FALSE
        elif name == 'PY3':
            result = ALWAYS_TRUE if pyversion[0] == 3 else ALWAYS_FALSE
        elif name == 'MYPY' or name == 'TYPE_CHECKING':
            result = MYPY_TRUE
        elif name in options.always_true:
            result = ALWAYS_TRUE
        elif name in options.always_false:
            result = ALWAYS_FALSE
    elif isinstance(expr, OpExpr) and expr.op in ('and', 'or'):
        left = infer_condition_value(expr.left, options)
        if ((left in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == 'and') or
                (left in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == 'or')):
            # Either `True and <other>` or `False or <other>`: the result will
            # always be the right-hand-side.
            result = infer_condition_value(expr.right, options)
        else:
            # The result will always be the left-hand-side (e.g. ALWAYS_* or
            # TRUTH_VALUE_UNKNOWN).
            result = left
        # Fall through rather than returning here, so that a leading `not`
        # is honoured: `not (TYPE_CHECKING or X)` must be MYPY_FALSE.
    else:
        result = consider_sys_version_info(expr, pyversion)
        if result == TRUTH_VALUE_UNKNOWN:
            result = consider_sys_platform(expr, options.platform)

    if negated:
        result = inverted_truth_mapping[result]
    return result
def infer_pattern_value(pattern: Pattern) -> int:
    """Return ALWAYS_TRUE if *pattern* always matches, else TRUTH_VALUE_UNKNOWN.

    An AsPattern without a sub-pattern always matches. An or-pattern always
    matches when any of its alternatives does.
    """
    if isinstance(pattern, AsPattern):
        irrefutable = pattern.pattern is None
    elif isinstance(pattern, OrPattern):
        irrefutable = any(
            infer_pattern_value(alternative) == ALWAYS_TRUE
            for alternative in pattern.patterns
        )
    else:
        irrefutable = False
    return ALWAYS_TRUE if irrefutable else TRUTH_VALUE_UNKNOWN
def consider_sys_version_info(expr: Expression, pyversion: Tuple[int, ...]) -> int:
    """Consider whether expr is a comparison involving sys.version_info.

    Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.

    Supported forms (sys.version_info may appear on either side):

    * ``sys.version_info[<int>] <op> <int>``
    * ``sys.version_info[<lo>:<hi>] <op> <tuple of ints>``
    * ``sys.version_info <op> <tuple of ints>``

    Only components 0 and 1 of *pyversion* are ever consulted. When the
    compared tuple is shorter than the selected slice, only the ordering
    operators (not ``==`` / ``!=``) give a definite answer.
    """
    # Non-comparisons and chained comparisons are not supported.
    if not isinstance(expr, ComparisonExpr) or len(expr.operators) > 1:
        return TRUTH_VALUE_UNKNOWN
    op = expr.operators[0]
    if op not in ('==', '!=', '<=', '>=', '<', '>'):
        return TRUTH_VALUE_UNKNOWN

    first, second = expr.operands[0], expr.operands[1]
    selector = contains_sys_version_info(first)
    bound = contains_int_or_tuple_of_ints(second)
    if selector is None or bound is None:
        # Retry with the operands swapped, e.g. `(3, 8) <= sys.version_info`,
        # flipping the operator so the comparison keeps its meaning.
        selector = contains_sys_version_info(second)
        bound = contains_int_or_tuple_of_ints(first)
        op = reverse_op[op]

    if isinstance(selector, int) and isinstance(bound, int):
        # sys.version_info[i] <op> k
        if not 0 <= selector <= 1:
            return TRUTH_VALUE_UNKNOWN
        return fixed_comparison(pyversion[selector], op, bound)

    if isinstance(selector, tuple) and isinstance(bound, tuple):
        lo, hi = selector
        lo = 0 if lo is None else lo
        hi = 2 if hi is None else hi
        if 0 <= lo < hi <= 2:
            selected = pyversion[lo:hi]
            exact_length = len(selected) == len(bound)
            ordering_only = len(selected) > len(bound) and op not in ('==', '!=')
            if exact_length or ordering_only:
                return fixed_comparison(selected, op, bound)
    return TRUTH_VALUE_UNKNOWN
def consider_sys_platform(expr: Expression, platform: str) -> int:
    """Consider whether expr is a comparison involving sys.platform.

    Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.

    Supported forms: ``sys.platform == '<str>'`` and
    ``sys.platform != '<str>'`` (with sys.platform on the left), and
    ``sys.platform.startswith('<str>')``.
    """
    if isinstance(expr, CallExpr):
        callee = expr.callee
        if (isinstance(callee, MemberExpr)
                and len(expr.args) == 1
                and isinstance(expr.args[0], (StrExpr, UnicodeExpr))
                and is_sys_attr(callee.expr, 'platform')
                and callee.name == 'startswith'):
            prefix = expr.args[0].value
            return ALWAYS_TRUE if platform.startswith(prefix) else ALWAYS_FALSE
        return TRUTH_VALUE_UNKNOWN

    if not isinstance(expr, ComparisonExpr):
        return TRUTH_VALUE_UNKNOWN
    # Chained comparisons are not supported yet.
    if len(expr.operators) > 1:
        return TRUTH_VALUE_UNKNOWN
    op = expr.operators[0]
    if op not in ('==', '!=') or not is_sys_attr(expr.operands[0], 'platform'):
        return TRUTH_VALUE_UNKNOWN
    target = expr.operands[1]
    if isinstance(target, (StrExpr, UnicodeExpr)):
        return fixed_comparison(platform, op, target.value)
    return TRUTH_VALUE_UNKNOWN
Targ = TypeVar('Targ', int, str, Tuple[int, ...])


def fixed_comparison(left: Targ, op: str, right: Targ) -> int:
    """Evaluate ``left <op> right`` for two concrete values.

    Return ALWAYS_TRUE or ALWAYS_FALSE for the operators ==, !=, <=, >=, <
    and >, and TRUTH_VALUE_UNKNOWN for any other *op*.
    """
    if op == '==':
        holds = left == right
    elif op == '!=':
        holds = left != right
    elif op == '<=':
        holds = left <= right
    elif op == '>=':
        holds = left >= right
    elif op == '<':
        holds = left < right
    elif op == '>':
        holds = left > right
    else:
        return TRUTH_VALUE_UNKNOWN
    return ALWAYS_TRUE if holds else ALWAYS_FALSE
def contains_int_or_tuple_of_ints(expr: Expression
                                  ) -> Union[None, int, Tuple[int], Tuple[int, ...]]:
    """Return the value of an int literal or of a literal tuple of int literals.

    Return None for anything else, including tuples that contain a non-int
    item or that ``literal()`` does not classify as LITERAL_YES.
    """
    if isinstance(expr, IntExpr):
        return expr.value
    if not isinstance(expr, TupleExpr) or literal(expr) != LITERAL_YES:
        return None
    ints = [item.value for item in expr.items if isinstance(item, IntExpr)]
    if len(ints) != len(expr.items):
        # At least one item was not an int literal.
        return None
    return tuple(ints)
def contains_sys_version_info(expr: Expression
                              ) -> Union[None, int, Tuple[Optional[int], Optional[int]]]:
    """Describe how *expr* reads ``sys.version_info``.

    Return ``(None, None)`` for bare ``sys.version_info``, the integer for
    ``sys.version_info[<int>]``, a ``(begin, end)`` pair (either may be None)
    for a slice with int-literal bounds and no stride other than 1, and None
    for anything else.
    """
    if is_sys_attr(expr, 'version_info'):
        return (None, None)  # Same as sys.version_info[:]
    if not (isinstance(expr, IndexExpr) and is_sys_attr(expr.base, 'version_info')):
        return None
    index = expr.index
    if isinstance(index, IntExpr):
        return index.value
    if not isinstance(index, SliceExpr):
        return None
    stride = index.stride
    if stride is not None and not (isinstance(stride, IntExpr) and stride.value == 1):
        return None
    lower, upper = index.begin_index, index.end_index
    if lower is not None and not isinstance(lower, IntExpr):
        return None
    if upper is not None and not isinstance(upper, IntExpr):
        return None
    return (
        lower.value if isinstance(lower, IntExpr) else None,
        upper.value if isinstance(upper, IntExpr) else None,
    )
def is_sys_attr(expr: Expression, name: str) -> bool:
    """Return True if *expr* is syntactically ``sys.<name>``."""
    # TODO: This currently doesn't work with code like this:
    # - import sys as _sys
    # - from sys import version_info
    if not isinstance(expr, MemberExpr) or expr.name != name:
        return False
    base = expr.expr
    # TODO: Guard against a local named sys, etc.
    # (Though later passes will still do most checking.)
    return isinstance(base, NameExpr) and base.name == 'sys'
def mark_block_unreachable(block: Block) -> None:
    """Flag *block* and every import statement nested inside it as unreachable."""
    block.is_unreachable = True
    visitor = MarkImportsUnreachableVisitor()
    block.accept(visitor)
class MarkImportsUnreachableVisitor(TraverserVisitor):
    """Traverser that sets ``is_unreachable`` on every import statement it visits."""

    def _mark_unreachable(self, node: Union[Import, ImportFrom, ImportAll]) -> None:
        # Shared by all three import statement kinds.
        node.is_unreachable = True

    def visit_import(self, node: Import) -> None:
        self._mark_unreachable(node)

    def visit_import_from(self, node: ImportFrom) -> None:
        self._mark_unreachable(node)

    def visit_import_all(self, node: ImportAll) -> None:
        self._mark_unreachable(node)
def mark_block_mypy_only(block: Block) -> None:
    """Flag every import statement nested inside *block* as mypy-only."""
    visitor = MarkImportsMypyOnlyVisitor()
    block.accept(visitor)
class MarkImportsMypyOnlyVisitor(TraverserVisitor):
    """Traverser that sets ``is_mypy_only`` (which affects import priority) on imports."""

    def _mark_mypy_only(self, node: Union[Import, ImportFrom, ImportAll]) -> None:
        # Shared by all three import statement kinds.
        node.is_mypy_only = True

    def visit_import(self, node: Import) -> None:
        self._mark_mypy_only(node)

    def visit_import_from(self, node: ImportFrom) -> None:
        self._mark_mypy_only(node)

    def visit_import_all(self, node: ImportAll) -> None:
        self._mark_mypy_only(node)