Skip to content

Commit

Permalink
Fix memoization issue (TuringLang#1414)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Sep 24, 2020
1 parent e6430f1 commit 96a79f3
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.14.3"
version = "0.14.4"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
20 changes: 8 additions & 12 deletions src/core/compat/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function gradient_logp(
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
)
T = typeof(getlogp(vi))

# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
Expand All @@ -46,12 +46,10 @@ end
@require Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" @eval begin
setrdcache(::Val{true}) = RDCache[] = true
function emptyrdcache()
for k in keys(Memoization.caches)
if k[1] === typeof(memoized_taperesult)
pop!(Memoization.caches, k)
end
end
Memoization.empty_cache!(memoized_taperesult)
return
end

function gradient_logp(
backend::ReverseDiffAD{true},
θ::AbstractVector{<:Real},
Expand All @@ -61,7 +59,7 @@ end
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
)
T = typeof(getlogp(vi))

# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
Expand All @@ -81,15 +79,13 @@ end
f::F
x::Tx
end
function Memoization._get!(f::Union{Function, Type}, d::IdDict, keys::Tuple{Tuple{RDTapeKey}, Any})
function Memoization._get!(f, d::Dict, keys::Tuple{Tuple{RDTapeKey}, Any})
key = keys[1][1]
return Memoization._get!(f, d, (typeof(key.f), typeof(key.x), size(key.x)))
return Memoization._get!(f, d, (key.f, typeof(key.x), size(key.x), Threads.threadid()))
end
memoized_taperesult(f, x) = memoized_taperesult(RDTapeKey(f, x))
Memoization.@memoize function memoized_taperesult(k::RDTapeKey)
Memoization.@memoize Dict function memoized_taperesult(k::RDTapeKey)
return compiledtape(k.f, k.x), GradientResult(k.x)
end
memoized_tape(f, x) = memoized_tape(RDTapeKey(f, x))
Memoization.@memoize memoized_tape(k::RDTapeKey) = compiledtape(k.f, k.x)
compiledtape(f, x) = compile(GradientTape(f, x))
end
28 changes: 26 additions & 2 deletions test/core/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,13 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)
sample(dir(), HMC(0.01, 1), 1000);
Turing.setrdcache(true)
sample(dir(), HMC(0.01, 1), 1000);
@test length(Memoization.caches) == 1
caches = Memoization.find_caches(Turing.Core.memoized_taperesult)
@test length(caches) == 1
@test !isempty(first(values(caches)))
Turing.emptyrdcache()
@test length(Memoization.caches) == 0
caches = Memoization.find_caches(Turing.Core.memoized_taperesult)
@test length(caches) == 1
@test isempty(first(values(caches)))
end
# FIXME: For some reasons PDMatDistribution AD tests fail with ReverseDiff
@testset "PDMatDistribution AD" begin
Expand Down Expand Up @@ -340,4 +344,24 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)
@test H_f == [1.0 0.0; 0.0 1.0]
@test H_f == H_r
end

@testset "memoization: issue #1393" begin
Turing.setadbackend(:reversediff)
Turing.setrdcache(true)

@model function demo(data)
sigma ~ Uniform(0.0, 20.0)
data ~ Normal(0, sigma)
end

N = 1000
for i in 1:5
d = Normal(0.0, i)
data = rand(d, N)
chn = sample(demo(data), NUTS(0.65), 1000)
@test mean(Array(chn[:sigma])) std(data) atol=0.5
end

Turing.emptyrdcache()
end
end
7 changes: 1 addition & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,8 @@ include("test_utils/AllUtils.jl")
include("core/container.jl")
end

test_adbackends = if VERSION >= v"1.2"
[:forwarddiff, :tracker, :reversediff]
else
[:forwarddiff, :tracker]
end
Turing.setrdcache(false)
for adbackend in test_adbackends
for adbackend in (:forwarddiff, :tracker, :reversediff)
Turing.setadbackend(adbackend)
@testset "inference: $adbackend" begin
@testset "samplers" begin
Expand Down

0 comments on commit 96a79f3

Please sign in to comment.