Skip to content

Commit

Permalink
Merge pull request #8726 from JuliaLang/teh/AmulB
Browse files Browse the repository at this point in the history
Relax dimensionality requirements in `A*_mul_B*!` and `transpose!`
  • Loading branch information
timholy committed Oct 18, 2014
2 parents f4cc46f + ec48d95 commit 636f481
Showing 5 changed files with 58 additions and 56 deletions.
4 changes: 2 additions & 2 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
@@ -201,7 +201,7 @@ end

copy(a::AbstractArray) = copy!(similar(a), a)

function copy!{R,S}(B::AbstractMatrix{R}, ir_dest::Range{Int}, jr_dest::Range{Int}, A::AbstractMatrix{S}, ir_src::Range{Int}, jr_src::Range{Int})
function copy!{R,S}(B::AbstractVecOrMat{R}, ir_dest::Range{Int}, jr_dest::Range{Int}, A::AbstractVecOrMat{S}, ir_src::Range{Int}, jr_src::Range{Int})
if length(ir_dest) != length(ir_src) || length(jr_dest) != length(jr_src)
error("source and destination must have same size")
end
@@ -219,7 +219,7 @@ function copy!{R,S}(B::AbstractMatrix{R}, ir_dest::Range{Int}, jr_dest::Range{In
return B
end

function copy_transpose!{R,S}(B::AbstractMatrix{R}, ir_dest::Range{Int}, jr_dest::Range{Int}, A::AbstractVecOrMat{S}, ir_src::Range{Int}, jr_src::Range{Int})
function copy_transpose!{R,S}(B::AbstractVecOrMat{R}, ir_dest::Range{Int}, jr_dest::Range{Int}, A::AbstractVecOrMat{S}, ir_src::Range{Int}, jr_src::Range{Int})
if length(ir_dest) != length(jr_src) || length(jr_dest) != length(ir_src)
error("source and destination must have same size")
end
12 changes: 6 additions & 6 deletions base/array.jl
Original file line number Diff line number Diff line change
@@ -1229,9 +1229,9 @@ end

## Transpose ##
const transposebaselength=64
function transpose!(B::StridedMatrix,A::StridedMatrix)
function transpose!(B::StridedVecOrMat,A::StridedMatrix)
m, n = size(A)
size(B) == (n,m) || throw(DimensionMismatch("transpose"))
size(B,1) == n && size(B,2) == m || throw(DimensionMismatch("transpose"))

if m*n<=4*transposebaselength
@inbounds begin
@@ -1246,7 +1246,7 @@ function transpose!(B::StridedMatrix,A::StridedMatrix)
end
return B
end
function transposeblock!(B::StridedMatrix,A::StridedMatrix,m::Int,n::Int,offseti::Int,offsetj::Int)
function transposeblock!(B::StridedVecOrMat,A::StridedMatrix,m::Int,n::Int,offseti::Int,offsetj::Int)
if m*n<=transposebaselength
@inbounds begin
for j = offsetj+(1:n)
@@ -1266,9 +1266,9 @@ function transposeblock!(B::StridedMatrix,A::StridedMatrix,m::Int,n::Int,offseti
end
return B
end
function ctranspose!(B::StridedMatrix,A::StridedMatrix)
function ctranspose!(B::StridedVecOrMat,A::StridedMatrix)
m, n = size(A)
size(B) == (n,m) || throw(DimensionMismatch("transpose"))
size(B,1) == n && size(B,2) == m || throw(DimensionMismatch("transpose"))

if m*n<=4*transposebaselength
@inbounds begin
@@ -1283,7 +1283,7 @@ function ctranspose!(B::StridedMatrix,A::StridedMatrix)
end
return B
end
function ctransposeblock!(B::StridedMatrix,A::StridedMatrix,m::Int,n::Int,offseti::Int,offsetj::Int)
function ctransposeblock!(B::StridedVecOrMat,A::StridedMatrix,m::Int,n::Int,offseti::Int,offsetj::Int)
if m*n<=transposebaselength
@inbounds begin
for j = offsetj+(1:n)
6 changes: 3 additions & 3 deletions base/linalg/blas.jl
Original file line number Diff line number Diff line change
@@ -272,8 +272,8 @@ for (fname, elty) in ((:dgemv_,:Float64),
# CHARACTER TRANS
#* .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)
function gemv!(trans::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, X::StridedVector{$elty}, beta::($elty), Y::StridedVector{$elty})
m,n = size(A)
function gemv!(trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty}, X::StridedVector{$elty}, beta::($elty), Y::StridedVector{$elty})
m,n = size(A,1),size(A,2)
length(X) == (trans == 'N' ? n : m) && length(Y) == (trans == 'N' ? m : n) || throw(DimensionMismatch(""))
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty},
@@ -556,7 +556,7 @@ for (gemm, elty) in
# CHARACTER TRANSA,TRANSB
# * .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
function gemm!(transA::BlasChar, transB::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty}, B::StridedMatrix{$elty}, beta::($elty), C::StridedVecOrMat{$elty})
function gemm!(transA::BlasChar, transB::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty}, B::StridedVecOrMat{$elty}, beta::($elty), C::StridedVecOrMat{$elty})
# if any([stride(A,1), stride(B,1), stride(C,1)] .!= 1)
# error("gemm!: BLAS module requires contiguous matrix columns")
# end # should this be checked on every call?
86 changes: 41 additions & 45 deletions base/linalg/matmul.jl
Original file line number Diff line number Diff line change
@@ -73,18 +73,18 @@ function (*){T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
end
(*)(A::AbstractVector, B::AbstractMatrix) = reshape(A,length(A),1)*B

A_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::StridedMatrix{T}, x::StridedVector{T}) = gemv!(y, 'N', A, x)
A_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) = gemv!(y, 'N', A, x)
for elty in (Float32,Float64)
@eval begin
function A_mul_B!(y::StridedVector{Complex{$elty}}, A::StridedMatrix{Complex{$elty}}, x::StridedVector{$elty})
function A_mul_B!(y::StridedVector{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, x::StridedVector{$elty})
Afl = reinterpret($elty,A,(2size(A,1),size(A,2)))
yfl = reinterpret($elty,y)
gemv!(yfl,'N',Afl,x)
return y
end
end
end
A_mul_B!(y::StridedVector, A::StridedMatrix, x::StridedVector) = generic_matvecmul!(y, 'N', A, x)
A_mul_B!(y::StridedVector, A::StridedVecOrMat, x::StridedVector) = generic_matvecmul!(y, 'N', A, x)

function At_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_type(arithtype(T),arithtype(S))
@@ -94,8 +94,8 @@ function At_mul_B{T,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_type(arithtype(T),arithtype(S))
At_mul_B!(similar(x,TS,size(A,2)), A, x)
end
At_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::StridedMatrix{T}, x::StridedVector{T}) = gemv!(y, 'T', A, x)
At_mul_B!(y::StridedVector, A::StridedMatrix, x::StridedVector) = generic_matvecmul!(y, 'T', A, x)
At_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) = gemv!(y, 'T', A, x)
At_mul_B!(y::StridedVector, A::StridedVecOrMat, x::StridedVector) = generic_matvecmul!(y, 'T', A, x)

