Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LinearAlgebra: Type-stability in broadcasting numbers over Bidiagonal #54067

Merged
merged 3 commits into from
May 2, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
count_structedmatrix instead of find_uplo flag
  • Loading branch information
jishnub committed May 1, 2024
commit 3aaca7cd6e7609428461d7fd27911d03a0201e64
23 changes: 11 additions & 12 deletions stdlib/LinearAlgebra/src/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,19 @@ structured_broadcast_alloc(bc, ::Type{Diagonal}, ::Type{ElType}, n) where {ElTyp
# Bidiagonal is tricky as we need to know if it's upper or lower. The promotion
# system will return Tridiagonal when there's more than one Bidiagonal, but when
# there's only one, we need to make figure out upper or lower
# the Val flag checks if only one Bidiagonal is encountered in the broadcast expression,
# in which case we may preserve the type of the array
merge_uplos(::Tuple{Nothing, Val{A}}, ::Tuple{Nothing, Val{B}}) where {A,B} = (nothing, Val(A & B))
merge_uplos(a::Tuple{Any,Val{A}}, ::Tuple{Nothing, Val{B}}) where {A,B} = (first(a), Val(A & B))
merge_uplos(::Tuple{Nothing, Val{A}}, b::Tuple{Any,Val{B}}) where {A,B} = (first(b), Val(A & B))
merge_uplos(a::Tuple{Any,Val}, b::Tuple{Any,Val}) = (first(a) == first(b) ? first(a) : 'T', Val(false))
merge_uplos(::Nothing, ::Nothing) = nothing
merge_uplos(a, ::Nothing) = a
merge_uplos(::Nothing, b) = b
merge_uplos(a, b) = a == b ? a : 'T'

find_uplo(a::Bidiagonal) = (a.uplo, Val(true))
find_uplo(a) = (nothing, Val(true))
find_uplo(bc::Broadcasted) = mapfoldl(find_uplo, merge_uplos, Broadcast.cat_nested(bc), init=(nothing, Val(true)))
find_uplo(a::Bidiagonal) = a.uplo
find_uplo(a) = nothing
find_uplo(bc::Broadcasted) = mapfoldl(find_uplo, merge_uplos, Broadcast.cat_nested(bc), init=nothing)

function structured_broadcast_alloc(bc, ::Type{Bidiagonal}, ::Type{ElType}, n) where {ElType}
uplo, val = find_uplo(bc)
uplo = n > 0 ? uplo : 'U'
uplo = n > 0 ? find_uplo(bc) : 'U'
n1 = max(n - 1, 0)
if val isa Val{false} && uplo == 'T'
if count_structedmatrix(Bidiagonal, bc) > 1 && uplo == 'T'
return Tridiagonal(Array{ElType}(undef, n1), Array{ElType}(undef, n), Array{ElType}(undef, n1))
end
return Bidiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n1), uplo)
Expand Down Expand Up @@ -138,6 +135,8 @@ iszerodefined(::Type{<:Number}) = true
iszerodefined(::Type{<:AbstractArray{T}}) where T = iszerodefined(T)
iszerodefined(::Type{<:UniformScaling{T}}) where T = iszerodefined(T)

count_structedmatrix(T, bc::Broadcasted) = sum(Base.Fix2(isa, T), Broadcast.cat_nested(bc); init = 0)

fzeropreserving(bc) = (v = fzero(bc); !ismissing(v) && (iszerodefined(typeof(v)) ? iszero(v) : v == 0))
# Like sparse matrices, we assume that the zero-preservation property of a broadcasted
# expression is stable. We can test the zero-preservability by applying the function
Expand Down