using CuArrays
using Flux
m1 = Tracker.param(rand(3, 100000) |> gpu)
m2 = Tracker.param(rand(100000, 3) |> gpu)
Tracker.forward((x,y)->x*y, m1, m2) # takes ~15 seconds even after repeated execution
Tracker.forward((x,y)->x*y, cpu(m1), cpu(m2)) # returns almost immediately
AD of a matrix-matrix product is needed e.g. for Neural Style Transfer (I used https://pytorch.org/tutorials/advanced/neural_style_tutorial.html as a rough guideline; the Gram matrix requires a matrix-matrix product), but I noticed that training on the GPU is far slower than on the CPU. I narrowed it down to the issue shown by the code above.
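For context, here is a minimal sketch of the style-transfer use case that hits this code path. The `gram` helper and the feature-map dimensions are illustrative, not taken from the tutorial; the point is only that the Gram matrix is a matrix-matrix product whose gradient goes through Tracker.

using Flux, CuArrays
using Flux: Tracker

# Gram matrix of a feature map (channels × pixels):
# a matrix-matrix product, which is where the slowdown shows up.
gram(features) = features * features'

feats = Tracker.param(rand(Float32, 64, 4096) |> gpu)
g, back = Tracker.forward(gram, feats)  # slow on GPU, fast on CPU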
(Created by ] add CuArrays Flux in a new environment)
For some reason, the @time macro is misleading here. It reports timings similar to those in your test above, but the first call blocks further computation and the REPL for much longer. That's why I put the timings in comments rather than using @time in the opening post.
julia> CuArrays.@time CuArrays.@sync Tracker.forward((x,y)->x*y, m1, m2);
0.011956 seconds (2.29 k CPU allocations: 111.901 KiB) (1 GPU allocation: 36 bytes, 0.65% gc time of which 100.00% spent allocating)
That should capture the full computation time (if it doesn't, this is a very strange issue). Assuming it does, post the stats here and they'll show whether it's a memory issue. Beyond that, it might be worth running under the profiling tools (both Julia's built-in profiler and CUDAnative/nvprof).
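A sketch of what that profiling could look like, assuming the reproducer from the opening post is already loaded; warming up with one call first keeps compilation out of the profile:

using Profile

# Warm-up call so JIT compilation is excluded from the profile.
Tracker.forward((x, y) -> x * y, m1, m2)

Profile.clear()
@profile Tracker.forward((x, y) -> x * y, m1, m2)
Profile.print()  # flat/tree view of where CPU time is spent

For the GPU side, running the whole script under nvprof (e.g. `nvprof --print-gpu-trace julia script.jl`, where script.jl is a hypothetical file containing the reproducer) would show whether the time is spent in kernels or in allocation/synchronization.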