Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow splatting in calls to new #30577

Merged
merged 2 commits into from
Feb 7, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
implement splatting in new calls
adds `splatnew` form; like `new` but accepts a tuple
  • Loading branch information
JeffBezanson committed Feb 7, 2019
commit 7deecacd6d5feecc6b39b6803aa3f2171353f171
11 changes: 7 additions & 4 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@ Julia v1.2 Release Notes
New language features
---------------------

* The `extrema` function now accepts a function argument in the same manner as `minimum` and
`maximum` ([#30323]).
* `hasmethod` can now check for matching keyword argument names ([#30712]).
* `startswith` and `endswith` now accept a `Regex` for the second argument ([#29790]).
* Argument splatting (`x...`) can now be used in calls to the `new` pseudo-function in
constructors ([#30577]).

Multi-threading changes
-----------------------
Expand Down Expand Up @@ -35,6 +33,11 @@ New library functions
Standard library changes
------------------------

* The `extrema` function now accepts a function argument in the same manner as `minimum` and
`maximum` ([#30323]).
* `hasmethod` can now check for matching keyword argument names ([#30712]).
* `startswith` and `endswith` now accept a `Regex` for the second argument ([#29790]).

#### LinearAlgebra

* Added keyword arguments `rtol`, `atol` to `pinv` and `nullspace` ([#29998]).
Expand Down
3 changes: 3 additions & 0 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,9 @@ function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
t = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), t, args, length(args)))
end
end
elseif e.head === :splatnew
t = instanceof_tfunc(abstract_eval(e.args[1], vtypes, sv))[1]
# TODO: improve
elseif e.head === :&
abstract_eval(e.args[1], vtypes, sv)
t = Any
Expand Down
25 changes: 25 additions & 0 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,31 @@ function assemble_inline_todo!(ir::IRCode, linetable::Vector{LineInfoNode}, sv::
todo = Any[]
for idx in 1:length(ir.stmts)
stmt = ir.stmts[idx]

if isexpr(stmt, :splatnew)
ty = ir.types[idx]
nf = nfields_tfunc(ty)
if nf isa Const
eargs = stmt.args
tup = eargs[2]
tt = argextype(tup, ir, sv.sptypes)
tnf = nfields_tfunc(tt)
if tnf isa Const && tnf.val <= nf.val
n = tnf.val
new_argexprs = Any[eargs[1]]
for j = 1:n
atype = getfield_tfunc(tt, Const(j))
new_call = Expr(:call, Core.getfield, tup, j)
new_argexpr = insert_node!(ir, idx, atype, new_call)
push!(new_argexprs, new_argexpr)
end
stmt.head = :new
stmt.args = new_argexprs
end
end
continue
end

isexpr(stmt, :call) || continue
eargs = stmt.args
isempty(eargs) && continue
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ function getindex(x::UseRef)
end

function is_relevant_expr(e::Expr)
return e.head in (:call, :invoke, :new, :(=), :(&),
return e.head in (:call, :invoke, :new, :splatnew, :(=), :(&),
:gc_preserve_begin, :gc_preserve_end,
:foreigncall, :isdefined, :copyast,
:undefcheck, :throw_undef_if_not,
Expand Down
20 changes: 10 additions & 10 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,17 +329,17 @@ function sizeof_tfunc(@nospecialize(x),)
return Int
end
add_tfunc(Core.sizeof, 1, 1, sizeof_tfunc, 0)
add_tfunc(nfields, 1, 1,
function (@nospecialize(x),)
isa(x, Const) && return Const(nfields(x.val))
isa(x, Conditional) && return Const(0)
if isa(x, DataType) && !x.abstract && !(x.name === Tuple.name && isvatuple(x))
if !(x.name === _NAMEDTUPLE_NAME && !isconcretetype(x))
return Const(length(x.types))
end
function nfields_tfunc(@nospecialize(x))
isa(x, Const) && return Const(nfields(x.val))
isa(x, Conditional) && return Const(0)
if isa(x, DataType) && !x.abstract && !(x.name === Tuple.name && isvatuple(x))
if !(x.name === _NAMEDTUPLE_NAME && !isconcretetype(x))
return Const(length(x.types))
end
return Int
end, 0)
end
return Int
end
add_tfunc(nfields, 1, 1, nfields_tfunc, 0)
add_tfunc(Core._expr, 1, INT_INF, (@nospecialize args...)->Expr, 100)
function typevar_tfunc(@nospecialize(n), @nospecialize(lb_arg), @nospecialize(ub_arg))
lb = Union{}
Expand Down
5 changes: 3 additions & 2 deletions base/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const VALID_EXPR_HEADS = IdDict{Any,Any}(
:method => 1:4,
:const => 1:1,
:new => 1:typemax(Int),
:splatnew => 2:2,
:return => 1:1,
:unreachable => 0:0,
:the_exception => 0:0,
Expand Down Expand Up @@ -142,7 +143,7 @@ function validate_code!(errors::Vector{>:InvalidCodeError}, c::CodeInfo, is_top_
head === :inbounds || head === :foreigncall || head === :cfunction ||
head === :const || head === :enter || head === :leave || head == :pop_exception ||
head === :method || head === :global || head === :static_parameter ||
head === :new || head === :thunk || head === :simdloop ||
head === :new || head === :splatnew || head === :thunk || head === :simdloop ||
head === :throw_undef_if_not || head === :unreachable
validate_val!(x)
else
Expand Down Expand Up @@ -224,7 +225,7 @@ end

function is_valid_rvalue(@nospecialize(x))
is_valid_argument(x) && return true
if isa(x, Expr) && x.head in (:new, :the_exception, :isdefined, :call, :invoke, :foreigncall, :cfunction, :gc_preserve_begin, :copyast)
if isa(x, Expr) && x.head in (:new, :splatnew, :the_exception, :isdefined, :call, :invoke, :foreigncall, :cfunction, :gc_preserve_begin, :copyast)
return true
end
return false
Expand Down
7 changes: 7 additions & 0 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,13 @@ convert(::Type{Tuple{Vararg{V}}}, x::Tuple{Vararg{V}}) where {V} = x
convert(T::Type{Tuple{Vararg{V}}}, x::Tuple) where {V} =
(convert(tuple_type_head(T), x[1]), convert(T, tail(x))...)

# used for splatting in `new`
convert_prefix(::Type{Tuple{}}, x::Tuple) = x
convert_prefix(::Type{<:AtLeast1}, x::Tuple{}) = x
convert_prefix(::Type{T}, x::T) where {T<:AtLeast1} = x
convert_prefix(::Type{T}, x::AtLeast1) where {T<:AtLeast1} =
(convert(tuple_type_head(T), x[1]), convert_prefix(tuple_type_tail(T), tail(x))...)

# TODO: the following definitions are equivalent (behaviorally) to the above method
# I think they may be faster / more efficient for inference,
# if we could enable them, but are they?
Expand Down
4 changes: 2 additions & 2 deletions base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1223,8 +1223,8 @@ function show_unquoted(io::IO, ex::Expr, indent::Int, prec::Int)
end

# new expr
elseif head === :new
show_enclosed_list(io, "%new(", args, ", ", ")", indent)
elseif head === :new || head === :splatnew
show_enclosed_list(io, "%$head(", args, ", ", ")", indent)

# other call-like expressions ("A[1,2]", "T{X,Y}", "f.(X,Y)")
elseif haskey(expr_calls, head) && nargs >= 1 # :ref/:curly/:calldecl/:(.)
Expand Down
5 changes: 5 additions & 0 deletions doc/src/devdocs/ast.md
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,11 @@ These symbols appear in the `head` field of [`Expr`](@ref)s in lowered form.
to this, and the type is always inserted by the compiler. This is very much an internal-only
feature, and does no checking. Evaluating arbitrary `new` expressions can easily segfault.

* `splatnew`

Similar to `new`, except field values are passed as a single tuple. Works similarly to
`Base.splat(new)` if `new` were a first-class function, hence the name.

* `return`

Returns its argument as the value of the enclosing function.
Expand Down
2 changes: 2 additions & 0 deletions src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ jl_sym_t *enter_sym; jl_sym_t *leave_sym;
jl_sym_t *pop_exception_sym;
jl_sym_t *exc_sym; jl_sym_t *error_sym;
jl_sym_t *new_sym; jl_sym_t *using_sym;
jl_sym_t *splatnew_sym;
jl_sym_t *const_sym; jl_sym_t *thunk_sym;
jl_sym_t *abstracttype_sym; jl_sym_t *primtype_sym;
jl_sym_t *structtype_sym; jl_sym_t *foreigncall_sym;
Expand Down Expand Up @@ -325,6 +326,7 @@ void jl_init_frontend(void)
leave_sym = jl_symbol("leave");
pop_exception_sym = jl_symbol("pop_exception");
new_sym = jl_symbol("new");
splatnew_sym = jl_symbol("splatnew");
const_sym = jl_symbol("const");
global_sym = jl_symbol("global");
thunk_sym = jl_symbol("thunk");
Expand Down
21 changes: 21 additions & 0 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ static Function *jltls_states_func;

// important functions
static Function *jlnew_func;
static Function *jlsplatnew_func;
static Function *jlthrow_func;
static Function *jlerror_func;
static Function *jltypeerror_func;
Expand Down Expand Up @@ -4069,6 +4070,15 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaval)
// it to the inferred type.
return mark_julia_type(ctx, val, true, (jl_value_t*)jl_any_type);
}
else if (head == splatnew_sym) {
jl_cgval_t argv[2];
argv[0] = emit_expr(ctx, args[0]);
argv[1] = emit_expr(ctx, args[1]);
Value *typ = boxed(ctx, argv[0]);
Value *tup = boxed(ctx, argv[1]);
Value *val = ctx.builder.CreateCall(prepare_call(jlsplatnew_func), { typ, tup });
return mark_julia_type(ctx, val, true, (jl_value_t*)jl_any_type);
}
else if (head == exc_sym) {
return mark_julia_type(ctx,
ctx.builder.CreateCall(prepare_call(jl_current_exception_func)),
Expand Down Expand Up @@ -6981,6 +6991,17 @@ static void init_julia_llvm_env(Module *m)
jlnew_func->addFnAttr(Thunk);
add_named_global(jlnew_func, &jl_new_structv);

std::vector<Type *> args_2rptrs_(0);
args_2rptrs_.push_back(T_prjlvalue);
args_2rptrs_.push_back(T_prjlvalue);
jlsplatnew_func =
Function::Create(FunctionType::get(T_prjlvalue, args_2rptrs_, false),
Function::ExternalLinkage,
"jl_new_structt", m);
add_return_attr(jlsplatnew_func, Attribute::NonNull);
jlsplatnew_func->addFnAttr(Thunk);
add_named_global(jlsplatnew_func, &jl_new_structt);

std::vector<Type*> args2(0);
args2.push_back(T_pint8);
#ifndef _OS_WINDOWS_
Expand Down
62 changes: 50 additions & 12 deletions src/datatype.c
Original file line number Diff line number Diff line change
Expand Up @@ -797,8 +797,24 @@ JL_DLLEXPORT jl_value_t *jl_new_struct(jl_datatype_t *type, ...)
return jv;
}

JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args,
uint32_t na)
static void init_struct_tail(jl_datatype_t *type, jl_value_t *jv, size_t na)
{
size_t nf = jl_datatype_nfields(type);
for(size_t i=na; i < nf; i++) {
if (jl_field_isptr(type, i)) {
*(jl_value_t**)((char*)jl_data_ptr(jv)+jl_field_offset(type,i)) = NULL;
}
else {
jl_value_t *ft = jl_field_type(type, i);
if (jl_is_uniontype(ft)) {
uint8_t *psel = &((uint8_t *)jv)[jl_field_offset(type, i) + jl_field_size(type, i) - 1];
*psel = 0;
}
}
}
}

JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args, uint32_t na)
{
jl_ptls_t ptls = jl_get_ptls_states();
if (type->instance != NULL) {
Expand All @@ -811,7 +827,6 @@ JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args,
}
if (type->layout == NULL)
jl_type_error("new", (jl_value_t*)jl_datatype_type, (jl_value_t*)type);
size_t nf = jl_datatype_nfields(type);
jl_value_t *jv = jl_gc_alloc(ptls, jl_datatype_size(type), type);
JL_GC_PUSH1(&jv);
for (size_t i = 0; i < na; i++) {
Expand All @@ -820,18 +835,41 @@ JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args,
jl_type_error("new", ft, args[i]);
jl_set_nth_field(jv, i, args[i]);
}
for(size_t i=na; i < nf; i++) {
if (jl_field_isptr(type, i)) {
*(jl_value_t**)((char*)jl_data_ptr(jv)+jl_field_offset(type,i)) = NULL;
}
else {
init_struct_tail(type, jv, na);
JL_GC_POP();
return jv;
}

JL_DLLEXPORT jl_value_t *jl_new_structt(jl_datatype_t *type, jl_value_t *tup)
{
jl_ptls_t ptls = jl_get_ptls_states();
if (!jl_is_tuple(tup))
jl_type_error("new", (jl_value_t*)jl_tuple_type, tup);
size_t na = jl_nfields(tup);
size_t nf = jl_datatype_nfields(type);
if (na > nf)
jl_too_many_args("new", nf);
if (type->instance != NULL) {
for (size_t i = 0; i < na; i++) {
jl_value_t *ft = jl_field_type(type, i);
if (jl_is_uniontype(ft)) {
uint8_t *psel = &((uint8_t *)jv)[jl_field_offset(type, i) + jl_field_size(type, i) - 1];
*psel = 0;
}
jl_value_t *fi = jl_get_nth_field(tup, i);
if (!jl_isa(fi, ft))
jl_type_error("new", ft, fi);
}
return type->instance;
}
if (type->layout == NULL)
jl_type_error("new", (jl_value_t*)jl_datatype_type, (jl_value_t*)type);
jl_value_t *jv = jl_gc_alloc(ptls, jl_datatype_size(type), type);
JL_GC_PUSH1(&jv);
for (size_t i = 0; i < na; i++) {
jl_value_t *ft = jl_field_type(type, i);
jl_value_t *fi = jl_get_nth_field(tup, i);
if (!jl_isa(fi, ft))
jl_type_error("new", ft, fi);
jl_set_nth_field(jv, i, fi);
}
init_struct_tail(type, jv, na);
JL_GC_POP();
return jv;
}
Expand Down
10 changes: 10 additions & 0 deletions src/interpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,16 @@ SECT_INTERP static jl_value_t *eval_value(jl_value_t *e, interpreter_state *s)
JL_GC_POP();
return v;
}
else if (head == splatnew_sym) {
jl_value_t **argv;
JL_GC_PUSHARGS(argv, 2);
argv[0] = eval_value(args[0], s);
argv[1] = eval_value(args[1], s);
assert(jl_is_structtype(argv[0]));
jl_value_t *v = jl_new_structt((jl_datatype_t*)argv[0], argv[1]);
JL_GC_POP();
return v;
}
else if (head == static_parameter_sym) {
ssize_t n = jl_unbox_long(args[0]);
assert(n > 0);
Expand Down
26 changes: 19 additions & 7 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -696,8 +696,6 @@
(call (curly ,name ,@params) ,@field-names)))))

(define (new-call Tname type-params params args field-names field-types)
(if (any vararg? args)
(error "... is not supported inside \"new\""))
(if (any kwarg? args)
(error "\"new\" does not accept keyword arguments"))
(if (length> params (length type-params))
Expand All @@ -706,8 +704,22 @@
`(outerref ,Tname)
`(curly (outerref ,Tname)
,@type-params))))
(cond ((length> args (length field-names))
`(call (top error) "new: too many arguments"))
(cond ((length> (filter (lambda (a) (not (vararg? a))) args) (length field-names))
`(call (core throw) (call (top ArgumentError)
,(string "new: too many arguments (expected " (length field-names) ")"))))
((any vararg? args)
(if (every (lambda (ty) (equal? ty '(core Any)))
field-types)
`(splatnew ,Texpr (call (core tuple) ,@args))
(let ((tn (make-ssavalue)))
`(block
(= ,tn ,Texpr)
(splatnew ,tn (call (top convert_prefix)
(curly (core Tuple)
,@(map (lambda (fld)
`(call (core fieldtype) ,tn (quote ,fld)))
field-names))
(call (core tuple) ,@args)))))))
(else
(if (equal? type-params params)
`(new ,Texpr ,@(map (lambda (fty val)
Expand Down Expand Up @@ -2995,7 +3007,7 @@ f(x) = yt(x)
#f)
((eq? (car e) 'scope-block)
(visit (cadr e)))
((memq (car e) '(block call new _do_while))
((memq (car e) '(block call new splatnew _do_while))
(eager-any visit (cdr e)))
((eq? (car e) 'break-block)
(visit (caddr e)))
Expand Down Expand Up @@ -3405,7 +3417,7 @@ f(x) = yt(x)
(or (ssavalue? lhs)
(valid-ir-argument? e)
(and (symbol? lhs) (pair? e)
(memq (car e) '(new the_exception isdefined call invoke foreigncall cfunction gc_preserve_begin copyast)))))
(memq (car e) '(new splatnew the_exception isdefined call invoke foreigncall cfunction gc_preserve_begin copyast)))))

(define (valid-ir-return? e)
;; returning lambda directly is needed for @generated
Expand Down Expand Up @@ -3595,7 +3607,7 @@ f(x) = yt(x)
((and (pair? e1) (eq? (car e1) 'globalref)) (emit e1) #f) ;; keep globals for undefined-var checking
(else #f)))
(case (car e)
((call new foreigncall cfunction)
((call new splatnew foreigncall cfunction)
(let* ((args
(cond ((eq? (car e) 'foreigncall)
;; NOTE: 2nd to 5th arguments of ccall must be left in place
Expand Down
4 changes: 2 additions & 2 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1146,8 +1146,8 @@ jl_datatype_t *jl_new_abstracttype(jl_value_t *name, jl_module_t *module,
// constructors
JL_DLLEXPORT jl_value_t *jl_new_bits(jl_value_t *bt, void *data);
JL_DLLEXPORT jl_value_t *jl_new_struct(jl_datatype_t *type, ...);
JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args,
uint32_t na);
JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args, uint32_t na);
JL_DLLEXPORT jl_value_t *jl_new_structt(jl_datatype_t *type, jl_value_t *tup);
JL_DLLEXPORT jl_value_t *jl_new_struct_uninit(jl_datatype_t *type);
JL_DLLEXPORT jl_method_instance_t *jl_new_method_instance_uninit(void);
JL_DLLEXPORT jl_svec_t *jl_svec(size_t n, ...) JL_MAYBE_UNROOTED;
Expand Down
Loading