-
-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Polymorphic inference: support for parameter specifications and lambdas #15837
Changes from all commits
9b4eac8
1959136
e796c6a
177b312
420f60d
d7cfbe9
d4c9146
0af630f
c5c1b76
582a4de
4f8afce
2d21032
59963c4
72da8f5
7a87692
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,12 @@ | |
from mypy.checkstrformat import StringFormatterChecker | ||
from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars | ||
from mypy.errors import ErrorWatcher, report_internal_error | ||
from mypy.expandtype import expand_type, expand_type_by_instance, freshen_function_type_vars | ||
from mypy.expandtype import ( | ||
expand_type, | ||
expand_type_by_instance, | ||
freshen_all_functions_type_vars, | ||
freshen_function_type_vars, | ||
) | ||
from mypy.infer import ArgumentInferContext, infer_function_type_arguments, infer_type_arguments | ||
from mypy.literals import literal | ||
from mypy.maptype import map_instance_to_supertype | ||
|
@@ -122,6 +127,7 @@ | |
false_only, | ||
fixup_partial_type, | ||
function_type, | ||
get_all_type_vars, | ||
get_type_vars, | ||
is_literal_type_like, | ||
make_simplified_union, | ||
|
@@ -145,6 +151,7 @@ | |
LiteralValue, | ||
NoneType, | ||
Overloaded, | ||
Parameters, | ||
ParamSpecFlavor, | ||
ParamSpecType, | ||
PartialType, | ||
|
@@ -167,6 +174,7 @@ | |
get_proper_types, | ||
has_recursive_types, | ||
is_named_instance, | ||
remove_dups, | ||
split_with_prefix_and_suffix, | ||
) | ||
from mypy.types_utils import ( | ||
|
@@ -1570,6 +1578,16 @@ def check_callable_call( | |
lambda i: self.accept(args[i]), | ||
) | ||
|
||
# This is tricky: return type may contain its own type variables, like in | ||
# def [S] (S) -> def [T] (T) -> tuple[S, T], so we need to update their ids | ||
# to avoid possible id clashes if this call itself appears in a generic | ||
# function body. | ||
ret_type = get_proper_type(callee.ret_type) | ||
if isinstance(ret_type, CallableType) and ret_type.variables: | ||
fresh_ret_type = freshen_all_functions_type_vars(callee.ret_type) | ||
freeze_all_type_vars(fresh_ret_type) | ||
callee = callee.copy_modified(ret_type=fresh_ret_type) | ||
|
||
if callee.is_generic(): | ||
need_refresh = any( | ||
isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables | ||
|
@@ -1588,7 +1606,7 @@ def check_callable_call( | |
lambda i: self.accept(args[i]), | ||
) | ||
callee = self.infer_function_type_arguments( | ||
callee, args, arg_kinds, formal_to_actual, context | ||
callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context | ||
) | ||
if need_refresh: | ||
formal_to_actual = map_actuals_to_formals( | ||
|
@@ -1855,6 +1873,8 @@ def infer_function_type_arguments_using_context( | |
# def identity(x: T) -> T: return x | ||
# | ||
# expects_literal(identity(3)) # Should type-check | ||
# TODO: we may want to add similar exception if all arguments are lambdas, since | ||
# in this case external context is almost everything we have. | ||
if not is_generic_instance(ctx) and not is_literal_type_like(ctx): | ||
return callable.copy_modified() | ||
args = infer_type_arguments(callable.variables, ret_type, erased_ctx) | ||
|
@@ -1876,7 +1896,9 @@ def infer_function_type_arguments( | |
callee_type: CallableType, | ||
args: list[Expression], | ||
arg_kinds: list[ArgKind], | ||
arg_names: Sequence[str | None] | None, | ||
formal_to_actual: list[list[int]], | ||
need_refresh: bool, | ||
context: Context, | ||
) -> CallableType: | ||
"""Infer the type arguments for a generic callee type. | ||
|
@@ -1918,7 +1940,14 @@ def infer_function_type_arguments( | |
if 2 in arg_pass_nums: | ||
# Second pass of type inference. | ||
(callee_type, inferred_args) = self.infer_function_type_arguments_pass2( | ||
callee_type, args, arg_kinds, formal_to_actual, inferred_args, context | ||
callee_type, | ||
args, | ||
arg_kinds, | ||
arg_names, | ||
formal_to_actual, | ||
inferred_args, | ||
need_refresh, | ||
context, | ||
) | ||
|
||
if ( | ||
|
@@ -1944,6 +1973,17 @@ def infer_function_type_arguments( | |
or set(get_type_vars(a)) & set(callee_type.variables) | ||
for a in inferred_args | ||
): | ||
if need_refresh: | ||
# Technically we need to refresh formal_to_actual after *each* inference pass, | ||
# since each pass can expand ParamSpec or TypeVarTuple. Although such situations | ||
# are very rare, not doing this can cause crashes. | ||
formal_to_actual = map_actuals_to_formals( | ||
arg_kinds, | ||
arg_names, | ||
callee_type.arg_kinds, | ||
callee_type.arg_names, | ||
lambda a: self.accept(args[a]), | ||
) | ||
# If the regular two-phase inference didn't work, try inferring type | ||
# variables while allowing for polymorphic solutions, i.e. for solutions | ||
# potentially involving free variables. | ||
|
@@ -1991,8 +2031,10 @@ def infer_function_type_arguments_pass2( | |
callee_type: CallableType, | ||
args: list[Expression], | ||
arg_kinds: list[ArgKind], | ||
arg_names: Sequence[str | None] | None, | ||
formal_to_actual: list[list[int]], | ||
old_inferred_args: Sequence[Type | None], | ||
need_refresh: bool, | ||
context: Context, | ||
) -> tuple[CallableType, list[Type | None]]: | ||
"""Perform second pass of generic function type argument inference. | ||
|
@@ -2014,6 +2056,14 @@ def infer_function_type_arguments_pass2( | |
if isinstance(arg, (NoneType, UninhabitedType)) or has_erased_component(arg): | ||
inferred_args[i] = None | ||
callee_type = self.apply_generic_arguments(callee_type, inferred_args, context) | ||
if need_refresh: | ||
formal_to_actual = map_actuals_to_formals( | ||
arg_kinds, | ||
arg_names, | ||
callee_type.arg_kinds, | ||
callee_type.arg_names, | ||
lambda a: self.accept(args[a]), | ||
) | ||
|
||
arg_types = self.infer_arg_types_in_context(callee_type, args, arg_kinds, formal_to_actual) | ||
|
||
|
@@ -4674,8 +4724,22 @@ def infer_lambda_type_using_context( | |
# they must be considered as indeterminate. We use ErasedType since it | ||
# does not affect type inference results (it is for purposes like this | ||
# only). | ||
callable_ctx = get_proper_type(replace_meta_vars(ctx, ErasedType())) | ||
assert isinstance(callable_ctx, CallableType) | ||
if self.chk.options.new_type_inference: | ||
# With new type inference we can preserve argument types even if they | ||
# are generic, since new inference algorithm can handle constraints | ||
# like S <: T (we still erase return type since it's ultimately unknown). | ||
extra_vars = [] | ||
for arg in ctx.arg_types: | ||
meta_vars = [tv for tv in get_all_type_vars(arg) if tv.id.is_meta_var()] | ||
extra_vars.extend([tv for tv in meta_vars if tv not in extra_vars]) | ||
callable_ctx = ctx.copy_modified( | ||
ret_type=replace_meta_vars(ctx.ret_type, ErasedType()), | ||
variables=list(ctx.variables) + extra_vars, | ||
) | ||
else: | ||
erased_ctx = replace_meta_vars(ctx, ErasedType()) | ||
assert isinstance(erased_ctx, ProperType) and isinstance(erased_ctx, CallableType) | ||
callable_ctx = erased_ctx | ||
|
||
# The callable_ctx may have a fallback of builtins.type if the context | ||
# is a constructor -- but this fallback doesn't make sense for lambdas. | ||
|
@@ -5632,18 +5696,28 @@ def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None: | |
self.bound_tvars: set[TypeVarLikeType] = set() | ||
self.seen_aliases: set[TypeInfo] = set() | ||
|
||
def visit_callable_type(self, t: CallableType) -> Type: | ||
found_vars = set() | ||
def collect_vars(self, t: CallableType | Parameters) -> list[TypeVarLikeType]: | ||
found_vars = [] | ||
for arg in t.arg_types: | ||
found_vars |= set(get_type_vars(arg)) & self.poly_tvars | ||
for tv in get_all_type_vars(arg): | ||
if isinstance(tv, ParamSpecType): | ||
normalized: TypeVarLikeType = tv.copy_modified( | ||
flavor=ParamSpecFlavor.BARE, prefix=Parameters([], [], []) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why drop the prefix? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you have something like |
||
) | ||
else: | ||
normalized = tv | ||
if normalized in self.poly_tvars and normalized not in self.bound_tvars: | ||
found_vars.append(normalized) | ||
return remove_dups(found_vars) | ||
|
||
found_vars -= self.bound_tvars | ||
self.bound_tvars |= found_vars | ||
def visit_callable_type(self, t: CallableType) -> Type: | ||
found_vars = self.collect_vars(t) | ||
self.bound_tvars |= set(found_vars) | ||
result = super().visit_callable_type(t) | ||
self.bound_tvars -= found_vars | ||
self.bound_tvars -= set(found_vars) | ||
|
||
assert isinstance(result, ProperType) and isinstance(result, CallableType) | ||
result.variables = list(result.variables) + list(found_vars) | ||
result.variables = list(result.variables) + found_vars | ||
return result | ||
|
||
def visit_type_var(self, t: TypeVarType) -> Type: | ||
|
@@ -5652,8 +5726,9 @@ def visit_type_var(self, t: TypeVarType) -> Type: | |
return super().visit_type_var(t) | ||
|
||
def visit_param_spec(self, t: ParamSpecType) -> Type: | ||
# TODO: Support polymorphic apply for ParamSpec. | ||
raise PolyTranslationError() | ||
if t in self.poly_tvars and t not in self.bound_tvars: | ||
raise PolyTranslationError() | ||
return super().visit_param_spec(t) | ||
|
||
def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type: | ||
# TODO: Support polymorphic apply for TypeVarTuple. | ||
|
@@ -5669,6 +5744,26 @@ def visit_type_alias_type(self, t: TypeAliasType) -> Type: | |
raise PolyTranslationError() | ||
|
||
def visit_instance(self, t: Instance) -> Type: | ||
if t.type.has_param_spec_type: | ||
# We need this special-casing to preserve the possibility to store a | ||
# generic function in an instance type. Things like | ||
# forall T . Foo[[x: T], T] | ||
# are not really expressible in current type system, but this looks like | ||
# a useful feature, so let's keep it. | ||
param_spec_index = next( | ||
i for (i, tv) in enumerate(t.type.defn.type_vars) if isinstance(tv, ParamSpecType) | ||
) | ||
p = get_proper_type(t.args[param_spec_index]) | ||
if isinstance(p, Parameters): | ||
found_vars = self.collect_vars(p) | ||
self.bound_tvars |= set(found_vars) | ||
new_args = [a.accept(self) for a in t.args] | ||
self.bound_tvars -= set(found_vars) | ||
|
||
repl = new_args[param_spec_index] | ||
assert isinstance(repl, ProperType) and isinstance(repl, Parameters) | ||
repl.variables = list(repl.variables) + list(found_vars) | ||
return t.copy_modified(args=new_args) | ||
# There is the same problem with callback protocols as with aliases | ||
# (callback protocols are essentially more flexible aliases to callables). | ||
# Note: consider supporting bindings in instances, e.g. LRUCache[[x: T], T]. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think
extra_vars
could be a set maybe? That would mean an IMO simpler comprehension.I'm also not sure why
ctx.variables
is guaranteed to not include these new variables.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIRC I use this logic with lists for variables here (and in few other places) to have stable order. Otherwise tests will randomly fail on
reveal_type()
(and it is generally good to have predictable stable order for comparison purposes).