From b6f9a516bf9b602cfbafe5d978412b63dba617a4 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Thu, 8 Jun 2017 08:57:48 -0500 Subject: [PATCH] Decide inline-worthiness based on a more nuanced cost model This switches to a model in which we estimate the runtime of the function; if it's fast, then we should inline it. The estimate of runtime is extremely crude, and doesn't even take loops into account. --- base/inference.jl | 380 +++++++++++++++++++++++++++++----------------- test/core.jl | 6 +- 2 files changed, 246 insertions(+), 140 deletions(-) diff --git a/base/inference.jl b/base/inference.jl index 549e7b6b4b895..99d195b2314e2 100644 --- a/base/inference.jl +++ b/base/inference.jl @@ -14,6 +14,8 @@ struct InferenceParams # optimization inlining::Bool + inline_cost_threshold::Int # number of CPU cycles beyond which it's not worth inlining + inline_nonleaf_penalty::Int # penalty for runtime/vararg method lookup # parameters limiting potentially-infinite types (configurable) MAX_METHODS::Int @@ -26,14 +28,17 @@ struct InferenceParams # reasonable defaults function InferenceParams(world::UInt; inlining::Bool = inlining_enabled(), + inline_cost_threshold::Int = 200, + inline_nonleaf_penalty::Int = 100, max_methods::Int = 4, tupletype_len::Int = 15, tuple_depth::Int = 4, tuple_splat::Int = 16, union_splitting::Int = 4, apply_union_enum::Int = 8) - return new(world, inlining, max_methods, tupletype_len, - tuple_depth, tuple_splat, union_splitting, apply_union_enum) + return new(world, inlining, inline_cost_threshold, inline_nonleaf_penalty, + max_methods, tupletype_len, + tuple_depth, tuple_splat, union_splitting, apply_union_enum) end end @@ -374,17 +379,22 @@ isconstType(t::ANY) = isType(t) && (isleaftype(t.parameters[1]) || t.parameters[ const IInf = typemax(Int) # integer infinity const n_ifunc = reinterpret(Int32,arraylen)+1 const t_ifunc = Array{Tuple{Int,Int,Any},1}(n_ifunc) +const t_ifunc_cost = Array{Int,1}(n_ifunc) const t_ffunc_key = Array{Function,1}(0) const t_ffunc_val = Array{Tuple{Int,Int,Any},1}(0) -function add_tfunc(f::IntrinsicFunction, minarg::Int, maxarg::Int, tfunc::ANY) - t_ifunc[reinterpret(Int32,f)+1] = (minarg, maxarg, tfunc) +const t_ffunc_cost = Array{Int,1}(0) +function add_tfunc(f::IntrinsicFunction, minarg::Int, maxarg::Int, tfunc::ANY, cost::Int) + idx = reinterpret(Int32,f)+1 + t_ifunc[idx] = (minarg, maxarg, tfunc) + t_ifunc_cost[idx] = cost end -function add_tfunc(f::Function, minarg::Int, maxarg::Int, tfunc::ANY) +function add_tfunc(f::Function, minarg::Int, maxarg::Int, tfunc::ANY, cost::Int) push!(t_ffunc_key, f) push!(t_ffunc_val, (minarg, maxarg, tfunc)) + push!(t_ffunc_cost, cost) end -add_tfunc(throw, 1, 1, (x::ANY) -> Bottom) +add_tfunc(throw, 1, 1, (x::ANY) -> Bottom, 0) # the inverse of typeof_tfunc function instanceof_tfunc(t::ANY) @@ -427,104 +437,103 @@ function fptosi_tfunc(x::ANY) end ## conversion ## -add_tfunc(bitcast, 2, 2, bitcast_tfunc) -add_tfunc(sext_int, 2, 2, bitcast_tfunc) -add_tfunc(zext_int, 2, 2, bitcast_tfunc) -add_tfunc(trunc_int, 2, 2, bitcast_tfunc) -add_tfunc(fptoui, 1, 2, fptoui_tfunc) -add_tfunc(fptosi, 1, 2, fptosi_tfunc) -add_tfunc(uitofp, 2, 2, bitcast_tfunc) -add_tfunc(sitofp, 2, 2, bitcast_tfunc) -add_tfunc(fptrunc, 2, 2, bitcast_tfunc) -add_tfunc(fpext, 2, 2, bitcast_tfunc) +add_tfunc(bitcast, 2, 2, bitcast_tfunc, 1) +add_tfunc(sext_int, 2, 2, bitcast_tfunc, 1) +add_tfunc(zext_int, 2, 2, bitcast_tfunc, 1) +add_tfunc(trunc_int, 2, 2, bitcast_tfunc, 1) +add_tfunc(fptoui, 1, 2, fptoui_tfunc, 1) +add_tfunc(fptosi, 1, 2, fptosi_tfunc, 1) +add_tfunc(uitofp, 2, 2, bitcast_tfunc, 1) +add_tfunc(sitofp, 2, 2, bitcast_tfunc, 1) +add_tfunc(fptrunc, 2, 2, bitcast_tfunc, 1) +add_tfunc(fpext, 2, 2, bitcast_tfunc, 1) ## checked conversion ## -add_tfunc(checked_trunc_sint, 2, 2, bitcast_tfunc) -add_tfunc(checked_trunc_uint, 2, 2, bitcast_tfunc) -add_tfunc(check_top_bit, 1, 1, math_tfunc) +add_tfunc(checked_trunc_sint, 2, 2, bitcast_tfunc, 3) +add_tfunc(checked_trunc_uint, 2, 2, bitcast_tfunc, 3) +add_tfunc(check_top_bit, 1, 1, math_tfunc, 2) ## arithmetic ## -add_tfunc(neg_int, 1, 1, math_tfunc) -add_tfunc(add_int, 2, 2, math_tfunc) -add_tfunc(sub_int, 2, 2, math_tfunc) -add_tfunc(mul_int, 2, 2, math_tfunc) -add_tfunc(sdiv_int, 2, 2, math_tfunc) -add_tfunc(udiv_int, 2, 2, math_tfunc) -add_tfunc(srem_int, 2, 2, math_tfunc) -add_tfunc(urem_int, 2, 2, math_tfunc) -add_tfunc(neg_float, 1, 1, math_tfunc) -add_tfunc(add_float, 2, 2, math_tfunc) -add_tfunc(sub_float, 2, 2, math_tfunc) -add_tfunc(mul_float, 2, 2, math_tfunc) -add_tfunc(div_float, 2, 2, math_tfunc) -add_tfunc(rem_float, 2, 2, math_tfunc) -add_tfunc(fma_float, 3, 3, math_tfunc) -add_tfunc(muladd_float, 3, 3, math_tfunc) +add_tfunc(neg_int, 1, 1, math_tfunc, 1) +add_tfunc(add_int, 2, 2, math_tfunc, 1) +add_tfunc(sub_int, 2, 2, math_tfunc, 1) +add_tfunc(mul_int, 2, 2, math_tfunc, 4) +add_tfunc(sdiv_int, 2, 2, math_tfunc, 30) +add_tfunc(udiv_int, 2, 2, math_tfunc, 30) +add_tfunc(srem_int, 2, 2, math_tfunc, 30) +add_tfunc(urem_int, 2, 2, math_tfunc, 30) +add_tfunc(neg_float, 1, 1, math_tfunc, 1) +add_tfunc(add_float, 2, 2, math_tfunc, 1) +add_tfunc(sub_float, 2, 2, math_tfunc, 1) +add_tfunc(mul_float, 2, 2, math_tfunc, 4) +add_tfunc(div_float, 2, 2, math_tfunc, 20) +add_tfunc(rem_float, 2, 2, math_tfunc, 20) +add_tfunc(fma_float, 3, 3, math_tfunc, 5) +add_tfunc(muladd_float, 3, 3, math_tfunc, 5) ## fast arithmetic ## -add_tfunc(neg_float_fast, 1, 1, math_tfunc) -add_tfunc(add_float_fast, 2, 2, math_tfunc) -add_tfunc(sub_float_fast, 2, 2, math_tfunc) -add_tfunc(mul_float_fast, 2, 2, math_tfunc) -add_tfunc(div_float_fast, 2, 2, math_tfunc) -add_tfunc(rem_float_fast, 2, 2, math_tfunc) +add_tfunc(neg_float_fast, 1, 1, math_tfunc, 1) +add_tfunc(add_float_fast, 2, 2, math_tfunc, 1) +add_tfunc(sub_float_fast, 2, 2, math_tfunc, 1) +add_tfunc(mul_float_fast, 2, 2, math_tfunc, 2) +add_tfunc(div_float_fast, 2, 2, math_tfunc, 10) +add_tfunc(rem_float_fast, 2, 2, math_tfunc, 10) ## bitwise operators ## -add_tfunc(and_int, 2, 2, math_tfunc) -add_tfunc(or_int, 2, 2, math_tfunc) -add_tfunc(xor_int, 2, 2, math_tfunc) -add_tfunc(not_int, 1, 1, math_tfunc) -add_tfunc(shl_int, 2, 2, math_tfunc) -add_tfunc(lshr_int, 2, 2, math_tfunc) -add_tfunc(ashr_int, 2, 2, math_tfunc) -add_tfunc(bswap_int, 1, 1, math_tfunc) -add_tfunc(ctpop_int, 1, 1, math_tfunc) -add_tfunc(ctlz_int, 1, 1, math_tfunc) -add_tfunc(cttz_int, 1, 1, math_tfunc) -add_tfunc(checked_sdiv_int, 2, 2, math_tfunc) -add_tfunc(checked_udiv_int, 2, 2, math_tfunc) -add_tfunc(checked_srem_int, 2, 2, math_tfunc) -add_tfunc(checked_urem_int, 2, 2, math_tfunc) +add_tfunc(and_int, 2, 2, math_tfunc, 1) +add_tfunc(or_int, 2, 2, math_tfunc, 1) +add_tfunc(xor_int, 2, 2, math_tfunc, 1) +add_tfunc(not_int, 1, 1, math_tfunc, 1) +add_tfunc(shl_int, 2, 2, math_tfunc, 1) +add_tfunc(lshr_int, 2, 2, math_tfunc, 1) +add_tfunc(ashr_int, 2, 2, math_tfunc, 1) +add_tfunc(bswap_int, 1, 1, math_tfunc, 1) +add_tfunc(ctpop_int, 1, 1, math_tfunc, 1) +add_tfunc(ctlz_int, 1, 1, math_tfunc, 1) +add_tfunc(cttz_int, 1, 1, math_tfunc, 1) +add_tfunc(checked_sdiv_int, 2, 2, math_tfunc, 40) +add_tfunc(checked_udiv_int, 2, 2, math_tfunc, 40) +add_tfunc(checked_srem_int, 2, 2, math_tfunc, 40) +add_tfunc(checked_urem_int, 2, 2, math_tfunc, 40) ## functions ## -add_tfunc(abs_float, 1, 1, math_tfunc) -add_tfunc(copysign_float, 2, 2, math_tfunc) -add_tfunc(flipsign_int, 2, 2, math_tfunc) -add_tfunc(ceil_llvm, 1, 1, math_tfunc) -add_tfunc(floor_llvm, 1, 1, math_tfunc) -add_tfunc(trunc_llvm, 1, 1, math_tfunc) -add_tfunc(rint_llvm, 1, 1, math_tfunc) -add_tfunc(sqrt_llvm, 1, 1, math_tfunc) -add_tfunc(sqrt_llvm_fast, 1, 1, math_tfunc) +add_tfunc(abs_float, 1, 1, math_tfunc, 5) +add_tfunc(copysign_float, 2, 2, math_tfunc, 5) +add_tfunc(flipsign_int, 2, 2, math_tfunc, 1) +add_tfunc(ceil_llvm, 1, 1, math_tfunc, 10) +add_tfunc(floor_llvm, 1, 1, math_tfunc, 10) +add_tfunc(trunc_llvm, 1, 1, math_tfunc, 10) +add_tfunc(rint_llvm, 1, 1, math_tfunc, 10) +add_tfunc(sqrt_llvm, 1, 1, math_tfunc, 20) ## same-type comparisons ## cmp_tfunc(x::ANY, y::ANY) = Bool -add_tfunc(eq_int, 2, 2, cmp_tfunc) -add_tfunc(ne_int, 2, 2, cmp_tfunc) -add_tfunc(slt_int, 2, 2, cmp_tfunc) -add_tfunc(ult_int, 2, 2, cmp_tfunc) -add_tfunc(sle_int, 2, 2, cmp_tfunc) -add_tfunc(ule_int, 2, 2, cmp_tfunc) -add_tfunc(eq_float, 2, 2, cmp_tfunc) -add_tfunc(ne_float, 2, 2, cmp_tfunc) -add_tfunc(lt_float, 2, 2, cmp_tfunc) -add_tfunc(le_float, 2, 2, cmp_tfunc) -add_tfunc(fpiseq, 2, 2, cmp_tfunc) -add_tfunc(fpislt, 2, 2, cmp_tfunc) -add_tfunc(eq_float_fast, 2, 2, cmp_tfunc) -add_tfunc(ne_float_fast, 2, 2, cmp_tfunc) -add_tfunc(lt_float_fast, 2, 2, cmp_tfunc) -add_tfunc(le_float_fast, 2, 2, cmp_tfunc) +add_tfunc(eq_int, 2, 2, cmp_tfunc, 1) +add_tfunc(ne_int, 2, 2, cmp_tfunc, 1) +add_tfunc(slt_int, 2, 2, cmp_tfunc, 1) +add_tfunc(ult_int, 2, 2, cmp_tfunc, 1) +add_tfunc(sle_int, 2, 2, cmp_tfunc, 1) +add_tfunc(ule_int, 2, 2, cmp_tfunc, 1) +add_tfunc(eq_float, 2, 2, cmp_tfunc, 2) +add_tfunc(ne_float, 2, 2, cmp_tfunc, 2) +add_tfunc(lt_float, 2, 2, cmp_tfunc, 2) +add_tfunc(le_float, 2, 2, cmp_tfunc, 2) +add_tfunc(fpiseq, 2, 2, cmp_tfunc, 1) +add_tfunc(fpislt, 2, 2, cmp_tfunc, 1) +add_tfunc(eq_float_fast, 2, 2, cmp_tfunc, 1) +add_tfunc(ne_float_fast, 2, 2, cmp_tfunc, 1) +add_tfunc(lt_float_fast, 2, 2, cmp_tfunc, 1) +add_tfunc(le_float_fast, 2, 2, cmp_tfunc, 1) ## checked arithmetic ## chk_tfunc(x::ANY, y::ANY) = Tuple{widenconst(x), Bool} -add_tfunc(checked_sadd_int, 2, 2, chk_tfunc) -add_tfunc(checked_uadd_int, 2, 2, chk_tfunc) -add_tfunc(checked_ssub_int, 2, 2, chk_tfunc) -add_tfunc(checked_usub_int, 2, 2, chk_tfunc) -add_tfunc(checked_smul_int, 2, 2, chk_tfunc) -add_tfunc(checked_umul_int, 2, 2, chk_tfunc) +add_tfunc(checked_sadd_int, 2, 2, chk_tfunc, 10) +add_tfunc(checked_uadd_int, 2, 2, chk_tfunc, 10) +add_tfunc(checked_ssub_int, 2, 2, chk_tfunc, 10) +add_tfunc(checked_usub_int, 2, 2, chk_tfunc, 10) +add_tfunc(checked_smul_int, 2, 2, chk_tfunc, 10) +add_tfunc(checked_umul_int, 2, 2, chk_tfunc, 10) ## other, misc intrinsics ## add_tfunc(Core.Intrinsics.llvmcall, 3, IInf, - (fptr::ANY, rt::ANY, at::ANY, a...) -> instanceof_tfunc(rt)) + (fptr::ANY, rt::ANY, at::ANY, a...) -> instanceof_tfunc(rt), 10) cglobal_tfunc(fptr::ANY) = Ptr{Void} cglobal_tfunc(fptr::ANY, t::ANY) = (isType(t) ? Ptr{t.parameters[1]} : Ptr) cglobal_tfunc(fptr::ANY, t::Const) = (isa(t.val, Type) ? Ptr{t.val} : Ptr) -add_tfunc(Core.Intrinsics.cglobal, 1, 2, cglobal_tfunc) +add_tfunc(Core.Intrinsics.cglobal, 1, 2, cglobal_tfunc, 5) add_tfunc(Core.Intrinsics.select_value, 3, 3, function (cnd::ANY, x::ANY, y::ANY) if isa(cnd, Const) @@ -538,7 +547,7 @@ add_tfunc(Core.Intrinsics.select_value, 3, 3, end (Bool ⊑ cnd) || return Bottom return tmerge(x, y) - end) + end, 1) add_tfunc(===, 2, 2, function (x::ANY, y::ANY) if isa(x, Const) && isa(y, Const) @@ -557,7 +566,7 @@ add_tfunc(===, 2, 2, x.val === true && return y end return Bool - end) + end, 1) function isdefined_tfunc(args...) arg1 = args[1] if isa(arg1, Const) @@ -598,8 +607,8 @@ function isdefined_tfunc(args...) Bool end # TODO change IInf to 2 when deprecation is removed -add_tfunc(isdefined, 1, IInf, isdefined_tfunc) -add_tfunc(Core.sizeof, 1, 1, x->Int) +add_tfunc(isdefined, 1, IInf, isdefined_tfunc, 1) +add_tfunc(Core.sizeof, 1, 1, x->Int, 0) add_tfunc(nfields, 1, 1, function (x::ANY) isa(x,Const) && return Const(nfields(x.val)) @@ -610,11 +619,11 @@ add_tfunc(nfields, 1, 1, return Const(length(x.types)) end return Int - end) -add_tfunc(Core._expr, 1, IInf, (args...)->Expr) -add_tfunc(applicable, 1, IInf, (f::ANY, args...)->Bool) -add_tfunc(Core.Intrinsics.arraylen, 1, 1, x->Int) -add_tfunc(arraysize, 2, 2, (a::ANY, d::ANY)->Int) + end, 0) +add_tfunc(Core._expr, 1, IInf, (args...)->Expr, 100) +add_tfunc(applicable, 1, IInf, (f::ANY, args...)->Bool, 100) +add_tfunc(Core.Intrinsics.arraylen, 1, 1, x->Int, 4) +add_tfunc(arraysize, 2, 2, (a::ANY, d::ANY)->Int, 4) add_tfunc(pointerref, 3, 3, function (a::ANY, i::ANY, align::ANY) a = widenconst(a) @@ -629,8 +638,8 @@ add_tfunc(pointerref, 3, 3, end end return Any - end) -add_tfunc(pointerset, 4, 4, (a::ANY, v::ANY, i::ANY, align::ANY) -> a) + end, 4) +add_tfunc(pointerset, 4, 4, (a::ANY, v::ANY, i::ANY, align::ANY) -> a, 5) function typeof_tfunc(t::ANY) if isa(t, Const) @@ -664,7 +673,7 @@ function typeof_tfunc(t::ANY) return DataType # typeof(anything)::DataType end end -add_tfunc(typeof, 1, 1, typeof_tfunc) +add_tfunc(typeof, 1, 1, typeof_tfunc, 0) add_tfunc(typeassert, 2, 2, function (v::ANY, t::ANY) t = instanceof_tfunc(t) @@ -681,7 +690,7 @@ add_tfunc(typeassert, 2, 2, return v end return typeintersect(v, t) - end) + end, 4) add_tfunc(isa, 2, 2, function (v::ANY, t::ANY) t = instanceof_tfunc(t) @@ -694,7 +703,7 @@ add_tfunc(isa, 2, 2, end # TODO: handle non-leaftype(t) by testing against lower and upper bounds return Bool - end) + end, 0) add_tfunc(issubtype, 2, 2, function (a::ANY, b::ANY) if (isa(a,Const) || isType(a)) && (isa(b,Const) || isType(b)) @@ -705,7 +714,7 @@ add_tfunc(issubtype, 2, 2, end end return Bool - end) + end, 0) function type_depth(t::ANY) if t === Bottom @@ -1180,8 +1189,8 @@ function getfield_tfunc(s00::ANY, name) # in the current type system return rewrap_unionall(limit_type_depth(R, MAX_TYPE_DEPTH), s00) end -add_tfunc(getfield, 2, 2, (s::ANY, name::ANY) -> getfield_tfunc(s, name)) -add_tfunc(setfield!, 3, 3, (o::ANY, f::ANY, v::ANY) -> v) +add_tfunc(getfield, 2, 2, (s::ANY, name::ANY) -> getfield_tfunc(s, name), 1) +add_tfunc(setfield!, 3, 3, (o::ANY, f::ANY, v::ANY) -> v, 3) function fieldtype_tfunc(s0::ANY, name::ANY) if s0 === Any || s0 === Type || DataType ⊑ s0 || UnionAll ⊑ s0 return Type @@ -1241,7 +1250,7 @@ function fieldtype_tfunc(s0::ANY, name::ANY) end return Type{<:ft} end -add_tfunc(fieldtype, 2, 2, fieldtype_tfunc) +add_tfunc(fieldtype, 2, 2, fieldtype_tfunc, 0) function valid_tparam(x::ANY) if isa(x,Tuple) @@ -1368,7 +1377,7 @@ function apply_type_tfunc(headtypetype::ANY, args::ANY...) end return ans end -add_tfunc(apply_type, 1, IInf, apply_type_tfunc) +add_tfunc(apply_type, 1, IInf, apply_type_tfunc, 10) @pure function type_typeof(v::ANY) if isa(v, Type) @@ -3242,10 +3251,10 @@ end #### finalize and record the result of running type inference #### -function isinlineable(m::Method, src::CodeInfo) +function isinlineable(m::Method, src::CodeInfo, params::InferenceParams) # compute the cost (size) of inlining this code inlineable = false - cost = 1000 + cost_threshold = params.inline_cost_threshold if m.module === _topmod(m.module) name = m.name sig = m.sig @@ -3255,11 +3264,11 @@ function isinlineable(m::Method, src::CodeInfo) inlineable = true elseif (name === :next || name === :done || name === :unsafe_convert || name === :cconvert) - cost ÷= 4 + cost_threshold *= 4 end end if !inlineable - inlineable = inline_worthy_stmts(src.code, cost) + inlineable = inline_worthy_stmts(src.code, params, cost_threshold) end return inlineable end @@ -3356,7 +3365,7 @@ function optimize(me::InferenceState) if force_noinline me.src.inlineable = false elseif !me.src.inlineable && isa(def, Method) - me.src.inlineable = isinlineable(def, me.src) + me.src.inlineable = isinlineable(def, me.src, me.params) end me.src.inferred = true nothing @@ -4343,9 +4352,9 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference # tuples, since we want to be able to inline those functions to # avoid the tuple allocation. current_stmts = vcat(sv.src.code, pending_stmts) - if inline_worthy_stmts(current_stmts) + if inline_worthy_stmts(current_stmts, sv.params) append!(current_stmts, ast) - if !inline_worthy_stmts(current_stmts) + if !inline_worthy_stmts(current_stmts, sv.params) return invoke_NF(argexprs0, e.typ, atypes, sv, atype_unlimited, invoke_data) end @@ -4411,7 +4420,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference end end free = effect_free(aei, sv.src, sv.mod, true) - if ((occ==0 && aeitype===Bottom) || (occ > 1 && !inline_worthy(aei, occ*2000)) || + if ((occ==0 && aeitype===Bottom) || (occ > 1 && !inline_worthy(aei, sv.params)) || #, occ*2000)) || (affect_free && !free) || (!affect_free && !effect_free(aei, sv.src, sv.mod, false))) if occ != 0 vnew = newvar!(sv, aeitype) @@ -4590,38 +4599,131 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference return (expr, stmts) end -inline_worthy(body::ANY, cost::Integer) = true - -# should the expression be part of the inline cost model -function inline_ignore(ex::ANY) - if isa(ex, LineNumberNode) || ex === nothing - return true +inline_worthy(body::ANY, params::InferenceParams, + cost_threshold::Integer=params.inline_cost_threshold) = true + +## Computing the cost of a function body + +# saturating sum (inputs are nonnegative), prevents overflow with typemax(Int) below +plus_saturate(x, y) = max(x, y, x+y) +# known return type +isknowntype(T) = (T == Union{}) || isleaftype(T) + +statement_cost(::Any, params::InferenceParams, line::Int) = 0 +# statement_cost(mi::MethodInstance, params::InferenceParams, line::Int) = +# (isknowntype(mi.specTypes) & isknowntype(mi.rettype)) ? 0 : params.inline_nonleaf_penalty +statement_cost(qn::QuoteNode, params::InferenceParams, line::Int) = + statement_cost(qn.value, params, line) +# function statement_cost(gn::GotoNode, params::InferenceParams, line::Int) +# # If we jump backwards, it's a sign of a loop, and we don't want +# # to inline functions with loops because the cost might be wildly +# # underestimated if there are many iterations +# return gn.label < line ? typemax(Int) : 0 +# end +function statement_cost(ex::Expr, params::InferenceParams, line::Int) + head = ex.head + if is_meta_expr(ex) || head == :copyast # not sure if copyast is right + return 0 end - return isa(ex, Expr) && is_meta_expr(ex::Expr) + argcost = 0 + for a in ex.args + argcost = plus_saturate(argcost, statement_cost(a, params, line)) + end + if head == :return || head == :(=) + return argcost + end + if head == :call + callfunc = ex.args[1] + if isa(callfunc, SSAValue) || isa(callfunc, Slot) + # Not quite sure what to do here. Would it be better + # to be able to look these up, or are these just + # constants? + return argcost + end + if isa(callfunc, Type) || isa(callfunc, Function) || + isa(callfunc, QuoteNode) || + isa(callfunc, Symbol) || isa(callfunc, Expr) + return argcost + end + if isa(callfunc, GlobalRef) + # Here is where the main cost accounting occurs, + # reading out the cost of intrinsics or low-level functions + grfunc = abstract_eval_global(callfunc.mod, callfunc.name) + if isa(grfunc, Type) + return argcost + end + if isa(grfunc, Const) + f = (grfunc::Const).val + if isa(f, IntrinsicFunction) + iidx = Int(reinterpret(Int32, f::IntrinsicFunction)) + 1 + if !isassigned(t_ifunc_cost, iidx) + # unknown/unhandled intrinsic + return plus_saturate(argcost, params.inline_nonleaf_penalty) + end + return plus_saturate(argcost, t_ifunc_cost[iidx]) + end + if isa(f, Function) + # The efficiency of operations like a[i] and s.b + # depend strongly on whether the result can be + # inferred, so check ex.typ + if f == Main.Core.getfield || f == Main.Core.tuple + return plus_saturate(argcost, isknowntype(ex.typ) ? 1 : params.inline_nonleaf_penalty) + elseif f == Main.Core.arrayref + return plus_saturate(argcost, isknowntype(ex.typ) ? 4 : params.inline_nonleaf_penalty) + end + fidx = findfirst(t_ffunc_key, f::Function) + if fidx == 0 + # unknown/unhandled builtin or anonymous function + # Use the generic cost of a direct function call + return plus_saturate(argcost, 20) + end + return plus_saturate(argcost, t_ffunc_cost[fidx]) + end + if isa(f, Type) || isa(f, UnionAll) + return argcost + end + stmt_cost_error("unhandled f, ", f, " of type ", typeof(f)) + end + stmt_cost_error("unhandled grfunc, ", grfunc, " of type ", typeof(grfunc)) + end + stmt_cost_error("unhandled callfunc, ", callfunc, " with type ", typeof(callfunc)) + elseif head == :foreigncall || ex.head == :invoke + return plus_saturate(20, argcost) + elseif head == :llvmcall + return plus_saturate(10, argcost) # a wild guess at typical cost + # elseif head == :gotoifnot + # return ex.args[2] < line ? typemax(Int) : argcost + elseif (head == :&) + return plus_saturate(length(ex.args), argcost) + end + argcost end -function inline_worthy_stmts(stmts::Vector{Any}, cost::Integer = 1000) +function stmt_cost_error(args...) + println(args...) + error("statement_cost is broken") +end + +function inline_worthy_stmts(stmts::Vector{Any}, params::InferenceParams, + cost_threshold::Integer=params.inline_cost_threshold) body = Expr(:block) body.args = stmts - return inline_worthy(body, cost) + return inline_worthy(body, params, cost_threshold) end -function inline_worthy(body::Expr, cost::Integer=1000) # precondition: 0 < cost; nominal cost = 1000 - symlim = 1000 + 5_000_000 ÷ cost - nstmt = 0 - for stmt in body.args - if !(isa(stmt, SSAValue) || inline_ignore(stmt)) - nstmt += 1 - end - end - if nstmt < (symlim + 500) ÷ 1000 - symlim *= 16 - symlim ÷= 1000 - if occurs_more(body, e->!inline_ignore(e), symlim) < symlim - return true +function inline_worthy(body::Expr, params::InferenceParams, + cost_threshold::Integer=params.inline_cost_threshold) + bodycost = 0 + if body.head == :block + for line = 1:length(body.args) + stmt = body.args[line] + thiscost = statement_cost(stmt, params, line) + bodycost = plus_saturate(bodycost, thiscost) end + else + bodycost = statement_cost(body, params, 1) end - return false + bodycost <= cost_threshold end ssavalue_increment(body::ANY, incr) = body diff --git a/test/core.jl b/test/core.jl index a4f3f02d85a63..03562c23ccd8a 100644 --- a/test/core.jl +++ b/test/core.jl @@ -2215,6 +2215,10 @@ mutable struct Obj; x; end push!(r, x) push!(wr, WeakRef(x)) end + @noinline function wr_pop!(a) + pop!(a) + nothing + end test_wr(r,wr) = @test r[1] == wr[1].value function test_wr() ref = [] @@ -2223,7 +2227,7 @@ mutable struct Obj; x; end test_wr(ref, wref) gc() test_wr(ref, wref) - pop!(ref) + wr_pop!(ref) gc() @test wref[1].value === nothing end