Skip to content
This repository has been archived by the owner on Oct 31, 2024. It is now read-only.

Commit

Permalink
SimpleJNFK working version
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 12, 2023
1 parent 59d69cd commit 9858095
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 18 deletions.
5 changes: 1 addition & 4 deletions ext/SimpleNonlinearSolveNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@ using ArrayInterface, DiffEqBase, LinearAlgebra, NNlib, SimpleNonlinearSolve, Sc
import SimpleNonlinearSolve: _construct_batched_problem_structure,
_get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace

function __init__()
SimpleNonlinearSolve.NNlibExtLoaded[] = true
return
end
SimpleNonlinearSolve.extension_loaded(::Val{NNlib}) = true

Check warning on line 7 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L7

Added line #L7 was not covered by tests

@views function SciMLBase.__solve(prob::NonlinearProblem,
alg::BatchedBroyden;
Expand Down
2 changes: 1 addition & 1 deletion src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function __init__()
@require_extensions
end

const NNlibExtLoaded = Ref{Bool}(false)
extension_loaded(::Val) = false

Check warning on line 19 in src/SimpleNonlinearSolve.jl

View check run for this annotation

Codecov / codecov/patch

src/SimpleNonlinearSolve.jl#L19

Added line #L19 was not covered by tests

abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
Expand Down
2 changes: 1 addition & 1 deletion src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function Broyden(; batched = false,
abstol = nothing,
reltol = nothing))
if batched
@assert NNlibExtLoaded[] "Please install and load `NNlib.jl` to use batched Broyden."
@assert extension_loaded(Val(:NNlib)) "Please install and load `NNlib.jl` to use batched Broyden."

Check warning on line 25 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L25

Added line #L25 was not covered by tests
return BatchedBroyden(termination_condition)
end
return Broyden(termination_condition)
Expand Down
36 changes: 25 additions & 11 deletions src/jnfk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,32 @@ function jvp_forwarddiff(f, x::AbstractArray{T}, v) where {T}
y = (Dual{Tag{SimpleJNFKJacVecTag, T}, T, 1}).(x, Partials.(tuple.(v_)))
return vec(ForwardDiff.partials.(vec(f(y)), 1))

Check warning on line 6 in src/jnfk.jl

View check run for this annotation

Codecov / codecov/patch

src/jnfk.jl#L3-L6

Added lines #L3 - L6 were not covered by tests
end
jvp_forwarddiff!(r, f, x, v) = copyto!(r, jvp_forwarddiff(f, x, v))

Check warning on line 8 in src/jnfk.jl

View check run for this annotation

Codecov / codecov/patch

src/jnfk.jl#L8

Added line #L8 was not covered by tests

struct JacVecOperator{F, X}
f::F
x::X
end

(jvp::JacVecOperator)(v, _, _) = jvp_forwarddiff(jvp.f, jvp.x, v)
(jvp::JacVecOperator)(r, v, _, _) = jvp_forwarddiff!(r, jvp.f, jvp.x, v)

Check warning on line 16 in src/jnfk.jl

View check run for this annotation

Codecov / codecov/patch

src/jnfk.jl#L15-L16

Added lines #L15 - L16 were not covered by tests

"""
SimpleJNFK()
SimpleJNFK(; batched::Bool = false)
A low overhead Jacobian-free Newton-Krylov method. This method internally uses `GMRES` to
avoid computing the Jacobian Matrix.
!!! warning
JNFK doesn't work well without preconditioning, which is currently not supported. We
recommend using `NewtonRaphson(linsolve = KrylovJL_GMRES())` for preconditioning
support.
"""
struct SimpleJFNK end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleJFNK, args...;

Check warning on line 32 in src/jnfk.jl

View check run for this annotation

Codecov / codecov/patch

src/jnfk.jl#L32

