Skip to content

Commit

Permalink
Implementing background infrastructure for recursive types: Part 1 (p…
Browse files Browse the repository at this point in the history
…ython#7330)

During planning discussions one of the main concerns about recursive types was the fact that we have hundreds of places where certain types are special-cased using `isinstance()`, and fixing all of them will take weeks.

So I did a little experiment this weekend, to understand how bad it _actually_ is. I wrote a simple mypy plugin for mypy self-check, and it discovered 800+ such call sites. This looks pretty bad, but it turns out that fixing half of them (roughly 400 plugin errors) took me less than 2 days. This is kind of a triumph of our tooling :-) (i.e. mypy plugin + PyCharm plugin).

Taking into account the results of this experiment, I propose to actually go ahead and implement recursive types. Here are some comments:

* There will be four subsequent PRs: second part of `isinstance()` cleanup, implementing visitors and related methods everywhere, actual core implementation, adding extra tests for tricky recursion patterns.
* The core idea of implementation stays the same as we discussed with @JukkaL: `TypeAliasType` and `TypeAlias` node will essentially match logic between `Instance` and `TypeInfo` (but structurally, as for protocols)
* I wanted to make `PlaceholderType` a non-`ProperType`, but it didn't work immediately because we call `make_union()` during semantic analysis. If this seems important, this can be done with a bit more effort.
* I make `TypeType.item` a proper type (following PEP 484, only very limited things can be passed to `Type[...]`). I also make `UnionType.items` proper types, mostly because of `make_simplified_union()`. Finally, I make `FuncBase.type` a proper type, since I think a type alias can never appear there.
* It is sometimes hard to decide exactly where to call `get_proper_type()`. I tried to balance calling it not too soon and not too late, depending on each individual case. Please review; I am open to modifying the logic in some places.
  • Loading branch information
ilevkivskyi authored Aug 16, 2019
1 parent 7fb7e26 commit e04bf78
Show file tree
Hide file tree
Showing 43 changed files with 659 additions and 320 deletions.
70 changes: 70 additions & 0 deletions misc/proper_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from mypy.plugin import Plugin, FunctionContext
from mypy.types import Type, Instance, CallableType, UnionType, get_proper_type

import os.path
from typing_extensions import Type as typing_Type
from typing import Optional, Callable

# Base names of source files that the isinstance() hook skips: a call located
# in one of these files is never reported by this plugin.
FILE_WHITELIST = [
    'checker.py',
    'checkexpr.py',
    'checkmember.py',
    'messages.py',
    'semanal.py',
    'typeanal.py',
]


class ProperTypePlugin(Plugin):
    """
    Mypy plugin that flags isinstance() checks performed on unexpanded types.

    The code base contains hundreds of special-casing call sites such as:

        if isinstance(typ, UnionType):
            ...  # special-case union

    Once TypeAliasType exists (and aliases are no longer expanded immediately),
    such checks are dangerous, because typ may be e.g. an alias to a union.
    This plugin hooks every call to builtins.isinstance so those sites can be
    found and fixed.
    """

    def get_function_hook(self, fullname: str
                          ) -> Optional[Callable[[FunctionContext], Type]]:
        """Return the isinstance() checker for 'builtins.isinstance', else None."""
        return isinstance_proper_hook if fullname == 'builtins.isinstance' else None


