Skip to content

Commit

Permalink
[mypyc] Improve support for compiling singledispatch (python#10795)
Browse files Browse the repository at this point in the history
This makes several improvements to the support for compiling singledispatch that was introduced in python#10753 by:

* Making sure registered implementations defined later in a file take precedence when multiple overlap
 * Using non-native calls to registered implementations to allow for adding other decorators to registered functions (099b047)
 * Creating a separate function that dispatches to the correct implementation instead of adding code to dispatch to one of the registered implementations directly into the main singledispatch function, allowing the main singledispatch function to be a generator (59555e4)
 * Avoiding a compilation error when trying to dispatch on an ABC (2d40421)
  • Loading branch information
pranavrajpal authored Jul 12, 2021
1 parent 66cae4b commit 7d69ce2
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 64 deletions.
4 changes: 1 addition & 3 deletions mypyc/irbuild/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ def __init__(self,
is_nested: bool = False,
contains_nested: bool = False,
is_decorated: bool = False,
in_non_ext: bool = False,
is_singledispatch: bool = False) -> None:
in_non_ext: bool = False) -> None:
self.fitem = fitem
self.name = name if not is_decorated else decorator_helper_name(name)
self.class_name = class_name
Expand All @@ -48,7 +47,6 @@ def __init__(self,
self.contains_nested = contains_nested
self.is_decorated = is_decorated
self.in_non_ext = in_non_ext
self.is_singledispatch = is_singledispatch

# TODO: add field for ret_type: RType = none_rprimitive

Expand Down
77 changes: 61 additions & 16 deletions mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
setup_func_for_recursive_call
)

from mypyc.primitives.registry import builtin_names


# Top-level transform functions

Expand Down Expand Up @@ -219,9 +221,12 @@ def c() -> None:
in_non_ext = not ir.is_ext_class
class_name = cdef.name

builder.enter(FuncInfo(fitem, name, class_name, gen_func_ns(builder),
is_nested, contains_nested, is_decorated, in_non_ext,
is_singledispatch))
if is_singledispatch:
func_name = '__mypyc_singledispatch_main_function_{}__'.format(name)
else:
func_name = name
builder.enter(FuncInfo(fitem, func_name, class_name, gen_func_ns(builder),
is_nested, contains_nested, is_decorated, in_non_ext))

# Functions that contain nested functions need an environment class to store variables that
# are free in their nested functions. Generator functions need an environment class to
Expand Down Expand Up @@ -254,9 +259,6 @@ def c() -> None:
if builder.fn_info.contains_nested and not builder.fn_info.is_generator:
finalize_env_class(builder)

if builder.fn_info.is_singledispatch:
add_singledispatch_registered_impls(builder)

builder.ret_types[-1] = sig.ret_type

# Add all variables and functions that are declared/defined within this
Expand Down Expand Up @@ -313,6 +315,15 @@ def c() -> None:
# calculate them *once* when the function definition is evaluated.
calculate_arg_defaults(builder, fn_info, func_reg, symtable)

if is_singledispatch:
# add the generated main singledispatch function
builder.functions.append(func_ir)
# create the dispatch function
assert isinstance(fitem, FuncDef)
dispatch_name = decorator_helper_name(name) if is_decorated else name
dispatch_func_ir = gen_dispatch_func_ir(builder, fitem, fn_info.name, dispatch_name, sig)
return dispatch_func_ir, None

return (func_ir, func_reg)


Expand Down Expand Up @@ -768,28 +779,62 @@ def check_if_isinstance(builder: IRBuilder, obj: Value, typ: TypeInfo, line: int
class_ir = builder.mapper.type_to_ir[typ]
return builder.builder.isinstance_native(obj, class_ir, line)
else:
class_obj = builder.load_module_attr_by_fullname(typ.fullname, line)
if typ.fullname in builtin_names:
builtin_addr_type, src = builtin_names[typ.fullname]
class_obj = builder.add(LoadAddress(builtin_addr_type, src, line))
else:
class_obj = builder.load_global_str(typ.name, line)
return builder.call_c(slow_isinstance_op, [obj, class_obj], line)


def add_singledispatch_registered_impls(builder: IRBuilder) -> None:
fitem = builder.fn_info.fitem
assert isinstance(fitem, FuncDef)
def generate_singledispatch_dispatch_function(
builder: IRBuilder,
main_singledispatch_function_name: str,
fitem: FuncDef,
) -> None:
impls = builder.singledispatch_impls[fitem]
line = fitem.line
current_func_decl = builder.mapper.func_to_decl[fitem]
arg_info = get_args(builder, current_func_decl.sig.args, line)
for dispatch_type, impl in impls:
func_decl = builder.mapper.func_to_decl[impl]

def gen_func_call_and_return(func_name: str) -> None:
func = builder.load_global_str(func_name, line)
# TODO: don't pass optional arguments if they weren't passed to this function
ret_val = builder.builder.py_call(
func, arg_info.args, line, arg_info.arg_kinds, arg_info.arg_names
)
coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line)
builder.nonlocal_control[-1].gen_return(builder, coerced, line)

