-
Notifications
You must be signed in to change notification settings - Fork 546
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
353 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,250 @@ | ||
// Copyright ©2021 The Gonum Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package r3 | ||
|
||
import ( | ||
"unsafe" | ||
|
||
"gonum.org/v1/gonum/blas/blas64" | ||
"gonum.org/v1/gonum/mat" | ||
) | ||
|
||
const ( | ||
badDim = "bad matrix dimensions" | ||
badIdx = "bad matrix index" | ||
) | ||
|
||
// Mat represents a 3×3 matrix. Useful for rotation matrices and such. | ||
type Mat struct { | ||
data *[3][3]float64 | ||
} | ||
|
||
var _ mat.Matrix = (*Mat)(nil) | ||
|
||
// NewMat returns a new 3×3 matrix Mat type and populates its elements | ||
// with values passed as argument in row-major form. If val argument | ||
// is nil then NewMat returns a matrix filled with zeros. | ||
func NewMat(val []float64) *Mat { | ||
if len(val) != 9 { | ||
if val == nil { | ||
return &Mat{data: new([3][3]float64)} | ||
} | ||
panic(badDim) | ||
} | ||
m := Mat{} | ||
m.setBackingSlice(val) | ||
return &m | ||
} | ||
|
||
// Dims returns the number of rows and columns of this matrix. | ||
// This method will always return 3×3 for a Mat. | ||
func (m *Mat) Dims() (r, c int) { return 3, 3 } | ||
|
||
// At returns the value of a matrix element at row i, column j. | ||
// At expects indices in the range [0,2]. | ||
// It will panic if i or j are out of bounds for the matrix. | ||
func (m *Mat) At(i, j int) float64 { | ||
return m.data[i][j] | ||
} | ||
|
||
// Set sets the element at row i, column j to the value v. | ||
func (m *Mat) Set(i, j int, v float64) { | ||
m.data[i][j] = v | ||
} | ||
|
||
// T returns the transpose of Mat. Changes in the receiver will be reflected in the returned matrix. | ||
func (m *Mat) T() mat.Matrix { return mat.Transpose{Matrix: m} } | ||
|
||
// RawMatrix returns the blas representation of the matrix with the backing data of this matrix. | ||
// Changes to the returned matrix will be reflected in the receiver. | ||
func (m *Mat) RawMatrix() blas64.General { | ||
return blas64.General{Rows: 3, Cols: 3, Data: m.backingSlice(), Stride: 3} | ||
} | ||
|
||
// Eye returns the 3×3 Identity matrix | ||
func Eye() *Mat { | ||
return &Mat{data: &[3][3]float64{ | ||
{1, 0, 0}, | ||
{0, 1, 0}, | ||
{0, 0, 1}, | ||
}} | ||
} | ||
|
||
// Scale multiplies the elements of a by f, placing the result in the receiver. | ||
// | ||
// See the mat.Scaler interface for more information. | ||
func (m *Mat) Scale(f float64, a mat.Matrix) { | ||
r, c := a.Dims() | ||
if r != 3 || c != 3 { | ||
panic(badDim) | ||
} | ||
for i := 0; i < 3; i++ { | ||
for j := 0; j < 3; j++ { | ||
m.Set(i, j, f*a.At(i, j)) | ||
} | ||
} | ||
} | ||
|
||
// Performs matrix multiplication on v: | ||
// result = M * v | ||
func (m *Mat) MulVec(v Vec) Vec { | ||
return Vec{ | ||
X: v.X*m.At(0, 0) + v.Y*m.At(0, 1) + v.Z*m.At(0, 2), | ||
Y: v.X*m.At(1, 0) + v.Y*m.At(1, 1) + v.Z*m.At(1, 2), | ||
Z: v.X*m.At(2, 0) + v.Y*m.At(2, 1) + v.Z*m.At(2, 2), | ||
} | ||
} | ||
|
||
// Performs transposed matrix multiplication on v: | ||
// result = Mᵀ * v | ||
func (m *Mat) MulVecTrans(v Vec) Vec { | ||
return Vec{ | ||
X: v.X*m.At(0, 0) + v.Y*m.At(1, 0) + v.Z*m.At(2, 0), | ||
Y: v.X*m.At(0, 1) + v.Y*m.At(1, 1) + v.Z*m.At(2, 1), | ||
Z: v.X*m.At(0, 2) + v.Y*m.At(1, 2) + v.Z*m.At(2, 2), | ||
} | ||
} | ||
|
||
// Skew returns the 3×3 skew symmetric matrix (right hand system) of v. | ||
// ⎡ 0 -z y⎤ | ||
// Skew({x,y,z}) = ⎢ z 0 -x⎥ | ||
// ⎣-y x 0⎦ | ||
func Skew(v Vec) (M *Mat) { | ||
return &Mat{data: &[3][3]float64{ | ||
{0, -v.Z, v.Y}, | ||
{v.Z, 0, -v.X}, | ||
{-v.Y, v.X, 0}, | ||
}} | ||
} | ||
|
||
// Mul takes the matrix product of a and b, placing the result in the receiver. | ||
// If the number of columns in a does not equal 3, Mul will panic. | ||
func (m *Mat) Mul(a, b mat.Matrix) { | ||
ra, ca := a.Dims() | ||
rb, cb := b.Dims() | ||
switch { | ||
case ra != 3: | ||
panic(badDim) | ||
case cb != 3: | ||
panic(badDim) | ||
case ca != rb: | ||
panic(badDim) | ||
} | ||
if ca != 3 { | ||
// General matrix multiplication for the case where the inner dimension is not 3. | ||
t := mat.NewDense(3, 3, m.backingSlice()) | ||
t.Mul(a, b) | ||
return | ||
} | ||
|
||
a00 := a.At(0, 0) | ||
b00 := b.At(0, 0) | ||
a01 := a.At(0, 1) | ||
b01 := b.At(0, 1) | ||
a02 := a.At(0, 2) | ||
b02 := b.At(0, 2) | ||
a10 := a.At(1, 0) | ||
b10 := b.At(1, 0) | ||
a11 := a.At(1, 1) | ||
b11 := b.At(1, 1) | ||
a12 := a.At(1, 2) | ||
b12 := b.At(1, 2) | ||
a20 := a.At(2, 0) | ||
b20 := b.At(2, 0) | ||
a21 := a.At(2, 1) | ||
b21 := b.At(2, 1) | ||
a22 := a.At(2, 2) | ||
b22 := b.At(2, 2) | ||
m.data[0][0] = a00*b00 + a01*b10 + a02*b20 | ||
m.data[0][1] = a00*b01 + a01*b11 + a02*b21 | ||
m.data[0][2] = a00*b02 + a01*b12 + a02*b22 | ||
m.data[1][0] = a10*b00 + a11*b10 + a12*b20 | ||
m.data[1][1] = a10*b01 + a11*b11 + a12*b21 | ||
m.data[1][2] = a10*b02 + a11*b12 + a12*b22 | ||
m.data[2][0] = a20*b00 + a21*b10 + a22*b20 | ||
m.data[2][1] = a20*b01 + a21*b11 + a22*b21 | ||
m.data[2][2] = a20*b02 + a21*b12 + a22*b22 | ||
} | ||
|
||
// CloneFrom makes a copy of a into the receiver m. | ||
// Mat expects a 3×3 input matrix. | ||
func (m *Mat) CloneFrom(a mat.Matrix) { | ||
r, c := a.Dims() | ||
if r != 3 || c != 3 { | ||
panic(badDim) | ||
} | ||
for i := 0; i < 3; i++ { | ||
for j := 0; j < 3; j++ { | ||
m.Set(i, j, a.At(i, j)) | ||
} | ||
} | ||
} | ||
|
||
// Sub subtracts the matrix b from a, placing the result in the receiver. | ||
// Sub will panic if the two matrices do not have the same shape. | ||
func (m *Mat) Sub(a, b mat.Matrix) { | ||
if r, c := a.Dims(); r != 3 || c != 3 { | ||
panic(badDim) | ||
} | ||
if r, c := b.Dims(); r != 3 || c != 3 { | ||
panic(badDim) | ||
} | ||
|
||
m.data[0][0] = a.At(0, 0) - b.At(0, 0) | ||
m.data[0][1] = a.At(0, 1) - b.At(0, 1) | ||
m.data[0][2] = a.At(0, 2) - b.At(0, 2) | ||
m.data[1][0] = a.At(1, 0) - b.At(1, 0) | ||
m.data[1][1] = a.At(1, 1) - b.At(1, 1) | ||
m.data[1][2] = a.At(1, 2) - b.At(1, 2) | ||
m.data[2][0] = a.At(2, 0) - b.At(2, 0) | ||
m.data[2][1] = a.At(2, 1) - b.At(2, 1) | ||
m.data[2][2] = a.At(2, 2) - b.At(2, 2) | ||
} | ||
|
||
// Add adds a and b element-wise, placing the result in the receiver. Add will panic if the two matrices do not have the same shape. | ||
func (m *Mat) Add(a, b mat.Matrix) { | ||
if r, c := a.Dims(); r != 3 || c != 3 { | ||
panic(badDim) | ||
} | ||
if r, c := b.Dims(); r != 3 || c != 3 { | ||
panic(badDim) | ||
} | ||
|
||
m.data[0][0] = a.At(0, 0) + b.At(0, 0) | ||
m.data[0][1] = a.At(0, 1) + b.At(0, 1) | ||
m.data[0][2] = a.At(0, 2) + b.At(0, 2) | ||
m.data[1][0] = a.At(1, 0) + b.At(1, 0) | ||
m.data[1][1] = a.At(1, 1) + b.At(1, 1) | ||
m.data[1][2] = a.At(1, 2) + b.At(1, 2) | ||
m.data[2][0] = a.At(2, 0) + b.At(2, 0) | ||
m.data[2][1] = a.At(2, 1) + b.At(2, 1) | ||
m.data[2][2] = a.At(2, 2) + b.At(2, 2) | ||
} | ||
|
||
// VecRow returns the elements in the ith row of the receiver. | ||
func (m *Mat) VecRow(i int) Vec { | ||
if i > 2 { | ||
panic(badIdx) | ||
} | ||
return Vec{X: m.At(i, 0), Y: m.At(i, 1), Z: m.At(i, 2)} | ||
} | ||
|
||
// VecCol returns the elements in the jth column of the receiver. | ||
func (m *Mat) VecCol(j int) Vec { | ||
if j > 2 { | ||
panic(badIdx) | ||
} | ||
return Vec{X: m.At(0, j), Y: m.At(1, j), Z: m.At(2, j)} | ||
} | ||
|
||
// setBackingSlice requires unsafe. | ||
func (m *Mat) setBackingSlice(vals []float64) { | ||
m.data = (*[3][3]float64)(unsafe.Pointer(&vals[0])) | ||
} | ||
|
||
// backingSlice requires unsafe. | ||
func (m *Mat) backingSlice() []float64 { | ||
return (*[9]float64)(unsafe.Pointer(m.data))[:] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
// Copyright ©2021 The Gonum Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package r3 | ||
|
||
import ( | ||
"math" | ||
"testing" | ||
|
||
"golang.org/x/exp/rand" | ||
|
||
"gonum.org/v1/gonum/mat" | ||
) | ||
|
||
func TestMatScale(t *testing.T) { | ||
const tol = 1e-12 | ||
rnd := rand.New(rand.NewSource(1)) | ||
for tc := 0; tc < 20; tc++ { | ||
v := rnd.Float64() | ||
a := randomMat(rnd) | ||
gotmat := NewMat(nil) | ||
gotmat.Scale(v, a) | ||
for iv := range a.data { | ||
i := iv / 3 | ||
j := iv % 3 | ||
expect := v * a.At(i, j) | ||
got := gotmat.At(i, j) | ||
if math.Abs(got-expect) > tol { | ||
t.Errorf( | ||
"case %d: got=%v, want=%v", | ||
tc, got, expect) | ||
} | ||
} | ||
} | ||
} | ||
|
||
func TestMatCloneFrom(t *testing.T) { | ||
rnd := rand.New(rand.NewSource(1)) | ||
for tc := 0; tc < 20; tc++ { | ||
a := randomMat(rnd) | ||
gotmat := NewMat(nil) | ||
gotmat.CloneFrom(a) | ||
if !mat.Equal(a, gotmat) { | ||
t.Error("Clonefrom fail") | ||
} | ||
} | ||
} | ||
|
||
func TestSkew(t *testing.T) { | ||
rnd := rand.New(rand.NewSource(1)) | ||
for tc := 0; tc < 20; tc++ { | ||
v1 := randomVec(rnd) | ||
v2 := randomVec(rnd) | ||
sk := Skew(v1) | ||
got := sk.MulVec(v2) | ||
expect := Cross(v1, v2) | ||
if got != expect { | ||
t.Error("r3.Cross(v1,v2) not match with r3.Skew(v1)*v2") | ||
} | ||
} | ||
} | ||
|
||
func TestTranspose(t *testing.T) { | ||
rnd := rand.New(rand.NewSource(1)) | ||
for tc := 0; tc < 20; tc++ { | ||
d := mat.NewDense(3, 3, nil) | ||
m := randomMat(rnd) | ||
d.CloneFrom(m) | ||
mt := m.T() | ||
dt := d.T() | ||
if !mat.Equal(mt, dt) { | ||
t.Error("Dense.T() not equal to r3.Mat.T()") | ||
} | ||
vd := mat.NewVecDense(3, nil) | ||
v := randomVec(rnd) | ||
vd.SetVec(0, v.X) | ||
vd.SetVec(1, v.Y) | ||
vd.SetVec(2, v.Z) | ||
got := m.MulVecTrans(v) | ||
vd.MulVec(dt, vd) | ||
if vd.AtVec(0) != got.X || vd.AtVec(1) != got.Y || vd.AtVec(2) != got.Z { | ||
t.Error("VecDense.MulVec(dense.T()) not equal to r3.Mat.MulVec(r3.Vec)") | ||
} | ||
} | ||
} | ||
|
||
func randomMat(rnd *rand.Rand) *Mat { | ||
m := Mat{data: new([3][3]float64)} | ||
for iv := 0; iv < 9; iv++ { | ||
i := iv / 3 | ||
j := iv % 3 | ||
m.Set(i, j, (rnd.Float64()-0.5)*20) | ||
} | ||
return &m | ||
} | ||
|
||
func randomVec(rnd *rand.Rand) (v Vec) { | ||
v.X = (rnd.Float64() - 0.5) * 20 | ||
v.Y = (rnd.Float64() - 0.5) * 20 | ||
v.Z = (rnd.Float64() - 0.5) * 20 | ||
return v | ||
} |