Skip to content

Commit

Permalink
refactor: Use lookup_full_name for detecting validate_arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
Daverball committed Dec 13, 2024
1 parent c81ff55 commit 85e5586
Showing 1 changed file with 4 additions and 11 deletions.
15 changes: 4 additions & 11 deletions flake8_type_checking/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,15 +272,17 @@ class PydanticMixin:

if TYPE_CHECKING:
pydantic_enabled: bool
pydantic_validate_arguments_import_name: str | None

def visit(self, node: ast.AST) -> ast.AST: # noqa: D102
...

def lookup_full_name(self, node: ast.AST) -> str | None: # noqa: D102
...

def _function_is_wrapped_by_validate_arguments(self, node: FunctionDef | AsyncFunctionDef) -> bool:
if self.pydantic_enabled and node.decorator_list:
for decorator_node in node.decorator_list:
if getattr(decorator_node, 'id', '') == self.pydantic_validate_arguments_import_name:
if self.lookup_full_name(decorator_node) == 'pydantic.validate_arguments':
return True
return False

Expand Down Expand Up @@ -1043,7 +1045,6 @@ def __init__(
)
self.injector_enabled = injector_enabled
self.pydantic_enabled_baseclass_passlist = pydantic_enabled_baseclass_passlist
self.pydantic_validate_arguments_import_name = None
self.cwd = cwd # we need to know the current directory to guess at which imports are remote and which are not

#: A list of modules that re-export symbols from the typing module
Expand Down Expand Up @@ -1365,14 +1366,6 @@ def add_import(self, node: Import) -> None: # noqa: C901
# Skip checking the import if the module is passlisted
exempt = all_exempt or (isinstance(node, ast.Import) and self.is_exempt_module(name_node.name))

# Look for pydantic.validate_arguments import
# TODO: Switch to using lookup_full_name instead
if name_node.name == 'validate_arguments':
if name_node.asname is not None:
self.pydantic_validate_arguments_import_name = name_node.asname
else:
self.pydantic_validate_arguments_import_name = name_node.name

if name_node.name == '*':
# don't record * imports
continue
Expand Down

0 comments on commit 85e5586

Please sign in to comment.