Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LinearAlgebra: improve type-inference in Symmetric/Hermitian matmul #54303

Merged
merged 15 commits into from
May 7, 2024
Merged
Prev Previous commit
Next Next commit
Remove some unnecessary type conversions
  • Loading branch information
jishnub committed May 2, 2024
commit 9e66d1a03dbe7743b8ecf1c7aa65229fd9c2e85c
12 changes: 6 additions & 6 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar
if _in(tA_uc, ('S', 'H'))
# re-wrap again and use plain ('N') matvec mul algorithm,
# because _generic_matvecmul! can't handle the HermOrSym cases specifically
return _generic_matvecmul!(y, oftype(tA, 'N'), wrap(A, tA), x, MulAddMul(α, β))
return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
else
return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
end
Expand Down Expand Up @@ -516,7 +516,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
elseif _in(tA_uc, ('S', 'H'))
# re-wrap again and use plain ('N') matvec mul algorithm,
# because _generic_matvecmul! can't handle the HermOrSym cases specifically
return _generic_matvecmul!(y, oftype(tA, 'N'), wrap(A, tA), x, MulAddMul(α, β))
return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
else
return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
end
Expand All @@ -530,10 +530,10 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
if tA_uc == 'T'
(nA, mA) = size(A,1), size(A,2)
tAt = oftype(tA, 'N')
tAt = 'N'
else
(mA, nA) = size(A,1), size(A,2)
tAt = oftype(tA, 'T')
tAt = 'T'
end
if nC != mA
throw(DimensionMismatch(lazy"output matrix has size: $(nC), but should have size $(mA)"))
Expand Down Expand Up @@ -571,10 +571,10 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
if tA_uc == 'C'
(nA, mA) = size(A,1), size(A,2)
tAt = oftype(tA, 'N')
tAt = 'N'
else
(mA, nA) = size(A,1), size(A,2)
tAt = oftype(tA, 'C')
tAt = 'C'
end
if nC != mA
throw(DimensionMismatch(lazy"output matrix has size: $(nC), but should have size $(mA)"))
Expand Down