Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LinearAlgebra: improve type-inference in Symmetric/Hermitian matmul #54303

Merged
merged 15 commits into from
May 7, 2024
Merged
Prev Previous commit
Next Next commit
LinearAlgbebra: constant propagate character in generic_matmatmul checks
  • Loading branch information
jishnub committed May 2, 2024
commit 54a4038a951a423d2176763615cf022e1a70d29b
21 changes: 18 additions & 3 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -360,11 +360,21 @@ julia> lmul!(F.Q, B)
"""
lmul!(A, B)

# unroll the in(a, b) computation to enable constant propagation
# This is a 2-valued in implementation that doesn't account for missing values
_in(t::AbstractChar, ::Tuple{}) = false
function _in(t::AbstractChar, chars::Tuple{Vararg{AbstractChar}})
return t == first(chars) || _in(t, Base.tail(chars))
end
all_in(chars, (tA, tB)) = _in(tA, chars) && _in(tB, chars)

# THE one big BLAS dispatch
# aggressive constant propagation makes mul!(C, A, B) invoke gemm_wrapper! directly
Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
_add::MulAddMul=MulAddMul()) where {T<:BlasFloat}
if all(in(('N', 'T', 'C')), (tA, tB))
# if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop
# The check is only on the wrapper type, so we may extract that from a WrapperChar
if all_in(('N', 'T', 'C'), map(_getwrapperchar, (tA, tB)))
if tA == 'T' && tB == 'N' && A === B
return syrk_wrapper!(C, 'T', A, _add)
elseif tA == 'N' && tB == 'T' && A === B
Expand Down Expand Up @@ -395,7 +405,10 @@ end
# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
_add::MulAddMul=MulAddMul()) where {T<:BlasReal}
if all(in(('N', 'T', 'C')), (tA, tB))
special_cases = ('N', 'T', 'C')
# if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop
# The check is only on the wrapper type, so we may extract that from a WrapperChar
if all_in(('N', 'T', 'C'), map(_getwrapperchar, (tA, tB)))
gemm_wrapper!(C, tA, tB, A, B, _add)
else
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
Expand Down Expand Up @@ -587,7 +600,9 @@ function gemm_wrapper(tA::AbstractChar, tB::AbstractChar,
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
C = similar(B, T, mA, nB)
if all(in(('N', 'T', 'C')), (tA, tB))
# if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop
# The check is only on the wrapper type, so we may extract that from a WrapperChar
if all_in(('N', 'T', 'C'), map(_getwrapperchar, (tA, tB)))
gemm_wrapper!(C, tA, tB, A, B)
else
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
Expand Down