diff --git a/lapack/lapack.go b/lapack/lapack.go
index a08324d9b..b61df69c0 100644
--- a/lapack/lapack.go
+++ b/lapack/lapack.go
@@ -27,6 +27,7 @@ type Float64 interface {
 	Dlansy(norm MatrixNorm, uplo blas.Uplo, n int, a []float64, lda int, work []float64) float64
 	Dlapmr(forward bool, m, n int, x []float64, ldx int, k []int)
 	Dlapmt(forward bool, m, n int, x []float64, ldx int, k []int)
+	Dorgqr(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
 	Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int)
 	Dormlq(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int)
 	Dpbcon(uplo blas.Uplo, n, kd int, ab []float64, ldab int, anorm float64, work []float64, iwork []int) float64
diff --git a/lapack/lapack64/lapack64.go b/lapack/lapack64/lapack64.go
index 2e623974c..a3b91ff73 100644
--- a/lapack/lapack64/lapack64.go
+++ b/lapack/lapack64/lapack64.go
@@ -694,6 +694,28 @@ func Ormlq(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64
 	lapack64.Dormlq(side, trans, c.Rows, c.Cols, a.Rows, a.Data, max(1, a.Stride), tau, c.Data, max(1, c.Stride), work, lwork)
 }
 
+// Orgqr generates an m×n matrix Q with orthonormal columns defined by the
+// product of elementary reflectors
+//
+//	Q = H_0 * H_1 * ... * H_{k-1}
+//
+// as computed by Geqrf.
+//
+// k is determined by the length of tau.
+//
+// The length of work must be at least n and it also must be that 0 <= k <= n
+// and 0 <= n <= m.
+//
+// work is temporary storage, and lwork specifies the usable memory length. At
+// minimum, lwork >= n, and the amount of blocking is limited by the usable
+// length. If lwork == -1, instead of computing Orgqr the optimal work length
+// is stored into work[0].
+//
+// Orgqr will panic if the conditions on input values are not met.
+func Orgqr(a blas64.General, tau []float64, work []float64, lwork int) {
+	lapack64.Dorgqr(a.Rows, a.Cols, len(tau), a.Data, max(1, a.Stride), tau, work, lwork)
+}
+
 // Ormqr multiplies an m×n matrix C by an orthogonal matrix Q as
 //
 //	C = Q * C if side == blas.Left and trans == blas.NoTrans,
@@ -705,12 +727,13 @@ func Ormlq(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64
 //
 //	Q = H_0 * H_1 * ... * H_{k-1}.
 //
+// k is determined by the length of tau.
+//
 // If side == blas.Left, A is an m×k matrix and 0 <= k <= m.
 // If side == blas.Right, A is an n×k matrix and 0 <= k <= n.
 // The ith column of A contains the vector which defines the elementary
-// reflector H_i and tau[i] contains its scalar factor. tau must have length k
-// and Ormqr will panic otherwise. Geqrf returns A and tau in the required
-// form.
+// reflector H_i and tau[i] contains its scalar factor. Geqrf returns A and tau
+// in the required form.
 //
 // work must have length at least max(1,lwork), and lwork must be at least n if
 // side == blas.Left and at least m if side == blas.Right, otherwise Ormqr will
@@ -725,7 +748,7 @@ func Ormlq(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64
 // If lwork is -1, instead of performing Ormqr, the optimal workspace size will
 // be stored into work[0].
 func Ormqr(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64, c blas64.General, work []float64, lwork int) {
-	lapack64.Dormqr(side, trans, c.Rows, c.Cols, a.Cols, a.Data, max(1, a.Stride), tau, c.Data, max(1, c.Stride), work, lwork)
+	lapack64.Dormqr(side, trans, c.Rows, c.Cols, len(tau), a.Data, max(1, a.Stride), tau, c.Data, max(1, c.Stride), work, lwork)
 }
 
 // Pocon estimates the reciprocal of the condition number of a positive-definite
diff --git a/mat/qr.go b/mat/qr.go
index f54bcc869..af99dbcaa 100644
--- a/mat/qr.go
+++ b/mat/qr.go
@@ -18,10 +18,45 @@ const badQR = "mat: invalid QR factorization"
 // QR is a type for creating and using the QR factorization of a matrix.
 type QR struct {
 	qr   *Dense
+	q    *Dense
 	tau  []float64
 	cond float64
 }
 
+// Dims returns the dimensions of the factorized matrix. It returns 0, 0 if
+// the receiver does not yet hold a factorization.
+func (qr *QR) Dims() (r, c int) {
+	if qr.qr == nil {
+		return 0, 0
+	}
+	return qr.qr.Dims()
+}
+
+// At returns the element at row i, column j of the product Q*R.
+// It panics with ErrRowAccess or ErrColAccess if an index is out of range.
+func (qr *QR) At(i, j int) float64 {
+	m, n := qr.Dims()
+	if uint(i) >= uint(m) {
+		panic(ErrRowAccess)
+	}
+	if uint(j) >= uint(n) {
+		panic(ErrColAccess)
+	}
+
+	// Sum over k <= j only: R is the upper triangle of the stored factors.
+	var val float64
+	for k := 0; k <= j; k++ {
+		val += qr.q.at(i, k) * qr.qr.at(k, j)
+	}
+	return val
+}
+
+// T performs an implicit transpose by returning the receiver inside a
+// Transpose.
+func (qr *QR) T() Matrix {
+	return Transpose{qr}
+}
+
 func (qr *QR) updateCond(norm lapack.MatrixNorm) {
 	// Since A = Q*R, and Q is orthogonal, we get for the condition number κ
 	// κ(A) := |A| |A^-1| = |Q*R| |(Q*R)^-1| = |R| |R^-1 * Qᵀ|
@@ -55,18 +90,37 @@ func (qr *QR) factorize(a Matrix, norm lapack.MatrixNorm) {
 	if m < n {
 		panic(ErrShape)
 	}
-	k := min(m, n)
 	if qr.qr == nil {
 		qr.qr = &Dense{}
 	}
 	qr.qr.CloneFrom(a)
 	work := []float64{0}
-	qr.tau = make([]float64, k)
+	qr.tau = make([]float64, n)
 	lapack64.Geqrf(qr.qr.mat, qr.tau, work, -1)
 	work = getFloat64s(int(work[0]), false)
 	lapack64.Geqrf(qr.qr.mat, qr.tau, work, len(work))
 	putFloat64s(work)
 	qr.updateCond(norm)
+	qr.updateQ()
+}
+
+// updateQ builds the explicit m×m factor Q from the reflectors and caches it.
+func (qr *QR) updateQ() {
+	m, _ := qr.Dims()
+	if qr.q == nil {
+		qr.q = &Dense{}
+	}
+	// NOTE(review): reuseAsNonZeroed is assumed to reject a non-empty receiver
+	// of another shape, as ReuseAs does; empty q first so refactorizing works.
+	qr.q.Reset()
+	qr.q.reuseAsNonZeroed(m, m)
+	// Construct Q from the elementary reflectors.
+	qr.q.Copy(qr.qr)
+	work := []float64{0}
+	lapack64.Orgqr(qr.q.mat, qr.tau, work, -1)
+	work = getFloat64s(int(work[0]), false)
+	lapack64.Orgqr(qr.q.mat, qr.tau, work, len(work))
+	putFloat64s(work)
 }
 
 // isValid returns whether the receiver contains a factorization.
@@ -143,20 +197,8 @@ func (qr *QR) QTo(dst *Dense) {
 		if r != r2 || r != c2 {
 			panic(ErrShape)
 		}
-		dst.Zero()
 	}
-
-	// Set Q = I.
-	for i := 0; i < r*r; i += r + 1 {
-		dst.mat.Data[i] = 1
-	}
-
-	// Construct Q from the elementary reflectors.
-	work := []float64{0}
-	lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, dst.mat, work, -1)
-	work = getFloat64s(int(work[0]), false)
-	lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, dst.mat, work, len(work))
-	putFloat64s(work)
+	dst.Copy(qr.q)
 }
 
 // SolveTo finds a minimum-norm solution to a system of linear equations defined
diff --git a/mat/qr_test.go b/mat/qr_test.go
index f3dfdc839..b71bee56b 100644
--- a/mat/qr_test.go
+++ b/mat/qr_test.go
@@ -42,6 +42,13 @@ func TestQR(t *testing.T) {
 			t.Errorf("Q is not orthonormal: m = %v, n = %v", m, n)
 		}
 
+		if !EqualApprox(a, &qr, 1e-14) {
+			t.Errorf("m=%d,n=%d: A and QR are not equal", m, n)
+		}
+		if !EqualApprox(a.T(), qr.T(), 1e-14) {
+			t.Errorf("m=%d,n=%d: Aᵀ and (QR)ᵀ are not equal", m, n)
+		}
+
 		qr.RTo(&r)
 
 		var got Dense