From 483b2433bf1efdd1feb6af196d1883c67c14eb6d Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 4 Jun 2024 21:18:36 +0530 Subject: [PATCH] Improve error message in inplace transpose --- stdlib/LinearAlgebra/src/transpose.jl | 15 ++++++++++----- stdlib/LinearAlgebra/test/adjtrans.jl | 10 ++++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/stdlib/LinearAlgebra/src/transpose.jl b/stdlib/LinearAlgebra/src/transpose.jl index 8aa04f7d34b48..a36919b2e557a 100644 --- a/stdlib/LinearAlgebra/src/transpose.jl +++ b/stdlib/LinearAlgebra/src/transpose.jl @@ -74,27 +74,32 @@ julia> A ``` """ adjoint!(B::AbstractMatrix, A::AbstractMatrix) = transpose_f!(adjoint, B, A) + +@noinline function check_transpose_axes(axesA, axesB) + axesB == reverse(axesA) || throw(DimensionMismatch("axes of the destination are incompatible with that of the source")) +end + function transpose!(B::AbstractVector, A::AbstractMatrix) - axes(B,1) == axes(A,2) && axes(A,1) == 1:1 || throw(DimensionMismatch("transpose")) + check_transpose_axes((axes(B,1), axes(B,2)), axes(A)) copyto!(B, A) end function transpose!(B::AbstractMatrix, A::AbstractVector) - axes(B,2) == axes(A,1) && axes(B,1) == 1:1 || throw(DimensionMismatch("transpose")) + check_transpose_axes(axes(B), (axes(A,1), axes(A,2))) copyto!(B, A) end function adjoint!(B::AbstractVector, A::AbstractMatrix) - axes(B,1) == axes(A,2) && axes(A,1) == 1:1 || throw(DimensionMismatch("transpose")) + check_transpose_axes((axes(B,1), axes(B,2)), axes(A)) ccopy!(B, A) end function adjoint!(B::AbstractMatrix, A::AbstractVector) - axes(B,2) == axes(A,1) && axes(B,1) == 1:1 || throw(DimensionMismatch("transpose")) + check_transpose_axes(axes(B), (axes(A,1), axes(A,2))) ccopy!(B, A) end const transposebaselength=64 function transpose_f!(f, B::AbstractMatrix, A::AbstractMatrix) inds = axes(A) - axes(B,1) == inds[2] && axes(B,2) == inds[1] || throw(DimensionMismatch(string(f))) + check_transpose_axes(axes(B), inds) m, n = length(inds[1]), length(inds[2]) if m*n<=4*transposebaselength diff --git a/stdlib/LinearAlgebra/test/adjtrans.jl b/stdlib/LinearAlgebra/test/adjtrans.jl index 2c533af37f912..1a66c7430723e 100644 --- a/stdlib/LinearAlgebra/test/adjtrans.jl +++ b/stdlib/LinearAlgebra/test/adjtrans.jl @@ -703,4 +703,14 @@ end @test B == At end +@testset "error message in transpose" begin + v = zeros(2) + A = zeros(1,1) + B = zeros(2,3) + for (t1, t2) in Any[(A, v), (v, A), (A, B)] + @test_throws "axes of the destination are incompatible with that of the source" transpose!(t1, t2) + @test_throws "axes of the destination are incompatible with that of the source" adjoint!(t1, t2) + end +end + end # module TestAdjointTranspose