function Ac_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_type(arithtype(T),arithtype(S))
@@ -106,83 +106,83 @@ function Ac_mul_B{T,S}(A::StridedMatrix{T}, x::StridedVector{S})
Ac_mul_B!(similar(x,TS,size(A,2)), A, x)
end

Ac_mul_B!{T<:BlasReal}(y::StridedVector{T}, A::StridedMatrix{T}, x::StridedVector{T}) = At_mul_B!(y, A, x)
Ac_mul_B!{T<:BlasComplex}(y::StridedVector{T}, A::StridedMatrix{T}, x::StridedVector{T}) = gemv!(y, 'C', A, x)
Ac_mul_B!(y::StridedVector, A::StridedMatrix, x::StridedVector) = generic_matvecmul!(y, 'C', A, x)
Ac_mul_B!{T<:BlasReal}(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) = At_mul_B!(y, A, x)
Ac_mul_B!{T<:BlasComplex}(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) = gemv!(y, 'C', A, x)
Ac_mul_B!(y::StridedVector, A::StridedVecOrMat, x::StridedVector) = generic_matvecmul!(y, 'C', A, x)

# Matrix-matrix multiplication

function (*){T,S}(A::AbstractMatrix{T}, B::StridedMatrix{S})
TS = promote_type(arithtype(T), arithtype(S))
A_mul_B!(similar(B, TS, (size(A,1), size(B,2))), A, B)
end
A_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = gemm_wrapper!(C, 'N', 'N', A, B)
A_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'N', 'N', A, B)
for elty in (Float32,Float64)
@eval begin
function A_mul_B!(C::StridedMatrix{Complex{$elty}}, A::StridedMatrix{Complex{$elty}}, B::StridedMatrix{$elty})
function A_mul_B!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty})
Afl = reinterpret($elty, A, (2size(A,1), size(A,2)))
Cfl = reinterpret($elty, C, (2size(C,1), size(C,2)))
gemm_wrapper!(Cfl, 'N', 'N', Afl, B)
return C
end
end
end
A_mul_B!(C::StridedMatrix, A::StridedMatrix, B::StridedMatrix) = generic_matmatmul!(C, 'N', 'N', A, B)
A_mul_B!(C::StridedMatrix, A::StridedVecOrMat, B::StridedVecOrMat) = generic_matmatmul!(C, 'N', 'N', A, B)

