Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

attrs.evolve: support generics and unions #15050

Merged
merged 4 commits into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
squashed
  • Loading branch information
ikonst committed Apr 14, 2023
commit 8c90ab37893fff8503e53431ab2af0dfdf15465d
107 changes: 84 additions & 23 deletions mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@

from __future__ import annotations

from typing import Iterable, List, cast
from collections import defaultdict
from functools import reduce
from typing import Iterable, List, Mapping, cast
from typing_extensions import Final, Literal

import mypy.plugin # To avoid circular imports.
from mypy.applytype import apply_generic_arguments
from mypy.checker import TypeChecker
from mypy.errorcodes import LITERAL_REQ
from mypy.expandtype import expand_type
from mypy.expandtype import expand_type, expand_type_by_instance
from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type
from mypy.meet import meet_types
from mypy.messages import format_type_bare
from mypy.nodes import (
ARG_NAMED,
Expand Down Expand Up @@ -67,6 +70,7 @@
Type,
TypeOfAny,
TypeVarType,
UninhabitedType,
UnionType,
get_proper_type,
)
Expand Down Expand Up @@ -943,12 +947,78 @@ def _get_attrs_init_type(typ: Instance) -> CallableType | None:
return init_method.type


def _get_attrs_cls_and_init(typ: ProperType) -> tuple[Instance | None, CallableType | None]:
def _format_not_attrs_class_failure(t: Type, parent_t: Type) -> str:
t_name = format_type_bare(t)
if parent_t is t:
return (
f'Argument 1 to "evolve" has a variable type "{t_name}" not bound to an attrs class'
if isinstance(t, TypeVarType)
else f'Argument 1 to "evolve" has incompatible type "{t_name}"; expected an attrs class'
)
else:
pt_name = format_type_bare(parent_t)
return (
f'Argument 1 to "evolve" has type "{pt_name}" whose item "{t_name}" is not bound to an attrs class'
if isinstance(t, TypeVarType)
else f'Argument 1 to "evolve" has incompatible type "{pt_name}" whose item "{t_name}" is not an attrs class'
)


def _get_expanded_attr_types(
ctx: mypy.plugin.FunctionSigContext,
typ: ProperType,
display_typ: ProperType,
parent_typ: ProperType,
) -> list[Mapping[str, Type]] | None:
"""
For a given type, determine what attrs classes it can be: for each class, return the field types.
For generic classes, the field types are expanded.
If the type contains Any or a non-attrs type, returns None; in the latter case, also reports an error.
"""
if isinstance(typ, AnyType):
return None
if isinstance(typ, UnionType):
ret: list[Mapping[str, Type]] | None = []
for item in typ.relevant_items():
item = get_proper_type(item)
item_types = _get_expanded_attr_types(ctx, item, item, parent_typ)
if ret is not None and item_types is not None:
ret += item_types
else:
ret = None # but keep iterating to emit all errors
return ret
if isinstance(typ, TypeVarType):
typ = get_proper_type(typ.upper_bound)
return _get_expanded_attr_types(
ctx, get_proper_type(typ.upper_bound), display_typ, parent_typ
)
if not isinstance(typ, Instance):
return None, None
return typ, _get_attrs_init_type(typ)
ctx.api.fail(_format_not_attrs_class_failure(display_typ, parent_typ), ctx.context)
return None
init_func = _get_attrs_init_type(typ)
if init_func is None:
ctx.api.fail(_format_not_attrs_class_failure(display_typ, parent_typ), ctx.context)
return None
init_func = expand_type_by_instance(init_func, typ)
field_names = cast(List[str], init_func.arg_names[1:])
field_types = init_func.arg_types[1:]
return [dict(zip(field_names, field_types))]


def _meet_fields(types: list[Mapping[str, Type]]) -> Mapping[str, Type]:
"""
"Meets" the fields of a list of attrs classes, i.e. for each field, its new type will be the lower bound.
"""
field_to_types = defaultdict(list)
for fields in types:
for name, typ in fields.items():
field_to_types[name].append(typ)

return {
name: get_proper_type(reduce(meet_types, f_types))
if len(f_types) == len(types)
hauntsaninja marked this conversation as resolved.
Show resolved Hide resolved
else UninhabitedType()
for name, f_types in field_to_types.items()
}


def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType:
Expand All @@ -972,27 +1042,18 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl
# </hack>

inst_type = get_proper_type(inst_type)
if isinstance(inst_type, AnyType):
return ctx.default_signature # evolve(Any, ....) -> Any
inst_type_str = format_type_bare(inst_type)

attrs_type, attrs_init_type = _get_attrs_cls_and_init(inst_type)
if attrs_type is None or attrs_init_type is None:
ctx.api.fail(
f'Argument 1 to "evolve" has a variable type "{inst_type_str}" not bound to an attrs class'
if isinstance(inst_type, TypeVarType)
else f'Argument 1 to "evolve" has incompatible type "{inst_type_str}"; expected an attrs class',
ctx.context,
)
attr_types = _get_expanded_attr_types(ctx, inst_type, inst_type, inst_type)
if attr_types is None:
return ctx.default_signature
fields = _meet_fields(attr_types)

