Refactor OptimizationState lifetime
In #36508 we decided after some consideration not to add the `stmtinfo`
to the `CodeInfo` object, since this info would never be used for
codegen. However, this also means that external AbstractInterpreters
that would like to cache pre-optimization results cannot simply cache
the unoptimized `CodeInfo` and then later feed it into the optimizer.
Instead they would need to cache the whole OptimizationState object,
or maybe convert to IRCode before caching. However, at the moment
we eagerly drop the `OptimizationState` wrapper as soon as we
decide not to run the optimizer. This refactors things to keep
the OptimizationState around for unoptimized methods, only dropping
it right before caching, in a way that can be overridden by
an external AbstractInterpreter.
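
As an illustration, an external interpreter could hook the new
`transform_result_for_cache` step to stash the pre-optimization state before
it is lowered for the global cache. This is only a rough sketch: the
`MyInterp` type and its `unopt_cache` side table are hypothetical, and the
usual interpreter plumbing is elided.

```julia
using Core.Compiler: AbstractInterpreter, MethodInstance, OptimizationState

struct MyInterp <: AbstractInterpreter
    # ... world age, caches, params, etc. elided ...
    unopt_cache::IdDict{MethodInstance,OptimizationState}  # hypothetical side table
end

function Core.Compiler.transform_result_for_cache(interp::MyInterp, linfo::MethodInstance,
                                                  @nospecialize(inferred_result))
    if inferred_result isa OptimizationState
        # Keep the unoptimized state so it can be fed to the optimizer later.
        interp.unopt_cache[linfo] = inferred_result
    end
    # Defer to the default transformation for what goes into the code cache.
    return invoke(Core.Compiler.transform_result_for_cache,
                  Tuple{AbstractInterpreter, MethodInstance, Any},
                  interp, linfo, inferred_result)
end
```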

We run into the inverse problem during constant propagation, where
inference would like to peek at the results of optimization in
order to decide whether constant propagation is likely to be
profitable. Of course, if optimization hasn't actually run yet
for this AbstractInterpreter, this doesn't work. Factor out
this logic such that an external interpreter can override this
heuristic. E.g. for my AD interpreter, I'm thinking just looking
at the vanilla function and checking its complexity would be
a good heuristic (since the AD'd version is supposed to give
the same result as the vanilla function, modulo capturing
some additional state for the reverse pass).
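
A minimal sketch of that kind of override (the `ADInterp` type and the size
cutoff are made-up placeholders; only the `const_prop_heuristic` hook itself
comes from this commit):

```julia
using Core.Compiler: AbstractInterpreter, Method, MethodInstance

struct ADInterp <: AbstractInterpreter
    # ... usual interpreter fields elided ...
end

# Judge const-prop profitability from the complexity of the vanilla method body,
# since the AD-transformed code has not been optimized (or even generated) yet.
function Core.Compiler.const_prop_heuristic(interp::ADInterp, method::Method, mi::MethodInstance)
    isdefined(method, :source) || return false
    src = Base.uncompressed_ast(method)   # lowered IR of the original function (assumed API)
    return length(src.code) <= 100        # hypothetical "simple enough" cutoff
end
```
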
Keno committed Jul 21, 2020
1 parent b08bd7d commit f9a72a4
Showing 2 changed files with 95 additions and 69 deletions.
42 changes: 27 additions & 15 deletions base/compiler/abstractinterpretation.jl
@@ -214,6 +214,31 @@ function const_prop_profitable(@nospecialize(arg))
return false
end

# This is a heuristic to avoid trying to const prop through complicated functions
# where we would spend a lot of time, but are probably unlikely to get an improved
# result anyway.
function const_prop_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance)
# Peek at the inferred result for the function to determine if the optimizer
# was able to cut it down to something simple (inlineable in particular).
# If so, there's a good chance we might be able to const prop all the way
# through and learn something new.
code = get(code_cache(interp), mi, nothing)
declared_inline = isdefined(method, :source) && ccall(:jl_ir_flag_inlineable, Bool, (Any,), method.source)
cache_inlineable = declared_inline
if isdefined(code, :inferred) && !cache_inlineable
cache_inf = code.inferred
if !(cache_inf === nothing)
cache_src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), cache_inf)
cache_src_inlineable = ccall(:jl_ir_flag_inlineable, Bool, (Any,), cache_inf)
cache_inlineable = cache_src_inferred && cache_src_inlineable
end
end
if !cache_inlineable
return false
end
return true
end

function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nospecialize(rettype), @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, sv::InferenceState, edgecycle::Bool)
method = match.method
nargs::Int = method.nargs
@@ -265,21 +290,8 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
mi === nothing && return Any
mi = mi::MethodInstance
# decide if it's likely to be worthwhile
if !force_inference
code = get(code_cache(interp), mi, nothing)
declared_inline = isdefined(method, :source) && ccall(:jl_ir_flag_inlineable, Bool, (Any,), method.source)
cache_inlineable = declared_inline
if isdefined(code, :inferred) && !cache_inlineable
cache_inf = code.inferred
if !(cache_inf === nothing)
cache_src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), cache_inf)
cache_src_inlineable = ccall(:jl_ir_flag_inlineable, Bool, (Any,), cache_inf)
cache_inlineable = cache_src_inferred && cache_src_inlineable
end
end
if !cache_inlineable
return Any
end
if !force_inference && !const_prop_heuristic(interp, method, mi)
return Any
end
inf_cache = get_inference_cache(interp)
inf_result = cache_lookup(mi, argtypes, inf_cache)
122 changes: 68 additions & 54 deletions base/compiler/typeinfer.jl
@@ -21,30 +21,34 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState)
finish(caller, interp)
end
# collect results for the new expanded frame
results = InferenceResult[ frames[i].result for i in 1:length(frames) ]
results = Tuple{InferenceResult, Bool}[ ( frames[i].result,
frames[i].cached || frames[i].parent !== nothing ) for i in 1:length(frames) ]
# empty!(frames)
min_valid = frame.min_valid
max_valid = frame.max_valid
cached = frame.cached
if cached || frame.parent !== nothing
for caller in results
for (caller, doopt) in results
opt = caller.src
if opt isa OptimizationState
optimize(opt, OptimizationParams(interp), caller.result)
finish(opt.src, interp)
# finish updating the result struct
validate_code_in_debug_mode(opt.linfo, opt.src, "optimized")
if opt.const_api
if caller.result isa Const
caller.src = caller.result
run_optimizer = doopt && may_optimize(interp)
if run_optimizer
optimize(opt, OptimizationParams(interp), caller.result)
finish(opt.src, interp)
# finish updating the result struct
validate_code_in_debug_mode(opt.linfo, opt.src, "optimized")
if opt.const_api
if caller.result isa Const
caller.src = caller.result
else
@assert isconstType(caller.result)
caller.src = Const(caller.result.parameters[1])
end
elseif opt.src.inferred
caller.src = opt.src::CodeInfo # stash a copy of the code (for inlining)
else
@assert isconstType(caller.result)
caller.src = Const(caller.result.parameters[1])
caller.src = nothing
end
elseif opt.src.inferred
caller.src = opt.src::CodeInfo # stash a copy of the code (for inlining)
else
caller.src = nothing
end
if min_valid < opt.min_valid
min_valid = opt.min_valid
@@ -79,14 +83,14 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState)
return true
end

function CodeInstance(result::InferenceResult, min_valid::UInt, max_valid::UInt,
may_compress=true, allow_discard_tree=true)
inferred_result = result.src
function CodeInstance(result::InferenceResult, @nospecialize(inferred_result::Any),
min_valid::UInt, max_valid::UInt)
local const_flags::Int32
if inferred_result isa Const
# use constant calling convention
rettype_const = (result.src::Const).val
const_flags = 0x3
inferred_result = nothing
else
if isa(result.result, Const)
rettype_const = (result.result::Const).val
@@ -101,28 +105,6 @@ function CodeInstance(result::InferenceResult, min_valid::UInt, max_valid::UInt,
rettype_const = nothing
const_flags = 0x00
end
if inferred_result isa CodeInfo
def = result.linfo.def
toplevel = !isa(def, Method)
if !toplevel
cache_the_tree = !allow_discard_tree || (result.src.inferred &&
(result.src.inlineable ||
ccall(:jl_isa_compileable_sig, Int32, (Any, Any), result.linfo.specTypes, def) != 0))
if cache_the_tree
if may_compress
nslots = length(inferred_result.slotflags)
resize!(inferred_result.slottypes, nslots)
resize!(inferred_result.slotnames, nslots)
inferred_result = ccall(:jl_compress_ir, Any, (Any, Any), def, inferred_result)
end
else
inferred_result = nothing
end
end
end
end
if !isa(inferred_result, Union{CodeInfo, Vector{UInt8}})
inferred_result = nothing
end
return CodeInstance(result.linfo,
widenconst(result.result), rettype_const, inferred_result,
@@ -138,8 +120,47 @@ already_inferred_quick_test(interp::NativeInterpreter, mi::MethodInstance) =
already_inferred_quick_test(interp::AbstractInterpreter, mi::MethodInstance) =
false

# inference completed on `me`
# update the MethodInstance
function maybe_compress_codeinfo(interp::AbstractInterpreter, linfo::MethodInstance, ci::CodeInfo)
def = linfo.def
toplevel = !isa(def, Method)
if toplevel
return ci
end
cache_the_tree = !may_discard_trees(interp) || (ci.inferred &&
(ci.inlineable ||
ccall(:jl_isa_compileable_sig, Int32, (Any, Any), linfo.specTypes, def) != 0))
if cache_the_tree
if may_compress(interp)
nslots = length(ci.slotflags)
resize!(ci.slottypes, nslots)
resize!(ci.slotnames, nslots)
return ccall(:jl_compress_ir, Any, (Any, Any), def, ci)
else
return ci
end
else
return nothing
end
end

function transform_result_for_cache(interp::AbstractInterpreter, linfo::MethodInstance,
@nospecialize(inferred_result))
local const_flags::Int32
# If we decided not to optimize, drop the OptimizationState now.
# External interpreters can override as necessary to cache additional information
if inferred_result isa OptimizationState
inferred_result = inferred_result.src
end
if inferred_result isa CodeInfo
inferred_result = maybe_compress_codeinfo(interp, linfo, inferred_result)
end
# The global cache can only handle objects that codegen understands
if !isa(inferred_result, Union{CodeInfo, Vector{UInt8}, Const})
inferred_result = nothing
end
return inferred_result
end

function cache_result!(interp::AbstractInterpreter, result::InferenceResult, min_valid::UInt, max_valid::UInt)
# check if the existing linfo metadata is also sufficient to describe the current inference result
# to decide if it is worth caching this
@@ -150,13 +171,15 @@ function cache_result!(interp::AbstractInterpreter, result::InferenceResult, min

# TODO: also don't store inferred code if we've previously decided to interpret this function
if !already_inferred
code_cache(interp)[result.linfo] = CodeInstance(result, min_valid, max_valid,
may_compress(interp), may_discard_trees(interp))
inferred_result = transform_result_for_cache(interp, result.linfo, result.src)
code_cache(interp)[result.linfo] = CodeInstance(result, inferred_result, min_valid, max_valid)
end
unlock_mi_inference(interp, result.linfo)
nothing
end

# inference completed on `me`
# update the MethodInstance
function finish(me::InferenceState, interp::AbstractInterpreter)
# prepare to run optimization passes on fulltree
if me.limited && me.cached && me.parent !== nothing
@@ -168,16 +191,7 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
else
# annotate fulltree with type information
type_annotate!(me)
can_optimize = may_optimize(interp)
run_optimizer = (me.cached || me.parent !== nothing) && can_optimize
if run_optimizer
# construct the optimizer for later use, if we're building this IR to cache it
# (otherwise, we'll run the optimization passes later, outside of inference)
opt = OptimizationState(me, OptimizationParams(interp), interp)
me.result.src = opt
elseif !can_optimize
me.result.src = me.src
end
me.result.src = OptimizationState(me, OptimizationParams(interp), interp)
end
me.result.result = me.bestguess
nothing