Skip to content

Commit

Permalink
Store info in Cholesky type (#21976)
Browse files Browse the repository at this point in the history
* add info field to Cholesky type

and delay throwing for non-positive definiteness

* comment update [ci skip]

* remove comment [ci skip]
  • Loading branch information
fredrikekre authored and KristofferC committed May 30, 2017
1 parent fdf97c1 commit 9b51302
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 41 deletions.
79 changes: 50 additions & 29 deletions base/linalg/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
# through the Hermitian and Symmetric views or exact symmetric or Hermitian elements which
# is checked for and an error is thrown if the check fails.

# The internal structure is as follows
# - _chol! returns the factor and info without checking positive definiteness
# - chol/chol! returns the factor and checks for positive definiteness
# - cholfact/cholfact! returns Cholesky with checking positive definiteness

# FixMe? The dispatch below seems overly complicated. One simplification could be to
# merge the two Cholesky types into one. It would remove the need for Val completely but
# the cost would be extra unnecessary/unused fields for the unpivoted Cholesky and runtime
Expand All @@ -27,9 +32,12 @@
struct Cholesky{T,S<:AbstractMatrix} <: Factorization{T}
factors::S
uplo::Char
info::BlasInt
end
Cholesky{T}(A::AbstractMatrix{T}, uplo::Symbol) = Cholesky{T,typeof(A)}(A, char_uplo(uplo))
Cholesky{T}(A::AbstractMatrix{T}, uplo::Char) = Cholesky{T,typeof(A)}(A, uplo)
Cholesky{T}(A::AbstractMatrix{T}, uplo::Symbol, info::BlasInt) =
Cholesky{T,typeof(A)}(A, char_uplo(uplo), info)
Cholesky{T}(A::AbstractMatrix{T}, uplo::Char, info::BlasInt) =
Cholesky{T,typeof(A)}(A, uplo, info)

struct CholeskyPivoted{T,S<:AbstractMatrix} <: Factorization{T}
factors::S
Expand All @@ -49,11 +57,11 @@ end
## BLAS/LAPACK element types
function _chol!(A::StridedMatrix{<:BlasFloat}, ::Type{UpperTriangular})
C, info = LAPACK.potrf!('U', A)
return @assertposdef UpperTriangular(C) info
return UpperTriangular(C), info
end
function _chol!(A::StridedMatrix{<:BlasFloat}, ::Type{LowerTriangular})
C, info = LAPACK.potrf!('L', A)
return @assertposdef LowerTriangular(C) info
return LowerTriangular(C), info
end

## Non BLAS/LAPACK element types (generic)
Expand All @@ -64,7 +72,10 @@ function _chol!(A::AbstractMatrix, ::Type{UpperTriangular})
for i = 1:k - 1
A[k,k] -= A[i,k]'A[i,k]
end
Akk = _chol!(A[k,k], UpperTriangular)
Akk, info = _chol!(A[k,k], UpperTriangular)
if info != 0
return UpperTriangular(A), info
end
A[k,k] = Akk
AkkInv = inv(Akk')
for j = k + 1:n
Expand All @@ -75,7 +86,7 @@ function _chol!(A::AbstractMatrix, ::Type{UpperTriangular})
end
end
end
return UpperTriangular(A)
return UpperTriangular(A), convert(BlasInt, 0)
end
function _chol!(A::AbstractMatrix, ::Type{LowerTriangular})
n = checksquare(A)
Expand All @@ -84,7 +95,10 @@ function _chol!(A::AbstractMatrix, ::Type{LowerTriangular})
for i = 1:k - 1
A[k,k] -= A[k,i]*A[k,i]'
end
Akk = _chol!(A[k,k], LowerTriangular)
Akk, info = _chol!(A[k,k], LowerTriangular)
if info != 0
return LowerTriangular(A), info
end
A[k,k] = Akk
AkkInv = inv(Akk)
for j = 1:k
Expand All @@ -99,30 +113,33 @@ function _chol!(A::AbstractMatrix, ::Type{LowerTriangular})
end
end
end
return LowerTriangular(A)
return LowerTriangular(A), convert(BlasInt, 0)
end

## Numbers
function _chol!(x::Number, uplo)
rx = real(x)
if rx != abs(x)
throw(ArgumentError("x must be positive semidefinite"))
end
rxr = sqrt(rx)
convert(promote_type(typeof(x), typeof(rxr)), rxr)
rxr = sqrt(abs(rx))
rval = convert(promote_type(typeof(x), typeof(rxr)), rxr)
rx == abs(x) ? (rval, convert(BlasInt, 0)) : (rval, convert(BlasInt, 1))
end

chol!(x::Number, uplo) = ((C, info) = _chol!(x, uplo); @assertposdef C info)

non_hermitian_error(f) = throw(ArgumentError("matrix is not symmetric/" *
"Hermitian. This error can be avoided by calling $f(Hermitian(A)) " *
"which will ignore either the upper or lower triangle of the matrix."))

# chol!. Destructive methods for computing Cholesky factor of real symmetric or Hermitian
# matrix
chol!(A::RealHermSymComplexHerm{<:Real,<:StridedMatrix}) =
_chol!(A.uplo == 'U' ? A.data : LinAlg.copytri!(A.data, 'L', true), UpperTriangular)
function chol!(A::RealHermSymComplexHerm{<:Real,<:StridedMatrix})
C, info = _chol!(A.uplo == 'U' ? A.data : LinAlg.copytri!(A.data, 'L', true), UpperTriangular)
@assertposdef C info
end
function chol!(A::StridedMatrix)
ishermitian(A) || non_hermitian_error("chol!")
return _chol!(A, UpperTriangular)
C, info = _chol!(A, UpperTriangular)
@assertposdef C info
end


Expand Down Expand Up @@ -184,7 +201,7 @@ julia> chol(16)
4.0
```
"""
chol(x::Number, args...) = _chol!(x, nothing)
chol(x::Number, args...) = ((C, info) = _chol!(x, nothing); @assertposdef C info)



Expand All @@ -193,9 +210,11 @@ chol(x::Number, args...) = _chol!(x, nothing)
## No pivoting
function cholfact!(A::RealHermSymComplexHerm, ::Type{Val{false}})
if A.uplo == 'U'
Cholesky(_chol!(A.data, UpperTriangular).data, 'U')
CU, info = _chol!(A.data, UpperTriangular)
Cholesky(CU.data, 'U', info)
else
Cholesky(_chol!(A.data, LowerTriangular).data, 'L')
CL, info = _chol!(A.data, LowerTriangular)
Cholesky(CL.data, 'L', info)
end
end

Expand Down Expand Up @@ -354,14 +373,15 @@ end

## Number
function cholfact(x::Number, uplo::Symbol=:U)
xf = fill(chol(x), 1, 1)
Cholesky(xf, uplo)
C, info = _chol!(x, uplo)
xf = fill(C, 1, 1)
Cholesky(xf, uplo, info)
end


function convert(::Type{Cholesky{T}}, C::Cholesky) where T
Cnew = convert(AbstractMatrix{T}, C.factors)
Cholesky{T, typeof(Cnew)}(Cnew, C.uplo)
Cholesky{T, typeof(Cnew)}(Cnew, C.uplo, C.info)
end
convert(::Type{Factorization{T}}, C::Cholesky{T}) where {T} = C
convert(::Type{Factorization{T}}, C::Cholesky) where {T} = convert(Cholesky{T}, C)
Expand All @@ -386,7 +406,7 @@ convert(::Type{Matrix}, F::CholeskyPivoted) = convert(Array, convert(AbstractArr
convert(::Type{Array}, F::CholeskyPivoted) = convert(Matrix, F)
full(F::CholeskyPivoted) = convert(AbstractArray, F)

copy(C::Cholesky) = Cholesky(copy(C.factors), C.uplo)
copy(C::Cholesky) = Cholesky(copy(C.factors), C.uplo, C.info)
copy(C::CholeskyPivoted) = CholeskyPivoted(copy(C.factors), C.uplo, C.piv, C.rank, C.tol, C.info)

size(C::Union{Cholesky, CholeskyPivoted}) = size(C.factors)
Expand Down Expand Up @@ -417,7 +437,7 @@ show(io::IO, C::Cholesky{<:Any,<:AbstractMatrix}) =
(println(io, "$(typeof(C)) with factor:");show(io,C[:UL]))

A_ldiv_B!(C::Cholesky{T,<:AbstractMatrix}, B::StridedVecOrMat{T}) where {T<:BlasFloat} =
LAPACK.potrs!(C.uplo, C.factors, B)
@assertposdef LAPACK.potrs!(C.uplo, C.factors, B) C.info

function A_ldiv_B!(C::Cholesky{<:Any,<:AbstractMatrix}, B::StridedVecOrMat)
if C.uplo == 'L'
Expand Down Expand Up @@ -465,16 +485,18 @@ function A_ldiv_B!(C::CholeskyPivoted, B::StridedMatrix)
end

function det(C::Cholesky)
C.info == 0 || throw(PosDefException(C.info))
dd = one(real(eltype(C)))
for i in 1:size(C.factors,1)
@inbounds for i in 1:size(C.factors,1)
dd *= real(C.factors[i,i])^2
end
dd
end

function logdet(C::Cholesky)
C.info == 0 || throw(PosDefException(C.info))
dd = zero(real(eltype(C)))
for i in 1:size(C.factors,1)
@inbounds for i in 1:size(C.factors,1)
dd += log(real(C.factors[i,i]))
end
dd + dd # instead of 2.0dd which can change the type
Expand Down Expand Up @@ -505,10 +527,9 @@ function logdet(C::CholeskyPivoted)
end

inv!(C::Cholesky{<:BlasFloat,<:StridedMatrix}) =
copytri!(LAPACK.potri!(C.uplo, C.factors), C.uplo, true)
@assertposdef copytri!(LAPACK.potri!(C.uplo, C.factors), C.uplo, true) C.info

inv(C::Cholesky{<:BlasFloat,<:StridedMatrix}) =
inv!(copy(C))
inv(C::Cholesky{<:BlasFloat,<:StridedMatrix}) = inv!(copy(C))

function inv(C::CholeskyPivoted)
chkfullrank(C)
Expand Down
8 changes: 5 additions & 3 deletions base/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -770,10 +770,12 @@ function factorize(A::StridedMatrix{T}) where T
return UpperTriangular(A)
end
if herm
try
return cholfact(A)
cf = cholfact(A)
if cf.info == 0
return cf
else
return factorize(Hermitian(A))
end
return factorize(Hermitian(A))
end
if sym
return factorize(Symmetric(A))
Expand Down
4 changes: 2 additions & 2 deletions base/linalg/lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2963,10 +2963,10 @@ for (posv, potrf, potri, potrs, pstrf, elty, rtyp) in
(Ptr{UInt8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}),
&uplo, &size(A,1), A, &lda, info)
chkargsok(info[])
#info[1]>0 means the leading minor of order info[i] is not positive definite
#info[] > 0 means the leading minor of order info[] is not positive definite
#ordinarily, throw Exception here, but return error code here
#this simplifies isposdef! and factorize
return A, info[]
return A, info[] # info stored in Cholesky
end

# SUBROUTINE DPOTRI( UPLO, N, A, LDA, INFO )
Expand Down
24 changes: 17 additions & 7 deletions test/linalg/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ debug = false

using Base.Test

using Base.LinAlg: BlasComplex, BlasFloat, BlasReal, QRPivoted
using Base.LinAlg: BlasComplex, BlasFloat, BlasReal, QRPivoted, PosDefException

n = 10

Expand Down Expand Up @@ -60,7 +60,7 @@ for eltya in (Float32, Float64, Complex64, Complex128, BigFloat, Int)

apos = apd[1,1] # test chol(x::Number), needs x>0
@test all(x -> x apos, cholfact(apos).factors)
@test_throws ArgumentError chol(-one(eltya))
@test_throws PosDefException chol(-one(eltya))

if eltya <: Real
capds = cholfact(apds)
Expand Down Expand Up @@ -194,10 +194,9 @@ end

begin
# Cholesky factor of Matrix with non-commutative elements, here 2x2-matrices

X = Matrix{Float64}[0.1*rand(2,2) for i in 1:3, j = 1:3]
L = full(Base.LinAlg._chol!(X*X', LowerTriangular))
U = full(Base.LinAlg._chol!(X*X', UpperTriangular))
L = full(Base.LinAlg._chol!(X*X', LowerTriangular)[1])
U = full(Base.LinAlg._chol!(X*X', UpperTriangular)[1])
XX = full(X*X')

@test sum(sum(norm, L*L' - XX)) < eps()
Expand All @@ -212,8 +211,8 @@ for elty in (Float32, Float64, Complex{Float32}, Complex{Float64})
A = randn(5,5)
end
A = convert(Matrix{elty}, A'A)
@test full(cholfact(A)[:L]) full(invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{LowerTriangular}}, copy(A), LowerTriangular))
@test full(cholfact(A)[:U]) full(invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{UpperTriangular}}, copy(A), UpperTriangular))
@test full(cholfact(A)[:L]) full(invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{LowerTriangular}}, copy(A), LowerTriangular)[1])
@test full(cholfact(A)[:U]) full(invoke(Base.LinAlg._chol!, Tuple{AbstractMatrix, Type{UpperTriangular}}, copy(A), UpperTriangular)[1])
end

# Test up- and downdates
Expand Down Expand Up @@ -272,3 +271,14 @@ end

# Fail for non-BLAS element types
@test_throws ArgumentError cholfact!(Hermitian(rand(Float16, 5,5)), Val{true})

@testset "throw for non positive matrix" begin
for T in (Float32, Float64, Complex64, Complex128)
A = T[1 2; 2 1]; B = T[1, 1]
C = cholfact(A)
@show typeof(A), typeof(B), typeof(C.factors)
@test_throws PosDefException C\B
@test_throws PosDefException det(C)
@test_throws PosDefException logdet(C)
end
end

1 comment on commit 9b51302

@tkelman
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

edit out [ci skip] from merge commit messages

Please sign in to comment.