Skip to content

Commit

Permalink
Improve precision of abstract_iteration (JuliaLang#36524)
Browse files Browse the repository at this point in the history
Currently abstract iteration works fine for length-1 iterators, but
fails over for other small length iterators, such as Pair{Int, Int},
or common patterns where `iterate` iterates over the fields of a
small struct. Other examples include StaticVectors and constant
iterators, which should all now unroll properly up to the
MAX_TUPLE_SPLAT limit. That said, MAX_TUPLE_SPLAT is quite high
at the moment, but because abstract_iteration isn't very precise,
it's unlikely to be limiting. With this increase in precision,
we may find that MAX_TUPLE_SPLAT is too high and we should lower it.

Also note that while this is a nice improvement, performance is still
not great, since we currently can't inline these apply calls. However,
fixing that is in progress in a parallel PR and with both changes
put together, the performance improvement of splatting of small
iterators is quite sizeable.
  • Loading branch information
Keno authored Jul 9, 2020
1 parent 47ffc00 commit bee6546
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 16 deletions.
44 changes: 28 additions & 16 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
end
elseif istopfunction(f, :iterate)
itrty = argtypes[2]
if itrty isa Type && !issingletontype(itrty)
return Any
elseif itrty Array
if itrty Array
return Any
end
end
Expand Down Expand Up @@ -569,37 +567,51 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
stateordonet === Bottom && return Any[Bottom]
valtype = statetype = Bottom
ret = Any[]
stateordonet = widenconst(stateordonet)
while !(Nothing <: stateordonet) && length(ret) < InferenceParams(interp).MAX_TUPLE_SPLAT
if !isa(stateordonet, DataType) || !(stateordonet <: Tuple) || isvatuple(stateordonet) || length(stateordonet.parameters) != 2
# Try to unroll the iteration up to MAX_TUPLE_SPLAT, which covers any finite
# length iterators, or interesting prefix
while true
stateordonet_widened = widenconst(stateordonet)
if stateordonet_widened === Nothing
return ret
end
if Nothing <: stateordonet_widened || length(ret) >= InferenceParams(interp).MAX_TUPLE_SPLAT
break
end
if stateordonet.parameters[2] <: statetype
# infinite (or failing) iterator
if !isa(stateordonet_widened, DataType) || !(stateordonet_widened <: Tuple) || isvatuple(stateordonet_widened) || length(stateordonet_widened.parameters) != 2
break
end
nstatetype = getfield_tfunc(stateordonet, Const(2))
# If there's no new information in this statetype, don't bother continuing,
# the iterator won't be finite.
if nstatetype statetype
return Any[Bottom]
end
valtype = stateordonet.parameters[1]
statetype = stateordonet.parameters[2]
valtype = getfield_tfunc(stateordonet, Const(1))
push!(ret, valtype)
statetype = nstatetype
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv)
stateordonet = widenconst(stateordonet)
end
if stateordonet === Nothing
return ret
end
# From here on, we start asking for results on the widened types, rather than
# the precise (potentially const) state type
statetype = widenconst(statetype)
valtype = widenconst(valtype)
while valtype !== Any
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv)
stateordonet = widenconst(stateordonet)
nounion = typesubtract(stateordonet, Nothing)
if !isa(nounion, DataType) || !(nounion <: Tuple) || isvatuple(nounion) || length(nounion.parameters) != 2
valtype = Any
break
end
if nounion.parameters[1] <: valtype && nounion.parameters[2] <: statetype
if typeintersect(stateordonet, Nothing) === Union{}
# Reached a fixpoint, but Nothing is not possible => iterator is infinite or failing
return Any[Bottom]
end
break
end
valtype = tmerge(valtype, nounion.parameters[1])
statetype = tmerge(statetype, nounion.parameters[2])
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv)
stateordonet = widenconst(stateordonet)
end
push!(ret, Vararg{valtype})
return ret
Expand Down
4 changes: 4 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2674,3 +2674,7 @@ partial_return_1(x) = (x, 1)
partial_return_2(x) = Val{partial_return_1(x)[2]}

@test Base.return_types(partial_return_2, (Int,)) == Any[Type{Val{1}}]

# Precision of abstract_iteration
f_splat(x) = (x...,)
@test Base.return_types(f_splat, (Pair{Int,Int},)) == Any[Tuple{Int, Int}]

0 comments on commit bee6546

Please sign in to comment.