function At_mul_B{T,S}(A::AbstractMatrix{T}, B::StridedMatrix{S})
TS = promote_type(arithtype(T), arithtype(S))
At_mul_B!(similar(B, TS, (size(A,2), size(B,2))), A, B)
end
At_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = is(A,B) ? syrk_wrapper!(C, 'T', A) : gemm_wrapper!(C, 'T', 'N', A, B)
At_mul_B!(C::StridedMatrix, A::StridedMatrix, B::StridedMatrix) = generic_matmatmul!(C, 'T', 'N', A, B)
At_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? syrk_wrapper!(C, 'T', A) : gemm_wrapper!(C, 'T', 'N', A, B)
At_mul_B!(C::StridedMatrix, A::StridedVecOrMat, B::StridedVecOrMat) = generic_matmatmul!(C, 'T', 'N', A, B)

function A_mul_Bt{T,S}(A::AbstractMatrix{T}, B::StridedMatrix{S})
TS = promote_type(arithtype(T), arithtype(S))
A_mul_Bt!(similar(B, TS, (size(A,1), size(B,1))), A, B)
end
A_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = is(A,B) ? syrk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'T', A, B)
A_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? syrk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'T', A, B)
for elty in (Float32,Float64)
@eval begin
function A_mul_Bt!(C::StridedMatrix{Complex{$elty}}, A::StridedMatrix{Complex{$elty}}, B::StridedMatrix{$elty})
function A_mul_Bt!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty})
Afl = reinterpret($elty, A, (2size(A,1), size(A,2)))
Cfl = reinterpret($elty, C, (2size(C,1), size(C,2)))
gemm_wrapper!(Cfl, 'N', 'T', Afl, B)
return C
end
end
end
A_mul_Bt!(C::StridedVecOrMat, A::StridedMatrix, B::StridedMatrix) = generic_matmatmul!(C, 'N', 'T', A, B)
A_mul_Bt!(C::StridedVecOrMat, A::StridedVecOrMat, B::StridedVecOrMat) = generic_matmatmul!(C, 'N', 'T', A, B)

function At_mul_Bt{T,S}(A::AbstractMatrix{T}, B::StridedMatrix{S})
function At_mul_Bt{T,S}(A::AbstractMatrix{T}, B::StridedVecOrMat{S})
TS = promote_type(arithtype(T), arithtype(S))
At_mul_Bt!(similar(B, TS, (size(A,2), size(B,1))), A, B)
end
At_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = gemm_wrapper!(C, 'T', 'T', A, B)
At_mul_Bt!(C::StridedMatrix, A::StridedMatrix, B::StridedMatrix) = generic_matmatmul!(C, 'T', 'T', A, B)
At_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'T', 'T', A, B)
At_mul_Bt!(C::StridedMatrix, A::StridedVecOrMat, B::StridedVecOrMat) = generic_matmatmul!(C, 'T', 'T', A, B)

