Skip to content

Commit

Permalink
mat: add support for tridiagonal matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimir-ch committed Apr 9, 2021
1 parent 02e2805 commit 49182b1
Show file tree
Hide file tree
Showing 8 changed files with 1,052 additions and 3 deletions.
4 changes: 2 additions & 2 deletions mat/band.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ func NewBandDense(r, c, kl, ku int, data []float64) *BandDense {
if r == 0 || c == 0 {
panic(ErrZeroLength)
}
panic("mat: negative dimension")
panic(ErrNegativeDimension)
}
if kl+1 > r || ku+1 > c {
panic("mat: band out of range")
panic(ErrBandwidth)
}
bc := kl + ku + 1
if data != nil && len(data) != min(r, c+kl)*bc {
Expand Down
8 changes: 8 additions & 0 deletions mat/basictypes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ func (m *basicMatrix) At(r, c int) float64 { return (*Dense)(m).At(r, c) }
func (m *basicMatrix) Dims() (r, c int) { return (*Dense)(m).Dims() }
func (m *basicMatrix) T() Matrix { return Transpose{m} }

type rawMatrix struct {
*basicMatrix
}

func (a *rawMatrix) RawMatrix() blas64.General {
return a.basicMatrix.mat
}

type basicVector VecDense

var _ Vector = &basicVector{}
Expand Down
49 changes: 49 additions & 0 deletions mat/index_bound_checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,52 @@ func (d *DiagDense) setDiag(i int, v float64) {
}
d.mat.Data[i*d.mat.Inc] = v
}

// At returns the element at row i, column j.
func (a *Tridiag) At(i, j int) float64 {
return a.at(i, j)
}

func (a *Tridiag) at(i, j int) float64 {
if uint(i) >= uint(d.mat.N) {
panic(ErrRowAccess)
}
if uint(j) >= uint(d.mat.N) {
panic(ErrColAccess)
}
switch i - j {
case -1:
return a.mat.DU[i]
case 0:
return a.mat.D[i]
case 1:
return a.mat.DL[j]
default:
return 0
}
}

// SetBand sets the element at row i, column j to the value v.
// It panics if the location is outside the appropriate region of the matrix.
func (a *Tridiag) SetBand(i, j int, v float64) {
a.set(i, j, v)
}

func (a *Tridiag) set(i, j int, v float64) {
if uint(i) >= uint(a.mat.N) {
panic(ErrRowAccess)
}
if uint(j) >= uint(a.mat.N) {
panic(ErrColAccess)
}
switch i - j {
case -1:
a.mat.DU[i] = v
case 0:
a.mat.D[i] = v
case 1:
a.mat.DL[j] = v
default:
panic(ErrBandSet)
}
}
49 changes: 49 additions & 0 deletions mat/index_no_bound_checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,52 @@ func (d *DiagDense) SetDiag(i int, v float64) {
func (d *DiagDense) setDiag(i int, v float64) {
d.mat.Data[i*d.mat.Inc] = v
}

// At returns the element at row i, column j.
func (a *Tridiag) At(i, j int) float64 {
if uint(i) >= uint(a.mat.N) {
panic(ErrRowAccess)
}
if uint(j) >= uint(a.mat.N) {
panic(ErrColAccess)
}
return a.at(i, j)
}

func (a *Tridiag) at(i, j int) float64 {
switch i - j {
case -1:
return a.mat.DU[i]
case 0:
return a.mat.D[i]
case 1:
return a.mat.DL[j]
default:
return 0
}
}

// SetBand sets the element at row i, column j to the value v.
// It panics if the location is outside the appropriate region of the matrix.
func (a *Tridiag) SetBand(i, j int, v float64) {
if uint(i) >= uint(a.mat.N) {
panic(ErrRowAccess)
}
if uint(j) >= uint(a.mat.N) {
panic(ErrColAccess)
}
a.set(i, j, v)
}

func (a *Tridiag) set(i, j int, v float64) {
switch i - j {
case -1:
a.mat.DU[i] = v
case 0:
a.mat.D[i] = v
case 1:
a.mat.DL[j] = v
default:
panic(ErrBandSet)
}
}
6 changes: 5 additions & 1 deletion mat/matrix.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func untranspose(a Matrix) (Matrix, bool) {
func untransposeExtract(a Matrix) (Matrix, bool) {
ut, trans := untranspose(a)
switch m := ut.(type) {
case *DiagDense, *SymBandDense, *TriBandDense, *BandDense, *TriDense, *SymDense, *Dense, *VecDense:
case *DiagDense, *SymBandDense, *TriBandDense, *BandDense, *TriDense, *SymDense, *Dense, *VecDense, *Tridiag:
return m, trans
// TODO(btracey): Add here if we ever have an equivalent of RawDiagDense.
case RawSymBander:
Expand Down Expand Up @@ -291,6 +291,10 @@ func untransposeExtract(a Matrix) (Matrix, bool) {
var v VecDense
v.SetRawVector(m.RawVector())
return &v, trans
case RawTridiagonaler:
var d Tridiag
d.SetRawTridiagonal(m.RawTridiagonal())
return &d, trans
default:
return ut, trans
}
Expand Down
12 changes: 12 additions & 0 deletions mat/matrix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,12 @@ func TestDoer(t *testing.T) {
NewSymBandDense(3, 1, ones(3*(1+1))),
NewSymBandDense(6, 1, ones(6*(1+1))),
NewSymBandDense(6, 2, ones(6*(2+1))),
NewTridiag(1, nil, ones(1), nil),
NewTridiag(2, ones(1), ones(2), ones(1)),
NewTridiag(3, ones(2), ones(3), ones(2)),
NewTridiag(4, ones(3), ones(4), ones(3)),
NewTridiag(7, ones(6), ones(7), ones(6)),
NewTridiag(10, ones(9), ones(10), ones(9)),
} {
r, c := m.Dims()

Expand Down Expand Up @@ -714,6 +720,12 @@ func TestMulVecToer(t *testing.T) {
NewSymBandDense(10, 0, random(10)),
NewSymBandDense(10, 1, random(20)),
NewSymBandDense(10, 4, random(50)),
NewTridiag(1, nil, random(1), nil),
NewTridiag(2, random(1), random(2), random(1)),
NewTridiag(3, random(2), random(3), random(2)),
NewTridiag(4, random(3), random(4), random(3)),
NewTridiag(7, random(6), random(7), random(6)),
NewTridiag(10, random(9), random(10), random(9)),
} {
// Dense copy of A used for computing the expected result.
var aDense Dense
Expand Down
Loading

0 comments on commit 49182b1

Please sign in to comment.