# Reverse the list of registered implementations so we use the implementations defined later
# if there are multiple overlapping implementations
for dispatch_type, impl in reversed(impls):
call_impl, next_impl = BasicBlock(), BasicBlock()
should_call_impl = check_if_isinstance(builder, arg_info.args[0], dispatch_type, line)
builder.add_bool_branch(should_call_impl, call_impl, next_impl)

# Call the registered implementation
builder.activate_block(call_impl)

ret_val = builder.builder.call(
func_decl, arg_info.args, arg_info.arg_kinds, arg_info.arg_names, line
)
builder.nonlocal_control[-1].gen_return(builder, ret_val, line)
gen_func_call_and_return(impl.name)
builder.activate_block(next_impl)

gen_func_call_and_return(main_singledispatch_function_name)


def gen_dispatch_func_ir(
builder: IRBuilder,
fitem: FuncDef,
main_func_name: str,
dispatch_name: str,
sig: FuncSignature,
) -> FuncIR:
"""Create a dispatch function (a function that checks the first argument type and dispatches
to the correct implementation)
"""
builder.enter()
generate_singledispatch_dispatch_function(builder, main_func_name, fitem)
args, _, blocks, _, fn_info = builder.leave()
func_decl = FuncDecl(dispatch_name, None, builder.module_name, sig)
dispatch_func_ir = FuncIR(func_decl, args, blocks)
return dispatch_func_ir
2 changes: 1 addition & 1 deletion mypyc/irbuild/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def build_ir(modules: List[MypyFile],

for module in modules:
# First pass to determine free symbols.
pbv = PreBuildVisitor()
pbv = PreBuildVisitor(errors, module)
module.accept(pbv)

# Construct and configure builder objects (cyclic runtime dependency).
Expand Down
24 changes: 22 additions & 2 deletions mypyc/irbuild/prebuildvisitor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from mypyc.errors import Errors
from mypy.types import Instance, get_proper_type
from typing import DefaultDict, Dict, List, NamedTuple, Set, Optional, Tuple
from collections import defaultdict

from mypy.nodes import (
Decorator, Expression, FuncDef, FuncItem, LambdaExpr, NameExpr, SymbolNode, Var, MemberExpr,
CallExpr, RefExpr, TypeInfo
CallExpr, RefExpr, TypeInfo, MypyFile
)
from mypy.traverser import TraverserVisitor

Expand All @@ -23,7 +24,7 @@ class PreBuildVisitor(TraverserVisitor):
The main IR build pass uses this information.
"""

def __init__(self) -> None:
def __init__(self, errors: Errors, current_file: MypyFile) -> None:
super().__init__()
# Dict from a function to symbols defined directly in the
# function that are used as non-local (free) variables within a
Expand Down Expand Up @@ -57,6 +58,10 @@ def __init__(self) -> None:
self.singledispatch_impls: DefaultDict[
FuncDef, List[Tuple[TypeInfo, FuncDef]]] = defaultdict(list)

self.errors: Errors = errors

self.current_file: MypyFile = current_file

def visit_decorator(self, dec: Decorator) -> None:
if dec.decorators:
# Only add the function being decorated if there exist
Expand All @@ -72,12 +77,27 @@ def visit_decorator(self, dec: Decorator) -> None:
else:
decorators_to_store = dec.decorators.copy()
removed: List[int] = []
# the index of the last non-register decorator before finding a register decorator
# when going through decorators from top to bottom
last_non_register: Optional[int] = None
for i, d in enumerate(decorators_to_store):
impl = get_singledispatch_register_call_info(d, dec.func)
if impl is not None:
self.singledispatch_impls[impl.singledispatch_func].append(
(impl.dispatch_type, dec.func))
removed.append(i)
if last_non_register is not None:
# found a register decorator after a non-register decorator, which we
# don't support because we'd have to make a copy of the function before
# calling the decorator so that we can call it later, which complicates
# the implementation for something that is probably not commonly used
self.errors.error(
"Calling decorator after registering function not supported",
self.current_file.path,
decorators_to_store[last_non_register].line,
)
else:
last_non_register = i
# calling register on a function that tries to dispatch based on type annotations
# raises a TypeError because compiled functions don't have an __annotations__
# attribute
Expand Down
21 changes: 21 additions & 0 deletions mypyc/test-data/commandline.test
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def f(x: int) -> int:
from typing import List, Any, AsyncIterable
from typing_extensions import Final
from mypy_extensions import trait, mypyc_attr
from functools import singledispatch

def busted(b: bool) -> None:
for i in range(1, 10, 0): # E: range() step can't be zero
Expand Down Expand Up @@ -219,3 +220,23 @@ async def async_with() -> None:

async def async_generators() -> AsyncIterable[int]:
yield 1 # E: async generators are unimplemented

@singledispatch
def a(arg) -> None:
pass

@decorator # E: Calling decorator after registering function not supported
@a.register
def g(arg: int) -> None:
pass

@a.register
@decorator
def h(arg: str) -> None:
pass

@decorator
@decorator # E: Calling decorator after registering function not supported
@a.register
def i(arg: Foo) -> None:
pass
Loading

0 comments on commit 7d69ce2

Please sign in to comment.