diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 28d7b2fe56eb7..0144a575ccfdf 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -435,8 +435,76 @@ const BiTri = Union{Bidiagonal,Tridiagonal} @inline _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) @inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) -lmul!(A::Bidiagonal, B::AbstractVecOrMat) = @inline _mul!(B, A, B, MulAddMul()) -rmul!(B::AbstractMatrix, A::Bidiagonal) = @inline _mul!(B, B, A, MulAddMul()) +# B .= A * B +function lmul!(A::Bidiagonal, B::AbstractVecOrMat) + _muldiag_size_check(A, B) + (; dv, ev) = A + if A.uplo == 'U' + for k in axes(B,2) + for i in axes(ev,1) + B[i,k] = dv[i] * B[i,k] + ev[i] * B[i+1,k] + end + B[end,k] = dv[end] * B[end,k] + end + else + for k in axes(B,2) + for i in reverse(axes(dv,1)[2:end]) + B[i,k] = dv[i] * B[i,k] + ev[i-1] * B[i-1,k] + end + B[1,k] = dv[1] * B[1,k] + end + end + return B +end +# B .= D * B +function lmul!(D::Diagonal, B::Bidiagonal) + _muldiag_size_check(D, B) + (; dv, ev) = B + isL = B.uplo == 'L' + dv[1] = D.diag[1] * dv[1] + for i in axes(ev,1) + ev[i] = D.diag[i + isL] * ev[i] + dv[i+1] = D.diag[i+1] * dv[i+1] + end + return B +end +# B .= B * A +function rmul!(B::AbstractMatrix, A::Bidiagonal) + _muldiag_size_check(A, B) + (; dv, ev) = A + if A.uplo == 'U' + for k in reverse(axes(dv,1)[2:end]) + for i in axes(B,1) + B[i,k] = B[i,k] * dv[k] + B[i,k-1] * ev[k-1] + end + end + for i in axes(B,1) + B[i,1] *= dv[1] + end + else + for k in axes(ev,1) + for i in axes(B,1) + B[i,k] = B[i,k] * dv[k] + B[i,k+1] * ev[k] + end + end + for i in axes(B,1) + B[i,end] *= dv[end] + end + end + return B +end +# B .= B * D +function rmul!(B::Bidiagonal, D::Diagonal) + _muldiag_size_check(B, D) + (; dv, ev) = B + isU = B.uplo == 'U' + dv[1] *= D.diag[1] + for i in axes(ev,1) + ev[i] *= D.diag[i + isU] + dv[i+1] *= D.diag[i+1] + end + return B +end function check_A_mul_B!_sizes(C, A, B) mA, nA = size(A) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 6ee4f1279b4fd..cd5937e3dcb4f 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -310,8 +310,49 @@ function (*)(D::Diagonal, V::AbstractVector) return D.diag .* V end -rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D) -lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B) +function rmul!(A::AbstractMatrix, D::Diagonal) + _muldiag_size_check(A, D) + for I in CartesianIndices(A) + row, col = Tuple(I) + @inbounds A[row, col] *= D.diag[col] + end + return A +end +# T .= T * D +function rmul!(T::Tridiagonal, D::Diagonal) + _muldiag_size_check(T, D) + (; dl, d, du) = T + d[1] *= D.diag[1] + for i in axes(dl,1) + dl[i] *= D.diag[i] + du[i] *= D.diag[i+1] + d[i+1] *= D.diag[i+1] + end + return T +end + +function lmul!(D::Diagonal, B::AbstractVecOrMat) + _muldiag_size_check(D, B) + for I in CartesianIndices(B) + row = I[1] + @inbounds B[I] = D.diag[row] * B[I] + end + return B +end + +# in-place multiplication with a diagonal +# T .= D * T +function lmul!(D::Diagonal, T::Tridiagonal) + _muldiag_size_check(D, T) + (; dl, d, du) = T + d[1] = D.diag[1] * d[1] + for i in axes(dl,1) + dl[i] = D.diag[i+1] * dl[i] + du[i] = D.diag[i] * du[i] + d[i+1] = D.diag[i+1] * d[i+1] + end + return T +end function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0} require_one_based_indexing(out, B) diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index 3dbb9adf5f562..614b1662f3627 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -884,6 +884,41 @@ end @test mul!(C1, B, sv, 1, 2) == mul!(C2, B, v, 1 ,2) end +@testset "rmul!/lmul! with banded matrices" begin + dv, ev = rand(4), rand(3) + for A in (Bidiagonal(dv, ev, :U), Bidiagonal(dv, ev, :L)) + @testset "$(nameof(typeof(B)))" for B in ( + Bidiagonal(dv, ev, :U), + Bidiagonal(dv, ev, :L), + Diagonal(dv) + ) + @test_throws ArgumentError rmul!(B, A) + @test_throws ArgumentError lmul!(A, B) + end + D = Diagonal(dv) + @test rmul!(copy(A), D) ≈ A * D + @test lmul!(D, copy(A)) ≈ D * A + end + @testset "non-commutative" begin + S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2)) + S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3)) + S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2)) + for uplo in (:L, :U) + B = Bidiagonal(fill(S32, 4), fill(S32, 3), uplo) + D = Diagonal(fill(S22, size(B,2))) + @test rmul!(copy(B), D) ≈ B * D + D = Diagonal(fill(S33, size(B,1))) + @test lmul!(D, copy(B)) ≈ D * B + end + + B = Bidiagonal(fill(S33, 4), fill(S33, 3), :U) + D = Diagonal(fill(S32, 4)) + @test lmul!(B, Array(D)) ≈ B * D + B = Bidiagonal(fill(S22, 4), fill(S22, 3), :U) + @test rmul!(Array(D), B) ≈ D * B + end +end + @testset "off-band indexing error" begin B = Bidiagonal(Vector{BigInt}(undef, 4), Vector{BigInt}(undef,3), :L) @test_throws "cannot set entry" B[1,2] = 4 diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 3087e87f63415..8397e97dcdf41 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -1288,4 +1288,16 @@ end @test yadj == x' end +@testset "rmul!/lmul! with banded matrices" begin + @testset "$(nameof(typeof(B)))" for B in ( + Bidiagonal(rand(4), rand(3), :L), + Tridiagonal(rand(3), rand(4), rand(3)) + ) + BA = Array(B) + D = Diagonal(rand(size(B,1))) + DA = Array(D) + @test rmul!(copy(B), D) ≈ B * D ≈ BA * DA + @test lmul!(D, copy(B)) ≈ D * B ≈ DA * BA + end +end end # module TestDiagonal diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index 00414fcdc8b56..4f84988cabd50 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -833,4 +833,22 @@ end @test axes(B) === (ax, ax) end +@testset "rmul!/lmul! with banded matrices" begin + dl, d, du = rand(3), rand(4), rand(3) + A = Tridiagonal(dl, d, du) + D = Diagonal(d) + @test rmul!(copy(A), D) ≈ A * D + @test lmul!(D, copy(A)) ≈ D * A + + @testset "non-commutative" begin + S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2)) + S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3)) + S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2)) + T = Tridiagonal(fill(S32,3), fill(S32, 4), fill(S32, 3)) + D = Diagonal(fill(S22, size(T,2))) + @test rmul!(copy(T), D) ≈ T * D + D = Diagonal(fill(S33, size(T,1))) + @test lmul!(D, copy(T)) ≈ D * T + end +end end # module TestTridiagonal diff --git a/test/testhelpers/SizedArrays.jl b/test/testhelpers/SizedArrays.jl index 43bc27e630479..bc02fb5cbbd20 100644 --- a/test/testhelpers/SizedArrays.jl +++ b/test/testhelpers/SizedArrays.jl @@ -23,6 +23,9 @@ Base.first(::SOneTo) = 1 Base.last(r::SOneTo) = length(r) Base.show(io::IO, r::SOneTo) = print(io, "SOneTo(", length(r), ")") +Broadcast.axistype(a::Base.OneTo, s::SOneTo) = s +Broadcast.axistype(s::SOneTo, a::Base.OneTo) = s + struct SizedArray{SZ,T,N,A<:AbstractArray} <: AbstractArray{T,N} data::A function SizedArray{SZ}(data::AbstractArray{T,N}) where {SZ,T,N} @@ -43,10 +46,25 @@ Base.size(a::SizedArray) = size(typeof(a)) Base.size(::Type{<:SizedArray{SZ}}) where {SZ} = SZ Base.axes(a::SizedArray) = map(SOneTo, size(a)) Base.getindex(A::SizedArray, i...) = getindex(A.data, i...) +Base.setindex!(A::SizedArray, v, i...) = setindex!(A.data, v, i...) Base.zero(::Type{T}) where T <: SizedArray = SizedArray{size(T)}(zeros(eltype(T), size(T))) +Base.parent(S::SizedArray) = S.data +(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = SizedArray{SZ}(S1.data + S2.data) ==(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = S1.data == S2.data +homogenize_shape(t::Tuple) = (_homogenize_shape(first(t)), homogenize_shape(Base.tail(t))...) +homogenize_shape(::Tuple{}) = () +_homogenize_shape(x::Integer) = x +_homogenize_shape(x::AbstractUnitRange) = length(x) +const Dims = Union{Integer, Base.OneTo, SOneTo} +function Base.similar(::Type{A}, shape::Tuple{Dims, Vararg{Dims}}) where {A<:AbstractArray} + similar(A, homogenize_shape(shape)) +end +function Base.similar(::Type{A}, shape::Tuple{SOneTo, Vararg{SOneTo}}) where {A<:AbstractArray} + R = similar(A, length.(shape)) + SizedArray{length.(shape)}(R) +end + const SizedMatrixLike = Union{SizedMatrix, Transpose{<:Any, <:SizedMatrix}, Adjoint{<:Any, <:SizedMatrix}} _data(S::SizedArray) = S.data