forked from python/mypy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpartially_defined.py
680 lines (590 loc) · 25 KB
/
partially_defined.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
from __future__ import annotations
from enum import Enum
from mypy import checker, errorcodes
from mypy.messages import MessageBuilder
from mypy.nodes import (
AssertStmt,
AssignmentExpr,
AssignmentStmt,
BreakStmt,
ClassDef,
Context,
ContinueStmt,
DictionaryComprehension,
Expression,
ExpressionStmt,
ForStmt,
FuncDef,
FuncItem,
GeneratorExpr,
GlobalDecl,
IfStmt,
Import,
ImportFrom,
LambdaExpr,
ListExpr,
Lvalue,
MatchStmt,
MypyFile,
NameExpr,
NonlocalDecl,
RaiseStmt,
ReturnStmt,
StarExpr,
SymbolTable,
TryStmt,
TupleExpr,
TypeAliasStmt,
WhileStmt,
WithStmt,
implicit_module_attrs,
)
from mypy.options import Options
from mypy.patterns import AsPattern, StarredPattern
from mypy.reachability import ALWAYS_TRUE, infer_pattern_value
from mypy.traverser import ExtendedTraverserVisitor
from mypy.types import Type, UninhabitedType
class BranchState:
    """Definition information for variables at the end of a branching statement.

    `if` and `match` are examples of branching statements.

    Attributes:
        must_be_defined: names defined in all branches.
        may_be_defined: names defined in only some branches.
        skipped: whether this state was marked as skipped (callers set this when a
            branch ends in a jump such as `return` or `raise`).
    """

    def __init__(
        self,
        must_be_defined: set[str] | None = None,
        may_be_defined: set[str] | None = None,
        skipped: bool = False,
    ) -> None:
        # Always keep private copies so later changes to the caller's sets don't leak in.
        self.may_be_defined = set() if may_be_defined is None else set(may_be_defined)
        self.must_be_defined = set() if must_be_defined is None else set(must_be_defined)
        self.skipped = skipped

    def copy(self) -> BranchState:
        """Return an independent copy of this state (the constructor copies the sets)."""
        return BranchState(self.must_be_defined, self.may_be_defined, self.skipped)
class BranchStatement:
    """Definition state for every branch of a single branching statement.

    Each branch starts from the defined / maybe-defined sets of `initial_state`;
    `branches[-1]` is the branch currently being visited. `done()` merges all
    branches into one `BranchState` describing the statement as a whole.
    """

    def __init__(self, initial_state: BranchState | None = None) -> None:
        if initial_state is None:
            initial_state = BranchState()
        self.initial_state = initial_state
        self.branches: list[BranchState] = [self._new_branch()]

    def _new_branch(self) -> BranchState:
        """Create a non-skipped branch seeded with the sets from `initial_state`."""
        return BranchState(
            must_be_defined=self.initial_state.must_be_defined,
            may_be_defined=self.initial_state.may_be_defined,
        )

    def copy(self) -> BranchStatement:
        """Return a copy with independently copied branches (`initial_state` is shared)."""
        result = BranchStatement(self.initial_state)
        result.branches = [b.copy() for b in self.branches]
        return result

    def next_branch(self) -> None:
        """Start a new branch, seeded from `initial_state`."""
        self.branches.append(self._new_branch())

    def record_definition(self, name: str) -> None:
        """Mark `name` as definitely defined in the current branch."""
        assert len(self.branches) > 0
        self.branches[-1].must_be_defined.add(name)
        self.branches[-1].may_be_defined.discard(name)

    def delete_var(self, name: str) -> None:
        """Forget `name` entirely in the current branch."""
        assert len(self.branches) > 0
        self.branches[-1].must_be_defined.discard(name)
        self.branches[-1].may_be_defined.discard(name)

    def record_nested_branch(self, state: BranchState) -> None:
        """Merge the result of a nested branching statement into the current branch.

        A skipped nested result marks the current branch as skipped and is otherwise ignored.
        """
        assert len(self.branches) > 0
        current_branch = self.branches[-1]
        if state.skipped:
            current_branch.skipped = True
            return
        current_branch.must_be_defined.update(state.must_be_defined)
        current_branch.may_be_defined.update(state.may_be_defined)
        # A name that is now definitely defined is no longer only "maybe" defined.
        current_branch.may_be_defined.difference_update(current_branch.must_be_defined)

    def skip_branch(self) -> None:
        """Mark the current branch as skipped."""
        assert len(self.branches) > 0
        self.branches[-1].skipped = True

    def is_possibly_undefined(self, name: str) -> bool:
        """Return whether `name` is only maybe-defined in the current branch."""
        assert len(self.branches) > 0
        return name in self.branches[-1].may_be_defined

    def is_undefined(self, name: str) -> bool:
        """Return whether `name` is neither defined nor maybe-defined in the current branch."""
        assert len(self.branches) > 0
        branch = self.branches[-1]
        return name not in branch.may_be_defined and name not in branch.must_be_defined

    def is_defined_in_a_branch(self, name: str) -> bool:
        """Return whether any branch (skipped or not) has `name` defined or maybe-defined."""
        assert len(self.branches) > 0
        for b in self.branches:
            if name in b.must_be_defined or name in b.may_be_defined:
                return True
        return False

    def done(self) -> BranchState:
        """Merge all branches into the state of the whole statement.

        The recorded branches are not modified.
        """
        # First, compute all vars, including skipped branches. We include skipped branches
        # because our goal is to capture all variables that semantic analyzer would
        # consider defined.
        all_vars: set[str] = set()
        for b in self.branches:
            all_vars.update(b.may_be_defined)
            all_vars.update(b.must_be_defined)
        # For the rest of the things, we only care about branches that weren't skipped.
        non_skipped_branches = [b for b in self.branches if not b.skipped]
        if non_skipped_branches:
            # Intersect into a fresh set: updating the first branch's set in place would
            # silently rewrite that branch's recorded state.
            must_be_defined = set(non_skipped_branches[0].must_be_defined)
            for b in non_skipped_branches[1:]:
                must_be_defined.intersection_update(b.must_be_defined)
        else:
            must_be_defined = set()
        # Everything that wasn't defined in all branches but was defined
        # in at least one branch should be in `may_be_defined`!
        may_be_defined = all_vars.difference(must_be_defined)
        return BranchState(
            must_be_defined=must_be_defined,
            may_be_defined=may_be_defined,
            skipped=len(non_skipped_branches) == 0,
        )
