Skip to content

Commit

Permalink
delete IR for non-inlineable functions after codegen to save memory
Browse files Browse the repository at this point in the history
  • Loading branch information
JeffBezanson committed Jun 13, 2016
1 parent 85d098c commit e24fec2
Show file tree
Hide file tree
Showing 12 changed files with 73 additions and 59 deletions.
22 changes: 14 additions & 8 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1421,7 +1421,7 @@ function typeinf_edge(method::Method, atypes::ANY, sparams::SimpleVector, needtr
# something completely new
elseif isa(code, LambdaInfo)
# something existing
if code.inferred
if code.inferred && !(needtree && code.code === nothing)
return (code, code.rettype, true)
end
else
Expand Down Expand Up @@ -1457,7 +1457,7 @@ function typeinf_edge(method::Method, atypes::ANY, sparams::SimpleVector, needtr
# inference not started yet, make a new frame for a new lambda
# add lam to be inferred and record the edge

if caller === nothing && needtree && in_typeinf_loop
if caller === nothing && in_typeinf_loop
# if the caller needed the ast, but we are already in the typeinf loop
# then just return early -- we can't fulfill this request
# if the client was inlining, then this means we decided not to try to infer this
Expand Down Expand Up @@ -1486,7 +1486,7 @@ function typeinf_edge(method::Method, atypes::ANY, sparams::SimpleVector, needtr
end
end

if isa(code, LambdaInfo)
if isa(code, LambdaInfo) && code.code !== nothing
# reuse the existing code object
linfo = code
@assert typeseq(linfo.specTypes, atypes)
Expand Down Expand Up @@ -2412,11 +2412,10 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
# end
# end

(linfo, ty, inferred) = typeinf(method, metharg, methsp, true)
if is(linfo,nothing) || !inferred
(linfo, ty, inferred) = typeinf(method, metharg, methsp, false)
if !inferred || linfo === nothing
return NF
end
if !linfo.inlineable
elseif !linfo.inlineable
# TODO
#=
if incompletematch
Expand All @@ -2443,6 +2442,11 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
end
=#
return NF
elseif linfo.code === nothing
(linfo, ty, inferred) = typeinf(method, metharg, methsp, true)
end
if linfo === nothing || !inferred || !linfo.inlineable
return NF
end

na = linfo.nargs
Expand Down Expand Up @@ -2483,7 +2487,9 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
nm = length(methargs)

ast = linfo.code
if !isa(ast,Array{Any,1})
if ast === nothing
return NF
elseif !isa(ast,Array{Any,1})
ast = ccall(:jl_uncompress_ast, Any, (Any,Any), linfo, ast)
else
ast = copy_exprargs(ast)
Expand Down
2 changes: 1 addition & 1 deletion base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ isempty(mt::MethodTable) = (mt.defs === nothing)

uncompressed_ast(l::Method) = uncompressed_ast(l.lambda_template)
uncompressed_ast(l::LambdaInfo) =
isa(l.code,Array{Any,1}) ? l.code::Array{Any,1} : ccall(:jl_uncompress_ast, Array{Any,1}, (Any,Any), l, l.code)
isa(l.code,Array{UInt8,1}) ? ccall(:jl_uncompress_ast, Array{Any,1}, (Any,Any), l, l.code) : l.code

# Printing code representations in IR and assembly
function _dump_function(f, t::ANY, native, wrapper, strip_ir_metadata, dump_module)
Expand Down
16 changes: 2 additions & 14 deletions base/serialize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -320,11 +320,6 @@ function serialize(s::SerializationState, meth::Method)
serialize(s, meth.ambig)
serialize(s, meth.isstaged)
serialize(s, meth.lambda_template)
if isdefined(meth, :roots)
serialize(s, meth.roots)
else
writetag(s.io, UNDEFREF_TAG)
end
nothing
end

Expand Down Expand Up @@ -611,12 +606,6 @@ function deserialize(s::SerializationState, ::Type{Method})
ambig = deserialize(s)
isstaged = deserialize(s)::Bool
template = deserialize(s)::LambdaInfo
tag = Int32(read(s.io, UInt8)::UInt8)
if tag != UNDEFREF_TAG
roots = handle_deserialize(s, tag)::Array{Any, 1}
else
roots = nothing
end
if makenew
meth.module = mod
meth.name = name
Expand All @@ -627,7 +616,6 @@ function deserialize(s::SerializationState, ::Type{Method})
meth.ambig = ambig
meth.isstaged = isstaged
meth.lambda_template = template
roots === nothing || (meth.roots = roots)
ccall(:jl_method_init_properties, Void, (Any,), meth)
known_object_data[lnumber] = meth
end
Expand All @@ -637,8 +625,8 @@ end
function deserialize(s::SerializationState, ::Type{LambdaInfo})
linfo = ccall(:jl_new_lambda_info_uninit, Ref{LambdaInfo}, (Ptr{Void},), C_NULL)
deserialize_cycle(s, linfo)
linfo.code = deserialize(s)::Array{Any, 1}
linfo.slotnames = deserialize(s)::Array{Any, 1}
linfo.code = deserialize(s)
linfo.slotnames = deserialize(s)
linfo.slottypes = deserialize(s)
linfo.slotflags = deserialize(s)
linfo.ssavaluetypes = deserialize(s)
Expand Down
8 changes: 4 additions & 4 deletions src/alloc.c
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ void jl_lambda_info_set_ast(jl_lambda_info_t *li, jl_expr_t *ast)
jl_expr_t *bodyex = (jl_expr_t*)jl_exprarg(ast, 2);
assert(jl_is_expr(bodyex));
jl_array_t *body = bodyex->args;
li->code = body; jl_gc_wb(li, li->code);
li->code = (jl_value_t*)body; jl_gc_wb(li, li->code);
if (has_meta(body, pure_sym))
li->pure = 1;
jl_array_t *vinfo = (jl_array_t*)jl_exprarg(ast, 1);
Expand Down Expand Up @@ -473,7 +473,7 @@ static jl_lambda_info_t *jl_instantiate_staged(jl_method_t *generator, jl_tuplet
func->specTypes = tt;
jl_gc_wb(func, tt);

jl_array_t *stmts = func->code;
jl_array_t *stmts = (jl_array_t*)func->code;
for(i = 0, l = jl_array_len(stmts); i < l; i++) {
jl_array_ptr_set(stmts, i, jl_resolve_globals(jl_array_ptr_ref(stmts, i), func));
}
Expand Down Expand Up @@ -541,7 +541,7 @@ JL_DLLEXPORT jl_lambda_info_t *jl_get_specialized(jl_method_t *m, jl_tupletype_t
JL_DLLEXPORT void jl_method_init_properties(jl_method_t *m)
{
jl_lambda_info_t *li = m->lambda_template;
jl_value_t *body1 = skip_meta(li->code);
jl_value_t *body1 = skip_meta((jl_array_t*)li->code);
if (jl_is_linenode(body1)) {
m->line = jl_linenode_line(body1);
}
Expand Down Expand Up @@ -614,7 +614,7 @@ jl_method_t *jl_new_method(jl_lambda_info_t *definition, jl_sym_t *name, jl_tupl
m->called = oldm->called;
}
else {
jl_array_t *stmts = definition->code;
jl_array_t *stmts = (jl_array_t*)definition->code;
int i, l;
for(i = 0, l = jl_array_len(stmts); i < l; i++) {
jl_array_ptr_set(stmts, i, jl_resolve_globals(jl_array_ptr_ref(stmts, i), definition));
Expand Down
21 changes: 20 additions & 1 deletion src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,16 @@ static void to_function(jl_lambda_info_t *li)
// mark the pointer calling convention
li->jlcall_api = (f->getFunctionType() == jl_func_sig ? 0 : 1);

// if not inlineable, code won't be needed again
if (JL_DELETE_NON_INLINEABLE &&
li->def && li->inferred && !li->inlineable && !jl_options.outputji) {
li->code = jl_nothing;
li->slottypes = jl_nothing;
li->ssavaluetypes = jl_box_long(jl_array_len(li->ssavaluetypes)); jl_gc_wb(li, li->ssavaluetypes);
li->slotflags = NULL;
li->slotnames = NULL;
}

// done compiling: restore global state
if (old != NULL) {
builder.SetInsertPoint(old);
Expand Down Expand Up @@ -1115,6 +1125,15 @@ void *jl_get_llvmf(jl_tupletype_t *tt, bool getwrapper, bool getdeclarations)
return NULL;
}

if (linfo->code == jl_nothing) {
// re-infer if we've deleted the code
jl_type_infer(linfo, 0);
if (linfo->code == jl_nothing) {
JL_GC_POP();
return NULL;
}
}

if (!getdeclarations) {
// emit this function into a new module
Function *f, *specf;
Expand Down Expand Up @@ -3946,7 +3965,7 @@ static std::unique_ptr<Module> emit_function(jl_lambda_info_t *lam, jl_llvm_func
assert(declarations && "Capturing declarations is always required");

// step 1. unpack AST and allocate codegen context for this function
jl_array_t *code = lam->code;
jl_array_t *code = (jl_array_t*)lam->code;
JL_GC_PUSH1(&code);
if (!jl_typeis(code,jl_array_any_type))
code = jl_uncompress_ast(lam, code);
Expand Down
18 changes: 8 additions & 10 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -467,15 +467,12 @@ static int module_in_worklist(jl_module_t *mod)

static int jl_prune_tcache(jl_typemap_entry_t *ml, void *closure)
{
if (!jl_is_leaf_type((jl_value_t*)ml->sig)) {
jl_value_t *ret = ml->func.value;
if (jl_is_lambda_info(ret)) {
jl_array_t *code = ((jl_lambda_info_t*)ret)->code;
if (jl_is_array(code) && jl_array_len(code) > 500) {
ml->func.value = ((jl_lambda_info_t*)ret)->rettype;
jl_gc_wb(ml, ml->func.value);
}
}
jl_value_t *ret = ml->func.value;
if (jl_is_lambda_info(ret) &&
((!jl_is_leaf_type((jl_value_t*)ml->sig) && !((jl_lambda_info_t*)ret)->inlineable) ||
((jl_lambda_info_t*)ret)->code == jl_nothing)) {
ml->func.value = ((jl_lambda_info_t*)ret)->rettype;
jl_gc_wb(ml, ml->func.value);
}
return 1;
}
Expand Down Expand Up @@ -1476,7 +1473,7 @@ static jl_value_t *jl_deserialize_value_(ios_t *s, jl_value_t *vtag, jl_value_t
NWORDS(sizeof(jl_lambda_info_t)));
if (usetable)
arraylist_push(&backref_list, li);
li->code = (jl_array_t*)jl_deserialize_value(s, (jl_value_t**)&li->code); jl_gc_wb(li, li->code);
li->code = jl_deserialize_value(s, &li->code); jl_gc_wb(li, li->code);
li->slotnames = (jl_array_t*)jl_deserialize_value(s, (jl_value_t**)&li->slotnames); jl_gc_wb(li, li->slotnames);
li->slottypes = jl_deserialize_value(s, &li->slottypes); jl_gc_wb(li, li->slottypes);
li->slotflags = (jl_array_t*)jl_deserialize_value(s, (jl_value_t**)&li->slotflags); jl_gc_wb(li, li->slotflags);
Expand Down Expand Up @@ -2076,6 +2073,7 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_lambda_info_t *li, jl_array_t *ast)
{
JL_LOCK(&dump_lock); // Might GC
assert(jl_is_lambda_info(li));
assert(jl_is_array(ast));
DUMP_MODES last_mode = mode;
mode = MODE_AST;
ios_t dest;
Expand Down
6 changes: 3 additions & 3 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ jl_value_t *jl_mk_builtin_func(const char *name, jl_fptr_t fptr)
jl_value_t *f = jl_new_generic_function_with_supertype(sname, jl_core_module, jl_builtin_type, 0);
jl_lambda_info_t *li = jl_new_lambda_info_uninit();
li->fptr = fptr;
// TODO jb/functions: what should li->ast be?
li->code = (jl_array_t*)jl_an_empty_vec_any; jl_gc_wb(li, li->code);
// TODO jb/functions: what should li->code be?
li->code = jl_nothing; jl_gc_wb(li, li->code);
li->def = jl_new_method_uninit();
li->def->name = sname;
li->def->lambda_template = li;
Expand All @@ -164,7 +164,7 @@ jl_lambda_info_t *jl_get_unspecialized(jl_lambda_info_t *method)
return method->unspecialized_ducttape;
if (method->sparam_syms != jl_emptysvec) {
if (def->needs_sparam_vals_ducttape == 2) {
jl_array_t *code = method->code;
jl_array_t *code = (jl_array_t*)method->code;
JL_GC_PUSH1(&code);
if (!jl_typeis(code, jl_array_any_type))
code = jl_uncompress_ast(def->lambda_template, code);
Expand Down
2 changes: 1 addition & 1 deletion src/interpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ static jl_value_t *eval_body(jl_array_t *stmts, interpreter_state *s, int start,

jl_value_t *jl_interpret_call(jl_lambda_info_t *lam, jl_value_t **args, uint32_t nargs, jl_svec_t *sparam_vals)
{
jl_array_t *stmts = lam->code;
jl_array_t *stmts = (jl_array_t*)lam->code;
assert(jl_typeis(stmts, jl_array_any_type));
jl_value_t **locals;
JL_GC_PUSHARGS(locals, jl_linfo_nslots(lam) + jl_linfo_nssavalues(lam));
Expand Down
20 changes: 10 additions & 10 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -3630,15 +3630,15 @@ void jl_init_types(void)
jl_new_datatype(jl_symbol("LambdaInfo"),
jl_any_type, jl_emptysvec,
jl_svec(25,
jl_symbol("code"),
jl_symbol("slotnames"),
jl_symbol("slottypes"),
jl_symbol("slotflags"),
jl_symbol("ssavaluetypes"),
jl_symbol("rettype"),
jl_symbol("sparam_syms"),
jl_symbol("sparam_vals"),
jl_symbol("specTypes"),
jl_symbol("code"),
jl_symbol("slottypes"),
jl_symbol("ssavaluetypes"),
jl_symbol("slotnames"),
jl_symbol("slotflags"),
jl_symbol("unspecialized_ducttape"),
jl_symbol("def"),
jl_symbol("nargs"),
Expand All @@ -3655,14 +3655,14 @@ void jl_init_types(void)
jl_symbol(""), jl_symbol("")),
jl_svec(25,
jl_any_type,
jl_array_any_type,
jl_simplevector_type,
jl_simplevector_type,
jl_any_type,
jl_array_uint8_type,
jl_any_type,
jl_any_type,
jl_simplevector_type,
jl_simplevector_type,
jl_any_type,
jl_array_any_type,
jl_array_uint8_type,
jl_any_type,
jl_method_type,
jl_int32_type,
Expand All @@ -3677,7 +3677,7 @@ void jl_init_types(void)
jl_any_type,
jl_any_type, jl_any_type,
jl_int32_type, jl_int32_type),
0, 1, 10);
0, 1, 7);
jl_svecset(jl_lambda_info_type->types, 9, jl_lambda_info_type);
jl_svecset(jl_method_type->types, 8, jl_lambda_info_type);

Expand Down
10 changes: 5 additions & 5 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,15 +228,15 @@ typedef struct _jl_method_t {
// a function pointer.
typedef struct _jl_lambda_info_t {
JL_DATA_TYPE
jl_array_t *code; // compressed uint8 array, or Any array of statements
jl_array_t *slotnames; // names of local variables
jl_value_t *slottypes;
jl_array_t *slotflags; // local var bit flags
jl_value_t *ssavaluetypes; // types of ssa values
jl_value_t *rettype;
jl_svec_t *sparam_syms; // sparams is a vector of values indexed by symbols
jl_svec_t *sparam_vals;
jl_tupletype_t *specTypes; // argument types this was specialized for
jl_value_t *code; // compressed uint8 array, or Any array of statements
jl_value_t *slottypes;
jl_value_t *ssavaluetypes; // types of ssa values
jl_array_t *slotnames; // names of local variables
jl_array_t *slotflags; // local var bit flags
struct _jl_lambda_info_t *unspecialized_ducttape; // if template can't be compiled due to intrinsics, an un-inferred executable copy may get stored here
jl_method_t *def; // method this is specialized from, (null if this is a toplevel thunk)
int32_t nargs;
Expand Down
3 changes: 3 additions & 0 deletions src/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
// with KEEP_BODIES, we keep LLVM function bodies around for later debugging
// #define KEEP_BODIES

// delete julia IR for non-inlineable functions after they're codegen'd
#define JL_DELETE_NON_INLINEABLE 1

// GC options -----------------------------------------------------------------

// debugging options
Expand Down
4 changes: 2 additions & 2 deletions src/toplevel.c
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ jl_value_t *jl_toplevel_eval_flex(jl_value_t *e, int fast, int expanded)
thk = (jl_lambda_info_t*)jl_exprarg(ex,0);
assert(jl_is_lambda_info(thk));
assert(jl_typeis(thk->code, jl_array_any_type));
ewc = jl_eval_with_compiler_p(thk, thk->code, fast, jl_current_module);
ewc = jl_eval_with_compiler_p(thk, (jl_array_t*)thk->code, fast, jl_current_module);
}
else {
if (head && jl_eval_expr_with_compiler_p((jl_value_t*)ex, fast, jl_current_module)) {
Expand Down Expand Up @@ -774,7 +774,7 @@ JL_DLLEXPORT void jl_method_def(jl_svec_t *argdata, jl_lambda_info_t *f, jl_valu
jl_call_tracer(jl_newmeth_tracer, (jl_value_t*)m);

if (jl_boot_file_loaded && f->code && jl_typeis(f->code, jl_array_any_type)) {
f->code = jl_compress_ast(f, f->code);
f->code = (jl_value_t*)jl_compress_ast(f, (jl_array_t*)f->code);
jl_gc_wb(f, f->code);
}
JL_GC_POP();
Expand Down

0 comments on commit e24fec2

Please sign in to comment.