Skip to content

Commit

Permalink
adding == for structured matrices (#30108)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcognetta authored and andreasnoack committed Dec 11, 2018
1 parent 411a7cf commit 2460301
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 1 deletion.
9 changes: 8 additions & 1 deletion stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,14 @@ end
*(A::Bidiagonal, B::Number) = Bidiagonal(A.dv*B, A.ev*B, A.uplo)
*(B::Number, A::Bidiagonal) = A*B
/(A::Bidiagonal, B::Number) = Bidiagonal(A.dv/B, A.ev/B, A.uplo)
==(A::Bidiagonal, B::Bidiagonal) = (A.uplo==B.uplo) && (A.dv==B.dv) && (A.ev==B.ev)

function ==(A::Bidiagonal, B::Bidiagonal)
if A.uplo == B.uplo
return A.dv == B.dv && A.ev == B.ev
else
return iszero(A.ev) && iszero(B.ev) && A.dv == B.dv
end
end

const BiTriSym = Union{Bidiagonal,Tridiagonal,SymTridiagonal}
const BiTri = Union{Bidiagonal,Tridiagonal}
Expand Down
22 changes: 22 additions & 0 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,25 @@ function fill!(A::Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}, x)
throw(ArgumentError("array of type $(typeof(A)) and size $(size(A)) can
not be filled with $x, since some of its entries are constrained."))
end

# equals and approx equals methods for structured matrices
# SymTridiagonal == Tridiagonal is already defined in tridiag.jl

# SymTridiagonal and Bidiagonal have the same field names
==(A::Diagonal, B::Union{SymTridiagonal, Bidiagonal}) = iszero(B.ev) && A.diag == B.dv
==(B::Bidiagonal, A::Diagonal) = A == B

==(A::Diagonal, B::Tridiagonal) = iszero(B.dl) && iszero(B.du) && A.diag == B.d
==(B::Tridiagonal, A::Diagonal) = A == B

function ==(A::Bidiagonal, B::Tridiagonal)
if A.uplo == 'U'
return iszero(B.dl) && A.dv == B.d && A.ev == B.du
else
return iszero(B.du) && A.dv == B.d && A.ev == B.dl
end
end
==(B::Tridiagonal, A::Bidiagonal) = A == B

==(A::Bidiagonal, B::SymTridiagonal) = iszero(B.ev) && iszero(A.ev) && A.dv == B.dv
==(B::SymTridiagonal, A::Bidiagonal) = A == B
24 changes: 24 additions & 0 deletions stdlib/LinearAlgebra/test/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,4 +298,28 @@ end
@test isa((@inferred vcat(Float64[], spzeros(1))), SparseVector)
end

@testset "== for structured matrices" begin
diag = rand(10)
offdiag = rand(9)
D = Diagonal(rand(10))
Bup = Bidiagonal(diag, offdiag, 'U')
Blo = Bidiagonal(diag, offdiag, 'L')
Bupd = Bidiagonal(diag, zeros(9), 'U')
Blod = Bidiagonal(diag, zeros(9), 'L')
T = Tridiagonal(offdiag, diag, offdiag)
Td = Tridiagonal(zeros(9), diag, zeros(9))
Tu = Tridiagonal(zeros(9), diag, offdiag)
Tl = Tridiagonal(offdiag, diag, zeros(9))
S = SymTridiagonal(diag, offdiag)
Sd = SymTridiagonal(diag, zeros(9))

mats = [D, Bup, Blo, Bupd, Blod, T, Td, Tu, Tl, S, Sd]

for a in mats
for b in mats
@test (a == b) == (Matrix(a) == Matrix(b)) == (b == a) == (Matrix(b) == Matrix(a))
end
end
end

end # module TestSpecial

0 comments on commit 2460301

Please sign in to comment.