RFC: Add inference->optimize analysis forwarding mechanism #36508

Merged 1 commit on Jul 15, 2020
Add inference->optimize analysis forwarding mechanism
This change attempts to solve the general problem
encountered in #36169. In short, we do a whole bunch of analysis
during inference to figure out the final type of an expression,
but sometimes we also need intermediate results that were
computed along the way. So far, we haven't had a good
place to put those results, so we end up re-computing
them during the optimization phase. That's what #36169 did,
but it is clearly not a scalable solution.

I encountered the exact same issue while working on a new AD
compiler plugin that needs to do a whole bunch of work during
inference to determine what to do (e.g. call a primitive, recurse,
or increase the derivative level), and the optimization passes need
access to this information.

This PR adds an additional `info` field to CodeInfo and IRCode
that can be used to forward this kind of information. As a proof
of concept, it forwards method match info from inference to
inlining (we do already cache these, so there's little performance
gain from this per se - it's more to exercise the infrastructure).

The plan is to do an alternative fix to #36169 on top of this
as the next step, but I figured I'd open it up for discussion first.
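
A minimal sketch of the forwarding idea described above, not the compiler's actual code: one pass records a per-statement result in a side vector, and a later pass reads it back, recomputing only when nothing was forwarded. `MatchInfo`, `infer_all`, `recompute` and `lookup` are hypothetical names invented for illustration, and the "analysis" is a placeholder.

```julia
# Hypothetical stand-in for a per-statement analysis result.
struct MatchInfo
    applicable::Any   # placeholder for the list of matching methods
    ambig::Bool
end

# "Inference": analyze every call statement once and record the result
# in a vector that runs parallel to the statement list.
function infer_all(stmts::Vector{Any})
    info = Any[nothing for _ in 1:length(stmts)]
    for (i, stmt) in enumerate(stmts)
        if stmt isa Expr && stmt.head === :call
            info[i] = MatchInfo(Any[stmt.args[1]], false)
        end
    end
    return info
end

# Fallback used by the later pass when no result was forwarded.
recompute(stmt::Expr) = MatchInfo(Any[stmt.args[1]], false)

# "Optimization": prefer the forwarded result, otherwise recompute.
function lookup(stmts::Vector{Any}, info::Vector{Any}, i::Int)
    cached = info[i]
    cached isa MatchInfo && return (cached, :forwarded)
    return (recompute(stmts[i]), :recomputed)
end

stmts = Any[:(f(x)), :(g(y))]
info = infer_all(stmts)
info[2] = nothing   # simulate a statement whose info was not forwarded
```

Calling `lookup(stmts, info, 1)` takes the forwarded path, while `lookup(stmts, info, 2)` falls back to `recompute`, which mirrors the role `recompute_method_matches` plays in the inlining changes of this PR.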
Keno committed Jul 15, 2020
commit 372e87824e99c4dbec6a87de9f4dfae611e81b2d
350 changes: 186 additions & 164 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions base/compiler/compiler.jl
@@ -107,6 +107,7 @@ include("compiler/typeutils.jl")
include("compiler/typelimits.jl")
include("compiler/typelattice.jl")
include("compiler/tfuncs.jl")
include("compiler/stmtinfo.jl")

include("compiler/abstractinterpretation.jl")
include("compiler/typeinfer.jl")
4 changes: 3 additions & 1 deletion base/compiler/inferencestate.jl
@@ -19,6 +19,7 @@ mutable struct InferenceState
nargs::Int
stmt_types::Vector{Any}
stmt_edges::Vector{Any}
stmt_info::Vector{Any}
# return type
bestguess #::Type
# current active instruction pointers
@@ -62,6 +63,7 @@ mutable struct InferenceState

nssavalues = src.ssavaluetypes::Int
src.ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]
stmt_info = Any[ nothing for i = 1:length(code) ]

n = length(code)
s_edges = Any[ nothing for i = 1:n ]
@@ -105,7 +107,7 @@ mutable struct InferenceState
InferenceParams(interp), result, linfo,
sp, slottypes, inmodule, 0,
src, get_world_counter(interp), min_valid, max_valid,
nargs, s_types, s_edges,
nargs, s_types, s_edges, stmt_info,
Union{}, W, 1, n,
cur_hand, handler_at, n_handlers,
ssavalue_uses, throw_blocks,
6 changes: 4 additions & 2 deletions base/compiler/optimize.jl
@@ -9,6 +9,7 @@ mutable struct OptimizationState
linfo::MethodInstance
calledges::Vector{Any}
src::CodeInfo
stmt_info::Vector{Any}
mod::Module
nargs::Int
world::UInt
@@ -31,7 +32,7 @@ mutable struct OptimizationState
src = frame.src
return new(params, frame.linfo,
s_edges::Vector{Any},
src, frame.mod, frame.nargs,
src, frame.stmt_info, frame.mod, frame.nargs,
frame.world, frame.min_valid, frame.max_valid,
frame.sptypes, frame.slottypes, false,
frame.matching_methods_cache, interp)
@@ -49,6 +50,7 @@ mutable struct OptimizationState
slottypes = Any[ Any for i = 1:nslots ]
end
s_edges = []
stmt_info = Any[nothing for i = 1:nssavalues]
# cache some useful state computations
toplevel = !isa(linfo.def, Method)
if !toplevel
@@ -61,7 +63,7 @@ mutable struct OptimizationState
end
return new(params, linfo,
s_edges::Vector{Any},
src, inmodule, nargs,
src, stmt_info, inmodule, nargs,
get_world_counter(), UInt(1), get_world_counter(),
sptypes_from_meth_instance(linfo), slottypes, false,
IdDict{Any, Tuple{Any, UInt, UInt, Bool}}(), interp)
7 changes: 5 additions & 2 deletions base/compiler/ssair/driver.jl
@@ -41,13 +41,15 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
changemap = fill(0, length(code))
labelmap = coverage ? fill(0, length(code)) : changemap
prevloc = zero(eltype(ci.codelocs))
stmtinfo = sv.stmt_info
while idx <= length(code)
codeloc = ci.codelocs[idx]
if coverage && codeloc != prevloc && codeloc != 0
# insert a side-effect instruction before the current instruction in the same basic block
insert!(code, idx, Expr(:code_coverage_effect))
insert!(ci.codelocs, idx, codeloc)
insert!(ci.ssavaluetypes, idx, Nothing)
insert!(stmtinfo, idx, nothing)
changemap[oldidx] += 1
if oldidx < length(labelmap)
labelmap[oldidx + 1] += 1
@@ -61,6 +63,7 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
insert!(code, idx + 1, ReturnNode())
insert!(ci.codelocs, idx + 1, ci.codelocs[idx])
insert!(ci.ssavaluetypes, idx + 1, Union{})
insert!(stmtinfo, idx + 1, nothing)
if oldidx < length(changemap)
changemap[oldidx + 1] += 1
coverage && (labelmap[oldidx + 1] += 1)
@@ -98,10 +101,10 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
end
end
end
strip_trailing_junk!(ci, code, flags)
strip_trailing_junk!(ci, code, stmtinfo, flags)
cfg = compute_basic_blocks(code)
types = Any[]
stmts = InstructionStream(code, types, ci.codelocs, flags)
stmts = InstructionStream(code, types, stmtinfo, ci.codelocs, flags)
ir = IRCode(stmts, cfg, collect(LineInfoNode, ci.linetable), sv.slottypes, meta, sv.sptypes)
return ir
end
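
The `insert!` calls added in this hunk, together with the `deleteat!` and `resize!` changes elsewhere in the PR, all serve one invariant: the info vector must stay index-aligned with the statement list. A small sketch of that invariant follows, using a hypothetical `MiniStream` type and helper names that are not part of the compiler:

```julia
# Hypothetical miniature of a statement stream with a parallel `info` column.
struct MiniStream
    inst::Vector{Any}
    info::Vector{Any}
end

# Inserting a statement must also insert a placeholder into `info`,
# otherwise every later entry would describe the wrong statement.
function insert_stmt!(s::MiniStream, idx::Int, stmt)
    insert!(s.inst, idx, stmt)
    insert!(s.info, idx, nothing)   # a synthetic statement has no analysis result
    return s
end

# Deleting a statement must drop its info entry as well.
function delete_stmt!(s::MiniStream, idx::Int)
    deleteat!(s.inst, idx)
    deleteat!(s.info, idx)
    return s
end

s = MiniStream(Any[:a, :b, :c], Any["info-a", nothing, "info-c"])
insert_stmt!(s, 2, :coverage)   # e.g. a coverage marker inserted before `:b`
delete_stmt!(s, 1)              # e.g. a dead statement removed
```

After both edits the entry recorded for `:c` is still found at the index of `:c`, which is the property the real code relies on when inlining reads `ir.stmts[idx][:info]`.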
59 changes: 44 additions & 15 deletions base/compiler/ssair/inlining.jl
@@ -994,6 +994,22 @@ function process_simple!(ir::IRCode, idx::Int, params::OptimizationParams, world
return (sig, invoke_data)
end

# This is not currently called in the regular course, but may be needed
# if we ever want to re-run inlining again later in the pass pipeline after
# additional type information was discovered.
function recompute_method_matches(atype, sv)
Review comment (Member), suggested change:
- function recompute_method_matches(atype, sv)
+ function recompute_method_matches(@nospecialize(atype), sv::OptimizationState)

# Regular case: Retrieve matching methods from cache (or compute them)
# World age does not need to be taken into account in the cache
# because it is forwarded from type inference through `sv.params`
Review comment (Member):
This seems confusing to state. The world age is taken into account, because that's part of what sv.params means, and is partly why the cache is part of the params.

Reply (Member Author):
That comment was there before ;)

Reply (Member Author):
Though I should probably update it, because it's not the regular case anymore.

# in the case that the cache is nonempty, so it should be unchanged
# The max number of methods should be the same as in inference most
# of the time, and should not affect correctness otherwise.
(meth, min_valid, max_valid, ambig) =
matching_methods(atype, sv.matching_methods_cache, sv.params.MAX_METHODS, sv.world)
update_valid_age!(min_valid, max_valid, sv)
MethodMatchInfo(meth, ambig)
end

function assemble_inline_todo!(ir::IRCode, sv::OptimizationState)
# todo = (inline_idx, (isva, isinvoke, na), method, spvals, inline_linetable, inline_ir, lie)
todo = Any[]
@@ -1005,8 +1021,15 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState)

stmt = ir.stmts[idx][:inst]
calltype = ir.stmts[idx][:type]
info = ir.stmts[idx][:info]
# Inference determined this couldn't be analyzed. Don't question it.
if info === false
continue
end

(sig, invoke_data) = r


# Ok, now figure out what method to call
if invoke_data !== nothing
inline_invoke!(ir, idx, sig, invoke_data, sv, todo)
@@ -1015,11 +1038,21 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState)

nu = countunionsplit(sig.atypes)
if nu == 1 || nu > sv.params.MAX_UNION_SPLITTING
if !isa(info, MethodMatchInfo)
info = nothing
end
infos = Any[info]
splits = Any[sig.atype]
else
splits = Any[]
for union_sig in UnionSplitSignature(sig.atypes)
push!(splits, argtypes_to_type(union_sig))
if !isa(info, UnionSplitInfo)
splits = Any[]
for union_sig in UnionSplitSignature(sig.atypes)
push!(splits, argtypes_to_type(union_sig))
end
infos = Any[nothing for i = 1:length(splits)]
else
splits = info.sigs
infos = info.matches
end
end

@@ -1029,16 +1062,14 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState)
too_many = false
local meth
local fully_covered = true
for atype in splits
# Regular case: Retrieve matching methods from cache (or compute them)
# World age does not need to be taken into account in the cache
# because it is forwarded from type inference through `sv.params`
# in the case that the cache is nonempty, so it should be unchanged
# The max number of methods should be the same as in inference most
# of the time, and should not affect correctness otherwise.
(meth, min_valid, max_valid, ambig) =
matching_methods(atype, sv.matching_methods_cache, sv.params.MAX_METHODS, sv.world)
if meth === false || ambig
for i in 1:length(splits)
atype = splits[i]
info = infos[i]
if info === nothing
info = recompute_method_matches(atype, sv)
end
meth = info.applicable
if meth === false || info.ambig
# Too many applicable methods
# Or there is a (partial?) ambiguity
too_many = true
@@ -1055,8 +1086,6 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState)
else
only_method = false
end
update_valid_age!(min_valid, max_valid, sv)

for match in meth::Vector{Any}
(metharg, methsp, method) = (match[1]::Type, match[2]::SimpleVector, match[3]::Method)
# TODO: This could be better
14 changes: 8 additions & 6 deletions base/compiler/ssair/ir.jl
@@ -159,25 +159,25 @@ end
struct InstructionStream
inst::Vector{Any}
type::Vector{Any}
info::Vector{Any}
line::Vector{Int32}
flag::Vector{UInt8}
end
function InstructionStream(len::Int)
insts = Array{Any}(undef, len)
types = Array{Any}(undef, len)
info = Array{Any}(undef, len)
fill!(info, nothing)
lines = fill(Int32(0), len)
flags = fill(0x00, len)
return InstructionStream(insts, types, lines, flags)
return InstructionStream(insts, types, info, lines, flags)
end
InstructionStream() = InstructionStream(0)
length(is::InstructionStream) = length(is.inst)
isempty(is::InstructionStream) = isempty(is.inst)
function add!(is::InstructionStream)
ninst = length(is) + 1
resize!(is.inst, ninst)
resize!(is.type, ninst)
resize!(is.line, ninst)
resize!(is.flag, ninst)
resize!(is, ninst)
return ninst
end
#function copy(is::InstructionStream) # unused
@@ -191,16 +191,17 @@ function resize!(stmts::InstructionStream, len)
old_length = length(stmts)
resize!(stmts.inst, len)
resize!(stmts.type, len)
resize!(stmts.info, len)
resize!(stmts.line, len)
resize!(stmts.flag, len)
for i in (old_length + 1):len
stmts.line[i] = 0
stmts.flag[i] = 0x00
stmts.info[i] = nothing
end
return stmts
end


struct Instruction
data::InstructionStream
idx::Int
@@ -220,6 +221,7 @@ end
function setindex!(is::InstructionStream, newval::Instruction, idx::Int)
is.inst[idx] = newval[:inst]
is.type[idx] = newval[:type]
is.info[idx] = newval[:info]
is.line[idx] = newval[:line]
is.flag[idx] = newval[:flag]
return is
2 changes: 1 addition & 1 deletion base/compiler/ssair/legacy.jl
@@ -32,7 +32,7 @@ function inflate_ir(ci::CodeInfo, sptypes::Vector{Any}, argtypes::Vector{Any})
ssavaluetypes = ci.ssavaluetypes
nstmts = length(code)
ssavaluetypes = ci.ssavaluetypes isa Vector{Any} ? copy(ci.ssavaluetypes) : Any[ Any for i = 1:(ci.ssavaluetypes::Int) ]
stmts = InstructionStream(code, ssavaluetypes, copy(ci.codelocs), copy(ci.ssaflags))
stmts = InstructionStream(code, ssavaluetypes, Any[nothing for i = 1:nstmts], copy(ci.codelocs), copy(ci.ssaflags))
ir = IRCode(stmts, cfg, collect(LineInfoNode, ci.linetable), argtypes, Any[], sptypes)
return ir
end
4 changes: 3 additions & 1 deletion base/compiler/ssair/slot2ssa.jl
@@ -181,14 +181,15 @@ function rename_uses!(ir::IRCode, ci::CodeInfo, idx::Int, @nospecialize(stmt), r
return fixemup!(stmt->true, stmt->renames[slot_id(stmt)], ir, ci, idx, stmt)
end

function strip_trailing_junk!(ci::CodeInfo, code::Vector{Any}, flags::Vector{UInt8})
function strip_trailing_junk!(ci::CodeInfo, code::Vector{Any}, info::Vector{Any}, flags::Vector{UInt8})
# Remove `nothing`s at the end, we don't handle them well
# (we expect the last instruction to be a terminator)
for i = length(code):-1:1
if code[i] !== nothing
resize!(code, i)
resize!(ci.ssavaluetypes, i)
resize!(ci.codelocs, i)
resize!(info, i)
resize!(flags, i)
break
end
@@ -200,6 +201,7 @@ function strip_trailing_junk!(ci::CodeInfo, code::Vector{Any}, flags::Vector{UIn
push!(code, ReturnNode())
push!(ci.ssavaluetypes, Union{})
push!(ci.codelocs, 0)
push!(info, nothing)
push!(flags, 0x00)
end
nothing
19 changes: 19 additions & 0 deletions base/compiler/stmtinfo.jl
@@ -0,0 +1,19 @@
struct MethodMatchInfo
applicable::Any
ambig::Bool
end

struct UnionSplitInfo
# TODO: In principle we shouldn't have to store this, but could just
# recompute it using `switchtuple` union. However, it is not the case
# that if T == S, then switchtupleunion(T) == switchtupleunion(S), e.g. for
# T = Tuple{Tuple{Union{Float64, Int64},String}}
# S = Tuple{Union{Tuple{Float64, String}, Tuple{Int64, String}}}
sigs::Vector{Any}
matches::Vector{MethodMatchInfo}
end

struct CallMeta
rt::Any
info::Any
end
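
The TODO in `UnionSplitInfo` relies on two types being equal while having different representations. A short check of that claim for the `T` and `S` named in the comment, using only ordinary type operations; it does not call `switchtupleunion` itself, whose exact signature is not shown in this diff:

```julia
# The two types from the TODO comment in `UnionSplitInfo`.
T = Tuple{Tuple{Union{Float64, Int64}, String}}
S = Tuple{Union{Tuple{Float64, String}, Tuple{Int64, String}}}

# Type equality is mutual subtyping, so it sees through where the
# `Union` sits inside the tuple.
same_type = (T == S)

# Structurally the two differ: T wraps a tuple type whose first field is a
# Union, while S wraps a Union of two tuple types.
inner_T = T.parameters[1]
inner_S = S.parameters[1]

println("T == S: ", same_type)
println("inner parameter of T is a ", typeof(inner_T))
println("inner parameter of S is a ", typeof(inner_S))
```

Because any splitting routine works on the structure rather than on the equivalence class, equal signatures can split differently, which is presumably why the struct stores `sigs` alongside `matches` instead of recomputing them.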
2 changes: 1 addition & 1 deletion base/compiler/tfuncs.jl
@@ -1509,7 +1509,7 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, v
if contains_is(argtypes_vec, Union{})
return Const(Union{})
end
rt = abstract_call(interp, nothing, argtypes_vec, vtypes, sv, -1)
rt = abstract_call(interp, nothing, argtypes_vec, vtypes, sv, -1).rt
if isa(rt, Const)
# output was computed to be constant
return Const(typeof(rt.val))
1 change: 1 addition & 0 deletions base/compiler/typeinfer.jl
@@ -382,6 +382,7 @@ function type_annotate!(sv::InferenceState)
deleteat!(states, i)
deleteat!(src.ssavaluetypes, i)
deleteat!(src.codelocs, i)
deleteat!(sv.stmt_info, i)
nexpr -= 1
if oldidx < length(changemap)
changemap[oldidx + 1] = -1
1 change: 1 addition & 0 deletions src/codegen.cpp
@@ -6220,6 +6220,7 @@ static std::pair<std::unique_ptr<Module>, jl_llvm_functions_t>
};
std::vector<DebugLineTable> linetable;
{
assert(jl_is_array(src->linetable));
size_t nlocs = jl_array_len(src->linetable);
std::map<std::tuple<StringRef, StringRef>, DISubprogram*> subprograms;
linetable.resize(nlocs + 1);
2 changes: 1 addition & 1 deletion test/compiler/ssair.jl
@@ -108,7 +108,7 @@ let cfg = CFG(BasicBlock[
make_bb([0, 1, 2] , [5] ), # 0 predecessor should be preserved
make_bb([2, 3] , [] ),
], Int[])
insts = Compiler.InstructionStream([], [], Int32[], UInt8[])
insts = Compiler.InstructionStream([], [], Any[], Int32[], UInt8[])
code = Compiler.IRCode(insts, cfg, LineInfoNode[], [], [], [])
compact = Compiler.IncrementalCompact(code, true)
@test length(compact.result_bbs) == 4 && 0 in compact.result_bbs[3].preds