# AttrClass.__init__ has the following signature (or similar, if having kw-only & defaults):
# def __init__(self, attr1: Type1, attr2: Type2) -> None:
# We want to generate a signature for evolve that looks like this:
# def evolve(inst: AttrClass, *, attr1: Type1 = ..., attr2: Type2 = ...) -> AttrClass:
return attrs_init_type.copy_modified(
arg_names=["inst"] + attrs_init_type.arg_names[1:],
arg_kinds=[ARG_POS] + [ARG_NAMED_OPT for _ in attrs_init_type.arg_kinds[1:]],
return CallableType(
arg_names=["inst", *fields.keys()],
arg_kinds=[ARG_POS] + [ARG_NAMED_OPT] * len(fields),
arg_types=[inst_type, *fields.values()],
ret_type=inst_type,
fallback=ctx.default_signature.fallback,
name=f"{ctx.default_signature.name} of {inst_type_str}",
)
81 changes: 80 additions & 1 deletion test-data/unit/check-attr.test
Original file line number Diff line number Diff line change
Expand Up @@ -1970,6 +1970,81 @@ reveal_type(ret) # N: Revealed type is "Any"

[typing fixtures/typing-medium.pyi]

[case testEvolveGeneric]
import attrs
from typing import Generic, TypeVar

T = TypeVar('T')

@attrs.define
class A(Generic[T]):
x: T


a = A(x=42)
reveal_type(a) # N: Revealed type is "__main__.A[builtins.int]"
a2 = attrs.evolve(a, x=42)
reveal_type(a2) # N: Revealed type is "__main__.A[builtins.int]"
a2 = attrs.evolve(a, x='42') # E: Argument "x" to "evolve" of "A[int]" has incompatible type "str"; expected "int"
reveal_type(a2) # N: Revealed type is "__main__.A[builtins.int]"

[builtins fixtures/attr.pyi]

[case testEvolveUnion]
# flags: --python-version 3.10
from typing import Generic, TypeVar
import attrs

T = TypeVar('T')


@attrs.define
class A(Generic[T]):
x: T # exercises meet(T=int, int) = int
y: bool # exercises meet(bool, int) = bool
z: str # exercises meet(str, bytes) = <nothing>
w: dict # exercises meet(dict, <nothing>) = <nothing>


@attrs.define
class B:
x: int
y: bool
z: bytes


a_or_b: A[int] | B
a2 = attrs.evolve(a_or_b, x=42, y=True)
a2 = attrs.evolve(a_or_b, x=42, y=True, z='42') # E: Argument "z" to "evolve" of "Union[A[int], B]" has incompatible type "str"; expected <nothing>
a2 = attrs.evolve(a_or_b, x=42, y=True, w={}) # E: Argument "w" to "evolve" of "Union[A[int], B]" has incompatible type "Dict[<nothing>, <nothing>]"; expected <nothing>

[builtins fixtures/attr.pyi]

[case testEvolveUnionOfTypeVar]
# flags: --python-version 3.10
import attrs
from typing import TypeVar

@attrs.define
class A:
x: int
y: int
z: str
w: dict


class B:
pass

TA = TypeVar('TA', bound=A)
TB = TypeVar('TB', bound=B)

def f(b_or_t: TA | TB | int) -> None:
a2 = attrs.evolve(b_or_t) # E: Argument 1 to "evolve" has type "Union[TA, TB, int]" whose item "TB" is not bound to an attrs class # E: Argument 1 to "evolve" has incompatible type "Union[TA, TB, int]" whose item "int" is not an attrs class


[builtins fixtures/attr.pyi]

[case testEvolveTypeVarBound]
import attrs
from typing import TypeVar
Expand Down Expand Up @@ -1997,11 +2072,12 @@ f(B(x=42))

[case testEvolveTypeVarBoundNonAttrs]
import attrs
from typing import TypeVar
from typing import Union, TypeVar

TInt = TypeVar('TInt', bound=int)
TAny = TypeVar('TAny')
TNone = TypeVar('TNone', bound=None)
TUnion = TypeVar('TUnion', bound=Union[str, int])

def f(t: TInt) -> None:
_ = attrs.evolve(t, x=42) # E: Argument 1 to "evolve" has a variable type "TInt" not bound to an attrs class
Expand All @@ -2012,6 +2088,9 @@ def g(t: TAny) -> None:
def h(t: TNone) -> None:
_ = attrs.evolve(t, x=42) # E: Argument 1 to "evolve" has a variable type "TNone" not bound to an attrs class

def x(t: TUnion) -> None:
_ = attrs.evolve(t, x=42) # E: Argument 1 to "evolve" has incompatible type "TUnion" whose item "str" is not an attrs class # E: Argument 1 to "evolve" has incompatible type "TUnion" whose item "int" is not an attrs class

[builtins fixtures/attr.pyi]

[case testEvolveTypeVarConstrained]
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/fixtures/attr.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ class object:
class type: pass
class bytes: pass
class function: pass
class bool: pass
class float: pass
class int:
@overload
def __init__(self, x: Union[str, bytes, int] = ...) -> None: ...
@overload
def __init__(self, x: Union[str, bytes], base: int) -> None: ...
class bool(int): pass
class complex:
@overload
def __init__(self, real: float = ..., im: float = ...) -> None: ...
Expand Down