Skip to content

Commit

Permalink
stubgen: properly convert overloaded functions (#9613)
Browse files Browse the repository at this point in the history
  • Loading branch information
chadrik authored Jan 15, 2021
1 parent e9edcb9 commit 92923b2
Show file tree
Hide file tree
Showing 2 changed files with 253 additions and 15 deletions.
87 changes: 72 additions & 15 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,10 +559,33 @@ def visit_mypy_file(self, o: MypyFile) -> None:
for name in sorted(undefined_names):
self.add('# %s\n' % name)

def visit_func_def(self, o: FuncDef, is_abstract: bool = False) -> None:
def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None:
"""@property with setters and getters, or @overload chain"""
overload_chain = False
for item in o.items:
if not isinstance(item, Decorator):
continue

if self.is_private_name(item.func.name, item.func.fullname):
continue

is_abstract, is_overload = self.process_decorator(item)

if not overload_chain:
self.visit_func_def(item.func, is_abstract=is_abstract, is_overload=is_overload)
if is_overload:
overload_chain = True
elif overload_chain and is_overload:
self.visit_func_def(item.func, is_abstract=is_abstract, is_overload=is_overload)
else:
# skip the overload implementation and clear the decorator we just processed
self.clear_decorators()

def visit_func_def(self, o: FuncDef, is_abstract: bool = False,
is_overload: bool = False) -> None:
if (self.is_private_name(o.name, o.fullname)
or self.is_not_in_all(o.name)
or self.is_recorded_name(o.name)):
or (self.is_recorded_name(o.name) and not is_overload)):
self.clear_decorators()
return
if not self._indent and self._state not in (EMPTY, FUNC) and not o.is_awaitable_coroutine:
Expand Down Expand Up @@ -599,7 +622,7 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False) -> None:
and not is_cls_arg):
self.add_typing_import("Any")
annotation = ": {}".format(self.typing_name("Any"))
elif annotated_type and not is_self_arg:
elif annotated_type and not is_self_arg and not is_cls_arg:
annotation = ": {}".format(self.print_annotation(annotated_type))
else:
annotation = ""
Expand Down Expand Up @@ -642,24 +665,43 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False) -> None:
def visit_decorator(self, o: Decorator) -> None:
if self.is_private_name(o.func.name, o.func.fullname):
return

is_abstract, _ = self.process_decorator(o)
self.visit_func_def(o.func, is_abstract=is_abstract)

def process_decorator(self, o: Decorator) -> Tuple[bool, bool]:
"""Process a series of decorataors.
Only preserve certain special decorators such as @abstractmethod.
Return a pair of booleans:
- True if any of the decorators makes a method abstract.
- True if any of the decorators is typing.overload.
"""
is_abstract = False
is_overload = False
for decorator in o.original_decorators:
if isinstance(decorator, NameExpr):
if self.process_name_expr_decorator(decorator, o):
is_abstract = True
i_is_abstract, i_is_overload = self.process_name_expr_decorator(decorator, o)
is_abstract = is_abstract or i_is_abstract
is_overload = is_overload or i_is_overload
elif isinstance(decorator, MemberExpr):
if self.process_member_expr_decorator(decorator, o):
is_abstract = True
self.visit_func_def(o.func, is_abstract=is_abstract)
i_is_abstract, i_is_overload = self.process_member_expr_decorator(decorator, o)
is_abstract = is_abstract or i_is_abstract
is_overload = is_overload or i_is_overload
return is_abstract, is_overload

def process_name_expr_decorator(self, expr: NameExpr, context: Decorator) -> bool:
def process_name_expr_decorator(self, expr: NameExpr, context: Decorator) -> Tuple[bool, bool]:
"""Process a function decorator of form @foo.
Only preserve certain special decorators such as @abstractmethod.
Return True if the decorator makes a method abstract.
Return a pair of booleans:
- True if the decorator makes a method abstract.
- True if the decorator is typing.overload.
"""
is_abstract = False
is_overload = False
name = expr.name
if name in ('property', 'staticmethod', 'classmethod'):
self.add_decorator(name)
Expand All @@ -675,27 +717,35 @@ def process_name_expr_decorator(self, expr: NameExpr, context: Decorator) -> boo
self.add_decorator('property')
self.add_decorator('abc.abstractmethod')
is_abstract = True
return is_abstract
elif self.refers_to_fullname(name, 'typing.overload'):
self.add_decorator(name)
self.add_typing_import('overload')
is_overload = True
return is_abstract, is_overload

def refers_to_fullname(self, name: str, fullname: str) -> bool:
module, short = fullname.rsplit('.', 1)
return (self.import_tracker.module_for.get(name) == module and
(name == short or
self.import_tracker.reverse_alias.get(name) == short))

