Skip to content

Commit

Permalink
[stubgen] Improve self annotations (python#18420)
Browse files Browse the repository at this point in the history
Print annotations for self variables if given. Aside from the most
common ones for `str`, `int`, `bool` etc. those were previously inferred
as `Incomplete`.
  • Loading branch information
cdce8p authored Jan 7, 2025
1 parent 8951a33 commit 32b860e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
10 changes: 5 additions & 5 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,11 +648,11 @@ def visit_func_def(self, o: FuncDef) -> None:
self.add("\n")
if not self.is_top_level():
self_inits = find_self_initializers(o)
for init, value in self_inits:
for init, value, annotation in self_inits:
if init in self.method_names:
# Can't have both an attribute and a method/property with the same name.
continue
init_code = self.get_init(init, value)
init_code = self.get_init(init, value, annotation)
if init_code:
self.add(init_code)

Expand Down Expand Up @@ -1414,7 +1414,7 @@ def find_method_names(defs: list[Statement]) -> set[str]:

class SelfTraverser(mypy.traverser.TraverserVisitor):
def __init__(self) -> None:
self.results: list[tuple[str, Expression]] = []
self.results: list[tuple[str, Expression, Type | None]] = []

def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
lvalue = o.lvalues[0]
Expand All @@ -1423,10 +1423,10 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
and isinstance(lvalue.expr, NameExpr)
and lvalue.expr.name == "self"
):
self.results.append((lvalue.name, o.rvalue))
self.results.append((lvalue.name, o.rvalue, o.unanalyzed_type))


def find_self_initializers(fdef: FuncBase) -> list[tuple[str, Expression]]:
def find_self_initializers(fdef: FuncBase) -> list[tuple[str, Expression, Type | None]]:
"""Find attribute initializers in a method.
Return a list of pairs (attribute name, r.h.s. expression).
Expand Down
11 changes: 11 additions & 0 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -238,13 +238,24 @@ class C:
def __init__(self, x: str) -> None: ...

[case testSelfAssignment]
from mod import A
from typing import Any, Dict, Union
class C:
def __init__(self):
self.a: A = A()
self.x = 1
x.y = 2
self.y: Dict[str, Any] = {}
self.z: Union[int, str, bool, None] = None
[out]
from mod import A
from typing import Any

class C:
a: A
x: int
y: dict[str, Any]
z: int | str | bool | None
def __init__(self) -> None: ...

[case testSelfAndClassBodyAssignment]
Expand Down

0 comments on commit 32b860e

Please sign in to comment.