Skip to content

Commit

Permalink
Backport "Fix (l/r)mul! with Diagonal/Bidiagonal #55052" to v1.11 (#5…
Browse files Browse the repository at this point in the history
…5359)

This should hopefully fix the failing tests.

Co-authored-by: Kristoffer Carlsson <kcarlsson89@gmail.com>
  • Loading branch information
jishnub and KristofferC authored Aug 19, 2024
1 parent 28fcbff commit 4e7648f
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 4 deletions.
72 changes: 70 additions & 2 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,76 @@ const BiTri = Union{Bidiagonal,Tridiagonal}
@inline _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))

lmul!(A::Bidiagonal, B::AbstractVecOrMat) = @inline _mul!(B, A, B, MulAddMul())
rmul!(B::AbstractMatrix, A::Bidiagonal) = @inline _mul!(B, B, A, MulAddMul())
# B .= A * B
function lmul!(A::Bidiagonal, B::AbstractVecOrMat)
_muldiag_size_check(A, B)
(; dv, ev) = A
if A.uplo == 'U'
for k in axes(B,2)
for i in axes(ev,1)
B[i,k] = dv[i] * B[i,k] + ev[i] * B[i+1,k]
end
B[end,k] = dv[end] * B[end,k]
end
else
for k in axes(B,2)
for i in reverse(axes(dv,1)[2:end])
B[i,k] = dv[i] * B[i,k] + ev[i-1] * B[i-1,k]
end
B[1,k] = dv[1] * B[1,k]
end
end
return B
end
# B .= D * B
function lmul!(D::Diagonal, B::Bidiagonal)
_muldiag_size_check(D, B)
(; dv, ev) = B
isL = B.uplo == 'L'
dv[1] = D.diag[1] * dv[1]
for i in axes(ev,1)
ev[i] = D.diag[i + isL] * ev[i]
dv[i+1] = D.diag[i+1] * dv[i+1]
end
return B
end
# B .= B * A
function rmul!(B::AbstractMatrix, A::Bidiagonal)
_muldiag_size_check(A, B)
(; dv, ev) = A
if A.uplo == 'U'
for k in reverse(axes(dv,1)[2:end])
for i in axes(B,1)
B[i,k] = B[i,k] * dv[k] + B[i,k-1] * ev[k-1]
end
end
for i in axes(B,1)
B[i,1] *= dv[1]
end
else
for k in axes(ev,1)
for i in axes(B,1)
B[i,k] = B[i,k] * dv[k] + B[i,k+1] * ev[k]
end
end
for i in axes(B,1)
B[i,end] *= dv[end]
end
end
return B
end
# B .= B * D
function rmul!(B::Bidiagonal, D::Diagonal)
_muldiag_size_check(B, D)
(; dv, ev) = B
isU = B.uplo == 'U'
dv[1] *= D.diag[1]
for i in axes(ev,1)
ev[i] *= D.diag[i + isU]
dv[i+1] *= D.diag[i+1]
end
return B
end

function check_A_mul_B!_sizes(C, A, B)
mA, nA = size(A)
Expand Down
45 changes: 43 additions & 2 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,49 @@ function (*)(D::Diagonal, V::AbstractVector)
return D.diag .* V
end

rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D)
lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B)
function rmul!(A::AbstractMatrix, D::Diagonal)
_muldiag_size_check(A, D)
for I in CartesianIndices(A)
row, col = Tuple(I)
@inbounds A[row, col] *= D.diag[col]
end
return A
end
# T .= T * D
function rmul!(T::Tridiagonal, D::Diagonal)
_muldiag_size_check(T, D)
(; dl, d, du) = T
d[1] *= D.diag[1]
for i in axes(dl,1)
dl[i] *= D.diag[i]
du[i] *= D.diag[i+1]
d[i+1] *= D.diag[i+1]
end
return T
end

function lmul!(D::Diagonal, B::AbstractVecOrMat)
_muldiag_size_check(D, B)
for I in CartesianIndices(B)
row = I[1]
@inbounds B[I] = D.diag[row] * B[I]
end
return B
end

# in-place multiplication with a diagonal
# T .= D * T
function lmul!(D::Diagonal, T::Tridiagonal)
_muldiag_size_check(D, T)
(; dl, d, du) = T
d[1] = D.diag[1] * d[1]
for i in axes(dl,1)
dl[i] = D.diag[i+1] * dl[i]
du[i] = D.diag[i] * du[i]
d[i+1] = D.diag[i+1] * d[i+1]
end
return T
end

function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
require_one_based_indexing(out, B)
Expand Down
35 changes: 35 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,41 @@ end
@test mul!(C1, B, sv, 1, 2) == mul!(C2, B, v, 1 ,2)
end

