-
Notifications
You must be signed in to change notification settings - Fork 546
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
statmat: added multi linear and lasso regression #1998
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've taken a brief look through this. There is more API here that I think is necessary. Also there are error returns that do not conform to the approach that we use in Gonum packages; we use error returns for error conditions that the user could not know before calling a function, but panics for cases where the calling parameters do not conform to the documented invariants for the call. Please match this.
I'll take a deeper look in the next week or so.
Thanks for explaining the differences! Will get the errors addressed and converted. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The API is too complex here. I'd estimate that we could make this somewhere between 1/3 and 2/3 of the code that's here by removing the extraneous code.
stat/statmat.go
Outdated
// Validate runs basic validation on OLS options | ||
func (o *OLSOptions) Validate() *OLSOptions { | ||
if o == nil { | ||
o = NewDefaultOLSOptions() | ||
} | ||
|
||
return o | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This method doesn't do what is on the tin. Validate implies that it checks that it is correct, this ensures that it is correct. I cannot think of another example where we do something like this in Gonum packages.
stat/statmat.go
Outdated
// NewDefaultOLSOptions returns a default set of OLS Regression options | ||
func NewDefaultOLSOptions() *OLSOptions { | ||
return &OLSOptions{ | ||
FitIntercept: true, | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems like more API than we need. Please take a look at how stat.LinearRegression
does this. The situation here is a little more complex, but not so much that we need all this. I think a pure function that returns the details that we need to perform predictions and to calculate scores from predictions. The model that we use for the solvers would be appropriate maybe.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got rid of a lot of the boiler plate. Let me know if this is more of what you were thinking. Took a look at the LinearRegression and the PrincipalComponent methods.
stat/statmat.go
Outdated
|
||
ym, _ := y.Dims() | ||
if ym != m { | ||
panic(ErrTargetLenMismatch) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We use mat.ErrShape
for this.
stat/statmat.go
Outdated
if x == nil { | ||
panic(ErrNoTrainingMatrix) | ||
} | ||
if y == nil { | ||
panic(ErrNoTargetMatrix) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These can just panic with a nil pointer deref.
stat/statmat.go
Outdated
// SoftThreshold returns 0.0 if the value is less than or equal to the gamma input | ||
func SoftThreshold(x, gamma float64) float64 { | ||
res := math.Max(0, math.Abs(x)-gamma) | ||
if math.Signbit(x) { | ||
return -res | ||
} | ||
return res | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this needs to be exported.
func softThreshold(x, gamma float64) float64 {
switch {
case x < -gamma:
return x + gamma
case gamma < x:
return x - gamma
default:
return 0
}
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point. will make it private
Sounds good! Will take a crack at simplifying it and model something close to the LinearRegression method. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks much more manageable.
Initial review only.
stat/statmat_test.go
Outdated
flatten( | ||
[][]float64{ | ||
{0, 0}, | ||
{3, 5}, | ||
{9, 20}, | ||
{12, 6}, | ||
}, | ||
), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
flatten( | |
[][]float64{ | |
{0, 0}, | |
{3, 5}, | |
{9, 20}, | |
{12, 6}, | |
}, | |
), | |
[]float64{ | |
0, 0, | |
3, 5, | |
9, 20, | |
12, 6, | |
}, |
(similar throughout and delete flatten
)
stat/statmat_test.go
Outdated
intercept float64 | ||
coef []float64 | ||
}{ | ||
"invalid lambda": { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"invalid lambda": { | |
"invalid_lambda": { |
(similar throughout; it simplifies finding cases)
stat/statmat_test.go
Outdated
|
||
func TestLassoRegression(t *testing.T) { | ||
// y = 2 + 3*x0 + 4*x1 | ||
testData := map[string]struct { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please don't do this with a map. Use a struct with a name field. Also put the test cases in a global var, lassoRegressionTests
, above the TestLassoRegression
func decl.
stat/statmat_test.go
Outdated
[]float64{ | ||
0.8, 0.3, 0.1, | ||
0.3, 0.7, -0.1, | ||
0.1, -0.1, 7}), | ||
0.1, -0.1, 7, | ||
}), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's leave these formatting changes alone. Please revert this and the changes above in this file.
stat/statmat_test.go
Outdated
func TestOLSRegression(t *testing.T) { | ||
// y = 2 + 3*x0 + 4*x1 | ||
testData := map[string]struct { | ||
x *mat.Dense | ||
y *mat.Dense | ||
model OLSModel | ||
tol float64 | ||
intercept float64 | ||
coef []float64 | ||
}{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar here.
stat/statmat_test.go
Outdated
@@ -321,6 +553,7 @@ func benchmarkCovarianceMatrix(b *testing.B, m mat.Matrix) { | |||
CovarianceMatrix(&res, m, nil) | |||
} | |||
} | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please revert these formatting changes.
@@ -478,3 +718,65 @@ func BenchmarkCorrToCov(b *testing.B) { | |||
corrToCov(cc, sigma) | |||
} | |||
} | |||
|
|||
func BenchmarkLassoRegression(b *testing.B) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we're adding benchmarks, they will need to have a timer reset after the set-up, but also the set up doesn't need to do all the work that it is; it should work directly into x
, and the matrix constructions should happen outside the benchmark loop.
} | ||
} | ||
|
||
func BenchmarkOLSRegression(b *testing.B) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here.
Hello! First time contributor here. I've been working on some multi-linear regression work lately and wanted to see if some of that could also be integrated here. This only partially solves some of the features requested in #1865 (MISO) but can extend it to more if we want to in this PR. I tried my best to follow similar conventions that I saw in the existing modules, but let me know what I can change to make it more similar.
The 2 new regression structs are OLSRegression and LassoRegression.
Each implement the following:
Inspired by the Python sklearn interface