-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
RFC: Add inference->optimize analysis forwarding mechanism #36508
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
This change attempts to be a solution to the generalized problem encountered in #36169. In short, we do a whole bunch of analysis during inference to figure out the final type of an expression, but sometimes, we may need intermediate results that were computed along the way. So far, we don't really have a great place to put those results, so we end up having to re-compute them during the optimization phase. That's what #36169 did, but is clearly not a scalable solution. I encountered the exact same issue while working on a new AD compiler plugin, that needs to do a whole bunch of work during inference to determine what to do (e.g. call a primitive, recurse, or increase the derivative level), and optimizations need to have access to this information. This PR adds an additional `info` field to CodeInfo and IRCode that can be used to forward this kind of information. As a proof of concept, it forwards method match info from inference to inlining (we do already cache these, so there's little performance gain from this per se - it's more to exercise the infrastructure). The plan is to do an alternative fix to #36169 on top of this as the next step, but I figured I'd open it up for discussion first.
- Loading branch information
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -994,6 +994,22 @@ function process_simple!(ir::IRCode, idx::Int, params::OptimizationParams, world | |
return (sig, invoke_data) | ||
end | ||
|
||
# This is not currently called in the regular course, but may be needed | ||
# if we ever want to re-run inlining again later in the pass pipeline after | ||
# additional type information was discovered. | ||
function recompute_method_matches(atype, sv) | ||
# Regular case: Retrieve matching methods from cache (or compute them) | ||
# World age does not need to be taken into account in the cache | ||
# because it is forwarded from type inference through `sv.params` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems confusing to state. The world age is taken into account, because that's part of what There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That comment was there before ;) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Though I should probably update it, because it's not the regular case anymore. |
||
# in the case that the cache is nonempty, so it should be unchanged | ||
# The max number of methods should be the same as in inference most | ||
# of the time, and should not affect correctness otherwise. | ||
(meth, min_valid, max_valid, ambig) = | ||
matching_methods(atype, sv.matching_methods_cache, sv.params.MAX_METHODS, sv.world) | ||
update_valid_age!(min_valid, max_valid, sv) | ||
MethodMatchInfo(meth, ambig) | ||
end | ||
|
||
function assemble_inline_todo!(ir::IRCode, sv::OptimizationState) | ||
# todo = (inline_idx, (isva, isinvoke, na), method, spvals, inline_linetable, inline_ir, lie) | ||
todo = Any[] | ||
|
@@ -1005,8 +1021,15 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState) | |
|
||
stmt = ir.stmts[idx][:inst] | ||
calltype = ir.stmts[idx][:type] | ||
info = ir.stmts[idx][:info] | ||
# Inference determined this couldn't be analyzed. Don't question it. | ||
if info === false | ||
continue | ||
end | ||
|
||
(sig, invoke_data) = r | ||
|
||
|
||
# Ok, now figure out what method to call | ||
if invoke_data !== nothing | ||
inline_invoke!(ir, idx, sig, invoke_data, sv, todo) | ||
|
@@ -1015,11 +1038,21 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState) | |
|
||
nu = countunionsplit(sig.atypes) | ||
if nu == 1 || nu > sv.params.MAX_UNION_SPLITTING | ||
if !isa(info, MethodMatchInfo) | ||
info = nothing | ||
end | ||
infos = Any[info] | ||
splits = Any[sig.atype] | ||
else | ||
splits = Any[] | ||
for union_sig in UnionSplitSignature(sig.atypes) | ||
push!(splits, argtypes_to_type(union_sig)) | ||
if !isa(info, UnionSplitInfo) | ||
splits = Any[] | ||
for union_sig in UnionSplitSignature(sig.atypes) | ||
push!(splits, argtypes_to_type(union_sig)) | ||
end | ||
infos = Any[nothing for i = 1:length(splits)] | ||
else | ||
splits = info.sigs | ||
infos = info.matches | ||
end | ||
end | ||
|
||
|
@@ -1029,16 +1062,14 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState) | |
too_many = false | ||
local meth | ||
local fully_covered = true | ||
for atype in splits | ||
# Regular case: Retrieve matching methods from cache (or compute them) | ||
# World age does not need to be taken into account in the cache | ||
# because it is forwarded from type inference through `sv.params` | ||
# in the case that the cache is nonempty, so it should be unchanged | ||
# The max number of methods should be the same as in inference most | ||
# of the time, and should not affect correctness otherwise. | ||
(meth, min_valid, max_valid, ambig) = | ||
matching_methods(atype, sv.matching_methods_cache, sv.params.MAX_METHODS, sv.world) | ||
if meth === false || ambig | ||
for i in 1:length(splits) | ||
atype = splits[i] | ||
info = infos[i] | ||
if info === nothing | ||
info = recompute_method_matches(atype, sv) | ||
end | ||
meth = info.applicable | ||
if meth === false || info.ambig | ||
# Too many applicable methods | ||
# Or there is a (partial?) ambiguity | ||
too_many = true | ||
|
@@ -1055,8 +1086,6 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState) | |
else | ||
only_method = false | ||
end | ||
update_valid_age!(min_valid, max_valid, sv) | ||
|
||
for match in meth::Vector{Any} | ||
(metharg, methsp, method) = (match[1]::Type, match[2]::SimpleVector, match[3]::Method) | ||
# TODO: This could be better | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
struct MethodMatchInfo | ||
applicable::Any | ||
ambig::Bool | ||
end | ||
|
||
struct UnionSplitInfo | ||
# TODO: In principle we shouldn't have to store this, but could just | ||
# recompute it using `switchtuple` union. However, it is not the case | ||
# that if T == S, then switchtupleunion(T) == switchtupleunion(S), e.g. for | ||
# T = Tuple{Tuple{Union{Float64, Int64},String}} | ||
# S = Tuple{Union{Tuple{Float64, String}, Tuple{Int64, String}}} | ||
sigs::Vector{Any} | ||
matches::Vector{MethodMatchInfo} | ||
end | ||
|
||
struct CallMeta | ||
rt::Any | ||
info::Any | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.