From 35ee48a5a25363d88d023ab7ada89c0c008b39fa Mon Sep 17 00:00:00 2001 From: Thibaut Lienart Date: Wed, 14 Jun 2023 17:28:17 +0200 Subject: [PATCH 1/3] typo fix in readme (#47) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c8bec3d..ee8b1f3 100644 --- a/README.md +++ b/README.md @@ -163,7 +163,7 @@ bound it do the variable `l` and added `1.0` to `l`, whereas if it was a `Right` The `@cases` macro still falls far short of a full on pattern matching system, lacking many features. For anything advanced, I'd recommend using `@match` from [MLStyle.jl](https://github.com/thautwarm/MLStyle.jl). -### Defining many repretitive cases simultaneously +### Defining many repetitive cases simultaneously `@cases` does not allow for fallback branches, and it also does not allow one to write inexhaustive cases. To avoid making some code overly verbose and repetitive, we instead provide syntax for defining many cases in one line: From f78634935da7a17ded0cfbcf0e2d5a0f7121f5b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leandro=20Mart=C3=ADnez?= <31046348+lmiq@users.noreply.github.com> Date: Sat, 17 Jun 2023 17:11:47 -0300 Subject: [PATCH 2/3] fix typo in README.md (#46) From c7628c48185f75fc7369e6929ab748418322b4e1 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Mon, 26 Jun 2023 23:54:59 -0600 Subject: [PATCH 3/3] Back to union storage (#49) * overhaul back to `Union` storage * cleanup * remove `matching_error` since its no longer needed * indentation fixes --- Project.toml | 2 +- README.md | 241 ++++++++++++++++++----------------------- src/SumTypes.jl | 56 +++------- src/cases.jl | 19 ++-- src/compute_storage.jl | 143 ------------------------ src/sum_type.jl | 100 +++++++---------- test/runtests.jl | 64 ++--------- 7 files changed, 181 insertions(+), 444 deletions(-) delete mode 100644 src/compute_storage.jl diff --git a/Project.toml b/Project.toml index 5c2a75c..8cc7312 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SumTypes" uuid = "8e1ec7a9-0e02-4297-b0fe-6433085c89f2" authors = ["MasonProtter "] -version = "0.4.8" +version = "0.5.0" [deps] MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" diff --git a/README.md b/README.md index ee8b1f3..eb11ec9 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,6 @@ - [Basics](https://github.com/MasonProtter/SumTypes.jl#basics) - [Destructuring sum types](https://github.com/MasonProtter/SumTypes.jl#destructuring-sum-types) -- [Using `full_type` to get the concrete type of a Sum Type](https://github.com/MasonProtter/SumTypes.jl/#using-full_type-to-get-the-concrete-type-of-a-sum-type) - [Avoiding namespace clutter](https://github.com/MasonProtter/SumTypes.jl#avoiding-namespace-clutter) - [Custom printing](https://github.com/MasonProtter/SumTypes.jl#custom-printing) - [Performance](https://github.com/MasonProtter/SumTypes.jl#performance) @@ -118,20 +117,20 @@ Okay, but how do I actually access the data enclosed in a `Fruit` or an `Either` SumTypes.jl exposes a `@cases` macro for efficiently unwrapping and branching on the contents of a sum type: ```julia -julia> myfruit = Orange -Orange::Fruit +julia> myfruit = orange +orange::Fruit julia> @cases myfruit begin - Apple => "Got an apple!" - Orange => "Got an orange!" - Banana => error("I'm allergic to bananas!") + apple => "Got an apple!" + orange => "Got an orange!" + banana => error("I'm allergic to bananas!") end "Got an orange!" -julia> @cases Banana begin - Apple => "Got an apple!" - Orange => "Got an orange!" - Banana => error("I'm allergic to bananas!") +julia> @cases banana begin + apple => "Got an apple!" + orange => "Got an orange!" + banana => error("I'm allergic to bananas!") end ERROR: I'm allergic to bananas! [...] @@ -139,10 +138,10 @@ ERROR: I'm allergic to bananas! `@cases` can automatically detect if you didn't give an exhaustive set of cases (with no runtime penalty) and throw an error. ```julia julia> @cases myfruit begin - Apple => "Got an apple!" - Orange => "Got an orange!" + apple => "Got an apple!" + orange => "Got an orange!" end -ERROR: Inexhaustive @cases specification. Got cases (:Apple, :Orange), expected (:Apple, :Banana, :Orange) +ERROR: Inexhaustive @cases specification. Got cases (:apple, :orange), expected (:apple, :banana, :orange) [...] ``` @@ -211,44 +210,6 @@ end; -## Using `full_type` to get the concrete type of a Sum Type - -
-Click to expand - -SumTypes.jl generates structs with a compactified memory layout which is computed on demand for parametric types. Because of this, -every SumTypes actually has two extra type parameters related to its memory layout. This means that for instance, with `Either{Int, Int}`: - -``` julia -julia> @sum_type Either{A, B} begin - Left{A}(::A) - Right{B}(::B) - end - -julia> isconcretetype(Either{Int, Int}) -false -``` - -In order to get the proper, concrete type corresponding to `Either{Int, Int}`, one can use the `full_type` function exported by SumTypes.jl: - -``` julia -julia> full_type(Either{Int, Int}) -Either{Int64, Int64, 8, 0, UInt64} - -julia> full_type(Either{Int, String}) -Either{Int64, String, 8, 1, UInt8} - -julia> full_type(Either{Tuple{Int, Int, Int}, String}) -Either{Tuple{Int64, Int64, Int64}, String, 24, 1, UInt8} - -julia> isconcretetype(ans) -true -``` - -Avoiding these extra parameters would require https://github.com/JuliaLang/julia/issues/8472 to be implemented. - -
- ## Avoiding namespace clutter @@ -334,58 +295,45 @@ julia> SumTypes.show_sumtype(io::IO, ::MIME"text/plain", x::Fruit2) = @cases x b julia> apple apple! ``` -If you overload `Base.show` directly inside a package, you might get annoying method deletion warnings during pre-compilation. +If you overload `Base.show** directly inside a package, you might get annoying method deletion warnings during pre-compilation. ## Performance -In the same way as [Unityper.jl](https://github.com/YingboMa/Unityper.jl) is able to provide a dramatic speedup versus manual union splitting, SumTypes.jl can do this too: +SumTypes.jl can provide some speedups compared to union-splitting when destructuring and branching on abstractly typed data. -Branching on abstractly typed data +#### SumTypes.jl
Benchmark code ``` julia -module AbstractTypeTest - -using BenchmarkTools +module SumTypeTest -abstract type AT end -Base.@kwdef struct A <: AT - common_field::Int = 0 - a::Bool = true - b::Int = 10 -end -Base.@kwdef struct B <: AT - common_field::Int = 0 - a::Int = 1 - b::Float64 = 1.0 - d::Complex = 1 + 1.0im # not isbits -end -Base.@kwdef struct C <: AT - common_field::Int = 0 - b::Float64 = 2.0 - d::Bool = false - e::Float64 = 3.0 - k::Complex{Real} = 1 + 2im # not isbits -end -Base.@kwdef struct D <: AT - common_field::Int = 0 - b::Any = :hi # not isbits +using SumTypes, BenchmarkTools +@sum_type AT begin + A(common_field::Int, a::Bool, b::Int) + B(common_field::Int, a::Int, b::Float64, d::Complex) + C(common_field::Int, b::Float64, d::Bool, e::Float64, k::Complex{Real}) + D(common_field::Int, b::Any) end foo!(xs) = for i in eachindex(xs) - @inbounds x = xs[i] - @inbounds xs[i] = x isa A ? B() : - x isa B ? C() : - x isa C ? D() : - x isa D ? A() : error() + xs[i] = @cases xs[i] begin + A(cf, a, b) => B(cf+1, a, b, b) + B(cf, a, b, d) => C(cf-1, b, isodd(a), b, d) + C(cf, b, d, e, k) => D(cf+1, isodd(cf) ? "hi" : "bye") + D(cf, b) => A(cf-1, b=="hi", cf) + end end +xs = rand((A(1, true, 10), + B(1, 1, 1.0, 1+1im), + C(1, 2.0, false, 3.0, Complex{Real}(1 + 2im)), + D(1, "hi")), + 10000); -xs = rand((A(), B(), C(), D()), 10000); display(@benchmark foo!($xs);) end @@ -395,74 +343,95 @@ end ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. - Range (min … max): 267.399 μs … 3.118 ms ┊ GC (min … max): 0.00% … 90.36% - Time (median): 278.904 μs ┊ GC (median): 0.00% - Time (mean ± σ): 316.971 μs ± 306.290 μs ┊ GC (mean ± σ): 11.68% ± 10.74% + Range (min … max): 300.541 μs … 2.585 ms ┊ GC (min … max): 0.00% … 86.91% + Time (median): 313.611 μs ┊ GC (median): 0.00% + Time (mean ± σ): 342.285 μs ± 242.158 μs ┊ GC (mean ± σ): 8.29% ± 10.04% █ ▁ - █▆▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▇ █ - 267 μs Histogram: log(frequency) by time 2.77 ms < + █▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ █ + 301 μs Histogram: log(frequency) by time 2.37 ms < - Memory estimate: 654.75 KiB, allocs estimate: 21952. + Memory estimate: 620.88 KiB, allocs estimate: 19900. ``` -SumTypes.jl + +#### Branching on abstractly typed data:
Benchmark code ``` julia -module SumTypeTest +module AbstractTypeTest -using SumTypes, BenchmarkTools -@sum_type AT begin - A(common_field::Int, a::Bool, b::Int) - B(common_field::Int, a::Int, b::Float64, d::Complex) - C(common_field::Int, b::Float64, d::Bool, e::Float64, k::Complex{Real}) - D(common_field::Int, b::Any) -end +using BenchmarkTools -A(;common_field=1, a=true, b=10) = A(common_field, a, b) -B(;common_field=1, a=1, b=1.0, d=1 + 1.0im) = B(common_field, a, b, d) -C(;common_field=1, b=2.0, d=false, e=3.0, k=Complex{Real}(1 + 2im)) = C(common_field, b, d, e, k) -D(;common_field=1, b=:hi) = D(common_field, b) +abstract type AT end +Base.@kwdef struct A <: AT + common_field::Int + a::Bool + b::Int +end +Base.@kwdef struct B <: AT + common_field::Int + a::Int + b::Float64 + d::Complex # not isbits +end +Base.@kwdef struct C <: AT + common_field::Int + b::Float64 + d::Bool + e::Float64 + k::Complex{Real} # not isbits +end +Base.@kwdef struct D <: AT + common_field::Int + b::Any # not isbits +end foo!(xs) = for i in eachindex(xs) - xs[i] = @cases xs[i] begin - A => B() - B => C() - C => D() - D => A() - end + @inbounds x = xs[i] + @inbounds xs[i] = x isa A ? B(x.common_field+1, x.a, x.b, x.b) : + x isa B ? C(x.common_field-1, x.b, isodd(x.a), x.b, x.d) : + x isa C ? D(x.common_field+1, isodd(x.common_field) ? "hi" : "bye") : + x isa D ? A(x.common_field-1, x.b=="hi", x.common_field) : error() end -xs = rand((A(), B(), C(), D()), 10000); + +xs = rand((A(1, true, 10), + B(1, 1, 1.0, 1+1im), + C(1, 2.0, false, 3.0, Complex{Real}(1 + 2im)), + D(1, "hi")), + 10000); display(@benchmark foo!($xs);) -end +end ```
``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. - Range (min … max): 52.680 μs … 72.570 μs ┊ GC (min … max): 0.00% … 0.00% - Time (median): 53.590 μs ┊ GC (median): 0.00% - Time (mean ± σ): 53.718 μs ± 756.064 ns ┊ GC (mean ± σ): 0.00% ± 0.00% + Range (min … max): 366.510 μs … 4.504 ms ┊ GC (min … max): 0.00% … 90.65% + Time (median): 386.470 μs ┊ GC (median): 0.00% + Time (mean ± σ): 478.369 μs ± 571.525 μs ┊ GC (mean ± σ): 18.62% ± 13.77% - ▁▂▁▃▆▅█▇▅▅▃▁▁ - ▁▂▂▃▅▆██████████████▇▇▅▄▄▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃ - 52.7 μs Histogram: frequency by time 56.7 μs < + █ ▂ ▁ + █▇▄▅▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ █ + 367 μs Histogram: log(frequency) by time 4.1 ms < - Memory estimate: 0 bytes, allocs estimate: 0. + Memory estimate: 1.06 MiB, allocs estimate: 31958. ``` -And Unityper.jl: +#### Unityper.jl + +Unityper.jl is a somewhat similar package, with some overlapping goals to SumTypes.jl. However, In this test, Unityper.jl ends up doing much worse than abstract containers or SumTypes.jl:
Benchmark code ``` julia + module UnityperTest using Unityper, BenchmarkTools @@ -487,17 +456,17 @@ using Unityper, BenchmarkTools k::Complex{Real} = 1 + 2im # not isbits end struct D <: AT - b::Any = :hi # not isbits + b::Any = "hi" # not isbits end end foo!(xs) = for i in eachindex(xs) @inbounds x = xs[i] @inbounds xs[i] = @compactified x::AT begin - A => B() - B => C() - C => D() - D => A() + A => B(;common_field=x.common_field+1, a=x.a, b=x.b, d=x.b) + B => C(;common_field=x.common_field-1, b=x.b, d=isodd(x.a), e=x.b, k=x.d) + C => D(;common_field=x.common_field+1, b=isodd(x.common_field) ? "hi" : "bye") + D => A(;common_field=x.common_field-1, a=x.b=="hi", b=x.common_field) end end @@ -510,18 +479,18 @@ end
``` -BenchmarkTools.Trial: 10000 samples with 1 evaluation. - Range (min … max): 54.220 μs … 76.000 μs ┊ GC (min … max): 0.00% … 0.00% - Time (median): 55.030 μs ┊ GC (median): 0.00% - Time (mean ± σ): 55.073 μs ± 466.103 ns ┊ GC (mean ± σ): 0.00% ± 0.00% - - ▁▁▅▄▄▅▅█▄▅▃▃▂ - ▁▁▁▁▂▂▂▃▄▆▆███████████████▇▆▅▆▃▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃ - 54.2 μs Histogram: frequency by time 56.7 μs < - Memory estimate: 0 bytes, allocs estimate: 0. +BenchmarkTools.Trial: 2539 samples with 1 evaluation. + Range (min … max): 1.847 ms … 5.341 ms ┊ GC (min … max): 0.00% … 64.05% + Time (median): 1.890 ms ┊ GC (median): 0.00% + Time (mean ± σ): 1.968 ms ± 478.604 μs ┊ GC (mean ± σ): 3.93% ± 9.68% + + █▆ + ██▇▆▁▃▁▁▃▁▃▁▁▁▁▁▃▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▆▆▆▇ █ + 1.85 ms Histogram: log(frequency) by time 4.9 ms < + + Memory estimate: 1.14 MiB, allocs estimate: 27272. ``` -SumTypes.jl and Unityper.jl are about equal in this benchmark, though there are cases where there are differences. SumTypes.jl has some other advantages relative to Unityper.jl such as: - SumTypes.jl allows [parametric types](https://docs.julialang.org/en/v1/manual/types/#Parametric-Types) for much greater container flexibility. - SumTypes.jl does not require default values for every field of the struct. @@ -529,4 +498,4 @@ SumTypes.jl has some other advantages relative to Unityper.jl such as: - SumTypes.jl allows you to hide its variants from the namespace (opt in). One advantage of Unityper.jl is: -- Because Unityper.jl doesn't allow parameterized types and needs to know all type information at macroexpansion time, their structs have a fixed layout for boxed variables that lets them avoid an allocation when storing heap allocated objects (this allocation would be in addition to the heap allocation for the object itself). If we had used `D(;common_field=1, b="hi")` in our benchmarks, SumTypes.jl could have incurred an allocation whereas Unityper.jl would not. As far as I know, this would requre https://github.com/JuliaLang/julia/issues/8472 in order to avoid in SumTypes.jl +- If you're not modifying the data and just re-using old heap allocated data, there are cases where Unityper.jl can avoid an allocation that SumTypes.jl would have incurred. diff --git a/src/SumTypes.jl b/src/SumTypes.jl index 6cbe2d6..cbe4559 100644 --- a/src/SumTypes.jl +++ b/src/SumTypes.jl @@ -4,22 +4,14 @@ export @sum_type, @cases, Uninit, full_type using MacroTools: MacroTools -function parent end function constructors end -function constructor end -function constructors_Union end +function constructor end function variants_Tuple end function unwrap end -function tags end -function deparameterize end is_sumtype(::Type{T}) where {T} = false -function flagtype end -function flag_to_symbol end -function symbol_to_flag end -function tags_flags_nt end -function variants_Tuple end -function strip_size_params end -function full_type end + +function get_tag end +function tags end isexpr(x, head) = x isa Expr && x.head == head @@ -38,56 +30,40 @@ For an `x` which was created as a `@sum_type`, check if it's variant tag is `s`. isvariant(x, :Right) # false end """ -isvariant(x::T, s::Symbol) where {T} = get_tag(x) == symbol_to_flag(T, s) +function isvariant end struct Unsafe end const unsafe = Unsafe() struct Uninit end -struct Variant{fieldnames, Tup <: Tuple} +struct Variant{name, fieldnames, Tup <: Tuple} data::Tup - Variant{fieldnames, Tup}(::Unsafe) where {fieldnames, Tup} = new{fieldnames, Tup}() - Variant{fieldnames, Tup}(t::Tuple) where {fieldnames, Tup <: Tuple} = new{fieldnames, Tup}(t) + Variant{name, fieldnames, Tup}(::Unsafe) where {name, fieldnames, Tup} = new{name, fieldnames, Tup}() + Variant{name, fieldnames, Tup}(t::Tuple) where {name, fieldnames, Tup <: Tuple} = new{name, fieldnames, Tup}(t) end -Base.:(==)(v1::Variant, v2::Variant) = v1.data == v2.data +get_name(::Variant{name}) where {name} = name +Base.:(==)(v1::Variant{name}, v2::Variant{name}) where {name} = v1.data == v2.data Base.iterate(x::Variant, s = 1) = iterate(x.data, s) Base.indexed_iterate(x::Variant, i::Int, state=1) = (Base.@_inline_meta; (getfield(x.data, i), i+1)) Base.getindex(x::Variant, i) = x.data[i] -const tag = Symbol("#tag#") -get_tag(x) = getfield(x, tag) -get_tag_sym(x::T) where {T} = keys(tags_flags_nt(T))[Int(get_tag(x)) + 1] - show_sumtype(io::IO, m::MIME, x) = show_sumtype(io, x) function show_sumtype(io::IO, x::T) where {T} - tag = get_tag(x) - sym = flag_to_symbol(T, tag) - T_stripped = T_string_stripped(T) - if unwrap(x) isa Variant{(), Tuple{}} - print(io, String(sym), "::", T_stripped) + data = unwrap(x) + sym = get_name(data) + if length(data.data) == 0 + print(io, String(sym), "::", T) else - print(io, String(sym), '(', join((repr(data) for data ∈ unwrap(x)), ", "), ")::", T_stripped) + print(io, String(sym), '(', join((repr(field) for field ∈ data), ", "), ")::", T) end end -function T_string_stripped(::Type{_T}) where {_T} - @assert is_sumtype(_T) - T = full_type(_T) - T_stripped = if length(T.parameters) == 3 - String(T.name.name) - else - string(String(T.name.name), "{", join(repr.(T.parameters[1:end-3]), ", "), "}") - end -end - struct Converter{T, U} end (::Converter{T, U})(x) where {T, U} = convert(T, U(x)) -Base.show(io::IO, x::Converter{T, U}) where {T, U} = print(io, "$(T_string_stripped(T))'.$U") - +Base.show(io::IO, x::Converter{T, U}) where {T, U} = print(io, "$T'.$U") -include("compute_storage.jl") include("sum_type.jl") # @sum_type defined here include("cases.jl") # @cases defined here diff --git a/src/cases.jl b/src/cases.jl index ef745b0..16dd6ef 100644 --- a/src/cases.jl +++ b/src/cases.jl @@ -2,7 +2,6 @@ # "Inexhaustive @cases specification. Got cases $(variants), expected $(tags(T))")) @noinline check_sum_type(::Type{T}) where {T} = is_sumtype(T) ? nothing : throw(error("@cases only works on SumTypes, got $T which is not a SumType")) -@noinline matching_error() = throw(error("Something went wrong during matching")) @generated function assert_exhaustive(::Type{Val{tags}}, ::Type{Val{variants}}) where {tags, variants} ret = nothing @@ -67,22 +66,21 @@ macro cases(to_match, block) @gensym con_Union @gensym Typ @gensym nt + @gensym unwrapped variants = map(x -> x.variant, stmts) - ex = :(if $isvariant($data, $(QuoteNode(stmts[1].variant))); - $(stmts[1].iscall ? :(($(stmts[1].fieldnames...),) = - $unwrap($data, $constructor($Typ, $Val{$(QuoteNode(stmts[1].variant))}), $variants_Tuple($Typ)) ) : nothing); + ex = :(if $unwrapped isa $Variant{$(QuoteNode(stmts[1].variant))} + $(stmts[1].iscall ? :(($(stmts[1].fieldnames...),) = $unwrapped) : nothing); $(stmts[1].rhs) end) Base.remove_linenums!(ex) pushfirst!(ex.args[2].args, lnns[1]) to_push = ex.args for i ∈ 2:length(stmts) - _if = :(if $isvariant($data, $(QuoteNode(stmts[i].variant))); - $(stmts[i].iscall ? :(($(stmts[i].fieldnames...),) = - $unwrap($data, $constructor($Typ, $Val{$(QuoteNode(stmts[i].variant))}), $variants_Tuple($Typ))) : nothing); - $(stmts[i].rhs) - end) + _if = :(if $unwrapped isa $Variant{$(QuoteNode(stmts[i].variant))} + $(stmts[i].iscall ? :(($(stmts[i].fieldnames...),) = $unwrapped) : nothing); + $(stmts[i].rhs) + end) _if.head = :elseif Base.remove_linenums!(_if) pushfirst!(_if.args[2].args, lnns[i]) @@ -90,13 +88,14 @@ macro cases(to_match, block) push!(to_push, _if) to_push = to_push[3].args end - push!(to_push, :($matching_error())) + # push!(to_push, :($matching_error())) deparameterize(x) = x isa Symbol ? x : x isa Expr && x.head == :curly ? x.args[1] : throw("Invalid variant name $x") quote let $data = $to_match $Typ = $typeof($data) $check_sum_type($Typ) $assert_exhaustive(Val{$tags($Typ)}, Val{$(Expr(:tuple, QuoteNode.(deparameterize.(variants))...))}) + $unwrapped = $unwrap($data) $ex end end |> esc diff --git a/src/compute_storage.jl b/src/compute_storage.jl deleted file mode 100644 index 361649d..0000000 --- a/src/compute_storage.jl +++ /dev/null @@ -1,143 +0,0 @@ -struct PlaceHolder end - -macro assume_effects(args...) - if isdefined(Base, Symbol("@assume_effects")) - ex = :($Base.@assume_effects($(args...))) - else - ex = args[end] - end - esc(ex) -end - -@assume_effects :foldable function unsafe_padded_reinterpret(::Type{T}, x::U) where {T, U} - @assert isbitstype(T) && isbitstype(U) - n, m = sizeof(T), sizeof(U) - if sizeof(U) < sizeof(T) - payload = (x, ntuple(_ -> zero(UInt8), Val(n-m)), ) - else - payload = x - end - let r = Ref(payload) - GC.@preserve r begin - p = pointer_from_objref(r) - unsafe_load(Ptr{T}(p)) - end - end -end - -_rf_findmax((fm, im), (fx, ix)) = isless(fm, fx) ? (fx, ix) : (fm, im) -_argmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)[2] - -function extract_info(::Type{ST}, variants) where {ST} - - data = map(variants) do variant - (names, store_types) = variant.parameters - bits = [] - ptrs = [] - @assert length(names) == length(store_types.parameters) - foreach(zip(names, store_types.parameters)) do (name, T) - if isbitstype(T) - push!(bits, name => T) - else - push!(bits, name => SumTypes.PlaceHolder) - push!(ptrs, name => T) - end - end - bits, ptrs - end - bitss = map(x -> x[1], data) - ptrss = map(x -> x[2], data) - nptrs = maximum(length, ptrss) - ptr_names = map(v -> map(x -> x[1], v), ptrss) - bit_names = map(v -> map(x -> x[1], v), bitss) - bit_sigs = map(v -> map(x -> x[2], v), bitss) - - FT = fieldtype(ST, 3) - bit_size = maximum(v -> sizeof(Tuple{map(x -> x[2], v)..., }), bitss) - - _FT = length(variants) < typemax(UInt8) ? UInt8 : length(variants) < typemax(UInt16) ? UInt16 : UInt32 - - FT = if nptrs == 0 - if bit_size <= 1 - _argmax(sizeof, (_FT, UInt8)) - elseif bit_size <= 2 - _argmax(sizeof, (_FT, UInt16)) - elseif bit_size <= 4 - _argmax(sizeof, (_FT, UInt32)) - else - UInt - end - else - _FT - end - - (; - bitss = bitss, - ptrss = ptrss, - nptrs = nptrs, - ptr_names = ptr_names, - bit_size = bit_size, - bit_names = bit_names, - bit_sigs = bit_sigs, - flagtype = FT - ) -end - - -make(::Type{ST}, to_make, tag) where {ST} = make(ST, to_make, tag, variants_Tuple(ST)) -@generated function make(::Type{ST}, to_make::Var, tag, ::Type{var_Tuple}) where {ST, Var <: Variant, var_Tuple <: Tuple} - variants = var_Tuple.parameters - i = findfirst(==(Var), variants) - nt = extract_info(ST, variants) - - nptrs = nt.nptrs - ptr_names = nt.ptr_names - bit_size = nt.bit_size - bit_names = nt.bit_names - bit_sigs = nt.bit_sigs - FT = nt.flagtype - - bitvariant = :(SumTypes.Variant{($(QuoteNode.(bit_names[i])...),), Tuple{$(bit_sigs[i]...)}}( - ($(([bit_sigs[i][j] == PlaceHolder ? PlaceHolder() : :(to_make.data[$j]) for j ∈ eachindex(bit_sigs[i]) ])...),) )) - ptr_args = [:(to_make.data[$j]) for j ∈ eachindex(bit_names[i]) if bit_names[i][j] ∈ ptr_names[i]] - con = Expr( - :new, - ST{bit_size, nptrs, FT}, - :(unsafe_padded_reinterpret(NTuple{$bit_size, UInt8}, $bitvariant)), - Expr(:tuple, ptr_args..., (nothing for _ ∈ 1:(nptrs-length(ptr_args)))...), - :($FT(tag)), - ) -end - - -function unwrap(x::ST, s::Symbol) where {ST} - isvariant(x, s) || error("Incorrect tag used in unwrap") - unwrap(x, constructor(ST, Val{s}), variants_Tuple(ST)) -end -unwrap(x::ST, var) where {ST} = unwrap(x, var, variants_Tuple(ST)) -@generated function unwrap(x::ST, ::Type{Var}, ::Type{var_Tuple}) where {ST, Var, var_Tuple} - variants = var_Tuple.parameters - i = findfirst(==(Var), variants) - nt = extract_info(ST, variants) - ptrss = nt.ptrss - nptrs = nt.nptrs - ptr_names = nt.ptr_names - bit_size = nt.bit_size - bit_names = nt.bit_names - bit_sigs = nt.bit_sigs - quote - names = ($(QuoteNode.(bit_names[i])...),) - bits = unsafe_padded_reinterpret(Variant{names, Tuple{$(bit_sigs[i]...)}}, getfield(x, :bits)) - args = $(Expr(:tuple, - (bit_names[i][j] ∈ ptr_names[i] ? let k = findfirst(x -> x == bit_names[i][j], ptr_names[i]) - :(getfield(x, :ptrs)[$k]:: $(ptrss[i][k][2])) - end : :(bits.data[$j]) for j ∈ eachindex(bit_names[i]))...)) - Variant{names, $(Var.parameters[2])}(args) - end -end - -Base.@generated function full_type(::Type{ST}, ::Type{var_Tuple}) where {ST, var_Tuple} - variants = var_Tuple.parameters - nt = extract_info(ST, variants) - :($ST{$(nt.bit_size), $(nt.nptrs), $(nt.flagtype)}) -end diff --git a/src/sum_type.jl b/src/sum_type.jl index 0658efc..33d8d82 100644 --- a/src/sum_type.jl +++ b/src/sum_type.jl @@ -28,9 +28,9 @@ function _sum_type(T, hidden, blk) error("constructors must have unique names, got $(map(x -> x.name, constructors))") end - con_expr = generate_constructor_exprs(T_name, T_params, T_params_constrained, T_nameparam, constructors) + con_expr, con_structs = generate_constructor_exprs(T_name, T_params, T_params_constrained, T_nameparam, constructors) out = generate_sum_struct_expr(T, T_name, T_params, T_params_constrained, T_param_bounds, T_nameparam, constructors) - Expr(:toplevel, out, con_expr) + Expr(:toplevel, con_structs, out, con_expr) end #------------------------------------------------------ @@ -50,8 +50,8 @@ function generate_constructor_data(T_name, T_params, T_params_constrained, T_nam name = name, gname = gname, params = [], - store_type = Variant{(), Tuple{}}, - store_type_uninit = Variant{(), Tuple{}}, + store_type = Variant{name, (), Tuple{}}, + store_type_uninit = Variant{name, (), Tuple{}}, outer_type = name, gouter_type = gname, field_names = [], @@ -116,8 +116,8 @@ function generate_constructor_data(T_name, T_params, T_params_constrained, T_nam name = con_name, gname = gname, params = con_params, - store_type = :($Variant{($(QuoteNode.(con_field_names)...),), Tuple{$(con_field_types...)}}), - store_type_uninit = :($Variant{($(QuoteNode.(con_field_names)...),), Tuple{$(con_field_types_uninit...)}}), + store_type = :($Variant{$(QuoteNode(con_name)), ($(QuoteNode.(con_field_names)...),), Tuple{$(con_field_types...)}}), + store_type_uninit = :($Variant{$(QuoteNode(con_name)), ($(QuoteNode.(con_field_names)...),), Tuple{$(con_field_types_uninit...)}}), outer_type = con_nameparam, gouter_type = gnameparam, field_names = con_field_names, @@ -139,15 +139,9 @@ end function generate_constructor_exprs(T_name, T_params, T_params_constrained, T_nameparam, constructors) out = Expr(:toplevel) converts = [] - - @gensym N M FT _tag _T x + con_structs = Expr(:block) + @gensym _tag _T x enumerate_constructors = collect(enumerate(constructors)) - if_nest_conv = :( - let $_tag = $get_tag($x) - $(mapfoldr(((cond, data), old) -> Expr(:if, cond, data, old), enumerate_constructors, init=:(error("invalid tag"))) do (i, nt) - :($_tag == $(i-1) ), :($make($T_nameparam, $unwrap($x, $(nt.store_type)) , $_tag)) - end) - end) for nt ∈ constructors name = nt.name @@ -170,41 +164,39 @@ function generate_constructor_exprs(T_name, T_params, T_params_constrained, T_na T_init = isempty(T_params) ? T_name : :($T_name{$(T_params...)}) if value ex = quote - const $gname = $(Expr(:call, make, T_uninit, :($(nt.store_type_uninit)($unsafe)), Expr(:call, symbol_to_flag, T_name, QuoteNode(name)) )) - + const $gname = $(Expr(:call, T_uninit, :($(nt.store_type_uninit)($unsafe)))) end push!(out.args, ex) else field_names_typed = map(((name, type),) -> :($name :: $type), zip(field_names, field_types)) T_con = :($gouter_type($(field_names_typed...)) where {$(params_constrained...)} = - $(Expr(:call, make, T_uninit, :($store_type(($(field_names...),))), Expr(:call, symbol_to_flag, T_name, QuoteNode(name)) ))) + $(Expr(:call, T_uninit, :($store_type(($(field_names...),))), ))) T_con2 = if !all(x -> x ∈ (Any, :Any) ,field_types) s = Expr(:call, store_type, Expr(:tuple, [:($convert($field_type, $field_name)) for (field_type, field_name) ∈ zip(field_types, field_names)]...)) - :($gouter_type($(field_names...)) where {$(params_constrained...)} = - $(Expr(:call, make, T_uninit, s, Expr(:call, symbol_to_flag, T_name, QuoteNode(name))))) + $(Expr(:call, T_uninit, s))) end maybe_no_param = if !isempty(params) :($gname($(field_names_typed...)) where {$(params...)} = $gouter_type($(field_names...))) end struct_def = Expr(:struct, false, gouter_type_constrained, Expr(:block, :(1 + 1))) ex = quote - $struct_def $T_con $T_con2 $maybe_no_param $SumTypes.parent(::Type{<:$gname}) = $T_name end + push!(con_structs.args, struct_def) push!(out.args, ex) end - if true + if true push!(converts, T_uninit => quote - $Base.convert(::$Type{<:$T_init}, $x::$T_uninit) where {$(T_params_constrained...)} = $if_nest_conv - (::$Type{<:$T_init})($x::$T_uninit) where {$(T_params_constrained...)} = $if_nest_conv - $Base.convert(::$Type{$T_init}, $x::$T_uninit{$N, $M, $FT}) where {$(T_params_constrained...), $N, $M, $FT} = $if_nest_conv - (::$Type{$T_init})($x::$T_uninit) where {$(T_params_constrained...)} = $if_nest_conv + $Base.convert(::$Type{$T_init}, $x::$T_uninit) where {$(T_params_constrained...)} = $T_init($unwrap($x)) + $Base.convert(::$Type{<:$T_init}, $x::$T_uninit) where {$(T_params_constrained...)} = $T_init($unwrap($x)) + (::$Type{<:$T_init})($x::$T_uninit) where {$(T_params_constrained...)} = $T_init($unwrap($x)) + (::$Type{$T_init})($x::$T_uninit) where {$(T_params_constrained...)} = $T_init($unwrap($x)) end) end end @@ -214,7 +206,7 @@ function generate_constructor_exprs(T_name, T_params, T_params_constrained, T_na $Base.convert(::$Type{$_T}, $x::$_T) where {$(T_params_constrained...), $_T <: $T_nameparam} = $x (::$Type{$_T})($x::$_T) where {$(T_params_constrained...), $_T <: $T_nameparam} = $x end) - out + out, con_structs end @@ -225,20 +217,21 @@ function generate_sum_struct_expr(T, T_name, T_params, T_params_constrained, T_p con_gouter_types = (x -> x.gouter_type).(constructors) con_names = (x -> x.name ).(constructors) con_gnames = (x -> x.gname ).(constructors) + store_types = (x -> x.store_type).(constructors) + T_full = T - N = Symbol("#N#") - M = Symbol("#M#") - FT = Symbol("#FT#") - T_full = T isa Expr && T.head == :curly ? Expr(:curly, T.args..., N, M, FT) : Expr(:curly, T, N, M, FT) sum_struct_def = Expr(:struct, false, T_full, - Expr(:block, :(bits :: $NTuple{$N, $UInt8}), :(ptrs :: $NTuple{$M, $Any}), :($tag :: $FT), :(1 + 1))) + Expr(:block, :(data :: ($Union){$(store_types...)}), )) enumerate_constructors = collect(enumerate(constructors)) - - if_nest_unwrap = mapfoldr(((cond, data), old) -> Expr(:if, cond, data, old), enumerate_constructors, init=:(error("invalid tag"))) do (i, nt) - :(tag == $FT($(i-1))), :($unwrap(x, $(nt.store_type))) - end + ifnest_isvariant = mapfoldr(((cond, data), old) -> Expr(:if, cond, data, old), enumerate_constructors, init=false) do (i, nt) + :(unwrapped isa $(nt.store_type)), :($(QuoteNode(nt.name)) == s) + end + ifnest_get_tag = mapfoldr(((cond, data), old) -> Expr(:if, cond, data, old), enumerate_constructors, init=:THIS_SHOULD_BE_UNREACHABLE) do (i, nt) + :(unwrapped isa $(nt.store_type)), :($(QuoteNode(nt.name))) + end + only_define_with_params = if !isempty(T_params) @gensym x quote @@ -249,58 +242,41 @@ function generate_sum_struct_expr(T, T_name, T_params, T_params_constrained, T_p for nt ∈ constructors)...))) $SumTypes.variants_Tuple(::Type{<:$T_nameparam}) where {$(T_params_constrained...)} = $Tuple{$((nt.store_type for nt ∈ constructors)...)} - $SumTypes.full_type(::Type{$T_name}) = $full_type($T_name{$(T_param_bounds...)}, $variants_Tuple($T_name{$(T_param_bounds...)})) end end - # @show only_define_with_params @gensym _T ex = quote $sum_struct_def $SumTypes.is_sumtype(::Type{<:$T_name}) = true - $SumTypes.strip_size_params(::Type{$T_name{$(T_params...), $N, $M, $FT}}) where {$(T_params_constrained...), $N, $M, $FT} = $T_nameparam - $SumTypes.flagtype(::Type{$_T}) where {$_T <: $T_name} = $flagtype($full_type($_T)) - $SumTypes.flagtype(::Type{$T_name{$(T_params...), $N, $M, $FT}}) where {$(T_params_constrained...), $N, $M, $FT} = $FT - - $SumTypes.symbol_to_flag(::Type{$_T}, sym::Symbol) where {$_T <: $T_name} = - $(foldr(collect(enumerate(con_names)), init=:(error("Invalid tag symbol $sym"))) do (i, _sym), old - Expr(:if, :(sym == $(QuoteNode(_sym))), :($flagtype($_T)($(i-1))), old) - end) - $SumTypes.flag_to_symbol(::Type{<:$T_name}, flag::$Integer) = - $(foldr(collect(enumerate(con_names)), init=:(error("Invalid tag symbol $sym"))) do (i, sym), old - Expr(:if, :(flag == $(i-1)), QuoteNode(sym), old) - end) - $SumTypes.tags_flags_nt(::Type{<:$_T}) where {$_T <: $T_name} = - $(Expr(:tuple, Expr(:parameters, (Expr(:kw, name, :($flagtype($_T)($(i-1)))) for (i, name) ∈ enumerate(con_names))...))) - $SumTypes.tags(::Type{<:$T_name}) = $(Expr(:tuple, map(x -> QuoteNode(x.name), constructors)...)) - $SumTypes.constructors(::Type{<:$T_name}) = $NamedTuple{$tags($T_name)}($(Expr(:tuple, (nt.store_type_uninit for nt ∈ constructors)...))) $SumTypes.variants_Tuple(::Type{<:$T_name}) = $Tuple{$((nt.store_type_uninit for nt ∈ constructors)...)} - $SumTypes.unwrap(x::$T_nameparam{$N, $M, $FT}) where {$(T_params_constrained...), $N, $M, $FT}= let tag = $get_tag(x) - $if_nest_unwrap - end + $SumTypes.unwrap(x::$T_nameparam) where {$(T_params_constrained...)} = $getfield(x, :data) $Base.adjoint(::Type{<:$T_name}) = $NamedTuple{$tags($T_name)}($(Expr(:tuple, (nt.gname for nt ∈ constructors)...))) - - $SumTypes.full_type(::Type{$T_nameparam}) where {$(T_params_constrained...)} = $full_type($T_nameparam, $variants_Tuple($T_nameparam)) - $SumTypes.full_type(::Type{$T_nameparam{$N, $M, $FT}}) where {$(T_params_constrained...), $N, $M, $FT} = $T_nameparam{$N, $M, $FT} - + $SumTypes.isvariant(x::$T_nameparam, s::Symbol) where {$(T_params_constrained...)} = let unwrapped = $unwrap(x) + $ifnest_isvariant + end + $SumTypes.get_tag(x::$T_nameparam) where {$(T_params_constrained...)} = let unwrapped = $unwrap(x) + $ifnest_get_tag + end + $SumTypes.tags(::Type{<:$T_name}) = $(Expr(:tuple, QuoteNode.(con_names)...)) $Base.show(io::IO, x::$T_name) = $show_sumtype(io, x) $Base.show(io::IO, m::MIME"text/plain", x::$T_name) = $show_sumtype(io, m, x) - Base.:(==)(x::$T_name, y::$T_name) = ($get_tag(x) == $get_tag(y)) && ($unwrap(x) == $unwrap(y)) + Base.:(==)(x::$T_name, y::$T_name) = $Base.:(==)($unwrap(x), $unwrap(y)) $only_define_with_params end foreach(constructors) do nt con1 = :($SumTypes.constructor(::Type{<:$T_name}, ::Type{Val{$(QuoteNode(nt.name))}}) = $(nt.store_type_uninit)) con2 = if !isempty(T_params) :($SumTypes.constructor(::Type{<:$T_nameparam}, ::Type{Val{$(QuoteNode(nt.name))}}) where {$(T_params_constrained...)} = $(nt.store_type)) - end + end push!(ex.args, con1, con2) end ex diff --git a/test/runtests.jl b/test/runtests.jl index 5782663..fde239e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -105,63 +105,26 @@ Base.getproperty(f::Result, s::Symbol) = error("Don't do that!") @test x == y @test x != z end - @test SumTypes.get_tag_sym(Left([1])) == :Left + @test SumTypes.get_tag(Left([1])) == :Left - @test convert(full_type(Either{Int, Int}), Left(1)) == Left(1) - @test convert(full_type(Either{Int, Int}), Left(1)) !== Left(1) - @test convert(full_type(Either{Int, Int}), Left(1)) === Either{Int, Int}'.Left(1) - @test Either{Int, Int, 8, 0, UInt}(Left(1)) isa Either{Int, Int, 8, 0, UInt} - @test Either{Int, Int, 8, 0, UInt}(Either{Int, Int}(Left(1))) isa Either{Int, Int, 8, 0, UInt} + @test convert(Either{Int, Int}, Left(1)) == Left(1) + @test convert(Either{Int, Int}, Left(1)) !== Left(1) + @test convert(Either{Int, Int}, Left(1)) === Either{Int, Int}'.Left(1) + @test Either{Int, Int}(Left(1)) isa Either{Int, Int} @test_throws MethodError Left{Int}("hi") @test_throws MethodError Right{String}(1) @test Left{Int}(0x01) === Left{Int}(1) - @test full_type(Either{Nothing, Nothing}) == Either{Nothing, Nothing, 0, 0, UInt8} - @test full_type(Either{Int, Int}) == Either{Int, Int, 8, 0, UInt} - @test full_type(Either{Int, String}) == Either{Int, String, 8, 1, UInt8} - - @test full_type(Either{Nothing, Int16}) == Either{Nothing, Int16, 2, 0, UInt16} - @test full_type(Either{Int32, Int32}) == Either{Int32, Int32, 4, 0, UInt32} - @test convert(full_type(Result{Float64}), Success(1.0)) == Success(1.0) - let x = Left(1.0) @test SumTypes.isvariant(x, :Left) == true @test SumTypes.isvariant(x, :Right) == false - @test SumTypes.unwrap(x, :Left)[1] == 1.0 + @test SumTypes.unwrap(x)[1] == 1.0 end end #-------------------------------------------------------- -function pass_through(x::Either{T, U})::Either{T, U} where {T, U} - @cases x begin - Left(l) => Left(l) - Right(r) => Right(r) - end -end - -@testset "search for memory safety problems from lying about consistency" begin - data = (rand(Int8), rand(Int16), rand(Int32), rand(Int64), rand(Int128), - (rand(Int8), rand(Int)), (rand(Int), rand(Int8)), (rand(Int8), rand(Int32), rand(Int)), "hi", (rand(Int8), "hi")) - for x ∈ data - T = typeof(x) - for y ∈ data - U = typeof(y) - L = Either{T, U}'.Left(x) - R = Either{T, U}'.Right(y) - @test L == pass_through(L) - @test R == pass_through(R) - - @test L == @eval pass_through($L) - @test R == @eval pass_through($R) - end - end -end - -#-------------------------------------------------------- - - @sum_type List{A} begin Nil Cons{A}(::A, ::List) @@ -217,16 +180,16 @@ end @sum_type AT begin A(common_field::Int, a::Bool, b::Int) - B(common_field::Int, a::Int, b::Float64, d::Complex) - C(common_field::Int, b::Float64, d::Bool, e::Float64, k::Complex{Real}) - D(common_field::Int, b::Any) + B(common_field::Int, a::Int, b::Float64, d::Complex{Float64}) + C(common_field::Int, b::Float64, d::Bool, e::Float64, k::Complex{Float64}) + D(common_field::Int, b::Char) end Base.getproperty(f::AT, s::Symbol) = error("Don't do that!") A(;common=1, a=true, b=10) = A(common, a, b) B(;common=1, a=1, b=1.0, d=1 + 1.0im) = B(common, a, b, d) -C(;common=1, b=2.0, d=false, e=3.0, k=Complex{Real}(1 + 2im)) = C(common, b, d, e, k) -D(;common=1, b=:hi) = D(common, b) +C(;common=1, b=2.0, d=false, e=3.0, k=1 + 2im) = C(common, b, d, e, k) +D(;common=1, b='h') = D(common, b) foo!(xs) = for i in eachindex(xs) xs[i] = @cases xs[i] begin @@ -236,9 +199,7 @@ foo!(xs) = for i in eachindex(xs) D => A() end end - - -# #CI Doesn't like this test so just uncomment it for local testing +#CI Doesn't like this test so just disable it in CI if !haskey(ENV, "CI") || ENV["CI"] != "true" @testset "Allocation-free @cases" begin xs = map(x->rand((A(), B(), C(), D())), 1:10000); @@ -287,7 +248,6 @@ end A => 1 B => 2 end - end