PairwiseFusion layer, take 2 #1983
Conversation
Add pretty printing
Also add NEWS.md entry
Codecov Report
```diff
@@            Coverage Diff             @@
##           master    #1983      +/-   ##
==========================================
- Coverage   87.82%   87.10%    -0.73%
==========================================
  Files          20       20
  Lines        1487     1528      +41
==========================================
+ Hits         1306     1331      +25
- Misses        181      197      +16
==========================================
```
Continue to review full report at Codecov.
Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
```julia
## Returns

A tuple of length `N` with the output of each fusion
((`y1`, `y2`, ..., `yN`) in the example above).
"""
```
Would be good to add a concrete example
There's one in the tests, but I was holding back because I wanted to take time later to come up with a better example
How about:

```julia
julia> PairwiseFusion(vcat, x->x+1, x->x.+2, x->x.^3)(2)
(3, [5, 4], [125, 64, 8])

julia> PairwiseFusion(vcat, x->x+1, x->x.+2, x->x.^3)((2, 10, 20))
(3, [5, 12], [125, 1728, 8000])
```
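For readers skimming the thread, the recurrence that produces these outputs can be sketched in plain Julia. This is a simplified model of the layer's semantics, not Flux's actual implementation, and `pairwise_fusion` is a made-up name for illustration:

```julia
# Simplified sketch of PairwiseFusion semantics (not Flux's implementation).
# With inputs (x1, ..., xN) and layers (f1, ..., fN):
#   y1 = f1(x1); y2 = f2(connection(y1, x2)); ...; yN = fN(connection(y_{N-1}, xN))
function pairwise_fusion(connection, layers, xs::Tuple)
    ys = []
    h = first(xs)
    for (i, f) in enumerate(layers)
        y = f(h)
        push!(ys, y)
        i == length(layers) && break
        # Fuse this step's output with the next raw input.
        h = connection(y, xs[i + 1])
    end
    return Tuple(ys)
end

# A single (non-tuple) input is reused at every step.
pairwise_fusion(connection, layers, x) =
    pairwise_fusion(connection, layers, ntuple(_ -> x, length(layers)))
```

Running this toy version on the two examples above reproduces the same tuples, which is a quick way to sanity-check the documented semantics.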
A couple of questions related to performance: despite the performance difference, do we feel good about a generated function? And why just this layer, and not all the control-flow layers?
Well, the performance on the forward pass is almost exactly the same. The backwards pass, though...

Using a generated function:

```julia
julia> model = PairwiseFusion(+, Dense(1, 30), Dense(30, 10));

julia> x = (rand(1, 10), rand(30, 10));

julia> @benchmark $model($x)
BenchmarkTools.Trial: 10000 samples with 10 evaluations.
 Range (min … max):  1.158 μs … 265.258 μs  ┊ GC (min … max):  0.00% … 98.56%
 Time  (median):     1.746 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   2.387 μs ±  13.244 μs  ┊ GC (mean ± σ):  28.64% ±  5.12%

  ▅▁    ▁▃█▄▃▁
  ▄██▃▂▂▂▁▁▁▁▁▂▂▂▂▂▂▄▅██████▇▆▅▄▃▄▄▃▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  1.16 μs         Histogram: frequency by time        2.69 μs <

 Memory estimate: 12.05 KiB, allocs estimate: 7.

julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 8 evaluations.
 Range (min … max):  3.229 μs … 330.734 μs  ┊ GC (min … max):  0.00% … 97.61%
 Time  (median):     4.875 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   6.764 μs ±  24.378 μs  ┊ GC (mean ± σ):  27.72% ±  7.54%

  ▃▁       ▆█▄▅▅
  ▃██▃▂▃▂▂▃▄▅▆▆▆███████▇▆▄▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  3.23 μs         Histogram: frequency by time        9.07 μs <

 Memory estimate: 34.39 KiB, allocs estimate: 41.
```

Using standard for-loops:

```julia
julia> @benchmark $model($x)
BenchmarkTools.Trial: 10000 samples with 10 evaluations.
 Range (min … max):  1.488 μs … 408.717 μs  ┊ GC (min … max):  0.00% … 98.80%
 Time  (median):     2.146 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   2.842 μs ±  16.806 μs  ┊ GC (mean ± σ):  24.95% ±  4.19%

      ▂▇▇█▄
  ▃▇▆▅▃▃▃▃▃▃▂▃▃▃▄▇██████▇▅▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂ ▃
  1.49 μs         Histogram: frequency by time        3.58 μs <

 Memory estimate: 12.30 KiB, allocs estimate: 13.

julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  41.417 μs …   5.011 ms  ┊ GC (min … max): 0.00% … 96.32%
 Time  (median):     44.000 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   48.305 μs ± 139.299 μs  ┊ GC (mean ± σ):   7.96% ±  2.74%

    ▁▁▃▆█▅▅▆▄▇▃
  ▁▂▃▅▇▆▇█████████████▇▅▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  41.4 μs         Histogram: frequency by time         52 μs <

 Memory estimate: 53.23 KiB, allocs estimate: 381.
```

So it's not a "difference" so much as the generated function absolutely wiping the floor with the regular for-loop on the backwards pass (roughly 8x). I'm not sure what kind of speedup there would be if the same thing is done for the other layers.

Julia version:

```julia
julia> versioninfo()
Julia Version 1.9.0-DEV.695
Commit d84f8901ae (2022-06-01 05:24 UTC)
Platform Info:
  OS: macOS (arm64-apple-darwin21.5.0)
  CPU: 8 × Apple M1
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, apple-m1)
  Threads: 4 on 4 virtual cores
Environment:
  JULIA_NUM_THREADS = 4
```

This is Flux 0.13.3, Zygote 0.6.40. I will have to check for other Julia versions, but anything before 1.7 might not be best benchmarked on my computer because it won't be native, so Rosetta might interfere with performance.
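For context, the two strategies being benchmarked look roughly like the following. This is a hedged sketch, not the PR's actual code: `fuse_loop` and `fuse_unrolled` are made-up names, and the recurrence assumed is `y1 = f1(x1)`, `yi = fi(connection(y(i-1), xi))`:

```julia
# Loop version: Zygote must differentiate through the loop and the
# growing tuple accumulator, which is where the extra allocations
# in the backwards pass come from.
function fuse_loop(connection, layers, x::Tuple)
    y = layers[1](x[1])
    ys = (y,)
    for i in 2:length(layers)
        y = layers[i](connection(y, x[i]))
        ys = (ys..., y)
    end
    return ys
end

# @generated version: unrolls the same recurrence into straight-line
# code at compile time, so Zygote never sees a loop.
@generated function fuse_unrolled(connection, layers::Tuple{Vararg{Any,N}}, x::Tuple) where {N}
    calls = Expr[:(y1 = layers[1](x[1]))]
    syms = Symbol[:y1]
    for i in 2:N
        yi, yprev = Symbol(:y, i), Symbol(:y, i - 1)
        push!(calls, :($yi = layers[$i](connection($yprev, x[$i]))))
        push!(syms, yi)
    end
    push!(calls, :(return tuple($(syms...))))
    return Expr(:block, calls...)
end
```

Both produce identical outputs; the only difference is whether the recurrence exists as a runtime loop or as unrolled straight-line code when the AD transform runs.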
I've also opened #1989 so that if there is an obvious performance trick I'm missing there, someone can help me fix it. It'll also help people benchmark independently on their own machines.
To be fair, we did already make that switch to generated functions elsewhere. For that reason, my vote is to get this in first and then work on getting rid of the generated functions afterwards.
Okay, this is embarrassing, but I didn't realize we'd already made that switch. In which case I'm inclined to agree with you: if we're already using generated functions in the short term, then there's no reason not to merge this.
```julia
end

ChainRulesCore.@non_differentiable _pairwise_check(lx, N, T)

function (m::PairwiseFusion)(x::T) where {T}
```
Should this, like `Parallel`, allow `m(x1, x2, x3) == m((x1, x2, x3))`?

I also wonder if the one-`x` case should be `x::AbstractArray{<:Number}`, or something, so that we don't find out someone is relying on some unintended behaviour, e.g. how a NamedTuple is handled. Although `Parallel` does not make such a restriction.
The first suggestion seems okay. The second would make the layer less usable if the sub-layers are custom layers that accept something other than the type restriction that we provide. Presumably, in most cases, the sub-layers should be appropriately restricted to throw an error?
I've allowed the first. Not sure about the second...
I don't see the method allowing this:

```julia
julia> PairwiseFusion(vcat, x->x.+1, x->x.+2, x->x.^3)(2, 10, 20)
ERROR: MethodError: no method matching (::PairwiseFusion{typeof(vcat), Tuple{var"#130#133", var"#131#134", var"#132#135"}})(::Int64, ::Int64, ::Int64)
Closest candidates are:
  (::PairwiseFusion)(::T) where T at REPL[11]:1
```
Agree that the second is less obvious. In particular it rules out many easy readme examples.
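For illustration, the convenience method being requested is essentially a one-line varargs forwarding definition. The sketch below is hypothetical and self-contained, using a toy `Fusion` struct in place of Flux's actual `PairwiseFusion`:

```julia
# Toy stand-in for PairwiseFusion, for illustration only.
struct Fusion{F,T}
    connection::F
    layers::T
end

# Tuple-of-inputs forward pass: y1 = f1(x1), yi = fi(connection(y_{i-1}, xi)).
function (m::Fusion)(x::Tuple)
    y = m.layers[1](x[1])
    ys = (y,)
    for i in 2:length(m.layers)
        y = m.layers[i](m.connection(y, x[i]))
        ys = (ys..., y)
    end
    return ys
end

# The requested convenience method: splat positional arguments into the
# tuple form, so that m(x1, x2, x3) == m((x1, x2, x3)), mirroring Parallel.
(m::Fusion)(xs...) = m(xs)
```

Dispatch works out because the single-`Tuple` method is more specific than the varargs method, so `m((x1, x2, x3))` still hits the tuple path directly with no recursion.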
Whoops, yeah, I think I'd missed pushing that - should be fixed now
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Co-Authored-By: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Co-Authored-By: Carlo Lucibello <carlo.lucibello@gmail.com>
Also move checks to non-differentiable helper functions
…ux.jl into pairwise-fusion-2
I rebased, not merged, but git seems to have weird duplicated docs commits that are partially squashed 🤦🏽♂️ Not sure what happened.

Bump?
No more comments from me, will merge tonight barring any objections :)
Add tests with `vcat`
I messed up the commits on #1971 a little bit, so this one should close #1971. If merged, this will also close #1989.
This PR implements a `PairwiseFusion` layer similar to the one in Lux. This might be especially useful for implementing models like Res2Net, since the forward pass for the bottleneck layer is quite similar to this.

PR Checklist

Parallel