Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Embedding layer #1516

Merged
merged 14 commits into from
Jul 13, 2021
Prev Previous commit
Next Next commit
more tests
  • Loading branch information
CarloLucibello committed Jul 13, 2021
commit e54440b476877ff5b3857f25ddb4028c8d85e74c
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Expand Down
44 changes: 9 additions & 35 deletions test/cuda/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,15 @@ pixelshuffle = [PixelShuffle]
# GPU gradient checks for PixelShuffle on a 4-d and a 3-d Float32 input.
# The cases run in the listed order, so random inputs are drawn in the same sequence.
for (label, dims) in (
    ("PixelShuffle 2d", (3, 4, 18, 3)),
    ("PixelShuffle 1d", (3, 18, 3)),
)
    gpu_gradtest(label, pixelshuffle, rand(Float32, dims...), 3)
end

# GPU gradient checks for the Embedding layer, one per kind of index input:
# integer vector, vector with repeated indices, scalar integer, integer matrix,
# one-hot vector, and one-hot matrix (with and without repeated columns).
# NOTE(review): the trailing `5, 2` are presumably forwarded by `gpu_gradtest`
# to the layer constructor — confirm against its definition.
# `Flux.Embedding` is qualified so the name resolves whether or not it is exported.
embedding = [Flux.Embedding]
gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2)
gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2)
gpu_gradtest("Embedding integer index", embedding, 1, 5, 2)
gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2)
gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2)
gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2)
# The layer list must be the second argument, ahead of the input, as in every call above.
gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2)

@testset "function layers" begin
x = rand(Float32, 3,3)
gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=1)), x)
Expand Down Expand Up @@ -259,38 +268,3 @@ end
end
end
end

# Compare the Embedding layer's forward output and weight gradients between the
# original model and the model/input after being passed through `gpu`.
@testset "Embedding" begin
    vocab_size, embed_size = 5, 2
    m = Flux.Embedding(vocab_size, embed_size)

    x = [1, 3, 5]
    y = m(x)
    m_g = gpu(m)
    x_g = gpu(x)
    y_g = m_g(x_g)
    @test collect(y_g) == y

    # Gradient of a plain sum with respect to the layer parameters.
    gs = gradient(params(m)) do
        sum(m(x))
    end
    gs_g = gradient(params(m_g)) do
        sum(m_g(x_g))
    end
    @test collect(gs_g[m_g.weight]) ≈ gs[m.weight]

    # Same comparison with an elementwise tanh applied before the sum.
    gs = gradient(params(m)) do
        sum(tanh.(m(x)))
    end
    gs_g = gradient(params(m_g)) do
        sum(tanh.(m_g(x_g)))
    end
    @test collect(gs_g[m_g.weight]) ≈ gs[m.weight]

    @testset "repeated indices" begin
        vocab_size, embed_size = 5, 2
        m = Flux.Embedding(vocab_size, embed_size)

        # Index 3 appears twice in the input.
        x = [1, 3, 5, 3]
        y = m(x)
        m_g = gpu(m)
        x_g = gpu(x)
        y_g = m_g(x_g)
        @test Array(y_g) == y

        gs = gradient(params(m)) do
            sum(m(x))
        end
        gs_g = gradient(params(m_g)) do
            sum(m_g(x_g))
        end
        @test Array(gs_g[m_g.weight]) ≈ gs[m.weight]
    end
end