From 685f5272c1e200f2772805a14042e01d261baf68 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 2 May 2024 14:03:12 +0530 Subject: [PATCH] LinearAlgebra: Type-stability in broadcasting numbers over Bidiagonal (#54067) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This makes the following type-stable: ```julia julia> B = Bidiagonal(rand(3), rand(2), :U); julia> @inferred (B -> B .* 2)(B) 3×3 Bidiagonal{Float64, Vector{Float64}}: 0.3929 1.93165 ⋅ ⋅ 1.61301 1.00202 ⋅ ⋅ 1.96483 ``` Similarly, for other operations involving a single `Bidiagonal` and numbers. This is not type-stable on master, as the number of `Bidiagonal` matrices in a broadcast operation is not tracked (even though this is used in promoting the `uplo`). Since the `uplo` can't be constant-propagated, we count this by introducing an additional flag in the promotion mechanism, which is entirely determined by the types of the terms in the broadcast operation. --------- Co-authored-by: N5N3 <2642243996@qq.com> --- .../LinearAlgebra/src/structuredbroadcast.jl | 4 +++- .../LinearAlgebra/test/structuredbroadcast.jl | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/structuredbroadcast.jl b/stdlib/LinearAlgebra/src/structuredbroadcast.jl index f2c35c8edcce4..248ec3105de30 100644 --- a/stdlib/LinearAlgebra/src/structuredbroadcast.jl +++ b/stdlib/LinearAlgebra/src/structuredbroadcast.jl @@ -78,7 +78,7 @@ find_uplo(bc::Broadcasted) = mapfoldl(find_uplo, merge_uplos, Broadcast.cat_nest function structured_broadcast_alloc(bc, ::Type{Bidiagonal}, ::Type{ElType}, n) where {ElType} uplo = n > 0 ? find_uplo(bc) : 'U' n1 = max(n - 1, 0) - if uplo == 'T' + if count_structedmatrix(Bidiagonal, bc) > 1 && uplo == 'T' return Tridiagonal(Array{ElType}(undef, n1), Array{ElType}(undef, n), Array{ElType}(undef, n1)) end return Bidiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n1), uplo) @@ -135,6 +135,8 @@ iszerodefined(::Type{<:Number}) = true iszerodefined(::Type{<:AbstractArray{T}}) where T = iszerodefined(T) iszerodefined(::Type{<:UniformScaling{T}}) where T = iszerodefined(T) +count_structedmatrix(T, bc::Broadcasted) = sum(Base.Fix2(isa, T), Broadcast.cat_nested(bc); init = 0) + fzeropreserving(bc) = (v = fzero(bc); !ismissing(v) && (iszerodefined(typeof(v)) ? iszero(v) : v == 0)) # Like sparse matrices, we assume that the zero-preservation property of a broadcasted # expression is stable. We can test the zero-preservability by applying the function diff --git a/stdlib/LinearAlgebra/test/structuredbroadcast.jl b/stdlib/LinearAlgebra/test/structuredbroadcast.jl index 3767fc10055f2..e4e78ad94102c 100644 --- a/stdlib/LinearAlgebra/test/structuredbroadcast.jl +++ b/stdlib/LinearAlgebra/test/structuredbroadcast.jl @@ -96,6 +96,24 @@ using Test, LinearAlgebra @test broadcast!(*, Z, X, Y) == broadcast(*, fX, fY) end end + + @testset "type-stability in Bidiagonal" begin + B2 = @inferred (B -> .- B)(B) + @test B2 isa Bidiagonal + @test B2 == -1 * B + B2 = @inferred (B -> B .* 2)(B) + @test B2 isa Bidiagonal + @test B2 == B + B + B2 = @inferred (B -> 2 .* B)(B) + @test B2 isa Bidiagonal + @test B2 == B + B + B2 = @inferred (B -> B ./ 1)(B) + @test B2 isa Bidiagonal + @test B2 == B + B2 = @inferred (B -> 1 .\ B)(B) + @test B2 isa Bidiagonal + @test B2 == B + end end @testset "broadcast! where the destination is a structured matrix" begin