Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Polymorphic inference: support for parameter specifications and lambdas #15837

Merged
merged 15 commits into from
Aug 15, 2023
Next Next commit
Add basic support for polymorphic inference with ParamSpec
  • Loading branch information
Ivan Levkivskyi committed Aug 9, 2023
commit 9b4eac81eed87577388474559a00bd72ed09bd80
43 changes: 34 additions & 9 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
false_only,
fixup_partial_type,
function_type,
get_all_type_vars,
get_type_vars,
is_literal_type_like,
make_simplified_union,
Expand All @@ -145,6 +146,7 @@
LiteralValue,
NoneType,
Overloaded,
Parameters,
ParamSpecFlavor,
ParamSpecType,
PartialType,
Expand All @@ -167,6 +169,7 @@
get_proper_types,
has_recursive_types,
is_named_instance,
remove_dups,
split_with_prefix_and_suffix,
)
from mypy.types_utils import (
Expand Down Expand Up @@ -5632,18 +5635,24 @@ def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None:
self.bound_tvars: set[TypeVarLikeType] = set()
self.seen_aliases: set[TypeInfo] = set()

def visit_callable_type(self, t: CallableType) -> Type:
found_vars = set()
def collect_vars(self, t: CallableType | Parameters) -> list[TypeVarLikeType]:
found_vars = []
for arg in t.arg_types:
found_vars |= set(get_type_vars(arg)) & self.poly_tvars
found_vars += [
tv
for tv in get_all_type_vars(arg)
if tv in self.poly_tvars and tv not in self.bound_tvars
]
return remove_dups(found_vars)

found_vars -= self.bound_tvars
self.bound_tvars |= found_vars
def visit_callable_type(self, t: CallableType) -> Type:
found_vars = self.collect_vars(t)
self.bound_tvars |= set(found_vars)
result = super().visit_callable_type(t)
self.bound_tvars -= found_vars
self.bound_tvars -= set(found_vars)

assert isinstance(result, ProperType) and isinstance(result, CallableType)
result.variables = list(result.variables) + list(found_vars)
result.variables = list(result.variables) + found_vars
return result

def visit_type_var(self, t: TypeVarType) -> Type:
Expand All @@ -5652,8 +5661,9 @@ def visit_type_var(self, t: TypeVarType) -> Type:
return super().visit_type_var(t)

def visit_param_spec(self, t: ParamSpecType) -> Type:
# TODO: Support polymorphic apply for ParamSpec.
raise PolyTranslationError()
if t in self.poly_tvars and t not in self.bound_tvars:
raise PolyTranslationError()
return super().visit_param_spec(t)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
# TODO: Support polymorphic apply for TypeVarTuple.
Expand All @@ -5669,6 +5679,21 @@ def visit_type_alias_type(self, t: TypeAliasType) -> Type:
raise PolyTranslationError()

def visit_instance(self, t: Instance) -> Type:
if t.type.has_param_spec_type:
param_spec_index = next(
i for (i, tv) in enumerate(t.type.defn.type_vars) if isinstance(tv, ParamSpecType)
)
p = get_proper_type(t.args[param_spec_index])
if isinstance(p, Parameters):
found_vars = self.collect_vars(p)
self.bound_tvars |= set(found_vars)
new_args = [a.accept(self) for a in t.args]
self.bound_tvars -= set(found_vars)

repl = new_args[param_spec_index]
assert isinstance(repl, ProperType) and isinstance(repl, Parameters)
repl.variables = list(repl.variables) + list(found_vars)
return t.copy_modified(args=new_args)
# There is the same problem with callback protocols as with aliases
# (callback protocols are essentially more flexible aliases to callables).
# Note: consider supporting bindings in instances, e.g. LRUCache[[x: T], T].
Expand Down
31 changes: 24 additions & 7 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
cactual = self.actual.with_unpacked_kwargs()
param_spec = template.param_spec()
if param_spec is None:
# FIX verify argument counts
# TODO: verify argument counts; more generally, use some "formal to actual" map
# TODO: Erase template variables if it is generic?
if (
type_state.infer_polymorphic
Expand Down Expand Up @@ -943,34 +943,52 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
cactual_args = cactual.arg_types
# The lengths should match, but don't crash (it will error elsewhere).
for t, a in zip(template_args, cactual_args):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the TODO referring to something like #15320 (...which I still need to fix)? That modifies the subtypes.py file but that sounds like what you're referring to...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is about #702 (wow, three-digit issue). I actually wanted to work on this at some point soon. I am not sure I will be able to handle all cases, but I guess it should be relatively straightforward to cover ~95% of cases.

if isinstance(a, ParamSpecType) and not isinstance(t, ParamSpecType):
# This avoids bogus constraints like T <: P.args
# TODO: figure out a more principled way to skip arg_kind mismatch
# (see also a similar to do item in corresponding branch below)
continue
# Negate direction due to function argument type contravariance.
res.extend(infer_constraints(t, a, neg_op(self.direction)))
else:
# sometimes, it appears we try to get constraints between two paramspec callables?

# TODO: Direction
# TODO: check the prefixes match
prefix = param_spec.prefix
prefix_len = len(prefix.arg_types)
cactual_ps = cactual.param_spec()

if type_state.infer_polymorphic and cactual.variables and not self.skip_neg_op:
# Similar logic to the branch above.
res.extend(
infer_constraints(
cactual, template, neg_op(self.direction), skip_neg_op=True
)
)
extra_tvars = True

if not cactual_ps:
max_prefix_len = len([k for k in cactual.arg_kinds if k in (ARG_POS, ARG_OPT)])
prefix_len = min(prefix_len, max_prefix_len)
res.append(
Constraint(
param_spec,
SUBTYPE_OF,
cactual.copy_modified(
neg_op(self.direction),
Parameters(
arg_types=cactual.arg_types[prefix_len:],
arg_kinds=cactual.arg_kinds[prefix_len:],
arg_names=cactual.arg_names[prefix_len:],
ret_type=UninhabitedType(),
variables=cactual.variables
if not type_state.infer_polymorphic
else [],
),
)
)
else:
res.append(Constraint(param_spec, SUBTYPE_OF, cactual_ps))
if not param_spec.prefix.arg_types or cactual_ps.prefix.arg_types:
# TODO: figure out a more general logic to reject shorter prefix in actual.
# This may be actually fixed by a more general to do item above.
res.append(Constraint(param_spec, neg_op(self.direction), cactual_ps))

# compare prefixes
cactual_prefix = cactual.copy_modified(
Expand All @@ -979,7 +997,6 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
arg_names=cactual.arg_names[:prefix_len],
)

# TODO: see above "FIX" comments for param_spec is None case
# TODO: this assumes positional arguments
for t, a in zip(prefix.arg_types, cactual_prefix.arg_types):
res.extend(infer_constraints(t, a, neg_op(self.direction)))
Expand Down
4 changes: 2 additions & 2 deletions mypy/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mypy.join import join_types
from mypy.meet import meet_type_list, meet_types
from mypy.subtypes import is_subtype
from mypy.typeops import get_type_vars
from mypy.typeops import get_all_type_vars
from mypy.types import (
AnyType,
Instance,
Expand Down Expand Up @@ -463,4 +463,4 @@ def check_linear(scc: set[TypeVarId], lowers: Bounds, uppers: Bounds) -> bool:

def get_vars(target: Type, vars: list[TypeVarId]) -> set[TypeVarId]:
"""Find type variables for which we are solving in a target type."""
return {tv.id for tv in get_type_vars(target)} & set(vars)
return {tv.id for tv in get_all_type_vars(target)} & set(vars)
19 changes: 14 additions & 5 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,22 +952,31 @@ def coerce_to_literal(typ: Type) -> Type:


def get_type_vars(tp: Type) -> list[TypeVarType]:
return tp.accept(TypeVarExtractor())
return cast("list[TypeVarType]", tp.accept(TypeVarExtractor()))


class TypeVarExtractor(TypeQuery[List[TypeVarType]]):
def __init__(self) -> None:
def get_all_type_vars(tp: Type) -> list[TypeVarLikeType]:
# TODO: should we always use this function instead of get_type_vars() above?
return tp.accept(TypeVarExtractor(include_all=True))


class TypeVarExtractor(TypeQuery[List[TypeVarLikeType]]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a visit function for TypeVarTuple too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be in a dedicated PR for TypeVarTuple (which will be the last PR in this series, well to be precise in this season, there are still a few things left, I will open follow up issues for them).

def __init__(self, include_all: bool = False) -> None:
super().__init__(self._merge)
self.include_all = include_all

def _merge(self, iter: Iterable[list[TypeVarType]]) -> list[TypeVarType]:
def _merge(self, iter: Iterable[list[TypeVarLikeType]]) -> list[TypeVarLikeType]:
out = []
for item in iter:
out.extend(item)
return out

def visit_type_var(self, t: TypeVarType) -> list[TypeVarType]:
def visit_type_var(self, t: TypeVarType) -> list[TypeVarLikeType]:
return [t]

def visit_param_spec(self, t: ParamSpecType) -> list[TypeVarLikeType]:
return [t] if self.include_all else []


def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool:
"""Does this type have a custom special method such as __format__() or __eq__()?
Expand Down
36 changes: 36 additions & 0 deletions test-data/unit/check-generics.test
Original file line number Diff line number Diff line change
Expand Up @@ -3035,3 +3035,39 @@ reveal_type(dec1(id2)) # N: Revealed type is "def [S in (builtins.int, builtins
reveal_type(dec2(id1)) # N: Revealed type is "def [UC <: __main__.C] (UC`5) -> builtins.list[UC`5]"
reveal_type(dec2(id2)) # N: Revealed type is "def (<nothing>) -> builtins.list[<nothing>]" \
# E: Argument 1 to "dec2" has incompatible type "Callable[[V], V]"; expected "Callable[[<nothing>], <nothing>]"

[case testInferenceAgainstGenericParamSpecBasicInList]
# flags: --new-type-inference
from typing import TypeVar, Callable, List, Tuple
from typing_extensions import ParamSpec

T = TypeVar('T')
P = ParamSpec('P')
U = TypeVar('U')
V = TypeVar('V')

def dec(f: Callable[P, T]) -> Callable[P, List[T]]: ...
def id(x: U) -> U: ...
def either(x: U, y: U) -> U: ...
def pair(x: U, y: V) -> Tuple[U, V]: ...
reveal_type(dec(id)) # N: Revealed type is "def [T] (x: T`2) -> builtins.list[T`2]"
reveal_type(dec(either)) # N: Revealed type is "def [T] (x: T`4, y: T`4) -> builtins.list[T`4]"
reveal_type(dec(pair)) # N: Revealed type is "def [U, V] (x: U`-1, y: V`-2) -> builtins.list[Tuple[U`-1, V`-2]]"
Comment on lines +3086 to +3088
Copy link
Contributor

@A5rocks A5rocks Aug 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pointed this out before but I'll move that comment here cause it's significantly more visible here ^^:

I don't really like how the typevars switched from a negative to positive id. While I can't imagine this has any negative (heh) impact, it annoys the pedant in me. (Specifically, positive ids are supposedly for classes, which these are not).

But more so, here specifically, I can see that we're treating a directly held typevar (the return type) differently to a typevar stuffed inside a generic type (Tuple[U, V]) (...though that's not an Instance). This feels like special casing that might bite later. Maybe it's fine?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is (much) more complicated than that. Positive ids are also generated by new_unification_variable(). Also there is such a thing as meta-level (and it is 1 for the unification variables, while 0 for regular TypeVars, i.e. those that are bound in current scope). Also we use freshen/freeze cycles for reasons other than inference, etc.

This is all indeed unnecessarily complicated, but this PR doesn't really add much to it, and cleaning it up would be quite tricky for a very modest benefit. There is one known problem caused by using these semi-global numeric ids -- accidental id clashes (and half of the complications are to avoid them), but it seems that the consensus is that the best solution to fix this is to introduce namespaces, like we already did for class type variables.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, alright. I was basing my assumption on the docstring for TypeVarId xd

[builtins fixtures/list.pyi]

[case testInferenceAgainstGenericParamSpecBasicDeList]
# flags: --new-type-inference
from typing import TypeVar, Callable, List, Tuple
from typing_extensions import ParamSpec

T = TypeVar('T')
P = ParamSpec('P')
U = TypeVar('U')
V = TypeVar('V')

def dec(f: Callable[P, List[T]]) -> Callable[P, T]: ...
def id(x: U) -> U: ...
def either(x: U, y: U) -> U: ...
reveal_type(dec(id)) # N: Revealed type is "def [T] (x: builtins.list[T`2]) -> T`2"
reveal_type(dec(either)) # N: Revealed type is "def [T] (x: builtins.list[T`4], y: builtins.list[T`4]) -> T`4"
[builtins fixtures/list.pyi]
4 changes: 2 additions & 2 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -6456,7 +6456,7 @@ P = ParamSpec("P")
R = TypeVar("R")

@overload
def func(x: Callable[Concatenate[Any, P], R]) -> Callable[P, R]: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I forgot to check why this error disappeared. I will take a look at it if you think it is important.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know how important this is, though the error is correct. Maybe (I haven't looked through the code yet) you're saying the P here is the same as the P in the other overload?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the mypy_primer output, I see there is something suspicious going on with overloads, so I will dig a bit deeper into this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I fixed one of the new overload errors in mypy_primer, this one however looks quite tricky to bring back. But note that when I enable --new-type-inference (also during unification manually, that flag doesn't control it) the error re-appears. So I guess we can just wait until it is on by default.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this doesn't look like a serious regression.

def func(x: Callable[Concatenate[Any, P], R]) -> Callable[P, R]: ...
@overload
def func(x: Callable[P, R]) -> Callable[Concatenate[str, P], R]: ...
def func(x: Callable[..., R]) -> Callable[..., R]: ...
Expand All @@ -6474,7 +6474,7 @@ eggs = lambda: 'eggs'
reveal_type(func(eggs)) # N: Revealed type is "def (builtins.str) -> builtins.str"

spam: Callable[..., str] = lambda x, y: 'baz'
reveal_type(func(spam)) # N: Revealed type is "def (*Any, **Any) -> builtins.str"
reveal_type(func(spam)) # N: Revealed type is "def (*Any, **Any) -> Any"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised about this change (just purely looking at output): mypy seems to think spam is going to be using a single overload of func, and both return the return type unchanged. How is str turning into Any...?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one I understand, there is a rule for overloads that if

  • arguments match more than one overload variant
  • the multiple match is caused by Any types in argument types
  • return types in matched overloads are not all equivalent

then the inferred return type becomes Any. IIUC all three conditions are satisfied in this case, so it is probably OK.


[builtins fixtures/paramspec.pyi]

Expand Down
21 changes: 14 additions & 7 deletions test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -1048,10 +1048,10 @@ class Job(Generic[_P, _T]):
def generic_f(x: _T) -> _T: ...

j = Job(generic_f)
reveal_type(j) # N: Revealed type is "__main__.Job[[x: _T`-1], _T`-1]"
reveal_type(j) # N: Revealed type is "__main__.Job[[x: _T`2], _T`2]"

jf = j.into_callable()
reveal_type(jf) # N: Revealed type is "def [_T] (x: _T`-1) -> _T`-1"
reveal_type(jf) # N: Revealed type is "def [_T] (x: _T`2) -> _T`2"
reveal_type(jf(1)) # N: Revealed type is "builtins.int"
[builtins fixtures/paramspec.pyi]

Expand Down Expand Up @@ -1307,15 +1307,18 @@ reveal_type(bar(C(fn=foo, x=1))) # N: Revealed type is "__main__.C[[x: builtins
[builtins fixtures/paramspec.pyi]

[case testParamSpecClassConstructor]
from typing import ParamSpec, Callable
from typing import ParamSpec, Callable, TypeVar

P = ParamSpec("P")

class SomeClass:
def __init__(self, a: str) -> None:
pass

def func(t: Callable[P, SomeClass], val: Callable[P, SomeClass]) -> None:
def func(t: Callable[P, SomeClass], val: Callable[P, SomeClass]) -> Callable[P, SomeClass]:
pass

def func_regular(t: Callable[[T], SomeClass], val: Callable[[T], SomeClass]) -> Callable[[T], SomeClass]:
pass

def constructor(a: str) -> SomeClass:
Expand All @@ -1324,9 +1327,13 @@ def constructor(a: str) -> SomeClass:
def wrong_constructor(a: bool) -> SomeClass:
return SomeClass("a")

def wrong_name_constructor(b: bool) -> SomeClass:
return SomeClass("a")

func(SomeClass, constructor)
func(SomeClass, wrong_constructor) # E: Argument 1 to "func" has incompatible type "Type[SomeClass]"; expected "Callable[[VarArg(<nothing>), KwArg(<nothing>)], SomeClass]" \
# E: Argument 2 to "func" has incompatible type "Callable[[bool], SomeClass]"; expected "Callable[[VarArg(<nothing>), KwArg(<nothing>)], SomeClass]"
reveal_type(func(SomeClass, wrong_constructor)) # N: Revealed type is "def (a: <nothing>) -> __main__.SomeClass"
reveal_type(func_regular(SomeClass, wrong_constructor)) # N: Revealed type is "def (<nothing>) -> __main__.SomeClass"
func(SomeClass, wrong_name_constructor) # E: Argument 1 to "func" has incompatible type "Type[SomeClass]"; expected "Callable[[<nothing>], SomeClass]"
[builtins fixtures/paramspec.pyi]

[case testParamSpecInTypeAliasBasic]
Expand Down Expand Up @@ -1547,5 +1554,5 @@ U = TypeVar("U")
def dec(f: Callable[P, T]) -> Callable[P, List[T]]: ...
def test(x: U) -> U: ...
reveal_type(dec) # N: Revealed type is "def [P, T] (f: def (*P.args, **P.kwargs) -> T`-2) -> def (*P.args, **P.kwargs) -> builtins.list[T`-2]"
reveal_type(dec(test)) # N: Revealed type is "def [U] (x: U`-1) -> builtins.list[U`-1]"
reveal_type(dec(test)) # N: Revealed type is "def [T] (x: T`2) -> builtins.list[T`2]"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a bit kinda like this?:

class A: ...
TA = TypeVar("TA", bound=A)

def test_with_bound(x: TA) -> TA: ...
reveal_type(dec(test_with_bound))
dec(test_with_bound)(0)  # should error
dec(test_with_bound)(A())  # should be AOK

? I'm a bit concerned about the replacing of type variables here, though I do remember seeing something in an earlier PR about updating bounds so that's probably already handled.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, upper bounds should be ~decently handled. I added the test.

[builtins fixtures/paramspec.pyi]