
PairwiseFusion layer, take 2 #1983

Merged: 14 commits, Jun 6, 2022

Conversation

@theabhirath (Member) commented Jun 2, 2022

I messed up the commits on #1971 a little bit, so this PR supersedes it and should close #1971. If merged, this will also close #1989.


This PR implements a PairwiseFusion layer similar to the one in Lux. This might be especially useful for implementing models like Res2Net, since the forward pass for the bottleneck layer is quite similar to this.
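
For context, the forward pass the layer is meant to implement looks roughly like the sketch below. This is only an illustration of the intended semantics with a made-up helper name (`pairwise_fusion_sketch`), not the code in this PR; `connection` and `layers` correspond to the constructor arguments.

# Sketch of the semantics of PairwiseFusion(connection, layers...) on a tuple input:
#   y1 = layers[1](x[1])
#   yi = layers[i](connection(y(i-1), x[i]))   for i = 2, ..., N
# and the layer returns (y1, y2, ..., yN).
function pairwise_fusion_sketch(connection, layers::Tuple, x::Tuple)
    y = layers[1](x[1])
    outputs = (y,)
    for i in 2:length(layers)
        y = layers[i](connection(y, x[i]))
        outputs = (outputs..., y)
    end
    return outputs
end

With a single (non-tuple) input, the same input is passed to every `connection` call in place of `x[i]`.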

PR Checklist

  • Add pretty-printing similar to Parallel
  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

@codecov-commenter commented Jun 2, 2022

Codecov Report

Merging #1983 (d0f0a29) into master (b6b3569) will decrease coverage by 0.72%.
The diff coverage is 67.34%.

@@            Coverage Diff             @@
##           master    #1983      +/-   ##
==========================================
- Coverage   87.82%   87.10%   -0.73%     
==========================================
  Files          20       20              
  Lines        1487     1528      +41     
==========================================
+ Hits         1306     1331      +25     
- Misses        181      197      +16     
Impacted Files        Coverage Δ
src/Flux.jl            0.00% <ø> (ø)
src/layers/show.jl    71.79% <50.00%> (-0.94%) ⬇️
src/layers/basic.jl   77.32% <68.08%> (-4.50%) ⬇️

Legend:
Δ = absolute <relative> (impact), ø = not affected, ? = missing data

Review thread on src/layers/basic.jl:
## Returns

A tuple of length `N` with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
"""
Member:
Would be good to add a concrete example

Member Author:

There's one in the tests, but I was holding back because I wanted to take time later to come up with a better example

@mcabbott (Member), Jun 6, 2022:

How about:

julia> PairwiseFusion(vcat, x->x+1, x->x.+2, x->x.^3)(2)
(3, [5, 4], [125, 64, 8])

julia> PairwiseFusion(vcat, x->x+1, x->x.+2, x->x.^3)((2, 10, 20))
(3, [5, 12], [125, 1728, 8000])

@darsnack (Member) commented Jun 2, 2022

Despite the performance difference, do we feel good about a generated function? Why just this layer and not all the control flow layers like Chain or Parallel? One of the goals of Julia + Flux is that you should be able to just write a loop, have it be fast, and have it be differentiable. Needing to unroll any loops seems clearly sub-optimal to me.

A couple questions related to performance:

  • Is the performance difference coming from the backwards pass? Hopefully the loop and generated function are not too far off on the forwards pass.
  • Is this consistent across Zygote versions/Julia versions (including 1.5)? Somewhere in the Zygote 0.6 and/or Julia 1.6 transition, there was a noticeable degradation. I know I have said this in several places without providing hard data, but it was measurable enough that we got issues on the Zygote repo about it, we saw it in Metalhead test times, and I noticed it in my own anecdotal experience. Hopefully I have time after I defend next week to benchmark these things (I feel guilty about complaining without doing the work 😄).
  • Does using a Vector help the loop-version performance like it does for Chain?

@theabhirath (Member Author) commented Jun 3, 2022

Well, the performance on the forward pass is almost exactly the same. The backwards pass, though...

Using a generated function:

julia> model = PairwiseFusion(+,  Dense(1, 30),  Dense(30, 10));

julia> x = (rand(1, 10), rand(30, 10));

julia> @benchmark $model($x)
BenchmarkTools.Trial: 10000 samples with 10 evaluations.
 Range (min … max):  1.158 μs … 265.258 μs  ┊ GC (min … max):  0.00% … 98.56%
 Time  (median):     1.746 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   2.387 μs ±  13.244 μs  ┊ GC (mean ± σ):  28.64% ±  5.12%

   ▅▁                 ▁▃█▄▃▁
  ▄██▃▂▂▂▁▁▁▁▁▂▂▂▂▂▂▄▅██████▇▆▅▄▃▄▄▃▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  1.16 μs         Histogram: frequency by time        2.69 μs <

 Memory estimate: 12.05 KiB, allocs estimate: 7.

julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 8 evaluations.
 Range (min … max):  3.229 μs … 330.734 μs  ┊ GC (min … max):  0.00% … 97.61%
 Time  (median):     4.875 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   6.764 μs ±  24.378 μs  ┊ GC (mean ± σ):  27.72% ±  7.54%

   ▃▁            ▆█▄▅▅
  ▃██▃▂▃▂▂▃▄▅▆▆▆███████▇▆▄▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  3.23 μs         Histogram: frequency by time        9.07 μs <

 Memory estimate: 34.39 KiB, allocs estimate: 41.

Using standard for-loops:

julia> @benchmark $model($x)
BenchmarkTools.Trial: 10000 samples with 10 evaluations.
 Range (min … max):  1.488 μs … 408.717 μs  ┊ GC (min … max):  0.00% … 98.80%
 Time  (median):     2.146 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   2.842 μs ±  16.806 μs  ┊ GC (mean ± σ):  24.95% ±  4.19%

                  ▂▇▇█▄
  ▃▇▆▅▃▃▃▃▃▃▂▃▃▃▄▇██████▇▅▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂ ▃
  1.49 μs         Histogram: frequency by time        3.58 μs <

 Memory estimate: 12.30 KiB, allocs estimate: 13.

julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  41.417 μs …   5.011 ms  ┊ GC (min … max): 0.00% … 96.32%
 Time  (median):     44.000 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   48.305 μs ± 139.299 μs  ┊ GC (mean ± σ):  7.96% ±  2.74%

          ▁▁▃▆█▅▅▆▄▇▃
  ▁▂▃▅▇▆▇█████████████▇▅▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  41.4 μs         Histogram: frequency by time           52 μs <

 Memory estimate: 53.23 KiB, allocs estimate: 381.

So it's not so much a "difference" as the generated function absolutely wiping the floor with the regular for-loop (8x on the backwards pass). I'm not sure what kind of speedup there would be if the same thing were done for Chain et al. While plain loops really should work, Zygote doesn't seem to be obliging just yet 😅
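
To make the comparison concrete, the two variants roughly correspond to the sketch below (the name `fuse_unrolled` is made up for illustration; this is a simplified picture of the approaches being compared, not the exact code in the PR):

# Loop version (simplified): Zygote differentiates through the loop and the
# growing accumulator tuple, which is where the extra allocations show up.
#   y = layers[1](x[1]); ys = (y,)
#   for i in 2:N;  y = layers[i](connection(y, x[i]));  ys = (ys..., y);  end
#
# Generated version: the same computation, unrolled into straight-line code
# at compile time, which Zygote's reverse pass handles much more cheaply.
@generated function fuse_unrolled(connection, layers::Tuple{Vararg{Any,N}}, x::Tuple) where {N}
    calls = [:(y1 = layers[1](x[1]))]
    for i in 2:N
        push!(calls, :($(Symbol(:y, i)) = layers[$i](connection($(Symbol(:y, i - 1)), x[$i]))))
    end
    return quote
        $(calls...)
        tuple($([Symbol(:y, i) for i in 1:N]...))
    end
end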

Julia version:

julia> versioninfo()
Julia Version 1.9.0-DEV.695
Commit d84f8901ae (2022-06-01 05:24 UTC)
Platform Info:
  OS: macOS (arm64-apple-darwin21.5.0)
  CPU: 8 × Apple M1
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, apple-m1)
  Threads: 4 on 4 virtual cores
Environment:
  JULIA_NUM_THREADS = 4

This is Flux 0.13.3, Zygote 0.6.40. I will have to check other Julia versions, but anything before 1.7 might not be best benchmarked on my machine, since it won't run natively and Rosetta might interfere with performance.

@theabhirath (Member Author) commented:

I've also opened #1989 so that if there is an obvious performance trick I'm missing, someone can help me fix it. It will also let people benchmark independently on their own machines.

@ToucheSir (Member) commented:

> Despite the performance difference, do we feel good about a generated function? Why just this layer and not all the control flow layers like Chain or Parallel?

To be fair, we did switch Chain over to a generated function for exactly this reason. Parallel doesn't need it because mapreduce is flexible enough to handle all of its cases, but after playing around with a number of variations (including foldl), a generated function came out on top for Chain.

For that reason, my vote is to get this in first and then work on getting rid of @generated. One idea I've been mulling over is a throwback to the old applychain model of recursively folding over a tuple, but with liberal use of @nospecialize to avoid the crushing compile time penalty that comes with it. It's unclear whether that will have better TTFG (time to first gradient) than a generated function, so I'd not count either approach out until we know for sure.
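
For reference, the recursive-fold idea would look roughly like the sketch below (the name `apply_fusion` is hypothetical; whether sprinkling `@nospecialize` on arguments like `layers` actually wins on compile time versus the `@generated` version is exactly the open question):

# Old applychain-style pattern: peel one layer off the tuple per call and
# accumulate the outputs; the recursion bottoms out on empty tuples.
apply_fusion(connection, ::Tuple{}, ::Tuple{}, ys::Tuple) = ys
function apply_fusion(connection, layers::Tuple, xs::Tuple, ys::Tuple)
    y = first(layers)(isempty(ys) ? first(xs) : connection(last(ys), first(xs)))
    return apply_fusion(connection, Base.tail(layers), Base.tail(xs), (ys..., y))
end

# e.g. apply_fusion(vcat, (x -> x + 1, x -> x .+ 2, x -> x .^ 3), (2, 10, 20), ())
# gives (3, [5, 12], [125, 1728, 8000]), matching the example above.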

@darsnack (Member) commented Jun 3, 2022

Okay, this is embarrassing, but I didn't realize we had already switched Chain over. I somehow thought we had ended up with a slightly varied recursive implementation + Vector. I should have checked the current code instead of my memory of PRs, my bad.

In which case I'm inclined to agree with you: if we're already using generated functions in the short term, then there's no reason not to merge this.

Review thread on src/layers/basic.jl:
end
ChainRulesCore.@non_differentiable _pairwise_check(lx, N, T)

function (m::PairwiseFusion)(x::T) where {T}
Member:

Should this, like Parallel, allow m(x1, x2, x3) == m((x1, x2, x3))?

I also wonder if the one-x case should be x::AbstractArray{<:Number}, or something. So that we don't find out someone is relying on some unintended behaviour, e.g. how a NamedTuple is handled. Although Parallel does not make such a restriction.
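
A minimal sketch of what the first suggestion amounts to, assuming the tuple-accepting method `(m::PairwiseFusion)(x)` shown in the diff context above (an illustration only, not necessarily how the PR ends up implementing it):

# Forward multiple positional arguments to the existing tuple method,
# so that m(x1, x2, x3) behaves the same as m((x1, x2, x3)).
(m::PairwiseFusion)(xs...) = m(xs)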

Member:

The first suggestion seems okay. The second would make the layer less usable if the sub-layers are custom layers that accept something other than the type restriction that we provide. Presumably, in most cases, the sub-layers should be appropriately restricted to throw an error?

Member Author:

I've allowed the first. Not sure about the second...

Member:

I don't see the method allowing this:

julia> PairwiseFusion(vcat, x->x.+1, x->x.+2, x->x.^3)(2, 10, 20)
ERROR: MethodError: no method matching (::PairwiseFusion{typeof(vcat), Tuple{var"#130#133", var"#131#134", var"#132#135"}})(::Int64, ::Int64, ::Int64)
Closest candidates are:
  (::PairwiseFusion)(::T) where T at REPL[11]:1

Agree that the second is less obvious. In particular it rules out many easy readme examples.

Member Author:

Whoops, yeah, I think I'd missed pushing that; it should be fixed now.

theabhirath and others added 4 commits June 3, 2022 22:09
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Co-Authored-By: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Co-Authored-By: Carlo Lucibello <carlo.lucibello@gmail.com>
Also move checks to non-differentiable helper functions
@theabhirath (Member Author) commented Jun 3, 2022

I rebased (not merged), but git seems to have some weird, partially squashed duplicate docs commits 🤦🏽‍♂️ Not sure what happened.

@theabhirath (Member Author) commented:

Bump?

@ToucheSir (Member) left a comment:

No more comments from me, will merge tonight barring any objections :)

Add tests with `vcat`
@ToucheSir ToucheSir merged commit 0b01b77 into FluxML:master Jun 6, 2022
@theabhirath theabhirath deleted the pairwise-fusion-2 branch June 6, 2022 16:01