Merge pull request #68 from SciML/dataarg
Don't use DiffResults in Flux optimiser dispatch with FiniteDiff
Vaibhavdixit02 authored Oct 25, 2020
2 parents 2d585bb + 69127dd commit 7a74b3b
Showing 4 changed files with 8 additions and 5 deletions.
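
For context: FiniteDiff writes the gradient into any mutable array you hand it, so the DiffResults container used by the AD backends is unnecessary on the AutoFiniteDiff path. A minimal standalone sketch of that point (illustrative only, not code from this commit; the objective and variable names are made up):

```julia
using FiniteDiff

# Illustrative objective; any scalar-valued f(θ) works.
rosenbrock(θ) = (1 - θ[1])^2 + 100 * (θ[2] - θ[1]^2)^2

θ = [0.0, 0.0]
g = similar(θ)   # a plain Vector{Float64}; no DiffResults.GradientResult needed

# FiniteDiff fills the preallocated buffer in place.
FiniteDiff.finite_difference_gradient!(g, rosenbrock, θ)
g  # ≈ [-2.0, 0.0]
```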
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "GalacticOptim"
uuid = "a75be94c-b780-496d-a8a9-0878b188d577"
authors = ["Vaibhavdixit02 <vaibhavyashdixit@gmail.com>"]
version = "0.4.0"
version = "0.4.1"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
4 changes: 2 additions & 2 deletions src/function.jl
@@ -212,13 +212,13 @@ function instantiate_function(f, x, adtype::AutoFiniteDiff, p, num_cons = 0)
_f = (θ, args...) -> first(f.f(θ, p, args...))

if f.grad === nothing
-grad = (res, θ, args...) -> FiniteDiff.finite_difference_gradient!(res,x ->_f(x, args...), θ, FiniteDiff.GradientCache(res, x, adtype.fdtype))
+grad = (res, θ, args...) -> FiniteDiff.finite_difference_gradient!(res, x ->_f(x, args...), θ, FiniteDiff.GradientCache(res, x, adtype.fdtype))
else
grad = f.grad
end

if f.hess === nothing
-hess = (res, θ, args...) -> FiniteDiff.finite_difference_hessian!(res,x ->_f(x, args...), θ, FiniteDiff.HessianCache(x, adtype.fdhtype))
+hess = (res, θ, args...) -> FiniteDiff.finite_difference_hessian!(res, x ->_f(x, args...), θ, FiniteDiff.HessianCache(x, adtype.fdhtype))
else
hess = f.hess
end
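The AutoFiniteDiff path in instantiate_function builds FiniteDiff caches keyed on the requested difference scheme (adtype.fdtype for gradients, adtype.fdhtype for Hessians). A rough standalone sketch of the cached calls this hunk touches (illustrative only; the objective, buffers, and the Val(:forward) choice are assumptions, not code from this commit):

```julia
using FiniteDiff

f(θ) = sum(abs2, θ)                 # illustrative objective
θ = [1.0, 2.0, 3.0]
G = similar(θ)
H = zeros(length(θ), length(θ))

# Gradient cache carries the buffer, the evaluation point, and the scheme.
gcache = FiniteDiff.GradientCache(G, θ, Val(:forward))
FiniteDiff.finite_difference_gradient!(G, f, θ, gcache)   # G ≈ [2.0, 4.0, 6.0]

# Hessian cache is built from the point alone (central differences by default).
hcache = FiniteDiff.HessianCache(θ)
FiniteDiff.finite_difference_hessian!(H, f, θ, hcache)    # H ≈ 2 * identity
```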
4 changes: 2 additions & 2 deletions src/solve.jl
@@ -94,7 +94,7 @@ function __solve(prob::OptimizationProblem, opt, _data = DEFAULT_DATA;cb = (args

@withprogress progress name="Training" begin
for (i,d) in enumerate(data)
-gs = DiffResults.GradientResult(θ)
+gs = prob.f.adtype isa AutoFiniteDiff ? Array{Number}(undef,length(θ)) : DiffResults.GradientResult(θ)
f.grad(gs, θ, d...)
x = f.f(θ, prob.p, d...)
cb_call = cb(θ, x...)
@@ -105,7 +105,7 @@ end
end
msg = @sprintf("loss: %.3g", x[1])
progress && ProgressLogging.@logprogress msg i/maxiters
-update!(opt, ps, DiffResults.gradient(gs))
+update!(opt, ps, prob.f.adtype isa AutoFiniteDiff ? gs : DiffResults.gradient(gs))

if save_best
if first(x) < first(min_err) #found a better solution
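In __solve for the Flux optimisers, the gradient container is now chosen from the problem's AD type: a plain array when finite differencing, a DiffResults.GradientResult otherwise, and the raw array is passed straight to update! in the FiniteDiff case. The effect, sketched standalone (the flag, loss, and loop here are illustrative stand-ins; in the package the flag is prob.f.adtype isa AutoFiniteDiff):

```julia
using Flux, FiniteDiff, ForwardDiff, DiffResults

loss(θ) = sum(abs2, θ .- 1)          # toy objective with minimum at θ = ones(3)
θ   = zeros(3)
opt = ADAM(0.1)

use_finitediff = true                # stands in for `prob.f.adtype isa AutoFiniteDiff`

for _ in 1:200
    # FiniteDiff fills a plain array; the AD backends fill a DiffResults container.
    gs = use_finitediff ? similar(θ) : DiffResults.GradientResult(θ)
    if use_finitediff
        FiniteDiff.finite_difference_gradient!(gs, loss, θ)
    else
        ForwardDiff.gradient!(gs, loss, θ)
    end
    # Hand Flux the raw gradient array in the FiniteDiff case.
    Flux.Optimise.update!(opt, θ, use_finitediff ? gs : DiffResults.gradient(gs))
end
θ   # ≈ [1.0, 1.0, 1.0]
```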
3 changes: 3 additions & 0 deletions test/ADtests.jl
@@ -107,3 +107,6 @@ sol = solve(prob, Newton())

sol = solve(prob, Optim.KrylovTrustRegion())
@test sol.minimum < l1 #the loss doesn't go below 5e-1 here

+sol = solve(prob, ADAM(0.1))
+@test 10*sol.minimum < l1