Added line #L32 was not covered by tests
abstol = nothing, reltol= nothing, maxiters = 1000, linsolve_kwargs = (;), kwargs...)
abstol = nothing, reltol = nothing, maxiters = 1000, linsolve_kwargs = (;), kwargs...)
iip = SciMLBase.isinplace(prob)
@assert !iip "SimpleJFNK does not support inplace problems"

Check warning on line 35 in src/jnfk.jl

View check run for this annotation

Codecov / codecov/patch

src/jnfk.jl#L34-L35

Added lines #L34 - L35 were not covered by tests

Expand All @@ -29,26 +39,30 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleJFNK, args...;
fx = f(x)
T = typeof(x)

Check warning on line 40 in src/jnfk.jl

View check run for this annotation

Codecov / codecov/patch

src/jnfk.jl#L37-L40

Added lines #L37 - L40 were not covered by tests

iszero(fx) &&

Check warning on line 42 in src/jnfk.jl

View check run for this annotation

Codecov / codecov/patch

src/jnfk.jl#L42

Added line #L42 was not covered by tests
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)

atol = abstol !== nothing ? abstol :

Check warning on line 45 in src/jnfk.jl

View check run for this annotation

Codecov / codecov/patch

src/jnfk.jl#L45

Added line #L45 was not covered by tests
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)

Check warning on line 47 in src/jnfk.jl

View check run for this annotation

Codecov / codecov/patch

src/jnfk.jl#L47

Added line #L47 was not covered by tests

op = FunctionOperator(JacVecOperator(f, x), x)
linprob = LinearProblem(op, -fx)
lincache = init(linprob, SimpleGMRES(); abstol, reltol, maxiters, linsolve_kwargs...)
linprob = LinearProblem(op, vec(fx))
lincache = init(linprob, KrylovJL_GMRES(); abstol = atol, reltol = rtol, maxiters,

Check warning on line 51 in src/jnfk.jl

View check run for this annotation

Codecov / codecov/patch

src/jnfk.jl#L49-L51

Added lines #L49 - L51 were not covered by tests
linsolve_kwargs...)

for i in 1:maxiters
iszero(fx) &&
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)

linsol = solve!(lincache)
x .-= linsol.u
axpy!(-1, linsol.u, x)
lincache = linsol.cache

Check warning on line 57 in src/jnfk.jl

View check run for this annotation

Codecov / codecov/patch

src/jnfk.jl#L54-L57

Added lines #L54 - L57 were not covered by tests

# FIXME: not nothing
if isapprox(x, nothing; atol, rtol)
fx = f(x)

Check warning on line 59 in src/jnfk.jl

View check run for this annotation

Codecov / codecov/patch

src/jnfk.jl#L59

Added line #L59 was not covered by tests

norm(fx, Inf) atol &&

Check warning on line 61 in src/jnfk.jl

View check run for this annotation

Codecov / codecov/patch

src/jnfk.jl#L61

Added line #L61 was not covered by tests
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
end

lincache.b = vec(fx)
lincache.A = FunctionOperator(JacVecOperator(f, x), x)
end

Check warning on line 66 in src/jnfk.jl

View check run for this annotation

Codecov / codecov/patch

src/jnfk.jl#L64-L66

Added lines #L64 - L66 were not covered by tests

return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)

Check warning on line 68 in src/jnfk.jl

View check run for this annotation

Codecov / codecov/patch

src/jnfk.jl#L68

Added line #L68 was not covered by tests
Expand Down
1 change: 0 additions & 1 deletion src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ function SimpleNewtonRaphson(; batched = false,
throw(ArgumentError("`termination_condition` is currently only supported for batched problems"))
end
if batched
# @assert ADLinearSolveFDExtLoaded[] "Please install and load `LinearSolve.jl`, `FiniteDifferences.jl` and `AbstractDifferentiation.jl` to use batched Newton-Raphson."
termination_condition = ismissing(termination_condition) ?
NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
Expand Down

0 comments on commit 9858095

Please sign in to comment.