Skip to content

Commit

Permalink
mat: make QR satisfy Matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimir-ch committed Oct 6, 2023
1 parent aef3c5f commit 45b7421
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 19 deletions.
1 change: 1 addition & 0 deletions lapack/lapack.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type Float64 interface {
Dlansy(norm MatrixNorm, uplo blas.Uplo, n int, a []float64, lda int, work []float64) float64
Dlapmr(forward bool, m, n int, x []float64, ldx int, k []int)
Dlapmt(forward bool, m, n int, x []float64, ldx int, k []int)
Dorgqr(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int)
Dormlq(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int)
Dpbcon(uplo blas.Uplo, n, kd int, ab []float64, ldab int, anorm float64, work []float64, iwork []int) float64
Expand Down
31 changes: 27 additions & 4 deletions lapack/lapack64/lapack64.go
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,28 @@ func Ormlq(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64
lapack64.Dormlq(side, trans, c.Rows, c.Cols, a.Rows, a.Data, max(1, a.Stride), tau, c.Data, max(1, c.Stride), work, lwork)
}

// Orgqr generates an m×n matrix Q with orthonormal columns defined by the
// product of elementary reflectors
//
// Q = H_0 * H_1 * ... * H_{k-1}
//
// as computed by Geqrf.
//
// k is determined by the length of tau.
//
// The length of work must be at least n and it also must be that 0 <= k <= n
// and 0 <= n <= m.
//
// work is temporary storage, and lwork specifies the usable memory length. At
// minimum, lwork >= n, and the amount of blocking is limited by the usable
// length. If lwork == -1, instead of computing Orgqr the optimal work length
// is stored into work[0].
//
// Orgqr will panic if the conditions on input values are not met.
func Orgqr(a blas64.General, tau []float64, work []float64, lwork int) {
lapack64.Dorgqr(a.Rows, a.Cols, len(tau), a.Data, a.Stride, tau, work, lwork)
}

// Ormqr multiplies an m×n matrix C by an orthogonal matrix Q as
//
// C = Q * C if side == blas.Left and trans == blas.NoTrans,
Expand All @@ -705,12 +727,13 @@ func Ormlq(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64
//
// Q = H_0 * H_1 * ... * H_{k-1}.
//
// k is determined by the length of tau.
//
// If side == blas.Left, A is an m×k matrix and 0 <= k <= m.
// If side == blas.Right, A is an n×k matrix and 0 <= k <= n.
// The ith column of A contains the vector which defines the elementary
// reflector H_i and tau[i] contains its scalar factor. tau must have length k
// and Ormqr will panic otherwise. Geqrf returns A and tau in the required
// form.
// reflector H_i and tau[i] contains its scalar factor. Geqrf returns A and tau
// in the required form.
//
// work must have length at least max(1,lwork), and lwork must be at least n if
// side == blas.Left and at least m if side == blas.Right, otherwise Ormqr will
Expand All @@ -725,7 +748,7 @@ func Ormlq(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64
// If lwork is -1, instead of performing Ormqr, the optimal workspace size will
// be stored into work[0].
func Ormqr(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64, c blas64.General, work []float64, lwork int) {
lapack64.Dormqr(side, trans, c.Rows, c.Cols, a.Cols, a.Data, max(1, a.Stride), tau, c.Data, max(1, c.Stride), work, lwork)
lapack64.Dormqr(side, trans, c.Rows, c.Cols, len(tau), a.Data, max(1, a.Stride), tau, c.Data, max(1, c.Stride), work, lwork)
}

// Pocon estimates the reciprocal of the condition number of a positive-definite
Expand Down
66 changes: 51 additions & 15 deletions mat/qr.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,42 @@ const badQR = "mat: invalid QR factorization"
// QR is a type for creating and using the QR factorization of a matrix.
type QR struct {
qr *Dense
q *Dense
tau []float64
cond float64
}

// Dims returns the dimensions of the matrix.
func (qr *QR) Dims() (r, c int) {
if qr.qr == nil {
return 0, 0
}
return qr.qr.Dims()
}

// At returns the element at row i, column j.
func (qr *QR) At(i, j int) float64 {
m, n := qr.Dims()
if uint(i) >= uint(m) {
panic(ErrRowAccess)
}
if uint(j) >= uint(n) {
panic(ErrColAccess)
}

var val float64
for k := 0; k <= j; k++ {
val += qr.q.at(i, k) * qr.qr.at(k, j)
}
return val
}

// T performs an implicit transpose by returning the receiver inside a
// Transpose.
func (qr *QR) T() Matrix {
return Transpose{qr}
}

func (qr *QR) updateCond(norm lapack.MatrixNorm) {
// Since A = Q*R, and Q is orthogonal, we get for the condition number κ
// κ(A) := |A| |A^-1| = |Q*R| |(Q*R)^-1| = |R| |R^-1 * Qᵀ|
Expand Down Expand Up @@ -55,18 +87,34 @@ func (qr *QR) factorize(a Matrix, norm lapack.MatrixNorm) {
if m < n {
panic(ErrShape)
}
k := min(m, n)
if qr.qr == nil {
qr.qr = &Dense{}
}
qr.qr.CloneFrom(a)
work := []float64{0}
qr.tau = make([]float64, k)
qr.tau = make([]float64, n)
lapack64.Geqrf(qr.qr.mat, qr.tau, work, -1)
work = getFloat64s(int(work[0]), false)
lapack64.Geqrf(qr.qr.mat, qr.tau, work, len(work))
putFloat64s(work)
qr.updateCond(norm)
qr.updateQ()
}

func (qr *QR) updateQ() {
m, _ := qr.Dims()
if qr.q == nil {
qr.q = NewDense(m, m, nil)
} else {
qr.q.reuseAsNonZeroed(m, m)
}
// Construct Q from the elementary reflectors.
qr.q.Copy(qr.qr)
work := []float64{0}
lapack64.Orgqr(qr.q.mat, qr.tau, work, -1)
work = getFloat64s(int(work[0]), false)
lapack64.Orgqr(qr.q.mat, qr.tau, work, len(work))
putFloat64s(work)
}

// isValid returns whether the receiver contains a factorization.
Expand Down Expand Up @@ -143,20 +191,8 @@ func (qr *QR) QTo(dst *Dense) {
if r != r2 || r != c2 {
panic(ErrShape)
}
dst.Zero()
}

// Set Q = I.
for i := 0; i < r*r; i += r + 1 {
dst.mat.Data[i] = 1
}

// Construct Q from the elementary reflectors.
work := []float64{0}
lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, dst.mat, work, -1)
work = getFloat64s(int(work[0]), false)
lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, dst.mat, work, len(work))
putFloat64s(work)
dst.Copy(qr.q)
}

// SolveTo finds a minimum-norm solution to a system of linear equations defined
Expand Down
7 changes: 7 additions & 0 deletions mat/qr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ func TestQR(t *testing.T) {
t.Errorf("Q is not orthonormal: m = %v, n = %v", m, n)
}

if !EqualApprox(a, &qr, 1e-14) {
t.Errorf("m=%d,n=%d: A and QR are not equal", m, n)
}
if !EqualApprox(a.T(), qr.T(), 1e-14) {
t.Errorf("m=%d,n=%d: Aᵀ and (QR)ᵀ are not equal", m, n)
}

qr.RTo(&r)

var got Dense
Expand Down

0 comments on commit 45b7421

Please sign in to comment.