class ScopeType(Enum):
    """Kind of scope pushed onto the definition tracker."""

    # Module-level scope (the tracker's initial scope).
    Global = 1
    # Class body.
    Class = 2
    # Function or lambda body.
    Func = 3
    # Generator expression / comprehension; starts from the enclosing scope's state.
    Generator = 4
class Scope:
    """Analysis state for one scope.

    Holds this scope's stack of branch statements and the name references that
    were seen before the name had any definition in this scope.
    """

    def __init__(self, stmts: list[BranchStatement], scope_type: ScopeType) -> None:
        self.branch_stmts: list[BranchStatement] = stmts
        self.scope_type = scope_type
        # Variable name -> NameExpr nodes that referenced it while it was undefined.
        self.undefined_refs: dict[str, set[NameExpr]] = {}

    def copy(self) -> Scope:
        """Return a copy whose branch statements and reference sets are independent."""
        result = Scope([s.copy() for s in self.branch_stmts], self.scope_type)
        # Copy each set as well: a shallow dict copy would share the set objects, so refs
        # recorded on one scope would leak into the other.
        result.undefined_refs = {name: set(refs) for name, refs in self.undefined_refs.items()}
        return result

    def record_undefined_ref(self, o: NameExpr) -> None:
        """Remember that `o` referenced a name not (yet) defined in this scope."""
        self.undefined_refs.setdefault(o.name, set()).add(o)

    def pop_undefined_ref(self, name: str) -> set[NameExpr]:
        """Remove and return the recorded refs to `name` (an empty set if there are none)."""
        return self.undefined_refs.pop(name, set())
