statmat: added multi linear and lasso regression #1998

Open
wants to merge 7 commits into
base: master


@aouyang1 aouyang1 commented Oct 28, 2024

Hello! First-time contributor here. I've been working on some multi-linear regression work lately and wanted to see if some of that could also be integrated here. This only partially covers the features requested in #1865 (MISO), but I can extend it in this PR if we want to. I tried my best to follow the conventions that I saw in the existing modules, but let me know what I can change to make it more consistent with them.

The two new regression structs are OLSRegression and LassoRegression.

  • OLSRegression uses a QR decomposition to compute ordinary least squares with multiple features. I'm sure there are more efficient algorithms here.
  • LassoRegression uses coordinate descent to find the optimal weights.

Each implements the following:

Fit(x, y mat.Matrix) (float64, []float64)
Predict(x mat.Matrix) []float64
Score(x, y mat.Matrix) float64

Inspired by the Python sklearn interface

Member

@kortschak kortschak left a comment

I've taken a brief look through this. There is more API here than I think is necessary. Also, there are error returns that do not conform to the approach that we use in Gonum packages: we use error returns for error conditions that the user could not know before calling a function, but panics for cases where the calling parameters do not conform to the documented invariants for the call. Please match this.

I'll take a deeper look in the next week or so.

Author

aouyang1 commented Nov 4, 2024

> I've taken a brief look through this. There is more API here than I think is necessary. Also, there are error returns that do not conform to the approach that we use in Gonum packages: we use error returns for error conditions that the user could not know before calling a function, but panics for cases where the calling parameters do not conform to the documented invariants for the call. Please match this.
>
> I'll take a deeper look in the next week or so.

Thanks for explaining the differences! Will get the errors addressed and converted.

Member

@kortschak kortschak left a comment

The API is too complex here. I'd estimate that we could make this somewhere between 1/3 and 2/3 of the code that's here by removing the extraneous code.

stat/statmat.go Outdated
Comment on lines 166 to 173
// Validate runs basic validation on OLS options
func (o *OLSOptions) Validate() *OLSOptions {
	if o == nil {
		o = NewDefaultOLSOptions()
	}

	return o
}
Member

This method doesn't do what it says on the tin. Validate implies that it checks that the options are correct; this instead ensures that they are correct. I cannot think of another example where we do something like this in Gonum packages.

stat/statmat.go Outdated
Comment on lines 175 to 180
// NewDefaultOLSOptions returns a default set of OLS Regression options
func NewDefaultOLSOptions() *OLSOptions {
	return &OLSOptions{
		FitIntercept: true,
	}
}
Member

This seems like more API than we need. Please take a look at how stat.LinearRegression does this. The situation here is a little more complex, but not so much that we need all this. I think a pure function that returns the details that we need to perform predictions and to calculate scores from predictions would be enough. The model that we use for the solvers may be appropriate.

Author

Got rid of a lot of the boilerplate. Let me know if this is closer to what you were thinking. I took a look at the LinearRegression and PrincipalComponent methods.

stat/statmat.go Outdated

ym, _ := y.Dims()
if ym != m {
	panic(ErrTargetLenMismatch)
Member

We use mat.ErrShape for this.

stat/statmat.go Outdated
Comment on lines 203 to 208
if x == nil {
	panic(ErrNoTrainingMatrix)
}
if y == nil {
	panic(ErrNoTargetMatrix)
}
Member

These can just panic with a nil pointer deref.

stat/statmat.go Outdated
Comment on lines 587 to 594
// SoftThreshold returns 0.0 if the value is less than or equal to the gamma input
func SoftThreshold(x, gamma float64) float64 {
	res := math.Max(0, math.Abs(x)-gamma)
	if math.Signbit(x) {
		return -res
	}
	return res
}
Member

I don't think this needs to be exported.

func softThreshold(x, gamma float64) float64 {
	switch {
	case x < -gamma:
		return x + gamma
	case gamma < x:
		return x - gamma
	default:
		return 0
	}
}

Author

good point. will make it private

@aouyang1
Author

> The API is too complex here. I'd estimate that we could make this somewhere between 1/3 and 2/3 of the code that's here by removing the extraneous code.

Sounds good! Will take a crack at simplifying it and model something close to the LinearRegression method.

Member

@kortschak kortschak left a comment

This looks much more manageable.

Initial review only.

Comment on lines 338 to 345
flatten(
	[][]float64{
		{0, 0},
		{3, 5},
		{9, 20},
		{12, 6},
	},
),
Member

Suggested change
-flatten(
-	[][]float64{
-		{0, 0},
-		{3, 5},
-		{9, 20},
-		{12, 6},
-	},
-),
+[]float64{
+	0, 0,
+	3, 5,
+	9, 20,
+	12, 6,
+},

(similar throughout and delete flatten)

	intercept float64
	coef      []float64
}{
	"invalid lambda": {
Member

Suggested change
-"invalid lambda": {
+"invalid_lambda": {

(similar throughout; it simplifies finding cases)


func TestLassoRegression(t *testing.T) {
	// y = 2 + 3*x0 + 4*x1
	testData := map[string]struct {
Member

Please don't do this with a map. Use a struct with a name field. Also put the test cases in a global var, lassoRegressionTests, above the TestLassoRegression func decl.

Comment on lines 289 to 293
 []float64{
 	0.8, 0.3, 0.1,
 	0.3, 0.7, -0.1,
-	0.1, -0.1, 7}),
+	0.1, -0.1, 7,
+}),
Member

Let's leave these formatting changes alone. Please revert this and the changes above in this file.

Comment on lines 462 to 471
func TestOLSRegression(t *testing.T) {
	// y = 2 + 3*x0 + 4*x1
	testData := map[string]struct {
		x         *mat.Dense
		y         *mat.Dense
		model     OLSModel
		tol       float64
		intercept float64
		coef      []float64
	}{
Member

Similar here.

@@ -321,6 +553,7 @@ func benchmarkCovarianceMatrix(b *testing.B, m mat.Matrix) {
		CovarianceMatrix(&res, m, nil)
	}
}

Member

Please revert these formatting changes.

@@ -478,3 +718,65 @@ func BenchmarkCorrToCov(b *testing.B) {
		corrToCov(cc, sigma)
	}
}

func BenchmarkLassoRegression(b *testing.B) {
Member

If we're adding benchmarks, they will need to have a timer reset after the set-up. The set-up also doesn't need to do all the work that it currently does: it should write directly into x, and the matrix constructions should happen outside the benchmark loop.

	}
}

func BenchmarkOLSRegression(b *testing.B) {
Member

Same here.

2 participants