Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix memoization issue #1395

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix memoization issues
  • Loading branch information
devmotion committed Aug 26, 2020
commit 8d818de4f9d45f903a87983d1f7b6af061350b58
14 changes: 5 additions & 9 deletions src/core/compat/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,10 @@ end
@require Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" @eval begin
setrdcache(::Val{true}) = RDCache[] = true
function emptyrdcache()
for k in keys(Memoization.caches)
if k[1] === typeof(memoized_taperesult)
pop!(Memoization.caches, k)
end
end
Memoization.empty_cache!(memoized_taperesult)
return
end

function gradient_logp(
backend::ReverseDiffAD{true},
θ::AbstractVector{<:Real},
Expand Down Expand Up @@ -81,15 +79,13 @@ end
f::F
x::Tx
end
function Memoization._get!(f::Union{Function, Type}, d::IdDict, keys::Tuple{Tuple{RDTapeKey}, Any})
function Memoization._get!(f, d::IdDict, keys::Tuple{Tuple{RDTapeKey}, Any})
key = keys[1][1]
return Memoization._get!(f, d, (typeof(key.f), typeof(key.x), size(key.x)))
return Memoization._get!(f, d, (key.f, typeof(key.x), size(key.x)))
end
memoized_taperesult(f, x) = memoized_taperesult(RDTapeKey(f, x))
Memoization.@memoize function memoized_taperesult(k::RDTapeKey)
return compiledtape(k.f, k.x), GradientResult(k.x)
end
memoized_tape(f, x) = memoized_tape(RDTapeKey(f, x))
Memoization.@memoize memoized_tape(k::RDTapeKey) = compiledtape(k.f, k.x)
compiledtape(f, x) = compile(GradientTape(f, x))
end
30 changes: 28 additions & 2 deletions test/core/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,13 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)
sample(dir(), HMC(0.01, 1), 1000);
Turing.setrdcache(true)
sample(dir(), HMC(0.01, 1), 1000);
@test length(Memoization.caches) == 1
caches = Memoization.find_caches(Turing.Core.memoized_taperesult)
@test length(caches) == 1
@test !isempty(first(values(caches)))
Turing.emptyrdcache()
@test length(Memoization.caches) == 0
caches = Memoization.find_caches(Turing.Core.memoized_taperesult)
@test length(caches) == 1
@test isempty(first(values(caches)))
end
# FIXME: For some reasons PDMatDistribution AD tests fail with ReverseDiff
@testset "PDMatDistribution AD" begin
Expand Down Expand Up @@ -340,4 +344,26 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)
@test H_f == [1.0 0.0; 0.0 1.0]
@test H_f == H_r
end

@testset "memoization: issue #1393" begin
Turing.setadbackend(:reversediff)
Turing.setrdcache(true)

@model function demo(data)
sigma ~ Uniform(0.0, 20.0)
data ~ Normal(0, sigma)
end

N = 1000

for emptycache in (true, false)
for i in 1:10
d = Normal(0.0, i)
data = rand(d, N)
chn = sample(demo(data), NUTS(0.65), 1000)
@test mean(Array(chn[:sigma])) ≈ std(data) atol=0.5
emptycache && Turing.emptyrdcache()
end
end
end
end