def isinstance_proper_hook(ctx: FunctionContext) -> Type:
    """Report isinstance() calls whose first argument may be an unexpanded type.

    Calls located in a file listed in FILE_WHITELIST are not checked. For any
    other call, an error is reported through ctx.api.fail() when the first
    argument is an improper type (see is_improper_type()), unless the second
    argument is a class known to be a safe target.

    Args:
        ctx: Function call context supplied by mypy for the isinstance() call.

    Returns:
        Always ctx.default_return_type; the hook only adds diagnostics.
    """
    if os.path.basename(ctx.api.path) in FILE_WHITELIST:
        return ctx.default_return_type

    # ctx.arg_types has one list of actual argument types per formal. If no
    # actual was bound to the second formal (presumably a malformed call such
    # as isinstance(x)), indexing [1][0] would raise IndexError, so bail out.
    if len(ctx.arg_types) != 2 or not ctx.arg_types[1]:
        return ctx.default_return_type

    # The isinstance() target is the same for every first-argument candidate.
    right = get_proper_type(ctx.arg_types[1][0])

    for arg in ctx.arg_types[0]:
        if not is_improper_type(arg):
            continue
        if isinstance(right, CallableType) and right.is_type_obj():
            if right.type_object().fullname() in (
                    # Special case: things like assert isinstance(typ, ProperType)
                    # are always OK.
                    'mypy.types.Type',
                    'mypy.types.ProperType',
                    'mypy.types.TypeAliasType',
                    # Special case: these are not valid targets for a type alias
                    # and thus safe.
                    'mypy.types.UnboundType',
                    'mypy.types.TypeVarType'):
                return ctx.default_return_type
        ctx.api.fail('Never apply isinstance() to unexpanded types;'
                     ' use mypy.types.get_proper_type() first', ctx.context)
    return ctx.default_return_type


def is_improper_type(typ: Type) -> bool:
    """Tell whether typ is a mypy type class that is not a ProperType subclass.

    An instance type counts as improper when its class derives from
    'mypy.types.Type' but not from 'mypy.types.ProperType'. A union counts as
    improper when any of its items does. Everything else is considered proper.
    """
    expanded = get_proper_type(typ)
    if isinstance(expanded, UnionType):
        # One improper member is enough to make the whole union improper.
        return any(is_improper_type(item) for item in expanded.items)
    if not isinstance(expanded, Instance):
        return False
    type_info = expanded.type
    derives_from_type = type_info.has_base('mypy.types.Type')
    return derives_from_type and not type_info.has_base('mypy.types.ProperType')


def plugin(version: str) -> typing_Type[ProperTypePlugin]:
    """Plugin entry point: return the plugin class (version is not used)."""
    return ProperTypePlugin
8 changes: 5 additions & 3 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import mypy.subtypes
import mypy.sametypes
from mypy.expandtype import expand_type
from mypy.types import Type, TypeVarId, TypeVarType, CallableType, AnyType, PartialType
from mypy.types import (
Type, TypeVarId, TypeVarType, CallableType, AnyType, PartialType, get_proper_types
)
from mypy.messages import MessageBuilder
from mypy.nodes import Context