def process_member_expr_decorator(self, expr: MemberExpr, context: Decorator) -> bool:
def process_member_expr_decorator(self, expr: MemberExpr, context: Decorator) -> Tuple[bool,
bool]:
"""Process a function decorator of form @foo.bar.
Only preserve certain special decorators such as @abstractmethod.
Return True if the decorator makes a method abstract.
Return a pair of booleans:
- True if the decorator makes a method abstract.
- True if the decorator is typing.overload.
"""
is_abstract = False
is_overload = False
if expr.name == 'setter' and isinstance(expr.expr, NameExpr):
self.add_decorator('%s.setter' % expr.expr.name)
elif (isinstance(expr.expr, NameExpr) and
(expr.expr.name == 'abc' or
self.import_tracker.reverse_alias.get('abc')) and
self.import_tracker.reverse_alias.get(expr.expr.name) == 'abc') and
expr.name in ('abstractmethod', 'abstractproperty')):
if expr.name == 'abstractproperty':
self.import_tracker.require_name(expr.expr.name)
Expand Down Expand Up @@ -723,7 +773,14 @@ def process_member_expr_decorator(self, expr: MemberExpr, context: Decorator) ->
self.add_coroutine_decorator(context.func,
expr.expr.name + '.coroutine',
expr.expr.name)
return is_abstract
elif (isinstance(expr.expr, NameExpr) and
(expr.expr.name == 'typing' or
self.import_tracker.reverse_alias.get(expr.expr.name) == 'typing') and
expr.name == 'overload'):
self.import_tracker.require_name(expr.expr.name)
self.add_decorator('%s.%s' % (expr.expr.name, 'overload'))
is_overload = True
return is_abstract, is_overload

def visit_class_def(self, o: ClassDef) -> None:
self.method_names = find_method_names(o.defs.body)
Expand Down
181 changes: 181 additions & 0 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -1471,6 +1471,20 @@ class A(metaclass=abc.ABCMeta):
@abc.abstractmethod
def meth(self): ...

[case testAbstractMethodMemberExpr2]
import abc as _abc

class A(metaclass=abc.ABCMeta):
@_abc.abstractmethod
def meth(self):
pass
[out]
import abc as _abc

class A(metaclass=abc.ABCMeta):
@_abc.abstractmethod
def meth(self): ...

[case testABCMeta_semanal]
from base import Base
from abc import abstractmethod
Expand Down Expand Up @@ -2288,3 +2302,170 @@ import p.a

x: a.X
y: p.a.Y

[case testOverload_fromTypingImport]
from typing import Tuple, Union, overload

class A:
@overload
def f(self, x: int, y: int) -> int:
...

@overload
def f(self, x: Tuple[int, int]) -> int:
...

def f(self, *args: Union[int, Tuple[int, int]]) -> int:
pass

@overload
def f(x: int, y: int) -> int:
...

@overload
def f(x: Tuple[int, int]) -> int:
...

def f(*args: Union[int, Tuple[int, int]]) -> int:
pass


[out]
from typing import Tuple, overload

class A:
@overload
def f(self, x: int, y: int) -> int: ...
@overload
def f(self, x: Tuple[int, int]) -> int: ...


@overload
def f(x: int, y: int) -> int: ...
@overload
def f(x: Tuple[int, int]) -> int: ...

[case testOverload_importTyping]
import typing

class A:
@typing.overload
def f(self, x: int, y: int) -> int:
...

@typing.overload
def f(self, x: typing.Tuple[int, int]) -> int:
...

def f(self, *args: typing.Union[int, typing.Tuple[int, int]]) -> int:
pass

@typing.overload
@classmethod
def g(cls, x: int, y: int) -> int:
...

@typing.overload
@classmethod
def g(cls, x: typing.Tuple[int, int]) -> int:
...

@classmethod
def g(self, *args: typing.Union[int, typing.Tuple[int, int]]) -> int:
pass

@typing.overload
def f(x: int, y: int) -> int:
...

@typing.overload
def f(x: typing.Tuple[int, int]) -> int:
...

def f(*args: typing.Union[int, typing.Tuple[int, int]]) -> int:
pass


[out]
import typing

class A:
@typing.overload
def f(self, x: int, y: int) -> int: ...
@typing.overload
def f(self, x: typing.Tuple[int, int]) -> int: ...
@typing.overload
@classmethod
def g(cls, x: int, y: int) -> int: ...
@typing.overload
@classmethod
def g(cls, x: typing.Tuple[int, int]) -> int: ...


@typing.overload
def f(x: int, y: int) -> int: ...
@typing.overload
def f(x: typing.Tuple[int, int]) -> int: ...


[case testOverload_importTypingAs]
import typing as t

class A:
@t.overload
def f(self, x: int, y: int) -> int:
...

@t.overload
def f(self, x: t.Tuple[int, int]) -> int:
...

def f(self, *args: typing.Union[int, t.Tuple[int, int]]) -> int:
pass

@t.overload
@classmethod
def g(cls, x: int, y: int) -> int:
...

@t.overload
@classmethod
def g(cls, x: t.Tuple[int, int]) -> int:
...

@classmethod
def g(self, *args: t.Union[int, t.Tuple[int, int]]) -> int:
pass

@t.overload
def f(x: int, y: int) -> int:
...

@t.overload
def f(x: t.Tuple[int, int]) -> int:
...

def f(*args: t.Union[int, t.Tuple[int, int]]) -> int:
pass


[out]
import typing as t

class A:
@t.overload
def f(self, x: int, y: int) -> int: ...
@t.overload
def f(self, x: t.Tuple[int, int]) -> int: ...
@t.overload
@classmethod
def g(cls, x: int, y: int) -> int: ...
@t.overload
@classmethod
def g(cls, x: t.Tuple[int, int]) -> int: ...


@t.overload
def f(x: int, y: int) -> int: ...
@t.overload
def f(x: t.Tuple[int, int]) -> int: ...

0 comments on commit 92923b2

Please sign in to comment.