class DefinedVariableTracker:
    """Scope and branch bookkeeping used by the possibly-undefined-variable visitor.

    `scopes` is a stack (innermost last). Every scope owns a stack of
    `BranchStatement`s whose first entry is that scope's "root" statement, so there
    is always at least one scope and each scope has at least one branch statement.
    """

    def __init__(self) -> None:
        root_scope = Scope([BranchStatement()], ScopeType.Global)
        self.scopes: list[Scope] = [root_scope]
        # While True, jump statements (return/raise/etc.) do not mark the current
        # branch as skipped; used when processing try/except/finally.
        self.disable_branch_skip = False

    def copy(self) -> DefinedVariableTracker:
        """Return a tracker whose scopes are copies of this tracker's scopes."""
        clone = DefinedVariableTracker()
        clone.scopes = [scope.copy() for scope in self.scopes]
        clone.disable_branch_skip = self.disable_branch_skip
        return clone

    def _scope(self) -> Scope:
        """Return the innermost scope."""
        assert len(self.scopes) > 0
        return self.scopes[-1]

    def _innermost_stmt(self) -> BranchStatement:
        """Return the innermost branch statement of the innermost scope."""
        stmts = self._scope().branch_stmts
        assert len(stmts) > 0
        return stmts[-1]

    def enter_scope(self, scope_type: ScopeType) -> None:
        """Push a new scope of the given type."""
        current = self._innermost_stmt()
        # Generators are special: they start from the outer scope's current branch.
        inherited = current.branches[-1] if scope_type == ScopeType.Generator else None
        self.scopes.append(Scope([BranchStatement(inherited)], scope_type))

    def exit_scope(self) -> None:
        """Pop the innermost scope."""
        self.scopes.pop()

    def in_scope(self, scope_type: ScopeType) -> bool:
        """Return whether the innermost scope has the given type."""
        return self._scope().scope_type == scope_type

    def start_branch_statement(self) -> None:
        """Open a nested branch statement seeded from the current branch."""
        parent_branch = self._innermost_stmt().branches[-1]
        self._scope().branch_stmts.append(BranchStatement(parent_branch))

    def next_branch(self) -> None:
        """Move to the next branch of the innermost (non-root) branch statement."""
        stmts = self._scope().branch_stmts
        assert len(stmts) > 1
        stmts[-1].next_branch()

    def end_branch_statement(self) -> None:
        """Close the innermost branch statement and merge its result into the parent."""
        stmts = self._scope().branch_stmts
        assert len(stmts) > 1
        merged = stmts.pop().done()
        stmts[-1].record_nested_branch(merged)

    def skip_branch(self) -> None:
        """Mark the current branch as skipped, unless disabled or at the scope's root."""
        stmts = self._scope().branch_stmts
        # The root branch statement of a scope is never skipped.
        if self.disable_branch_skip or len(stmts) <= 1:
            return
        stmts[-1].skip_branch()

    def record_definition(self, name: str) -> None:
        """Mark `name` as defined in the current branch."""
        self._innermost_stmt().record_definition(name)

    def delete_var(self, name: str) -> None:
        """Forget `name` in the current branch."""
        self._innermost_stmt().delete_var(name)

    def record_undefined_ref(self, o: NameExpr) -> None:
        """Records an undefined reference. These can later be retrieved via `pop_undefined_ref`."""
        self._scope().record_undefined_ref(o)

    def pop_undefined_ref(self, name: str) -> set[NameExpr]:
        """Return (and forget) the recorded undefined references to `name` in this scope."""
        return self._scope().pop_undefined_ref(name)

    def is_possibly_undefined(self, name: str) -> bool:
        """Return whether `name` is maybe-defined but not must-defined in the current branch."""
        return self._innermost_stmt().is_possibly_undefined(name)

    def is_defined_in_different_branch(self, name: str) -> bool:
        """Return whether `name` is undefined here but defined in some other branch."""
        if not self._innermost_stmt().is_undefined(name):
            return False
        return any(stmt.is_defined_in_a_branch(name) for stmt in self._scope().branch_stmts)

    def is_undefined(self, name: str) -> bool:
        """Return whether `name` has no (maybe-)definition in the current branch."""
        return self._innermost_stmt().is_undefined(name)
class Loop:
    """Bookkeeping for one `for`/`while` loop being visited."""

    def __init__(self) -> None:
        # Flipped to True when a `break` is visited while this is the innermost tracked loop.
        self.has_break = False