@testset "rmul!/lmul! with banded matrices" begin
dv, ev = rand(4), rand(3)
for A in (Bidiagonal(dv, ev, :U), Bidiagonal(dv, ev, :L))
@testset "$(nameof(typeof(B)))" for B in (
Bidiagonal(dv, ev, :U),
Bidiagonal(dv, ev, :L),
Diagonal(dv)
)
@test_throws ArgumentError rmul!(B, A)
@test_throws ArgumentError lmul!(A, B)
end
D = Diagonal(dv)
@test rmul!(copy(A), D) A * D
@test lmul!(D, copy(A)) D * A
end
@testset "non-commutative" begin
S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2))
S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3))
S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2))
for uplo in (:L, :U)
B = Bidiagonal(fill(S32, 4), fill(S32, 3), uplo)
D = Diagonal(fill(S22, size(B,2)))
@test rmul!(copy(B), D) B * D
D = Diagonal(fill(S33, size(B,1)))
@test lmul!(D, copy(B)) D * B
end

B = Bidiagonal(fill(S33, 4), fill(S33, 3), :U)
D = Diagonal(fill(S32, 4))
@test lmul!(B, Array(D)) B * D
B = Bidiagonal(fill(S22, 4), fill(S22, 3), :U)
@test rmul!(Array(D), B) D * B
end
end

@testset "off-band indexing error" begin
B = Bidiagonal(Vector{BigInt}(undef, 4), Vector{BigInt}(undef,3), :L)
@test_throws "cannot set entry" B[1,2] = 4
Expand Down
12 changes: 12 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1288,4 +1288,16 @@ end
@test yadj == x'
end

@testset "rmul!/lmul! with banded matrices" begin
@testset "$(nameof(typeof(B)))" for B in (
Bidiagonal(rand(4), rand(3), :L),
Tridiagonal(rand(3), rand(4), rand(3))
)
BA = Array(B)
D = Diagonal(rand(size(B,1)))
DA = Array(D)
@test rmul!(copy(B), D) B * D BA * DA
@test lmul!(D, copy(B)) D * B DA * BA
end
end
end # module TestDiagonal
18 changes: 18 additions & 0 deletions stdlib/LinearAlgebra/test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -833,4 +833,22 @@ end
@test axes(B) === (ax, ax)
end

@testset "rmul!/lmul! with banded matrices" begin
dl, d, du = rand(3), rand(4), rand(3)
A = Tridiagonal(dl, d, du)
D = Diagonal(d)
@test rmul!(copy(A), D) A * D
@test lmul!(D, copy(A)) D * A

@testset "non-commutative" begin
S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2))
S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3))
S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2))
T = Tridiagonal(fill(S32,3), fill(S32, 4), fill(S32, 3))
D = Diagonal(fill(S22, size(T,2)))
@test rmul!(copy(T), D) T * D
D = Diagonal(fill(S33, size(T,1)))
@test lmul!(D, copy(T)) D * T
end
end
end # module TestTridiagonal
18 changes: 18 additions & 0 deletions test/testhelpers/SizedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ Base.first(::SOneTo) = 1
Base.last(r::SOneTo) = length(r)
Base.show(io::IO, r::SOneTo) = print(io, "SOneTo(", length(r), ")")

Broadcast.axistype(a::Base.OneTo, s::SOneTo) = s
Broadcast.axistype(s::SOneTo, a::Base.OneTo) = s

struct SizedArray{SZ,T,N,A<:AbstractArray} <: AbstractArray{T,N}
data::A
function SizedArray{SZ}(data::AbstractArray{T,N}) where {SZ,T,N}
Expand All @@ -43,10 +46,25 @@ Base.size(a::SizedArray) = size(typeof(a))
Base.size(::Type{<:SizedArray{SZ}}) where {SZ} = SZ
Base.axes(a::SizedArray) = map(SOneTo, size(a))
Base.getindex(A::SizedArray, i...) = getindex(A.data, i...)
Base.setindex!(A::SizedArray, v, i...) = setindex!(A.data, v, i...)
Base.zero(::Type{T}) where T <: SizedArray = SizedArray{size(T)}(zeros(eltype(T), size(T)))
Base.parent(S::SizedArray) = S.data
+(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = SizedArray{SZ}(S1.data + S2.data)
==(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = S1.data == S2.data

homogenize_shape(t::Tuple) = (_homogenize_shape(first(t)), homogenize_shape(Base.tail(t))...)
homogenize_shape(::Tuple{}) = ()
_homogenize_shape(x::Integer) = x
_homogenize_shape(x::AbstractUnitRange) = length(x)
const Dims = Union{Integer, Base.OneTo, SOneTo}
function Base.similar(::Type{A}, shape::Tuple{Dims, Vararg{Dims}}) where {A<:AbstractArray}
similar(A, homogenize_shape(shape))
end
function Base.similar(::Type{A}, shape::Tuple{SOneTo, Vararg{SOneTo}}) where {A<:AbstractArray}
R = similar(A, length.(shape))
SizedArray{length.(shape)}(R)
end

const SizedMatrixLike = Union{SizedMatrix, Transpose{<:Any, <:SizedMatrix}, Adjoint{<:Any, <:SizedMatrix}}

_data(S::SizedArray) = S.data
Expand Down

0 comments on commit 4e7648f

Please sign in to comment.