Skip to content

Commit

Permalink
Make refcounting handle error values loaded into registers (python#10819
Browse files Browse the repository at this point in the history
)

Currently the finally path has a hacky pointless branch
on a register that is potentially NULL in order to convince
refcounting to do the right thing.

I want to add another case that would require this, so just fix
refcounting to properly account for error values in registers.
  • Loading branch information
msullivan authored Jul 14, 2021
1 parent 7d69ce2 commit d0bd1c8
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 32 deletions.
19 changes: 16 additions & 3 deletions mypyc/analysis/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,19 @@ class DefinedVisitor(BaseAnalysisVisitor):
Note that this only deals with registers and not temporaries, on
the assumption that we never access temporaries when they might be
undefined.
If strict_errors is True, then we regard any use of LoadErrorValue
as making a register undefined. Otherwise we only do if
`undefines` is set on the error value.
This lets us only consider the things we care about during
uninitialized variable checking while capturing all possibly
undefined things for refcounting.
"""

def __init__(self, strict_errors: bool = False) -> None:
self.strict_errors = strict_errors

def visit_branch(self, op: Branch) -> GenAndKill:
return set(), set()

Expand All @@ -252,7 +263,8 @@ def visit_register_op(self, op: RegisterOp) -> GenAndKill:

def visit_assign(self, op: Assign) -> GenAndKill:
# Loading an error value may undefine the register.
if isinstance(op.src, LoadErrorValue) and op.src.undefines:
if (isinstance(op.src, LoadErrorValue)
and (op.src.undefines or self.strict_errors)):
return set(), {op.dest}
else:
return {op.dest}, set()
Expand Down Expand Up @@ -284,7 +296,8 @@ def analyze_must_defined_regs(
blocks: List[BasicBlock],
cfg: CFG,
initial_defined: Set[Value],
regs: Iterable[Value]) -> AnalysisResult[Value]:
regs: Iterable[Value],
strict_errors: bool = False) -> AnalysisResult[Value]:
"""Calculate always defined registers at each CFG location.
This analysis can work before exception insertion, since it is a
Expand All @@ -296,7 +309,7 @@ def analyze_must_defined_regs(
"""
return run_analysis(blocks=blocks,
cfg=cfg,
gen_and_kill=DefinedVisitor(),
gen_and_kill=DefinedVisitor(strict_errors=strict_errors),
initial=initial_defined,
backward=False,
kind=MUST_ANALYSIS,
Expand Down
8 changes: 0 additions & 8 deletions mypyc/irbuild/nonlocalcontrol.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,6 @@ def __init__(self, outer: NonlocalControl, ret_reg: Optional[Value], saved: Valu
self.saved = saved

def gen_cleanup(self, builder: 'IRBuilder', line: int) -> None:
# Do an error branch on the return value register, which
# may be undefined. This will allow it to be properly
# decrefed if it is not null. This is kind of a hack.
if self.ret_reg:
target = BasicBlock()
builder.add(Branch(self.ret_reg, target, target, Branch.IS_ERROR))
builder.activate_block(target)

# Restore the old exc_info
target, cleanup = BasicBlock(), BasicBlock()
builder.add(Branch(self.saved, target, cleanup, Branch.IS_ERROR))
Expand Down
40 changes: 20 additions & 20 deletions mypyc/test-data/exceptions.test
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ L1:
L2:
r3 = PyObject_CallFunctionObjArgs(r2, 0)
dec_ref r2
if is_error(r3) goto L5 (error at a:3) else goto L20
if is_error(r3) goto L5 (error at a:3) else goto L19
L3:
r4 = 'hi'
inc_ref r4
Expand All @@ -278,52 +278,52 @@ L6:
r11 = builtins :: module
r12 = 'print'
r13 = CPyObject_GetAttr(r11, r12)
if is_error(r13) goto L13 (error at a:6) else goto L7
if is_error(r13) goto L20 (error at a:6) else goto L7
L7:
r14 = PyObject_CallFunctionObjArgs(r13, r10, 0)
dec_ref r13
if is_error(r14) goto L13 (error at a:6) else goto L21
if is_error(r14) goto L20 (error at a:6) else goto L21
L8:
if is_error(r7) goto L11 else goto L9
if is_error(r7) goto L11 else goto L22
L9:
CPy_Reraise()
if not 0 goto L13 else goto L22 :: bool
if not 0 goto L13 else goto L23 :: bool
L10:
unreachable
L11:
if is_error(r5) goto L18 else goto L12
if is_error(r5) goto L17 else goto L12
L12:
return r5
L13:
if is_error(r5) goto L14 else goto L23
if is_error(r7) goto L15 else goto L14
L14:
if is_error(r7) goto L16 else goto L15
L15:
CPy_RestoreExcInfo(r7)
dec_ref r7
L16:
xdec_ref r7
L15:
r15 = CPy_KeepPropagating()
if not r15 goto L19 else goto L17 :: bool
if not r15 goto L18 else goto L16 :: bool
L16:
unreachable
L17:
unreachable
L18:
unreachable
L19:
r16 = <error> :: str
return r16
L20:
L19:
dec_ref r3
goto L3
L20:
xdec_ref r5
goto L13
L21:
dec_ref r14
goto L8
L22:
dec_ref r5
dec_ref r7
goto L10
xdec_ref r5
goto L9
L23:
dec_ref r5
goto L14
xdec_ref r7
goto L10

[case testDocstring1]
def lol() -> None:
Expand Down
2 changes: 1 addition & 1 deletion mypyc/transform/refcount.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def insert_ref_count_opcodes(ir: FuncIR) -> None:
args: Set[Value] = set(ir.arg_regs)
live = analyze_live_regs(ir.blocks, cfg)
borrow = analyze_borrowed_arguments(ir.blocks, cfg, borrowed)
defined = analyze_must_defined_regs(ir.blocks, cfg, args, values)
defined = analyze_must_defined_regs(ir.blocks, cfg, args, values, strict_errors=True)
ordering = make_value_ordering(ir)
cache: BlockCache = {}
for block in ir.blocks[:]:
Expand Down

0 comments on commit d0bd1c8

Please sign in to comment.