Skip to content

Commit

Permalink
Cleans up redundant head scope for ClassDef.
Browse files Browse the repository at this point in the history
Fixes runtime symbol lookup into outer scopes.
Treats NamedExpr like definitions.
Improves scope handling of comprehensions and if expressions.
  • Loading branch information
Daverball authored and sondrelg committed Nov 25, 2023
1 parent 38810c8 commit 48fbeda
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 33 deletions.
198 changes: 165 additions & 33 deletions flake8_type_checking/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,15 @@ def ast_unparse(node: ast.AST) -> str:
from collections.abc import Iterator
from typing import Any, Optional, Union

from flake8_type_checking.types import Flake8Generator, Function, HasPosition, Import, ImportTypeValue, Name
from flake8_type_checking.types import (
Comprehension,
Flake8Generator,
Function,
HasPosition,
Import,
ImportTypeValue,
Name,
)


class AttrsMixin:
Expand Down Expand Up @@ -420,8 +428,9 @@ class Scope:
when comprehension inlining becomes a thing and it no longer generates a
new stack frame.
For ClassDef/FunctionDef/AsyncFunctionDef we create a tiny virtual scope
for the head to properly handle PEP695 parameter scopes.
For FunctionDef/AsyncFunctionDef we create a tiny virtual scope for the
head containing only the function signature to properly handle PEP695
type parameter scopes.
"""

def __init__(self, node: ast.Module | ast.ClassDef | Function, parent: Scope | None = None, is_head: bool = False):
Expand All @@ -438,7 +447,11 @@ def __init__(self, node: ast.Module | ast.ClassDef | Function, parent: Scope | N
# symbols
self.parent = parent

#: For function scopes/class scopes whether it is just for the head or also the body
#: For function scopes whether it is just for the head or also the body
# This is to deal with the fact that defaults and annotations are part
# of the outer scope, but type params are private to the function without
# leaking outside, so there is a thin faux-scope around the head which
# contains just the type params
self.is_head = is_head

#: The name of the class if this is a scope created by a class definition
Expand Down Expand Up @@ -489,6 +502,19 @@ def lookup(self, symbol_name: str, use: HasPosition | None = None, runtime_only:
while parent is not None and parent.class_name is not None:
parent = parent.parent

# we only propagate the use to the outer scope if we're a head-scope or
# a class scope, this is to deal with the fact that even if a symbol is
# defined after a function definition, it will still be available inside
# the function. If the function is called before the symbol actually
# exists, then an UnboundLocalError is raised, we can't easily detect
# this case, so there's no point in trying to handle it.
# Inversely, in case of a type checking lookup if we're not using a
# futures import the location does matter even in outer scope, since
# annotations are evaluated immediately, so that's why we're doing
# this inside the `if runtime_only` block.
if not self.is_head and not self.class_name:
use = None

# we're done looking up and didn't find anything
if parent is None:
return None
Expand Down Expand Up @@ -571,6 +597,9 @@ def __init__(
self.typing_cast_aliases: set[str] = set()
self.unquoted_types_in_casts: list[tuple[int, int, str]] = []

#: For tracking which comprehension/IfExp we're currently inside of
self.active_context: Optional[Comprehension | ast.IfExp] = None

@contextmanager
def create_scope(self, node: ast.ClassDef | Function, is_head: bool = True) -> Iterator[Scope]:
"""Create a new scope."""
Expand Down Expand Up @@ -883,10 +912,10 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
)
)

with self.create_scope(node, is_head=True) as head_scope:
# add PEP695 type parameters to class head scope
with self.create_scope(node) as scope:
# add PEP695 type parameters to class scope
for type_param in getattr(node, 'type_params', ()):
head_scope.symbols[type_param.name].append(
scope.symbols[type_param.name].append(
Symbol(
type_param.name,
type_param.lineno,
Expand All @@ -898,30 +927,27 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
for head_expr in chain(node.bases, node.keywords):
self.visit(head_expr)

with self.create_scope(node, is_head=False):
has_base_classes = node.bases
all_base_classes_ignored = all(
isinstance(base, ast.Name) and base.id in self.pydantic_enabled_baseclass_passlist
for base in node.bases
)
affected_by_pydantic_support = (
self.pydantic_enabled and has_base_classes and not all_base_classes_ignored
)
affected_by_cattrs_support = self.cattrs_enabled and self.is_attrs_class(node)

if affected_by_pydantic_support or affected_by_cattrs_support:
# When pydantic or cattrs support is enabled, treat any class variable
# annotation as being required at runtime. We need to do this, or
# users run the risk of guarding imports to resources that actually are
# required at runtime. This can be pretty scary, since it will crashes
# the application at runtime.
for element in node.body:
if isinstance(element, ast.AnnAssign):
self.visit(element.annotation)

for stmt in node.body:
self.visit(stmt)
return node
has_base_classes = node.bases
all_base_classes_ignored = all(
isinstance(base, ast.Name) and base.id in self.pydantic_enabled_baseclass_passlist
for base in node.bases
)
affected_by_pydantic_support = self.pydantic_enabled and has_base_classes and not all_base_classes_ignored
affected_by_cattrs_support = self.cattrs_enabled and self.is_attrs_class(node)

if affected_by_pydantic_support or affected_by_cattrs_support:
# When pydantic or cattrs support is enabled, treat any class variable
# annotation as being required at runtime. We need to do this, or
# users run the risk of guarding imports to resources that actually are
# required at runtime. This can be pretty scary, since it will crashes
# the application at runtime.
for element in node.body:
if isinstance(element, ast.AnnAssign):
self.visit(element.annotation)

for stmt in node.body:
self.visit(stmt)
return node

def visit_Name(self, node: ast.Name) -> ast.Name:
"""Map names."""
Expand Down Expand Up @@ -1040,6 +1066,9 @@ def visit_Assign(self, node: ast.Assign) -> ast.Assign:
in_type_checking_block = self.in_type_checking_block(node.lineno, node.col_offset)

for target in node.targets:
# each target can either be an ast.Name or an ast.Tuple/ast.List containing
# ast.Names, but there's also assignments to ast.Subscript/ast.Attribute, we
# only need to record new symbols for ast.Name
for name in getattr(target, 'elts', [target]):
if not hasattr(name, 'id'):
continue
Expand Down Expand Up @@ -1232,13 +1261,116 @@ def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
super().visit_AsyncFunctionDef(node)
self.register_function_annotations(node)

self.visit(node.args)

def visit_Lambda(self, node: ast.Lambda) -> None:
"""Remove and map argument symbols."""
with self.create_scope(node, is_head=True), self.create_scope(node, is_head=False):
self.register_function_annotations(node)

@contextmanager
def set_context(self, node: Comprehension | ast.IfExp) -> Iterator[None]:
"""
Set the active context for ast.NamedExpr/ast.comprehension.
This is to deal with the fact that comprehensions and ast.IfExp are
evaluated out of order, so in order for our symbol lookups to be a
little bit more accurate we need to attach declarations to the active
context, rather than the node itself.
"""
old_context = self.active_context
self.active_context = node
yield
self.active_context = old_context

def visit_ListComp(self, node: ast.ListComp) -> None:
"""Map symbols in list comprehension."""
with self.set_context(node):
self.generic_visit(node)

def visit_SetComp(self, node: ast.SetComp) -> None:
"""Map symbols in set comprehension."""
with self.set_context(node):
self.generic_visit(node)

def visit_DictComp(self, node: ast.DictComp) -> None:
"""Map symbols in dict comprehension."""
with self.set_context(node):
self.generic_visit(node)

def visit_GeneratorExp(self, node: ast.GeneratorExp) -> None:
"""Map symbols in generator expressions."""
with self.set_context(node):
self.generic_visit(node)

def visit_comprehension(self, node: ast.comprehension) -> None:
"""
Map all the symbols in a comprehension.
Comprehensions are a bit of a special case, since the expressions
are evaluated out of order, which complicates the symbol lookup.
We get around that by attaching all targets and all NamedExpr to
the comprehesion rather than themselves. So everyone inside the
comprehension can see the symbols.
This is technically not quite correct, since inside an individual
if expression the order of symbols still matters. But we don't try
to catch every single case here, we just use this to figure out
if type checking symbols are used at runtime, so it's fine if we're
a little lax here, since there are no annotations inside comprehensions
anyways.
"""
in_type_checking_block = self.in_type_checking_block(node.lineno, node.col_offset)

assert self.active_context is not None
for name in getattr(node.target, 'elts', [node.target]):
if not hasattr(name, 'id'):
continue

self.current_scope.symbols[name.id].append(
Symbol(
name.id,
# these symbols can be used in elt/key/value even though
# those appear before the comprehension, so we use the
# start of the expression as the location of the definition
self.active_context.lineno,
self.active_context.col_offset,
'definition',
in_type_checking_block=in_type_checking_block,
)
)

self.visit(node.iter)
for if_expr in node.ifs:
self.visit(if_expr)

def visit_IfExp(self, node: ast.IfExp) -> ast.IfExp:
"""Set the context for named expressions."""
with self.set_context(node):
self.generic_visit(node)
return node

def visit_NamedExpr(self, node: ast.NamedExpr) -> ast.NamedExpr:
"""
Keep track of variable definitions.
If we're inside a comprehension/IfExp then we treat definitions as if
they occured at the start of the expression to deal with the out of
order evaluation of comprehensions and if expressions.
"""
location_node = self.active_context or node
self.current_scope.symbols[node.target.id].append(
Symbol(
node.target.id,
location_node.lineno,
location_node.col_offset,
'definition',
in_type_checking_block=self.in_type_checking_block(node.lineno, node.col_offset),
)
)
self.visit(node.value)

return node

def register_unquoted_type_in_typing_cast(self, node: ast.Call) -> None:
"""Find typing.cast() calls with the type argument unquoted."""
func = node.func
Expand Down
1 change: 1 addition & 0 deletions flake8_type_checking/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Generator, Optional, Protocol, Tuple, Union

Function = Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda]
Comprehension = Union[ast.ListComp, ast.SetComp, ast.DictComp, ast.GeneratorExp]
Import = Union[ast.Import, ast.ImportFrom]
Flake8Generator = Generator[Tuple[int, int, str, Any], None, None]

Expand Down

0 comments on commit 48fbeda

Please sign in to comment.