Ac_mul_B{T<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{T}) = At_mul_B(A, B)
Ac_mul_B!{T<:BlasReal}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = At_mul_B!(C, A, B)
Ac_mul_B!{T<:BlasReal}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = At_mul_B!(C, A, B)
function Ac_mul_B{T,S}(A::StridedMatrix{T}, B::StridedMatrix{S})
TS = promote_type(arithtype(T), arithtype(S))
Ac_mul_B!(similar(B, TS, (size(A,2), size(B,2))), A, B)
end
Ac_mul_B!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = is(A,B) ? herk_wrapper!(C,'C',A) : gemm_wrapper!(C,'C', 'N', A, B)
Ac_mul_B!(C::StridedMatrix, A::StridedMatrix, B::StridedMatrix) = generic_matmatmul!(C, 'C', 'N', A, B)
Ac_mul_B!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? herk_wrapper!(C,'C',A) : gemm_wrapper!(C,'C', 'N', A, B)
Ac_mul_B!(C::StridedMatrix, A::StridedVecOrMat, B::StridedVecOrMat) = generic_matmatmul!(C, 'C', 'N', A, B)

A_mul_Bc{T<:BlasFloat,S<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{S}) = A_mul_Bt(A, B)
A_mul_Bc!{T<:BlasFloat,S<:BlasReal}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{S}) = A_mul_Bt!(C, A, B)
A_mul_Bc!{T<:BlasFloat,S<:BlasReal}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{S}) = A_mul_Bt!(C, A, B)
function A_mul_Bc{T,S}(A::StridedMatrix{T}, B::StridedMatrix{S})
TS = promote_type(arithtype(T),arithtype(S))
A_mul_Bc!(similar(B,TS,(size(A,1),size(B,1))),A,B)
end
A_mul_Bc!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = is(A,B) ? herk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'C', A, B)
A_mul_Bc!(C::StridedMatrix, A::StridedMatrix, B::StridedMatrix) = generic_matmatmul!(C, 'N', 'C', A, B)
A_mul_Bc!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? herk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'C', A, B)
A_mul_Bc!(C::StridedMatrix, A::StridedVecOrMat, B::StridedVecOrMat) = generic_matmatmul!(C, 'N', 'C', A, B)

Ac_mul_Bc{T,S}(A::AbstractMatrix{T}, B::StridedMatrix{S}) = Ac_mul_Bc!(similar(B, promote_type(arithtype(T), arithtype(S)), (size(A,2), size(B,1))), A, B)
Ac_mul_Bc!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = gemm_wrapper!(C, 'C', 'C', A, B)
Ac_mul_Bc!(C::StridedMatrix, A::StridedMatrix, B::StridedMatrix) = generic_matmatmul!(C, 'C', 'C', A, B)
Ac_mul_Bc!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'C', 'C', A, B)
Ac_mul_Bc!(C::StridedMatrix, A::StridedVecOrMat, B::StridedVecOrMat) = generic_matmatmul!(C, 'C', 'C', A, B)
Ac_mul_Bt{T,S}(A::AbstractMatrix{T}, B::StridedMatrix{S}) = Ac_mul_Bt(similar(B, promote_type(arithtype(A), arithtype(B)), (size(A,2), size(B,1))), A, B)
Ac_mul_Bt!(C::StridedMatrix, A::StridedMatrix, B::StridedMatrix) = generic_matmatmul!(C, 'C', 'T', A, B)
Ac_mul_Bt!(C::StridedMatrix, A::StridedVecOrMat, B::StridedVecOrMat) = generic_matmatmul!(C, 'C', 'T', A, B)

# Supporting functions for matrix multiplication

@@ -203,27 +203,23 @@ function copytri!(A::StridedMatrix, uplo::Char, conjugate::Bool=false)
A
end

function gemv!{T<:BlasFloat}(y::StridedVector{T}, tA::Char, A::StridedMatrix{T}, x::StridedVector{T})
function gemv!{T<:BlasFloat}(y::StridedVector{T}, tA::Char, A::StridedVecOrMat{T}, x::StridedVector{T})
stride(A, 1)==1 || return generic_matvecmul!(y, tA, A, x)
if tA != 'N'
(nA, mA) = size(A)
else
(mA, nA) = size(A)
end
mA, nA = lapack_size(tA, A)
nA==length(x) || throw(DimensionMismatch(""))
mA==length(y) || throw(DimensionMismatch(""))
mA == 0 && return y
nA == 0 && return fill!(y,0)
return BLAS.gemv!(tA, one(T), A, x, zero(T), y)
end

