Skip to content

Commit

Permalink
Fix MLE/MAP with Zygote and ReverseDiff (TuringLang#1408)
Browse files Browse the repository at this point in the history
* Fix AD for modes

* Increment version number

* Remove display calls

* Move tests
  • Loading branch information
cpfiffer authored Sep 15, 2020
1 parent ae24c28 commit 1d87a1d
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.14.2"
version = "0.14.3"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
6 changes: 4 additions & 2 deletions src/core/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ getADbackend(spl::SampleFromPrior) = ADBackend()()
θ::AbstractVector{<:Real},
vi::VarInfo,
model::Model,
sampler::AbstractSampler=SampleFromPrior(),
sampler::AbstractSampler,
ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
)
Computes the value of the log joint of `θ` and its gradient for the model
Expand All @@ -89,6 +90,7 @@ gradient_logp(
vi::VarInfo,
model::Model,
sampler::AbstractSampler = SampleFromPrior(),
ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
)
Compute the value of the log joint of `θ` and its gradient for the model
Expand Down Expand Up @@ -160,7 +162,7 @@ function gradient_logp(
# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
model(new_vi, sampler)
model(new_vi, sampler, context)
return getlogp(new_vi)
end

Expand Down
4 changes: 2 additions & 2 deletions src/core/compat/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function gradient_logp(
# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
model(new_vi, sampler)
model(new_vi, sampler, context)
return getlogp(new_vi)
end
tp, result = taperesult(f, θ)
Expand Down Expand Up @@ -65,7 +65,7 @@ end
# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
model(new_vi, sampler)
model(new_vi, sampler, context)
return getlogp(new_vi)
end
ctp, result = memoized_taperesult(f, θ)
Expand Down
1 change: 1 addition & 0 deletions test/modes/ModeEstimation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using NamedArrays
using ReverseDiff
using Random
using LinearAlgebra
using Zygote

dir = splitdir(splitdir(pathof(Turing))[1])[1]
include(dir*"/test/test_utils/AllUtils.jl")
Expand Down
8 changes: 4 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ include("test_utils/AllUtils.jl")
@testset "variational algorithms : $adbackend" begin
include("variational/advi.jl")
end

@testset "modes" begin
include("modes/ModeEstimation.jl")
end
end
@testset "variational optimisers" begin
include("variational/optimisers.jl")
Expand All @@ -55,8 +59,4 @@ include("test_utils/AllUtils.jl")
# include("utilities/stan-interface.jl")
include("inference/utilities.jl")
end

@testset "modes" begin
include("modes/ModeEstimation.jl")
end
end

0 comments on commit 1d87a1d

Please sign in to comment.