From 9b51302a295880bbcb1d2f1c6e098722c424d025 Mon Sep 17 00:00:00 2001 From: Fredrik Ekre Date: Tue, 30 May 2017 11:30:32 +0200 Subject: [PATCH] Store info in Cholesky type (#21976) * add info field to Cholesky type and delay throwing for non-positive definiteness * comment update [ci skip] * remove comment [ci skip] --- base/linalg/cholesky.jl | 79 ++++++++++++++++++++++++++--------------- base/linalg/dense.jl | 8 +++-- base/linalg/lapack.jl | 4 +-- test/linalg/cholesky.jl | 24 +++++++++---- 4 files changed, 74 insertions(+), 41 deletions(-) diff --git a/base/linalg/cholesky.jl b/base/linalg/cholesky.jl index 0d09fa998028d..2b5c89cad03ef 100644 --- a/base/linalg/cholesky.jl +++ b/base/linalg/cholesky.jl @@ -18,6 +18,11 @@ # through the Hermitian and Symmetric views or exact symmetric or Hermitian elements which # is checked for and an error is thrown if the check fails. +# The internal structure is as follows +# - _chol! returns the factor and info without checking positive definiteness +# - chol/chol! returns the factor and checks for positive definiteness +# - cholfact/cholfact! returns Cholesky without checking positive definiteness + # FixMe? The dispatch below seems overly complicated. One simplification could be to # merge the two Cholesky types into one. 
It would remove the need for Val completely but # the cost would be extra unnecessary/unused fields for the unpivoted Cholesky and runtime @@ -27,9 +32,12 @@ struct Cholesky{T,S<:AbstractMatrix} <: Factorization{T} factors::S uplo::Char + info::BlasInt end -Cholesky{T}(A::AbstractMatrix{T}, uplo::Symbol) = Cholesky{T,typeof(A)}(A, char_uplo(uplo)) -Cholesky{T}(A::AbstractMatrix{T}, uplo::Char) = Cholesky{T,typeof(A)}(A, uplo) +Cholesky{T}(A::AbstractMatrix{T}, uplo::Symbol, info::BlasInt) = + Cholesky{T,typeof(A)}(A, char_uplo(uplo), info) +Cholesky{T}(A::AbstractMatrix{T}, uplo::Char, info::BlasInt) = + Cholesky{T,typeof(A)}(A, uplo, info) struct CholeskyPivoted{T,S<:AbstractMatrix} <: Factorization{T} factors::S @@ -49,11 +57,11 @@ end ## BLAS/LAPACK element types function _chol!(A::StridedMatrix{<:BlasFloat}, ::Type{UpperTriangular}) C, info = LAPACK.potrf!('U', A) - return @assertposdef UpperTriangular(C) info + return UpperTriangular(C), info end function _chol!(A::StridedMatrix{<:BlasFloat}, ::Type{LowerTriangular}) C, info = LAPACK.potrf!('L', A) - return @assertposdef LowerTriangular(C) info + return LowerTriangular(C), info end ## Non BLAS/LAPACK element types (generic) @@ -64,7 +72,10 @@ function _chol!(A::AbstractMatrix, ::Type{UpperTriangular}) for i = 1:k - 1 A[k,k] -= A[i,k]'A[i,k] end - Akk = _chol!(A[k,k], UpperTriangular) + Akk, info = _chol!(A[k,k], UpperTriangular) + if info != 0 + return UpperTriangular(A), info + end A[k,k] = Akk AkkInv = inv(Akk') for j = k + 1:n @@ -75,7 +86,7 @@ function _chol!(A::AbstractMatrix, ::Type{UpperTriangular}) end end end - return UpperTriangular(A) + return UpperTriangular(A), convert(BlasInt, 0) end function _chol!(A::AbstractMatrix, ::Type{LowerTriangular}) n = checksquare(A) @@ -84,7 +95,10 @@ function _chol!(A::AbstractMatrix, ::Type{LowerTriangular}) for i = 1:k - 1 A[k,k] -= A[k,i]*A[k,i]' end - Akk = _chol!(A[k,k], LowerTriangular) + Akk, info = _chol!(A[k,k], LowerTriangular) + if info != 0 + return 
LowerTriangular(A), info + end A[k,k] = Akk AkkInv = inv(Akk) for j = 1:k @@ -99,30 +113,33 @@ function _chol!(A::AbstractMatrix, ::Type{LowerTriangular}) end end end - return LowerTriangular(A) + return LowerTriangular(A), convert(BlasInt, 0) end ## Numbers function _chol!(x::Number, uplo) rx = real(x) - if rx != abs(x) - throw(ArgumentError("x must be positive semidefinite")) - end - rxr = sqrt(rx) - convert(promote_type(typeof(x), typeof(rxr)), rxr) + rxr = sqrt(abs(rx)) + rval = convert(promote_type(typeof(x), typeof(rxr)), rxr) + rx == abs(x) ? (rval, convert(BlasInt, 0)) : (rval, convert(BlasInt, 1)) end +chol!(x::Number, uplo) = ((C, info) = _chol!(x, uplo); @assertposdef C info) + non_hermitian_error(f) = throw(ArgumentError("matrix is not symmetric/" * "Hermitian. This error can be avoided by calling $f(Hermitian(A)) " * "which will ignore either the upper or lower triangle of the matrix.")) # chol!. Destructive methods for computing Cholesky factor of real symmetric or Hermitian # matrix -chol!(A::RealHermSymComplexHerm{<:Real,<:StridedMatrix}) = - _chol!(A.uplo == 'U' ? A.data : LinAlg.copytri!(A.data, 'L', true), UpperTriangular) +function chol!(A::RealHermSymComplexHerm{<:Real,<:StridedMatrix}) + C, info = _chol!(A.uplo == 'U' ? A.data : LinAlg.copytri!(A.data, 'L', true), UpperTriangular) + @assertposdef C info +end function chol!(A::StridedMatrix) ishermitian(A) || non_hermitian_error("chol!") - return _chol!(A, UpperTriangular) + C, info = _chol!(A, UpperTriangular) + @assertposdef C info end @@ -184,7 +201,7 @@ julia> chol(16) 4.0 ``` """ -chol(x::Number, args...) = _chol!(x, nothing) +chol(x::Number, args...) = ((C, info) = _chol!(x, nothing); @assertposdef C info) @@ -193,9 +210,11 @@ chol(x::Number, args...) 
= _chol!(x, nothing) ## No pivoting function cholfact!(A::RealHermSymComplexHerm, ::Type{Val{false}}) if A.uplo == 'U' - Cholesky(_chol!(A.data, UpperTriangular).data, 'U') + CU, info = _chol!(A.data, UpperTriangular) + Cholesky(CU.data, 'U', info) else - Cholesky(_chol!(A.data, LowerTriangular).data, 'L') + CL, info = _chol!(A.data, LowerTriangular) + Cholesky(CL.data, 'L', info) end end @@ -354,14 +373,15 @@ end ## Number function cholfact(x::Number, uplo::Symbol=:U) - xf = fill(chol(x), 1, 1) - Cholesky(xf, uplo) + C, info = _chol!(x, uplo) + xf = fill(C, 1, 1) + Cholesky(xf, uplo, info) end function convert(::Type{Cholesky{T}}, C::Cholesky) where T Cnew = convert(AbstractMatrix{T}, C.factors) - Cholesky{T, typeof(Cnew)}(Cnew, C.uplo) + Cholesky{T, typeof(Cnew)}(Cnew, C.uplo, C.info) end convert(::Type{Factorization{T}}, C::Cholesky{T}) where {T} = C convert(::Type{Factorization{T}}, C::Cholesky) where {T} = convert(Cholesky{T}, C) @@ -386,7 +406,7 @@ convert(::Type{Matrix}, F::CholeskyPivoted) = convert(Array, convert(AbstractArr convert(::Type{Array}, F::CholeskyPivoted) = convert(Matrix, F) full(F::CholeskyPivoted) = convert(AbstractArray, F) -copy(C::Cholesky) = Cholesky(copy(C.factors), C.uplo) +copy(C::Cholesky) = Cholesky(copy(C.factors), C.uplo, C.info) copy(C::CholeskyPivoted) = CholeskyPivoted(copy(C.factors), C.uplo, C.piv, C.rank, C.tol, C.info) size(C::Union{Cholesky, CholeskyPivoted}) = size(C.factors) @@ -417,7 +437,7 @@ show(io::IO, C::Cholesky{<:Any,<:AbstractMatrix}) = (println(io, "$(typeof(C)) with factor:");show(io,C[:UL])) A_ldiv_B!(C::Cholesky{T,<:AbstractMatrix}, B::StridedVecOrMat{T}) where {T<:BlasFloat} = - LAPACK.potrs!(C.uplo, C.factors, B) + @assertposdef LAPACK.potrs!(C.uplo, C.factors, B) C.info function A_ldiv_B!(C::Cholesky{<:Any,<:AbstractMatrix}, B::StridedVecOrMat) if C.uplo == 'L' @@ -465,16 +485,18 @@ function A_ldiv_B!(C::CholeskyPivoted, B::StridedMatrix) end function det(C::Cholesky) + C.info == 0 || 
throw(PosDefException(C.info)) dd = one(real(eltype(C))) - for i in 1:size(C.factors,1) + @inbounds for i in 1:size(C.factors,1) dd *= real(C.factors[i,i])^2 end dd end function logdet(C::Cholesky) + C.info == 0 || throw(PosDefException(C.info)) dd = zero(real(eltype(C))) - for i in 1:size(C.factors,1) + @inbounds for i in 1:size(C.factors,1) dd += log(real(C.factors[i,i])) end dd + dd # instead of 2.0dd which can change the type @@ -505,10 +527,9 @@ function logdet(C::CholeskyPivoted) end inv!(C::Cholesky{<:BlasFloat,<:StridedMatrix}) = - copytri!(LAPACK.potri!(C.uplo, C.factors), C.uplo, true) + @assertposdef copytri!(LAPACK.potri!(C.uplo, C.factors), C.uplo, true) C.info -inv(C::Cholesky{<:BlasFloat,<:StridedMatrix}) = - inv!(copy(C)) +inv(C::Cholesky{<:BlasFloat,<:StridedMatrix}) = inv!(copy(C)) function inv(C::CholeskyPivoted) chkfullrank(C) diff --git a/base/linalg/dense.jl b/base/linalg/dense.jl index 0f96ffc3ab5dd..c1dc43965e13b 100644 --- a/base/linalg/dense.jl +++ b/base/linalg/dense.jl @@ -770,10 +770,12 @@ function factorize(A::StridedMatrix{T}) where T return UpperTriangular(A) end if herm - try - return cholfact(A) + cf = cholfact(A) + if cf.info == 0 + return cf + else + return factorize(Hermitian(A)) end - return factorize(Hermitian(A)) end if sym return factorize(Symmetric(A)) diff --git a/base/linalg/lapack.jl b/base/linalg/lapack.jl index 14d8f76df34ae..1f9c71488a773 100644 --- a/base/linalg/lapack.jl +++ b/base/linalg/lapack.jl @@ -2963,10 +2963,10 @@ for (posv, potrf, potri, potrs, pstrf, elty, rtyp) in (Ptr{UInt8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}), &uplo, &size(A,1), A, &lda, info) chkargsok(info[]) - #info[1]>0 means the leading minor of order info[i] is not positive definite + #info[] > 0 means the leading minor of order info[] is not positive definite #ordinarily, throw Exception here, but return error code here #this simplifies isposdef! 
and factorize - return A, info[] + return A, info[] # info stored in Cholesky end # SUBROUTINE DPOTRI( UPLO, N, A, LDA, INFO ) diff --git a/test/linalg/cholesky.jl b/test/linalg/cholesky.jl index e6aabd7f8ff17..b6d92d89ba9f7 100644 --- a/test/linalg/cholesky.jl +++ b/test/linalg/cholesky.jl @@ -4,7 +4,7 @@ debug = false using Base.Test -using Base.LinAlg: BlasComplex, BlasFloat, BlasReal, QRPivoted +using Base.LinAlg: BlasComplex, BlasFloat, BlasReal, QRPivoted, PosDefException n = 10 @@ -60,7 +60,7 @@ for eltya in (Float32, Float64, Complex64, Complex128, BigFloat, Int) apos = apd[1,1] # test chol(x::Number), needs x>0 @test all(x -> x ≈ √apos, cholfact(apos).factors) - @test_throws ArgumentError chol(-one(eltya)) + @test_throws PosDefException chol(-one(eltya)) if eltya <: Real capds = cholfact(apds) @@ -194,10 +194,9 @@ end begin # Cholesky factor of Matrix with non-commutative elements, here 2x2-matrices - X = Matrix{Float64}[0.1*rand(2,2) for i in 1:3, j = 1:3] - L = full(Base.LinAlg._chol!(X*X', LowerTriangular)) - U = full(Base.LinAlg._chol!(X*X', UpperTriangular)) + L = full(Base.LinAlg._chol!(X*X', LowerTriangular)[1]) + U = full(Base.LinAlg._chol!(X*X', UpperTriangular)[1]) XX = full(X*X') @test sum(sum(norm, L*L' - XX)) < eps() @@ -212,8 +211,8 @@ for elty in (Float32, Float64, Complex{Float32}, Complex{Float64}) A = randn(5,5) end A = convert(Matrix{elty}, A'A) - @test full(cholfact(A)[:L]) ≈ full(invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{LowerTriangular}}, copy(A), LowerTriangular)) - @test full(cholfact(A)[:U]) ≈ full(invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{UpperTriangular}}, copy(A), UpperTriangular)) + @test full(cholfact(A)[:L]) ≈ full(invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{LowerTriangular}}, copy(A), LowerTriangular)[1]) + @test full(cholfact(A)[:U]) ≈ full(invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{UpperTriangular}}, copy(A), UpperTriangular)[1]) end # Test up- and downdates @@ -272,3 +271,14 @@ 
end # Fail for non-BLAS element types @test_throws ArgumentError cholfact!(Hermitian(rand(Float16, 5,5)), Val{true}) + +@testset "throw for non positive matrix" begin + for T in (Float32, Float64, Complex64, Complex128) + A = T[1 2; 2 1]; B = T[1, 1] + C = cholfact(A) + @show typeof(A), typeof(B), typeof(C.factors) + @test_throws PosDefException C\B + @test_throws PosDefException det(C) + @test_throws PosDefException logdet(C) + end +end