function syrk_wrapper!{T<:BlasFloat}(C::StridedMatrix{T}, tA::Char, A::StridedMatrix{T})
function syrk_wrapper!{T<:BlasFloat}(C::StridedMatrix{T}, tA::Char, A::StridedVecOrMat{T})
nC = chksquare(C)
if tA == 'T'
(nA, mA) = size(A)
(nA, mA) = size(A,1), size(A,2)
tAt = 'N'
else
(mA, nA) = size(A)
(mA, nA) = size(A,1), size(A,2)
tAt = 'T'
end
nC == mA || throw(DimensionMismatch("output matrix has size: $(nC), but should have size $(mA)"))
@@ -235,13 +231,13 @@ function syrk_wrapper!{T<:BlasFloat}(C::StridedMatrix{T}, tA::Char, A::StridedMa
copytri!(BLAS.syrk!('U', tA, one(T), A, zero(T), C), 'U')
end

function herk_wrapper!{T<:BlasFloat}(C::StridedMatrix{T}, tA::Char, A::StridedMatrix{T})
function herk_wrapper!{T<:BlasFloat}(C::StridedMatrix{T}, tA::Char, A::StridedVecOrMat{T})
nC = chksquare(C)
if tA == 'C'
(nA, mA) = size(A)
(nA, mA) = size(A,1), size(A,2)
tAt = 'N'
else
(mA, nA) = size(A)
(mA, nA) = size(A,1), size(A,2)
tAt = 'C'
end
nC == mA || throw(DimensionMismatch("output matrix has size: $(nC), but should have size $(mA)"))
@@ -259,7 +255,7 @@ end

function gemm_wrapper{T<:BlasFloat}(tA::Char, tB::Char,
A::StridedVecOrMat{T},
B::StridedMatrix{T})
B::StridedVecOrMat{T})
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
C = similar(B, T, mA, nB)
@@ -268,7 +264,7 @@ end

function gemm_wrapper!{T<:BlasFloat}(C::StridedVecOrMat{T}, tA::Char, tB::Char,
A::StridedVecOrMat{T},
B::StridedMatrix{T})
B::StridedVecOrMat{T})
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)

@@ -290,7 +286,7 @@ end

lapack_size(t::Char, M::AbstractVecOrMat) = (size(M, t=='N' ? 1:2), size(M, t=='N' ? 2:1))

function copy!{R,S}(B::AbstractMatrix{R}, ir_dest::UnitRange{Int}, jr_dest::UnitRange{Int}, tM::Char, M::AbstractMatrix{S}, ir_src::UnitRange{Int}, jr_src::UnitRange{Int})
function copy!{R,S}(B::AbstractVecOrMat{R}, ir_dest::UnitRange{Int}, jr_dest::UnitRange{Int}, tM::Char, M::AbstractVecOrMat{S}, ir_src::UnitRange{Int}, jr_src::UnitRange{Int})
if tM == 'N'
copy!(B, ir_dest, jr_dest, M, ir_src, jr_src)
else
@@ -314,7 +310,7 @@ end
# NOTE: the generic version is also called as fallback for
# strides != 1 cases

function generic_matvecmul!{T,S,R}(C::AbstractVector{R}, tA, A::AbstractMatrix{T}, B::AbstractVector{S})
function generic_matvecmul!{T,S,R}(C::AbstractVector{R}, tA, A::AbstractVecOrMat{T}, B::AbstractVector{S})
mB = length(B)
mA, nA = lapack_size(tA, A)
mB==nA || throw(DimensionMismatch("*"))
@@ -366,7 +362,7 @@ const Abuf = Array(Uint8, tilebufsize)
const Bbuf = Array(Uint8, tilebufsize)
const Cbuf = Array(Uint8, tilebufsize)

function generic_matmatmul!{T,S,R}(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractMatrix{S})
function generic_matmatmul!{T,S,R}(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S})
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
mB==nA || throw(DimensionMismatch("*"))
6 changes: 6 additions & 0 deletions test/linalg3.jl
Original file line number Diff line number Diff line change
@@ -92,6 +92,12 @@ C = Array(Int, size(A, 1), size(B, 2))
@test At_mul_B!(C, A, B) == A'*B
@test A_mul_Bt!(C, A, B) == A*B'
@test At_mul_Bt!(C, A, B) == A'*B'
v = [1,2,3]
C = Array(Int, 3, 3)
@test A_mul_Bt!(C, v, v) == v*v'
vf = float64(v)
C = Array(Float64, 3, 3)
@test A_mul_Bt!(C, v, v) == v*v'

# matrix algebra with subarrays of floats (stride != 1)
A = reshape(float64(1:20),5,4)

0 comments on commit 636f481

Please sign in to comment.