Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: fixes and improvements for backedge computation #46741

Merged
merged 5 commits into from
Sep 15, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
inference: setup separate functions for each backedge kind
Also changes the argument list so that they are ordered as
`(caller, [backedge information])`.
  • Loading branch information
aviatesk committed Sep 15, 2022
commit 3e070f463b1fef94c03e8e10ab2af50501598858
18 changes: 11 additions & 7 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -483,15 +483,15 @@ function add_call_backedges!(interp::AbstractInterpreter,
end
end
for edge in edges
add_backedge!(edge, sv)
add_backedge!(sv, edge)
end
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
if isa(matches, MethodMatches)
matches.fullmatch || add_mt_backedge!(matches.mt, atype, sv)
matches.fullmatch || add_mt_backedge!(sv, matches.mt, atype)
else
for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts)
thisfullmatch || add_mt_backedge!(mt, atype, sv)
thisfullmatch || add_mt_backedge!(sv, mt, atype)
end
end
end
Expand Down Expand Up @@ -889,7 +889,11 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter,
end
res = concrete_eval_call(interp, f, result, arginfo, sv, invokecall)
if isa(res, ConstCallResults)
add_backedge!(res.const_result.mi, sv, invokecall === nothing ? nothing : invokecall.lookupsig)
if invokecall === nothing
add_backedge!(sv, res.const_result.mi)
else
add_invoke_backedge!(sv, invokecall.lookupsig, res.const_result.mi)
end
return res
end
mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, match, sv)
Expand Down Expand Up @@ -936,7 +940,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter,
result = inf_result.result
# if constant inference hits a cycle, just bail out
isa(result, InferenceState) && return nothing
add_backedge!(mi, sv)
add_backedge!(sv, mi)
return ConstCallResults(result, ConstPropResult(inf_result), inf_result.ipo_effects)
end

Expand Down Expand Up @@ -1692,7 +1696,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
ti = tienv[1]; env = tienv[2]::SimpleVector
result = abstract_call_method(interp, method, ti, env, false, sv)
(; rt, edge, effects) = result
edge !== nothing && add_backedge!(edge::MethodInstance, sv, lookupsig)
edge !== nothing && add_invoke_backedge!(sv, lookupsig, edge::MethodInstance)
match = MethodMatch(ti, env, method, argtype <: method.sig)
res = nothing
sig = match.spec_types
Expand Down Expand Up @@ -1848,7 +1852,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
sig = argtypes_to_type(arginfo.argtypes)
result = abstract_call_method(interp, closure.source, sig, Core.svec(), false, sv)
(; rt, edge, effects) = result
edge !== nothing && add_backedge!(edge, sv)
edge !== nothing && add_backedge!(sv, edge)
tt = closure.typ
sigT = (unwrap_unionall(tt)::DataType).parameters[1]
match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt))
Expand Down
41 changes: 26 additions & 15 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -478,38 +478,49 @@ function record_ssa_assign!(ssa_id::Int, @nospecialize(new), frame::InferenceSta
return nothing
end

function add_cycle_backedge!(frame::InferenceState, caller::InferenceState, currpc::Int)
function add_cycle_backedge!(caller::InferenceState, frame::InferenceState, currpc::Int)
update_valid_age!(frame, caller)
backedge = (caller, currpc)
contains_is(frame.cycle_backedges, backedge) || push!(frame.cycle_backedges, backedge)
add_backedge!(frame.linfo, caller)
add_backedge!(caller, frame.linfo)
return frame
end

# temporarily accumulate our edges to later add as backedges in the callee
function add_backedge!(li::MethodInstance, caller::InferenceState, @nospecialize(invokesig=nothing))
isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs
edges = caller.stmt_edges[caller.currpc]
if edges === nothing
edges = caller.stmt_edges[caller.currpc] = []
function add_backedge!(caller::InferenceState, li::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, li)
end
if invokesig !== nothing
push!(edges, invokesig)
return nothing
end

function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::Type), li::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, invokesig, li)
end
push!(edges, li)
return nothing
end

# used to temporarily accumulate our no method errors to later add as backedges in the callee method table
function add_mt_backedge!(mt::Core.MethodTable, @nospecialize(typ), caller::InferenceState)
isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs
function add_mt_backedge!(caller::InferenceState, mt::Core.MethodTable, @nospecialize(typ))
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, mt, typ)
end
return nothing
end

function get_stmt_edges!(caller::InferenceState)
if !isa(caller.linfo.def, Method)
return nothing # don't add backedges to toplevel exprs
end
edges = caller.stmt_edges[caller.currpc]
if edges === nothing
edges = caller.stmt_edges[caller.currpc] = []
end
push!(edges, mt)
push!(edges, typ)
return nothing
return edges
end

function empty_backedges!(frame::InferenceState, currpc::Int = frame.currpc)
Expand Down
8 changes: 6 additions & 2 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,13 @@ intersect!(et::EdgeTracker, range::WorldRange) =
et.valid_worlds[] = intersect(et.valid_worlds[], range)

push!(et::EdgeTracker, mi::MethodInstance) = push!(et.edges, mi)
function add_edge!(et::EdgeTracker, @nospecialize(invokesig), mi::MethodInstance)
invokesig === nothing && return push!(et.edges, mi)
function add_backedge!(et::EdgeTracker, mi::MethodInstance)
push!(et.edges, mi)
return nothing
end
function add_invoke_backedge!(et::EdgeTracker, @nospecialize(invokesig), mi::MethodInstance)
push!(et.edges, invokesig, mi)
return nothing
end
function push!(et::EdgeTracker, ci::CodeInstance)
intersect!(et, WorldRange(min_world(li), max_world(li)))
Expand Down
24 changes: 21 additions & 3 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,13 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
inferred_src = match.src
if isa(inferred_src, ConstAPI)
# use constant calling convention
et !== nothing && add_edge!(et, invokesig, mi)
if et !== nothing
if invokesig === nothing
add_backedge!(et, mi)
else
add_invoke_backedge!(et, invokesig, mi)
end
end
return ConstantCase(quoted(inferred_src.val))
else
src = inferred_src # ::Union{Nothing,CodeInfo} for NativeInterpreter
Expand All @@ -843,7 +849,13 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
if code isa CodeInstance
if use_const_api(code)
# in this case function can be inlined to a constant
et !== nothing && add_edge!(et, invokesig, mi)
if et !== nothing
if invokesig === nothing
add_backedge!(et, mi)
else
add_invoke_backedge!(et, invokesig, mi)
end
end
return ConstantCase(quoted(code.rettype_const))
else
src = @atomic :monotonic code.inferred
Expand All @@ -867,7 +879,13 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
src === nothing && return compileable_specialization(et, match, effects;
compilesig_invokes=state.params.compilesig_invokes)

et !== nothing && add_edge!(et, invokesig, mi)
if et !== nothing
if invokesig === nothing
add_backedge!(et, mi)
else
add_invoke_backedge!(et, invokesig, mi)
end
end
return InliningTodo(mi, retrieve_ir_for_inlining(mi, src), effects)
end

Expand Down
2 changes: 1 addition & 1 deletion base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ function merge_call_chain!(interp::AbstractInterpreter, parent::InferenceState,
# of recursion.
merge_effects!(interp, parent, Effects(EFFECTS_TOTAL; terminates=false))
while true
add_cycle_backedge!(child, parent, parent.currpc)
add_cycle_backedge!(parent, child, parent.currpc)
union_caller_cycle!(ancestor, child)
merge_effects!(interp, child, Effects(EFFECTS_TOTAL; terminates=false))
child = parent
Expand Down