Skip to content

Commit

Permalink
Merge pull request #9 from CliMA/ck/tests2
Browse files Browse the repository at this point in the history
Add and fix tests for identity edge cases
  • Loading branch information
charleskawczynski authored Sep 27, 2024
2 parents 73ac980 + 0504f3d commit 7f2787a
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/LazyBroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function transform(e::Expr)
se = code_lowered_single_expression(e)
margs = materialize_args(se)
return :($(margs[2]))
elseif e.head == :call && isa_dot_op(e.args[1])
elseif e.head == :call && isa_dot_op(e.args[1]) || e.head == :.=
se = code_lowered_single_expression(e)
margs = materialize_args(se)
return :($(margs[2]))
Expand Down
21 changes: 17 additions & 4 deletions src/code_lowered_single_expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,22 @@ substitute(s::Symbol, code) = s
substitute(e::Expr, code) =
Expr(substitute(e.head, code), substitute.(e.args, Ref(code))...)

code_info(expr) = Base.Meta.lower(Main, expr).args[1]
function code_info(expr::Expr)
lc = Base.Meta.lower(Main, expr)
if lc isa Symbol
return code_info(lc)
else
return code_info(lc.args[1])
end
end
code_info(s::Symbol) = s
code_info(ci::Core.CodeInfo) = ci.code
function code_lowered_single_expression(expr)
code = code_info(expr).code # vector
s = string(substitute(code[end], code))
return Base.Meta.parse(s)
code = code_info(expr) # vector
if code isa Symbol
return code
else
s = string(substitute(code[end], code))
return Base.Meta.parse(s)
end
end
6 changes: 4 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
wrap_bce(e) = :(Base.Broadcast.broadcasted(identity, $e))
materialize_args(s::Symbol) = (wrap_bce(s), wrap_bce(s))
function materialize_args(expr::Expr)
@assert expr.head == :call
if expr.args[1] == :(Base.materialize!)
return (expr.args[2], expr.args[3])
elseif expr.args[1] == :(Base.materialize)
return (expr.args[2], expr.args[2]) # for type-stability, we always return a 2-tuple
else
error("Invalid expression given to materialize_args")
return (wrap_bce(expr), wrap_bce(expr))
end
end

Expand Down Expand Up @@ -42,8 +44,8 @@ function check_restrictions(expr)
# as foo could change the pointer of `b` to something else
# however, this seems unlikely.
elseif expr.head == Symbol(".") # dot function call
elseif expr.head == Symbol(".=") # dot = expr
else
@show dump(expr)
@show dump(expr)
error("Uncaught edge case")
end
Expand Down
4 changes: 1 addition & 3 deletions test/expr_materialize_args.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,5 @@ import LazyBroadcast as LB
@test LB.materialize_args(expr_in) == (entry_out, entry_out)

expr_in = :(foo(Base.broadcasted(+, x1, x2, x3, x4)))
@test_throws ErrorException("Invalid expression given to materialize_args") LB.materialize_args(
expr_in,
)
@test expr_in == :(foo(Base.broadcasted(+, x1, x2, x3, x4)))
end
59 changes: 59 additions & 0 deletions test/lazy_broadcasted.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,62 @@ b = rand(3, 3)
bce = LB.lazy_broadcasted(:(a .+ foo(b)))
@test bce == :(Base.broadcasted(+, a, foo(b)))
end


Base.@kwdef struct Foo{T}
x::T
end
Base.Broadcast.broadcastable(x::Foo) = tuple(x)

@testset "edge cases" begin
state = Foo(; x = 1)
bce = LB.lazy_broadcasted(:(@. state.x))
@test bce ==
:(Base.Broadcast.broadcasted(identity, Base.getproperty(state, :x)))
bco = LB.@lazy_broadcasted @. state.x
@test bco ==
Base.Broadcast.broadcasted(identity, Base.getproperty(state, :x))

bce = LB.lazy_broadcasted(:(@. state))
bco = LB.@lazy_broadcasted @. state
@test bce == :(Base.Broadcast.broadcasted(identity, state))
@test bco == Base.Broadcast.broadcasted(identity, state)

a, b = [1], [2]
bce = LB.lazy_broadcasted(:(a .= b))
@test bce == :(Base.broadcasted(Base.identity, b))

bce = LB.lazy_broadcasted(:(out .= x .+ foo(y) .+ bar.(z)))
@test bce == :(Base.broadcasted(
+,
Base.broadcasted(+, x, foo(y)),
Base.broadcasted(bar, z),
))
end

@testset "edge cases" begin
state = (; x = 1)
bce = LB.lazy_broadcasted(:(@. state.x))
@test bce ==
:(Base.Broadcast.broadcasted(identity, Base.getproperty(state, :x)))
bco = LB.@lazy_broadcasted @. state.x
@test bco ==
Base.Broadcast.broadcasted(identity, Base.getproperty(state, :x))

bce = LB.lazy_broadcasted(:(@. state))
@test_throws ArgumentError(
"broadcasting over dictionaries and `NamedTuple`s is reserved",
) LB.@lazy_broadcasted @. state
@test bce == :(Base.Broadcast.broadcasted(identity, state))

a, b = [1], [2]
bce = LB.lazy_broadcasted(:(a .= b))
@test bce == :(Base.broadcasted(Base.identity, b))

bce = LB.lazy_broadcasted(:(out .= x .+ foo(y) .+ bar.(z)))
@test bce == :(Base.broadcasted(
+,
Base.broadcasted(+, x, foo(y)),
Base.broadcasted(bar, z),
))
end

0 comments on commit 7f2787a

Please sign in to comment.