class PossiblyUndefinedVariableVisitor(ExtendedTraverserVisitor):
    """Detects the following cases:
    - A variable that's defined only part of the time.
    - If a variable is used before definition

    An example of a partial definition:
    if foo():
        x = 1
    print(x)  # Error: "x" may be undefined.

    Example of a used before definition:
    x = y
    y: int = 2

    Note that this code does not detect variables not defined in any of the branches -- that is
    handled by the semantic analyzer.
    """

    def __init__(
        self,
        msg: MessageBuilder,
        type_map: dict[Expression, Type],
        options: Options,
        names: SymbolTable,
    ) -> None:
        self.msg = msg
        self.type_map = type_map
        self.options = options
        self.builtins = SymbolTable()
        builtins_mod = names.get("__builtins__", None)
        if builtins_mod:
            assert isinstance(builtins_mod.node, MypyFile)
            self.builtins = builtins_mod.node.names
        # Loops whose body is currently being visited, innermost last.
        self.loops: list[Loop] = []
        # Number of enclosing try statements currently being visited.
        self.try_depth = 0
        self.tracker = DefinedVariableTracker()
        # Names in `implicit_module_attrs` are treated as defined from the start.
        for name in implicit_module_attrs:
            self.tracker.record_definition(name)

    def var_used_before_def(self, name: str, context: Context) -> None:
        if self.msg.errors.is_error_code_enabled(errorcodes.USED_BEFORE_DEF):
            self.msg.var_used_before_def(name, context)

    def variable_may_be_undefined(self, name: str, context: Context) -> None:
        if self.msg.errors.is_error_code_enabled(errorcodes.POSSIBLY_UNDEFINED):
            self.msg.variable_may_be_undefined(name, context)

    def process_definition(self, name: str) -> None:
        """Record a definition of `name`, reporting any earlier forward references to it."""
        # Was this name previously used? If yes, it's a used-before-definition error.
        if not self.tracker.in_scope(ScopeType.Class):
            refs = self.tracker.pop_undefined_ref(name)
            for ref in refs:
                if self.loops:
                    self.variable_may_be_undefined(name, ref)
                else:
                    self.var_used_before_def(name, ref)
        else:
            # Errors in class scopes are caught by the semantic analyzer.
            pass
        self.tracker.record_definition(name)

    def visit_global_decl(self, o: GlobalDecl) -> None:
        for name in o.names:
            self.process_definition(name)
        super().visit_global_decl(o)

    def visit_nonlocal_decl(self, o: NonlocalDecl) -> None:
        for name in o.names:
            self.process_definition(name)
        super().visit_nonlocal_decl(o)

    def process_lvalue(self, lvalue: Lvalue | None) -> None:
        """Record definitions for every name bound by an assignment target."""
        if isinstance(lvalue, NameExpr):
            self.process_definition(lvalue.name)
        elif isinstance(lvalue, StarExpr):
            self.process_lvalue(lvalue.expr)
        elif isinstance(lvalue, (ListExpr, TupleExpr)):
            for item in lvalue.items:
                self.process_lvalue(item)

    def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
        for lvalue in o.lvalues:
            self.process_lvalue(lvalue)
        super().visit_assignment_stmt(o)

    def visit_assignment_expr(self, o: AssignmentExpr) -> None:
        o.value.accept(self)
        self.process_lvalue(o.target)

    def visit_if_stmt(self, o: IfStmt) -> None:
        for e in o.expr:
            e.accept(self)
        self.tracker.start_branch_statement()
        for b in o.body:
            if b.is_unreachable:
                continue
            b.accept(self)
            self.tracker.next_branch()
        if o.else_body:
            if not o.else_body.is_unreachable:
                o.else_body.accept(self)
            else:
                self.tracker.skip_branch()
        self.tracker.end_branch_statement()

    def visit_match_stmt(self, o: MatchStmt) -> None:
        o.subject.accept(self)
        self.tracker.start_branch_statement()
        for i in range(len(o.patterns)):
            pattern = o.patterns[i]
            pattern.accept(self)
            guard = o.guards[i]
            if guard is not None:
                guard.accept(self)
            if not o.bodies[i].is_unreachable:
                o.bodies[i].accept(self)
            else:
                self.tracker.skip_branch()
            is_catchall = infer_pattern_value(pattern) == ALWAYS_TRUE
            if not is_catchall:
                self.tracker.next_branch()
        self.tracker.end_branch_statement()

    def visit_func_def(self, o: FuncDef) -> None:
        self.process_definition(o.name)
        super().visit_func_def(o)

    def visit_func(self, o: FuncItem) -> None:
        if o.is_dynamic() and not self.options.check_untyped_defs:
            return

        args = o.arguments or []
        # Process initializers (defaults) outside the function scope.
        for arg in args:
            if arg.initializer is not None:
                arg.initializer.accept(self)

        self.tracker.enter_scope(ScopeType.Func)
        for arg in args:
            self.process_definition(arg.variable.name)
            super().visit_var(arg.variable)
        o.body.accept(self)
        self.tracker.exit_scope()

    def visit_generator_expr(self, o: GeneratorExpr) -> None:
        self.tracker.enter_scope(ScopeType.Generator)
        for idx in o.indices:
            self.process_lvalue(idx)
        super().visit_generator_expr(o)
        self.tracker.exit_scope()

    def visit_dictionary_comprehension(self, o: DictionaryComprehension) -> None:
        self.tracker.enter_scope(ScopeType.Generator)
        for idx in o.indices:
            self.process_lvalue(idx)
        super().visit_dictionary_comprehension(o)
        self.tracker.exit_scope()

    def visit_for_stmt(self, o: ForStmt) -> None:
        o.expr.accept(self)
        self.process_lvalue(o.index)
        o.index.accept(self)
        self.tracker.start_branch_statement()
        loop = Loop()
        self.loops.append(loop)
        o.body.accept(self)
        self.tracker.next_branch()
        self.tracker.end_branch_statement()
        # The `else` clause is not part of the loop: a `break` inside it belongs to the
        # enclosing loop, and it runs at most once. Stop tracking this loop before visiting it.
        self.loops.pop()
        if o.else_body is not None:
            # If the loop has a `break` inside, `else` is executed conditionally.
            # If the loop doesn't have a `break` either the function will return or
            # execute the `else`.
            has_break = loop.has_break
            if has_break:
                self.tracker.start_branch_statement()
                self.tracker.next_branch()
            o.else_body.accept(self)
            if has_break:
                self.tracker.end_branch_statement()

    def visit_return_stmt(self, o: ReturnStmt) -> None:
        super().visit_return_stmt(o)
        self.tracker.skip_branch()

    def visit_lambda_expr(self, o: LambdaExpr) -> None:
        self.tracker.enter_scope(ScopeType.Func)
        super().visit_lambda_expr(o)
        self.tracker.exit_scope()

    def visit_assert_stmt(self, o: AssertStmt) -> None:
        super().visit_assert_stmt(o)
        if checker.is_false_literal(o.expr):
            self.tracker.skip_branch()

    def visit_raise_stmt(self, o: RaiseStmt) -> None:
        super().visit_raise_stmt(o)
        self.tracker.skip_branch()

    def visit_continue_stmt(self, o: ContinueStmt) -> None:
        super().visit_continue_stmt(o)
        self.tracker.skip_branch()

    def visit_break_stmt(self, o: BreakStmt) -> None:
        super().visit_break_stmt(o)
        if self.loops:
            self.loops[-1].has_break = True
        self.tracker.skip_branch()

    def visit_expression_stmt(self, o: ExpressionStmt) -> None:
        # An expression of uninhabited type (or one missing from the type map) is
        # treated as ending the current branch.
        if isinstance(self.type_map.get(o.expr, None), (UninhabitedType, type(None))):
            self.tracker.skip_branch()
        super().visit_expression_stmt(o)

    def visit_try_stmt(self, o: TryStmt) -> None:
        """
        Note that finding undefined vars in `finally` requires different handling from
        the rest of the code. In particular, we want to disallow skipping branches due to jump
        statements in except/else clauses for finally but not for other cases. Imagine a case like:
        def f() -> int:
            try:
                x = 1
            except:
                # This jump statement needs to be handled differently depending on whether or
                # not we're trying to process `finally` or not.
                return 0
            finally:
                # `x` may be undefined here.
                pass
            # `x` is always defined here.
            return x
        """
        self.try_depth += 1
        if o.finally_body is not None:
            # In order to find undefined vars in `finally`, we need to
            # process try/except with branch skipping disabled. However, for the rest of the code
            # after finally, we need to process try/except with branch skipping enabled.
            # Therefore, we need to process try/finally twice.
            # Because processing is not idempotent, we should make a copy of the tracker.
            old_tracker = self.tracker.copy()
            self.tracker.disable_branch_skip = True
            self.process_try_stmt(o)
            self.tracker = old_tracker
        self.process_try_stmt(o)
        self.try_depth -= 1

    def process_try_stmt(self, o: TryStmt) -> None:
        """
        Processes try statement decomposing it into the following:
        if ...:
            body
            else_body
        elif ...:
            except 1
        elif ...:
            except 2
        else:
            except n
        finally
        """
        self.tracker.start_branch_statement()
        o.body.accept(self)
        if o.else_body is not None:
            o.else_body.accept(self)
        if len(o.handlers) > 0:
            assert len(o.handlers) == len(o.vars) == len(o.types)
            for i in range(len(o.handlers)):
                self.tracker.next_branch()
                exc_type = o.types[i]
                if exc_type is not None:
                    exc_type.accept(self)
                var = o.vars[i]
                if var is not None:
                    self.process_definition(var.name)
                    var.accept(self)
                o.handlers[i].accept(self)
                if var is not None:
                    # The `except ... as var` name is unbound again after the handler.
                    self.tracker.delete_var(var.name)
        self.tracker.end_branch_statement()

        if o.finally_body is not None:
            o.finally_body.accept(self)

    def visit_while_stmt(self, o: WhileStmt) -> None:
        o.expr.accept(self)
        self.tracker.start_branch_statement()
        loop = Loop()
        self.loops.append(loop)
        o.body.accept(self)
        has_break = loop.has_break
        if not checker.is_true_literal(o.expr):
            # If this is a loop like `while True`, we can consider the body to be
            # a single branch statement (we're guaranteed that the body is executed at least once).
            # If not, call next_branch() to make all variables defined there conditional.
            self.tracker.next_branch()
        self.tracker.end_branch_statement()
        # The `else` clause is not part of the loop: a `break` inside it belongs to the
        # enclosing loop, and it runs at most once. Stop tracking this loop before visiting it.
        self.loops.pop()
        if o.else_body is not None:
            # If the loop has a `break` inside, `else` is executed conditionally.
            # If the loop doesn't have a `break` either the function will return or
            # execute the `else`.
            if has_break:
                self.tracker.start_branch_statement()
                self.tracker.next_branch()
            o.else_body.accept(self)
            if has_break:
                self.tracker.end_branch_statement()

    def visit_as_pattern(self, o: AsPattern) -> None:
        if o.name is not None:
            self.process_lvalue(o.name)
        super().visit_as_pattern(o)

    def visit_starred_pattern(self, o: StarredPattern) -> None:
        if o.capture is not None:
            self.process_lvalue(o.capture)
        super().visit_starred_pattern(o)

    def visit_name_expr(self, o: NameExpr) -> None:
        if o.name in self.builtins and self.tracker.in_scope(ScopeType.Global):
            return
        if self.tracker.is_possibly_undefined(o.name):
            # A variable is only defined in some branches.
            self.variable_may_be_undefined(o.name, o)
            # We don't want to report the error on the same variable multiple times.
            self.tracker.record_definition(o.name)
        elif self.tracker.is_defined_in_different_branch(o.name):
            # A variable is defined in one branch but used in a different branch.
            if self.loops or self.try_depth > 0:
                # If we're in a loop or in a try, we can't be sure that this variable
                # is undefined. Report it as "may be undefined".
                self.variable_may_be_undefined(o.name, o)
            else:
                self.var_used_before_def(o.name, o)
        elif self.tracker.is_undefined(o.name):
            # A variable is undefined. It could be due to two things:
            # 1. A variable is just totally undefined
            # 2. The variable is defined later in the code.
            # Case (1) will be caught by semantic analyzer. Case (2) is a forward ref that should
            # be caught by this visitor. Save the ref for later, so that if we see a definition,
            # we know it's a used-before-definition scenario.
            self.tracker.record_undefined_ref(o)
        super().visit_name_expr(o)

    def visit_with_stmt(self, o: WithStmt) -> None:
        for expr, idx in zip(o.expr, o.target):
            expr.accept(self)
            self.process_lvalue(idx)
        o.body.accept(self)

    def visit_class_def(self, o: ClassDef) -> None:
        self.process_definition(o.name)
        self.tracker.enter_scope(ScopeType.Class)
        super().visit_class_def(o)
        self.tracker.exit_scope()

    def visit_import(self, o: Import) -> None:
        for mod, alias in o.ids:
            if alias is not None:
                self.tracker.record_definition(alias)
            else:
                # When you do `import x.y`, only `x` becomes defined.
                names = mod.split(".")
                if names:
                    # `names` should always be nonempty, but we don't want mypy
                    # to crash on invalid code.
                    self.tracker.record_definition(names[0])
        super().visit_import(o)

    def visit_import_from(self, o: ImportFrom) -> None:
        for mod, alias in o.names:
            name = alias
            if name is None:
                name = mod
            self.tracker.record_definition(name)
        super().visit_import_from(o)

    def visit_type_alias_stmt(self, o: TypeAliasStmt) -> None:
        # Type alias target may contain forward references
        self.tracker.record_definition(o.name.name)