Skip to content

Commit

Permalink
Adds TC009 for type checking declarations used at runtime
Browse files Browse the repository at this point in the history
Extends TC100/TC200 to deal with type checking declarations
Avoids a couple of false positives in TC100/TC101
Avoids contradicting TC004/TC009 errors vs. TC100/TC200 errors
  • Loading branch information
Daverball authored and sondrelg committed Nov 25, 2023
1 parent 7aa077e commit 807cc3e
Show file tree
Hide file tree
Showing 7 changed files with 245 additions and 30 deletions.
102 changes: 79 additions & 23 deletions flake8_type_checking/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
TC006,
TC007,
TC008,
TC009,
TC100,
TC101,
TC200,
Expand All @@ -53,6 +54,7 @@ def ast_unparse(node: ast.AST) -> str:
from typing import Any, Optional, Union

from flake8_type_checking.types import (
Declaration,
Flake8Generator,
FunctionRangesDict,
FunctionScopeNamesDict,
Expand Down Expand Up @@ -425,10 +427,11 @@ def __init__(
# This lets us identify imports that *are* needed at runtime, for TC004 errors.
self.type_checking_block_imports: set[tuple[Import, str]] = set()

#: Set of variable names for all declarations defined within a type-checking block
#: Tuple of (node, variable name) for all global declarations within a type-checking block
# This lets us avoid false positives for annotations referring to e.g. a TypeAlias
# defined within a type checking block
self.type_checking_block_declarations: set[str] = set()
# defined within a type checking block. We currently ignore function definitions, since
# those should be exceedingly rare inside type checking blocks.
self.type_checking_block_declarations: set[tuple[Declaration, str]] = set()

#: Set of all the class names defined within the file
# This lets us avoid false positives for classes referring to themselves
Expand Down Expand Up @@ -492,6 +495,11 @@ def names(self) -> set[str]:
"""Return unique names."""
return set(self.uses.keys())

@property
def type_checking_names(self) -> set[str]:
"""Return unique names either imported or declared in type checking blocks."""
return {name for _, name in chain(self.type_checking_block_imports, self.type_checking_block_declarations)}

# -- Map type checking block ---------------

def in_type_checking_block(self, lineno: int, col_offset: int) -> bool:
Expand Down Expand Up @@ -775,7 +783,10 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
if isinstance(element, ast.AnnAssign):
self.visit(element.annotation)

self.class_names.add(node.name)
if getattr(node, GLOBAL_PROPERTY, False) and self.in_type_checking_block(node.lineno, node.col_offset):
self.type_checking_block_declarations.add((node, node.name))
else:
self.class_names.add(node.name)
self.generic_visit(node)
return node

Expand Down Expand Up @@ -868,7 +879,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
self.add_annotation(node.value, 'alias')

if getattr(node, GLOBAL_PROPERTY, False) and self.in_type_checking_block(node.lineno, node.col_offset):
self.type_checking_block_declarations.add(node.target.id)
self.type_checking_block_declarations.add((node, node.target.id))

# if it wasn't a TypeAlias we need to visit the value expression
else:
Expand All @@ -887,7 +898,7 @@ def visit_Assign(self, node: ast.Assign) -> ast.Assign:
and isinstance(node.targets[0], ast.Name)
and self.in_type_checking_block(node.lineno, node.col_offset)
):
self.type_checking_block_declarations.add(node.targets[0].id)
self.type_checking_block_declarations.add((node, node.targets[0].id))

super().visit_Assign(node)
return node
Expand All @@ -907,7 +918,7 @@ def visit_TypeAlias(self, node: ast.TypeAlias) -> None:
self.add_annotation(node.value, 'new-alias')

if getattr(node, GLOBAL_PROPERTY, False) and self.in_type_checking_block(node.lineno, node.col_offset):
self.type_checking_block_declarations.add(node.name.id)
self.type_checking_block_declarations.add((node, node.name.id))

def register_function_ranges(self, node: Union[FunctionDef, AsyncFunctionDef]) -> None:
"""
Expand Down Expand Up @@ -1105,6 +1116,8 @@ def __init__(self, node: ast.Module, options: Optional[Namespace]) -> None:
self.empty_type_checking_blocks,
# TC006
self.unquoted_type_in_cast,
# TC009
self.used_type_checking_declarations,
# TC100
self.missing_futures_import,
# TC101
Expand Down Expand Up @@ -1184,13 +1197,54 @@ def unquoted_type_in_cast(self) -> Flake8Generator:
for lineno, col_offset, annotation in self.visitor.unquoted_types_in_casts:
yield lineno, col_offset, TC006.format(annotation=annotation), None

def used_type_checking_declarations(self) -> Flake8Generator:
"""TC009."""
for decl, decl_name in self.visitor.type_checking_block_declarations:
if decl_name in self.visitor.uses:
# If we get to here, we're pretty sure that the declaration
# shouldn't actually live inside a type-checking block

use = self.visitor.uses[decl_name]

# .. or whether one of the argument names shadows a declaration
use_in_function = False
if use.lineno in self.visitor.function_ranges:
for i in range(
self.visitor.function_ranges[use.lineno]['start'],
self.visitor.function_ranges[use.lineno]['end'],
):
if (
i in self.visitor.function_scope_names
and decl_name in self.visitor.function_scope_names[i]['names']
):
use_in_function = True
break

if not use_in_function:
yield decl.lineno, decl.col_offset, TC009.format(name=decl_name), None

def missing_futures_import(self) -> Flake8Generator:
"""TC100."""
if (
not self.visitor.futures_annotation
and {name for _, name in self.visitor.type_checking_block_imports} - self.visitor.names
):
if self.visitor.futures_annotation:
return

# if all the symbols imported/declared in type checking blocks are used
# at runtime, then we're covered by TC004
unused_type_checking_names = self.visitor.type_checking_names - self.visitor.names
if not unused_type_checking_names:
return

# if any of the symbols imported/declared in type checking blocks are used
# in an annotation outside a type checking block, then we need to emit TC100
for item in self.visitor.unwrapped_annotations:
if item.annotation not in unused_type_checking_names:
continue

if self.visitor.in_type_checking_block(item.lineno, item.col_offset):
continue

yield 1, 0, TC100, None
return

def futures_excess_quotes(self) -> Flake8Generator:
"""TC101."""
Expand Down Expand Up @@ -1233,11 +1287,13 @@ def futures_excess_quotes(self) -> Flake8Generator:
So we don't try to unwrap the annotations as far as possible, we just check if the entire
annotation can be unwrapped or not.
"""
type_checking_names = self.visitor.type_checking_names

for item in self.visitor.wrapped_annotations:
if item.type != 'annotation': # TypeAlias value will not be affected by a futures import
continue

if any(import_name in item.names for _, import_name in self.visitor.type_checking_block_imports):
if not item.names.isdisjoint(type_checking_names):
continue

if any(class_name in item.names for class_name in self.visitor.class_names):
Expand All @@ -1247,6 +1303,8 @@ def futures_excess_quotes(self) -> Flake8Generator:

def missing_quotes(self) -> Flake8Generator:
"""TC200 and TC007."""
unused_type_checking_names = self.visitor.type_checking_names - self.visitor.names

for item in self.visitor.unwrapped_annotations:
# A new style alias does never need to be wrapped
if item.type == 'new-alias':
Expand All @@ -1258,16 +1316,17 @@ def missing_quotes(self) -> Flake8Generator:
if self.visitor.in_type_checking_block(item.lineno, item.col_offset):
continue

for _, name in self.visitor.type_checking_block_imports:
if item.annotation == name:
if item.type == 'alias':
error = TC007.format(alias=item.annotation)
else:
error = TC200.format(annotation=item.annotation)
yield item.lineno, item.col_offset, error, None
if item.annotation in unused_type_checking_names:
if item.type == 'alias':
error = TC007.format(alias=item.annotation)
else:
error = TC200.format(annotation=item.annotation)
yield item.lineno, item.col_offset, error, None

def excess_quotes(self) -> Flake8Generator:
"""TC201 and TC008."""
type_checking_names = self.visitor.type_checking_names

for item in self.visitor.wrapped_annotations:
# A new style type alias should never be wrapped
if item.type == 'new-alias':
Expand All @@ -1284,15 +1343,12 @@ def excess_quotes(self) -> Flake8Generator:
continue

# See comment in futures_excess_quotes
if any(import_name in item.names for _, import_name in self.visitor.type_checking_block_imports):
if not item.names.isdisjoint(type_checking_names):
continue

if any(class_name in item.names for class_name in self.visitor.class_names):
continue

if any(variable_name in item.names for variable_name in self.visitor.type_checking_block_declarations):
continue

if item.type == 'alias':
error = TC008.format(alias=item.annotation)
else:
Expand Down
1 change: 1 addition & 0 deletions flake8_type_checking/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
TC006 = "TC006 Annotation '{annotation}' in typing.cast() should be a string literal"
TC007 = "TC007 Type alias '{alias}' needs to be made into a string literal"
TC008 = "TC008 Type alias '{alias}' does not need to be a string literal"
TC009 = "TC009 Move declaration '{name}' out of type-checking block. Variable is used for more than type hinting."
TC100 = "TC100 Add 'from __future__ import annotations' import"
TC101 = "TC101 Annotation '{annotation}' does not need to be a string literal"
TC200 = "TC200 Annotation '{annotation}' needs to be made into a string literal"
Expand Down
5 changes: 5 additions & 0 deletions flake8_type_checking/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

if TYPE_CHECKING:
import ast
import sys
from typing import Any, Generator, Optional, Protocol, Tuple, TypedDict, Union

class FunctionRangesDict(TypedDict):
Expand All @@ -13,6 +14,10 @@ class FunctionRangesDict(TypedDict):
class FunctionScopeNamesDict(TypedDict):
names: list[str]

if sys.version_info >= (3, 12):
Declaration = Union[ast.ClassDef, ast.AnnAssign, ast.Assign, ast.TypeAlias]
else:
Declaration = Union[ast.ClassDef, ast.AnnAssign, ast.Assign]
Import = Union[ast.Import, ast.ImportFrom]
Flake8Generator = Generator[Tuple[int, int, str, Any], None, None]

Expand Down
138 changes: 138 additions & 0 deletions tests/test_tc009.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""
This file tests the TC009 error:
>> Move declaration out of type-checking block. Variable is used for more than type hinting.
"""
import sys
import textwrap

import pytest

from flake8_type_checking.constants import TC009
from tests.conftest import _get_error

examples = [
# No error
('', set()),
# Used in file
(
textwrap.dedent("""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
datetime = Any
x = datetime
"""),
{'5:4 ' + TC009.format(name='datetime')},
),
# Used in function
(
textwrap.dedent("""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
class date: ...
def example():
return date()
"""),
{'5:4 ' + TC009.format(name='date')},
),
# Used, but only used inside the type checking block
(
textwrap.dedent("""
if TYPE_CHECKING:
class date: ...
CustomType = date
"""),
set(),
),
# Used for typing only
(
textwrap.dedent("""
if TYPE_CHECKING:
class date: ...
def example(*args: date, **kwargs: date):
return
my_type: Type[date] | date
"""),
set(),
),
(
textwrap.dedent("""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
class AsyncIterator: ...
class Example:
async def example(self) -> AsyncIterator[list[str]]:
yield 0
"""),
set(),
),
(
textwrap.dedent("""
from typing import TYPE_CHECKING
from weakref import WeakKeyDictionary
if TYPE_CHECKING:
Any = str
d = WeakKeyDictionary["Any", "Any"]()
"""),
set(),
),
(
textwrap.dedent("""
if TYPE_CHECKING:
a = int
b: TypeAlias = str
class c(Protocol): ...
class d(TypedDict): ...
def test_function(a, /, b, *, c, **d):
print(a, b, c, d)
"""),
set(),
),
]

if sys.version_info >= (3, 12):
examples.append(
(
textwrap.dedent("""
if TYPE_CHECKING:
type Foo = int
x = Foo
"""),
{'3:4 ' + TC009.format(name='Foo')},
)
)
examples.append(
(
textwrap.dedent("""
if TYPE_CHECKING:
type Foo = int
x: Foo
"""),
set(),
)
)


@pytest.mark.parametrize(('example', 'expected'), examples)
def test_TC009_errors(example, expected):
assert _get_error(example, error_code_filter='TC009') == expected
Loading

0 comments on commit 807cc3e

Please sign in to comment.