Expand All @@ -25,10 +27,10 @@ def apply_generic_arguments(callable: CallableType, orig_types: Sequence[Optiona
assert len(tvars) == len(orig_types)
# Check that inferred type variable values are compatible with allowed
# values and bounds. Also, promote subtype values to allowed values.
types = list(orig_types)
types = get_proper_types(orig_types)
for i, type in enumerate(types):
assert not isinstance(type, PartialType), "Internal error: must never apply partial type"
values = callable.variables[i].values
values = get_proper_types(callable.variables[i].values)
if type is None:
continue
if values:
Expand Down
9 changes: 6 additions & 3 deletions mypy/argmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from typing import List, Optional, Sequence, Callable, Set

from mypy.types import Type, Instance, TupleType, AnyType, TypeOfAny, TypedDictType
from mypy.types import (
Type, Instance, TupleType, AnyType, TypeOfAny, TypedDictType, get_proper_type
)
from mypy import nodes


Expand Down Expand Up @@ -34,7 +36,7 @@ def map_actuals_to_formals(actual_kinds: List[int],
formal_to_actual[fi].append(ai)
elif actual_kind == nodes.ARG_STAR:
# We need to know the actual type to map varargs.
actualt = actual_arg_type(ai)
actualt = get_proper_type(actual_arg_type(ai))
if isinstance(actualt, TupleType):
# A tuple actual maps to a fixed number of formals.
for _ in range(len(actualt.items)):
Expand Down Expand Up @@ -65,7 +67,7 @@ def map_actuals_to_formals(actual_kinds: List[int],
formal_to_actual[formal_kinds.index(nodes.ARG_STAR2)].append(ai)
else:
assert actual_kind == nodes.ARG_STAR2
actualt = actual_arg_type(ai)
actualt = get_proper_type(actual_arg_type(ai))
if isinstance(actualt, TypedDictType):
for name, value in actualt.items.items():
if name in formal_names:
Expand Down Expand Up @@ -153,6 +155,7 @@ def expand_actual_type(self,
This is supposed to be called for each formal, in order. Call multiple times per
formal if multiple actuals map to a formal.
"""
actual_type = get_proper_type(actual_type)
if actual_kind == nodes.ARG_STAR:
if isinstance(actual_type, Instance):
if actual_type.type.fullname() == 'builtins.list':
Expand Down
17 changes: 12 additions & 5 deletions mypy/binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from typing import Dict, List, Set, Iterator, Union, Optional, Tuple, cast
from typing_extensions import DefaultDict

from mypy.types import Type, AnyType, PartialType, UnionType, TypeOfAny, NoneType
from mypy.types import (
Type, AnyType, PartialType, UnionType, TypeOfAny, NoneType, get_proper_type
)
from mypy.subtypes import is_subtype
from mypy.join import join_simple
from mypy.sametypes import is_same_type
Expand Down Expand Up @@ -191,7 +193,7 @@ def update_from_options(self, frames: List[Frame]) -> bool:

type = resulting_values[0]
assert type is not None
declaration_type = self.declarations.get(key)
declaration_type = get_proper_type(self.declarations.get(key))
if isinstance(declaration_type, AnyType):
# At this point resulting values can't contain None, see continue above
if not all(is_same_type(type, cast(Type, t)) for t in resulting_values[1:]):
Expand Down Expand Up @@ -246,6 +248,9 @@ def assign_type(self, expr: Expression,
type: Type,
declared_type: Optional[Type],
restrict_any: bool = False) -> None:
type = get_proper_type(type)
declared_type = get_proper_type(declared_type)

if self.type_assignments is not None:
# We are in a multiassign from union, defer the actual binding,
# just collect the types.
Expand All @@ -270,7 +275,7 @@ def assign_type(self, expr: Expression,
# times?
return

enclosing_type = self.most_recent_enclosing_type(expr, type)
enclosing_type = get_proper_type(self.most_recent_enclosing_type(expr, type))
if isinstance(enclosing_type, AnyType) and not restrict_any:
# If x is Any and y is int, after x = y we do not infer that x is int.
# This could be changed.
Expand All @@ -287,7 +292,8 @@ def assign_type(self, expr: Expression,
elif (isinstance(type, AnyType)
and isinstance(declared_type, UnionType)
and any(isinstance(item, NoneType) for item in declared_type.items)
and isinstance(self.most_recent_enclosing_type(expr, NoneType()), NoneType)):
and isinstance(get_proper_type(self.most_recent_enclosing_type(expr, NoneType())),
NoneType)):
# Replace any Nones in the union type with Any
new_items = [type if isinstance(item, NoneType) else item
for item in declared_type.items]
Expand Down Expand Up @@ -320,6 +326,7 @@ def invalidate_dependencies(self, expr: BindableExpression) -> None:
self._cleanse_key(dep)

def most_recent_enclosing_type(self, expr: BindableExpression, type: Type) -> Optional[Type]:
type = get_proper_type(type)
if isinstance(type, AnyType):
return get_declaration(expr)
key = literal_hash(expr)
Expand Down Expand Up @@ -412,7 +419,7 @@ def top_frame_context(self) -> Iterator[Frame]:

def get_declaration(expr: BindableExpression) -> Optional[Type]:
if isinstance(expr, RefExpr) and isinstance(expr.node, Var):
type = expr.node.type
type = get_proper_type(expr.node.type)
if not isinstance(type, PartialType):
return type
return None
8 changes: 4 additions & 4 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
impl_type = None # type: Optional[CallableType]
if defn.impl:
if isinstance(defn.impl, FuncDef):
inner_type = defn.impl.type
inner_type = defn.impl.type # type: Optional[Type]
elif isinstance(defn.impl, Decorator):
inner_type = defn.impl.var.type
else:
Expand Down Expand Up @@ -3650,8 +3650,8 @@ def find_isinstance_check(self, node: Expression
# Restrict the type of the variable to True-ish/False-ish in the if and else branches
# respectively
vartype = type_map[node]
if_type = true_only(vartype)
else_type = false_only(vartype)
if_type = true_only(vartype) # type: Type
else_type = false_only(vartype) # type: Type
ref = node # type: Expression
if_map = {ref: if_type} if not isinstance(if_type, UninhabitedType) else None
else_map = {ref: else_type} if not isinstance(else_type, UninhabitedType) else None
Expand Down Expand Up @@ -4139,7 +4139,7 @@ def or_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap:
# expressions whose type is refined by both conditions. (We do not
# learn anything about expressions whose type is refined by only
# one condition.)
result = {}
result = {} # type: Dict[Expression, Type]
for n1 in m1:
for n2 in m2:
if literal_hash(n1) == literal_hash(n2):
Expand Down
11 changes: 6 additions & 5 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
TupleType, TypedDictType, Instance, TypeVarType, ErasedType, UnionType,
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue,
true_only, false_only, is_named_instance, function_type, callable_type, FunctionLike,
StarType, is_optional, remove_optional, is_generic_instance
StarType, is_optional, remove_optional, is_generic_instance, get_proper_type
)
from mypy.nodes import (
NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
Expand Down Expand Up @@ -587,6 +587,7 @@ def apply_function_plugin(self,
# Apply method plugin
method_callback = self.plugin.get_method_hook(fullname)
assert method_callback is not None # Assume that caller ensures this
object_type = get_proper_type(object_type)
return method_callback(
MethodContext(object_type, formal_arg_types, formal_arg_kinds,
callee.arg_names, formal_arg_names,
Expand All @@ -608,6 +609,7 @@ def apply_method_signature_hook(
for formal, actuals in enumerate(formal_to_actual):
for actual in actuals:
formal_arg_exprs[formal].append(args[actual])
object_type = get_proper_type(object_type)
return signature_hook(
MethodSigContext(object_type, formal_arg_exprs, callee, context, self.chk))
else:
Expand Down Expand Up @@ -2710,7 +2712,7 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression)
else:
typ = self.accept(index)
if isinstance(typ, UnionType):
key_types = typ.items
key_types = list(typ.items) # type: List[Type]
else:
key_types = [typ]

Expand Down Expand Up @@ -3549,7 +3551,7 @@ def has_member(self, typ: Type, member: str) -> bool:
elif isinstance(typ, TypeType):
# Type[Union[X, ...]] is always normalized to Union[Type[X], ...],
# so we don't need to care about unions here.
item = typ.item
item = typ.item # type: Type
if isinstance(item, TypeVarType):
item = item.upper_bound
if isinstance(item, TupleType):
Expand Down Expand Up @@ -3743,8 +3745,7 @@ def narrow_type_from_binder(self, expr: Expression, known_type: Type, # noqa
not is_overlapping_types(known_type, restriction,
prohibit_none_typevar_overlap=True)):
return None
ans = narrow_declared_type(known_type, restriction)
return ans
return narrow_declared_type(known_type, restriction)
return known_type


Expand Down
11 changes: 6 additions & 5 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mypy.types import (
Type, Instance, AnyType, TupleType, TypedDictType, CallableType, FunctionLike, TypeVarDef,
Overloaded, TypeVarType, UnionType, PartialType, UninhabitedType, TypeOfAny, LiteralType,
DeletedType, NoneType, TypeType, function_type, get_type_vars,
DeletedType, NoneType, TypeType, function_type, get_type_vars, get_proper_type
)
from mypy.nodes import (
TypeInfo, FuncBase, Var, FuncDef, SymbolNode, Context, MypyFile, TypeVarExpr,
Expand Down Expand Up @@ -371,8 +371,8 @@ def analyze_member_var_access(name: str,
fullname = '{}.{}'.format(method.info.fullname(), name)
hook = mx.chk.plugin.get_attribute_hook(fullname)
if hook:
result = hook(AttributeContext(mx.original_type, result,
mx.context, mx.chk))
result = hook(AttributeContext(get_proper_type(mx.original_type),
result, mx.context, mx.chk))
return result
else:
setattr_meth = info.get_method('__setattr__')
Expand Down Expand Up @@ -511,7 +511,7 @@ def analyze_var(name: str,
mx.msg.read_only_property(name, itype.type, mx.context)
if mx.is_lvalue and var.is_classvar:
mx.msg.cant_assign_to_classvar(name, mx.context)
result = t
result = t # type: Type
if var.is_initialized_in_class and isinstance(t, FunctionLike) and not t.is_type_obj():
if mx.is_lvalue:
if var.is_property:
Expand Down Expand Up @@ -552,7 +552,8 @@ def analyze_var(name: str,
result = analyze_descriptor_access(mx.original_type, result, mx.builtin_type,
mx.msg, mx.context, chk=mx.chk)
if hook:
result = hook(AttributeContext(mx.original_type, result, mx.context, mx.chk))
result = hook(AttributeContext(get_proper_type(mx.original_type),
result, mx.context, mx.chk))
return result


Expand Down
4 changes: 2 additions & 2 deletions mypy/checkstrformat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import Final, TYPE_CHECKING

from mypy.types import (
Type, AnyType, TupleType, Instance, UnionType, TypeOfAny
Type, AnyType, TupleType, Instance, UnionType, TypeOfAny, get_proper_type
)
from mypy.nodes import (
StrExpr, BytesExpr, UnicodeExpr, TupleExpr, DictExpr, Context, Expression, StarExpr
Expand Down Expand Up @@ -137,7 +137,7 @@ def check_simple_str_interpolation(self, specifiers: List[ConversionSpecifier],
if checkers is None:
return

rhs_type = self.accept(replacements)
rhs_type = get_proper_type(self.accept(replacements))
rep_types = [] # type: List[Type]
if isinstance(rhs_type, TupleType):
rep_types = rhs_type.items
Expand Down
13 changes: 8 additions & 5 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
CallableType, Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarType, Instance,
TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType,
UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance, TypeOfAny, LiteralType,
ProperType, get_proper_type
)
from mypy.maptype import map_instance_to_supertype
import mypy.subtypes
Expand Down Expand Up @@ -88,6 +89,8 @@ def infer_constraints(template: Type, actual: Type,
The constraints are represented as Constraint objects.
"""
template = get_proper_type(template)
actual = get_proper_type(actual)

# If the template is simply a type variable, emit a Constraint directly.
# We need to handle this case before handling Unions for two reasons:
Expand Down Expand Up @@ -199,12 +202,12 @@ def is_same_constraint(c1: Constraint, c2: Constraint) -> bool:
and mypy.sametypes.is_same_type(c1.target, c2.target))


def simplify_away_incomplete_types(types: List[Type]) -> List[Type]:
def simplify_away_incomplete_types(types: Iterable[Type]) -> List[Type]:
complete = [typ for typ in types if is_complete_type(typ)]
if complete:
return complete
else:
return types
return list(types)


def is_complete_type(typ: Type) -> bool:
Expand All @@ -229,9 +232,9 @@ class ConstraintBuilderVisitor(TypeVisitor[List[Constraint]]):

# The type that is compared against a template
# TODO: The value may be None. Is that actually correct?
actual = None # type: Type
actual = None # type: ProperType

def __init__(self, actual: Type, direction: int) -> None:
def __init__(self, actual: ProperType, direction: int) -> None:
# Direction must be SUBTYPE_OF or SUPERTYPE_OF.
self.actual = actual
self.direction = direction
Expand Down Expand Up @@ -298,7 +301,7 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
if isinstance(actual, Instance):
instance = actual
erased = erase_typevars(template)
assert isinstance(erased, Instance)
assert isinstance(erased, Instance) # type: ignore
# We always try nominal inference if possible,
# it is much faster than the structural one.
if (self.direction == SUBTYPE_OF and
Expand Down
Loading

0 comments on commit e04bf78

Please sign in to comment.