ENH add newton-cholesky solver to LogisticRegression #24767
    if solver in ("liblinear", "newton-cholesky") and multi_class == "multinomial":
        pytest.skip("'multinomial' is not supported by liblinear and newton-cholesky")
    if solver == "newton-cholesky" and max_iter > 1:
        pytest.skip("solver newton-cholesky might converge very fast")
Excellent!
I would be curious to see a benchmark with n_samples >> n_features on a pipeline with one-hot encoded categorical variables with 1/3 of rare yet predictive categories, in particular compared to "lbfgs", "newton-cg", "sag" and "saga".
@ogrisel Do you know a good dataset? For binary classification, I'll run one benchmark with the kicks dataset.
You can always threshold the y of the regression benchmark we used previously.
Or you can try with Adult Census and some baseline
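The thresholding suggestion above can be sketched as follows. This is only an illustration: the target below is synthetic (the regression benchmark data is not reproduced here), and cutting at the median is an arbitrary choice made to get balanced classes.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for a continuous regression target; NOT the benchmark data.
y_reg = rng.gamma(shape=2.0, scale=1.0, size=1_000)

# Threshold at the median (an assumption) to turn the regression
# target into a balanced binary classification target.
threshold = np.median(y_reg)
y_binary = (y_reg > threshold).astype(float)

print(y_binary.mean())
```

Any other cut point works the same way; a quantile further from the median gives an imbalanced problem.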
With the kicks dataset, every point is a run with a different tol value:
import warnings
from pathlib import Path
from time import perf_counter
import numpy as np
from sklearn._loss import HalfBinomialLoss
from sklearn.compose import ColumnTransformer
from sklearn.datasets import fetch_openml
from sklearn.impute import SimpleImputer
from sklearn.linear_model._glm.glm import _GeneralizedLinearRegressor
from sklearn.linear_model._linear_loss import LinearModelLoss
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.exceptions import ConvergenceWarning
from sklearn.model_selection import train_test_split
import pandas as pd
import joblib
import matplotlib.pyplot as plt
def prepare_data():
    df = fetch_openml(data_id=41162, as_frame=True, parser="auto").frame
    linear_model_preprocessor = ColumnTransformer(
        [
            (
                "passthrough_numeric",
                make_pipeline(SimpleImputer(), StandardScaler()),
                [
                    "MMRAcquisitionAuctionAveragePrice",
                    "MMRAcquisitionAuctionCleanPrice",
                    "MMRCurrentAuctionAveragePrice",
                    "MMRCurrentAuctionCleanPrice",
                    "MMRCurrentRetailAveragePrice",
                    "MMRCurrentRetailCleanPrice",
                    "MMRCurrentRetailAveragePrice",
                    "MMRCurrentRetailCleanPrice",
                    "VehBCost",
                    "VehOdo",
                    "VehYear",
                    "VehicleAge",
                    "WarrantyCost",
                ],
            ),
            (
                "onehot_categorical",
                OneHotEncoder(min_frequency=10),
                [
                    "Auction",
                    "Color",
                    "IsOnlineSale",
                    "Make",
                    "Model",
                    "Nationality",
                    "Size",
                    "SubModel",
                    "Transmission",
                    "Trim",
                    "WheelType",
                ],
            ),
        ],
        remainder="drop",
    )
    y = np.asarray(df["IsBadBuy"] == "1", dtype=float)
    X = linear_model_preprocessor.fit_transform(df)
    return X, y
X, y = prepare_data()
# X = X.toarray()
X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=0.9, random_state=0
)
print(f"{X_train.shape = }")
results = []
loss_sw = np.full_like(y_train, fill_value=(1. / y_train.shape[0]))
alpha = 1e-4
for tol in np.logspace(-1, -10, 10):
    for solver in ["lbfgs", "liblinear", "newton-cg", "newton-cholesky", "sag", "saga"]:
        tic = perf_counter()
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=ConvergenceWarning)
            reg = LogisticRegression(
                solver=solver, C=1 / alpha / X.shape[0], tol=tol, max_iter=10_000
            ).fit(X_train, y_train)
        toc = perf_counter()
        train_time = toc - tic
        train_loss = LinearModelLoss(
            base_loss=HalfBinomialLoss(), fit_intercept=reg.fit_intercept
        ).loss(
            coef=np.r_[np.squeeze(reg.coef_), reg.intercept_],
            X=X_train,
            y=y_train,
            l2_reg_strength=alpha,
            sample_weight=loss_sw,
        )
        result = {
            "solver": solver,
            "tol": tol,
            "train_loss": train_loss,
            "train_time": train_time,
            "train_score": reg.score(X_train, y_train),
            "test_score": reg.score(X_test, y_test),
            "n_iter": np.squeeze(reg.n_iter_),
            "converged": np.squeeze(reg.n_iter_) < np.squeeze(reg.max_iter),
        }
        print(result)
        results.append(result)
results = pd.DataFrame.from_records(results)
results["suboptimality"] = results["train_loss"] - results["train_loss"].min() + 1e-15
fig, ax = plt.subplots(figsize=(8, 6))
for label, group in results.groupby("solver"):
    ax.plot(group.n_iter, group.suboptimality, marker="o", label=label)
ax.set_xscale("log")
ax.set_yscale("log")
ax.legend()
ax.set_xlabel("n_iter")
ax.set_ylabel("suboptimality")
ax.set_title("Suboptimality by iterations, sparse X")
fig, ax = plt.subplots(figsize=(8, 6))
for label, group in results.groupby("solver"):
    ax.plot(group.train_time, group.suboptimality, marker="o", label=label)
ax.set_xscale("log")
ax.set_yscale("log")
ax.legend()
ax.set_xlabel("train_time")
ax.set_ylabel("suboptimality")
ax.set_title("Suboptimality by time, sparse X")
I suspect that using
In your benchmark you used
Do you want to see more benchmarks, or is it enough?
I would like to have a better understanding of the practical conditions under which the lbfgs solver dramatically fails. I suppose that it happens when the rank of the Hessian is significantly larger than the memory size of LBFGS and its spectrum is ill-conditioned. This can probably happen with those rare categorical variables, but maybe also with other kinds of common feature engineering strategies.
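To make the ill-conditioning argument above concrete, here is a minimal sketch, assuming a single one-hot encoded feature, no intercept, and an L2-penalized logistic loss evaluated at coef = 0. The helper name `hessian_condition_number`, the category counts, and `alpha` are hypothetical. It only shows how rare categories spread the Hessian spectrum; it does not show lbfgs failing, and it does not model the LBFGS memory size.

```python
import numpy as np


def hessian_condition_number(category_counts, alpha=1e-4):
    """Condition number of X^T W X / n + alpha * I at coef = 0.

    X is the one-hot encoding of one categorical feature whose
    categories occur `category_counts` times (hypothetical helper).
    """
    counts = np.asarray(category_counts)
    n_samples = counts.sum()
    # One-hot design matrix: row i has a single 1 in its category column.
    X = np.repeat(np.eye(len(counts)), counts, axis=0)
    # At coef = 0 every predicted probability is 0.5, so p * (1 - p) = 0.25.
    weights = np.full(n_samples, 0.25)
    hessian = (X.T * weights) @ X / n_samples + alpha * np.eye(len(counts))
    return np.linalg.cond(hessian)


balanced = hessian_condition_number([250, 250, 250, 250])
with_rare = hessian_condition_number([970, 10, 10, 10])
print(balanced, with_rare)
```

With balanced categories all Hessian eigenvalues coincide, while a dominant category next to rare ones stretches the spectrum, which is the regime where a full Newton step would be expected to help most.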
One has to distinguish 2 points:
Out of curiosity, I just re-ran a quick bench to test the behavior of the new solver on a problem derived from the covtype dataset:
In [30]: from sklearn.datasets import fetch_covtype
...: from sklearn.pipeline import make_pipeline
...: from sklearn.model_selection import train_test_split
...: from sklearn.preprocessing import MinMaxScaler
...: from sklearn.kernel_approximation import PolynomialCountSketch
...: from sklearn.linear_model import LogisticRegression
...:
...: X, y = fetch_covtype(return_X_y=True)
...: y = y == 2
...: pipe = make_pipeline(
...:     MinMaxScaler(),
...:     PolynomialCountSketch(degree=2, n_components=1000, random_state=0),
...:     LogisticRegression(max_iter=1000, solver="lbfgs"),
...: )
...: X_train, X_test, y_train, y_test = train_test_split(
...:     X, y, train_size=50000, test_size=1000, random_state=42
...: )
...: %time pipe.fit(X_train, y_train).score(X_test, y_test)
CPU times: user 2min 17s, sys: 11 s, total: 2min 28s
Wall time: 23 s
Out[30]: 0.808
In [31]: from sklearn.datasets import fetch_covtype
...: from sklearn.pipeline import make_pipeline
...: from sklearn.model_selection import train_test_split
...: from sklearn.preprocessing import MinMaxScaler
...: from sklearn.kernel_approximation import PolynomialCountSketch
...: from sklearn.linear_model import LogisticRegression
...:
...: X, y = fetch_covtype(return_X_y=True)
...: y = y == 2
...: pipe = make_pipeline(
...:     MinMaxScaler(),
...:     PolynomialCountSketch(degree=2, n_components=1000, random_state=0),
...:     LogisticRegression(max_iter=10, solver="newton-cholesky"),
...: )
...: X_train, X_test, y_train, y_test = train_test_split(
...:     X, y, train_size=50000, test_size=1000, random_state=42
...: )
...: %time pipe.fit(X_train, y_train).score(X_test, y_test)
CPU times: user 42.8 s, sys: 3.68 s, total: 46.5 s
Wall time: 8.5 s
Out[31]: 0.808
So the new solver converges both faster and below the default
LGTM.
LGTM. Thank you, @lorentzenchr.
Here are some minor suggestions.
    if solver in ("sag", "saga"):
        clf.set_params(tol=1e-5, max_iter=10000, random_state=0)
    clf.fit(X, y)
    assert_array_almost_equal(clf.coef_, clf_lbfgs.coef_, decimal=4)


def test_logistic_regression_sample_weights():
This test is long and asserts two things:
1. the consistency of the solvers' results (before l.747)
2. the consistency of the results when dual=True, for the l1 penalty and the l2 penalty independently (after l.747)
Would it make sense to split this test into two tests for 1. and 2. respectively? Also, can the logic of 1. (and 2.) be simplified using test parametrisation?
2 x yes. I would suggest doing it in another PR in order not to bloat this one.
As a quick comment, do not forget to update the solver section from the User Guide: https://scikit-learn.org/dev/modules/linear_model.html#solvers
I don't think this is automatically updated.
Minor doc and styling changes.
    if solver in ("sag", "saga"):
        clf.set_params(tol=1e-5, max_iter=10000, random_state=0)
    clf.fit(X, y)
    assert_array_almost_equal(clf.coef_, clf_lbfgs.coef_, decimal=4)
It would be better to use assert_allclose with a tolerance instead of assert_array_almost_equal.
Yes. We could open an issue for a sprint to replace assert_array_almost_equal by assert_allclose.
Here, I wanted to keep the change small.
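For context on the assert_allclose suggestion, here is a small sketch of how the two helpers differ. The coefficient values are made up for illustration and do not come from the test suite.

```python
import numpy as np
from numpy.testing import assert_allclose, assert_array_almost_equal

# Made-up small coefficients that differ by roughly 40%.
coef_ref = np.array([1.0e-6, 2.0e-6])
coef_new = np.array([1.4e-6, 2.9e-6])

# Passes: decimal=4 only checks abs(desired - actual) < 1.5e-4,
# which is far larger than the coefficients themselves.
assert_array_almost_equal(coef_new, coef_ref, decimal=4)

# assert_allclose with a relative tolerance catches the discrepancy.
try:
    assert_allclose(coef_new, coef_ref, rtol=1e-4)
    relative_check_passed = True
except AssertionError:
    relative_check_passed = False

print(relative_check_passed)
```

The absolute check is scale-dependent, which is why an explicit rtol (and, where needed, atol) tends to make solver-consistency tests more meaningful.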
As you want.
On 3 Nov 2022, at 12:42, Christian Lorentzen wrote:
> @lorentzenchr commented on this pull request.
> In sklearn/linear_model/tests/test_logistic.py:
> @@ -121,16 +123,9 @@ def test_predict_3_classes():
>      check_predictions(LogisticRegression(C=10), X_sp, Y2)
> …
> -def test_predict_iris():
> -    # Test logistic regression with the iris dataset
> -    n_samples, n_features = iris.data.shape
> -
> -    target = iris.target_names[iris.target]
> -
> -    # Test that both multinomial and OvR solvers handle
> -    # multiclass data correctly and give good accuracy
> -    # score (>0.95) for the training data.
> -    for clf in [
> ***@***.***(
> What is better: 1) Filter the ConvergenceWarning or 2) increase max_iter in lbfgs?
@glemaitre Thanks for the 3rd pair of eyes, which just spotted more than 2 pairs 😄
Done. If that part is ok and CI green, it's ready again, I guess.
The error is unrelated.
LGTM. Merging. Thanks @lorentzenchr
@lorentzenchr you might be interested in this work to generalize this solver to non-smooth penalized models: https://gbareilles.fr/software/#additive_nonsmooth_minimization
Reference Issues/PRs
Completes #16634.
Follow-up of #24637.
What does this implement/fix? Explain your changes.
This adds the solver "newton-cholesky" to the classes LogisticRegression and LogisticRegressionCV.
Any other comments?
For multiclass problems, it uses a one-vs-rest strategy.
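As a rough illustration of the idea behind the solver name, the sketch below takes Newton steps for binary L2-penalized logistic regression and solves each Newton system through a Cholesky factorization of the Hessian. It is not the scikit-learn implementation: the function names, the absence of an intercept, the dense NumPy arrays, the simple backtracking rule, and the synthetic data are all assumptions made for the example.

```python
import numpy as np


def loss_and_grad(w, X, y, alpha):
    """Mean logistic loss + alpha/2 * ||w||^2, its gradient, and probabilities."""
    z = X @ w
    p = 1.0 / (1.0 + np.exp(-z))
    loss = np.mean(np.logaddexp(0.0, z) - y * z) + 0.5 * alpha * (w @ w)
    grad = X.T @ (p - y) / X.shape[0] + alpha * w
    return loss, grad, p


def newton_cholesky_logreg(X, y, alpha=1e-2, max_iter=50, tol=1e-8):
    """Hypothetical helper: damped Newton iterations with Cholesky solves."""
    n_samples, n_features = X.shape
    w = np.zeros(n_features)
    for n_iter in range(1, max_iter + 1):
        loss, grad, p = loss_and_grad(w, X, y, alpha)
        if np.max(np.abs(grad)) <= tol:
            break
        # Hessian X^T diag(p * (1 - p)) X / n + alpha * I is positive definite.
        hess = (X.T * (p * (1.0 - p))) @ X / n_samples + alpha * np.eye(n_features)
        chol = np.linalg.cholesky(hess)  # hess = chol @ chol.T
        # Solve hess @ step = grad through the two triangular factors
        # (a dedicated triangular solver would exploit the structure).
        step = np.linalg.solve(chol.T, np.linalg.solve(chol, grad))
        # Backtracking (Armijo) line search: halve the step until the
        # loss decreases sufficiently.
        t = 1.0
        while t > 1e-8 and (
            loss_and_grad(w - t * step, X, y, alpha)[0]
            > loss - 1e-4 * t * (grad @ step)
        ):
            t *= 0.5
        w = w - t * step
    return w, n_iter


# Synthetic, non-separable binary problem (illustrative data only).
rng = np.random.default_rng(0)
X = rng.standard_normal((200, 5))
true_coef = np.array([1.0, -2.0, 0.5, 0.0, 0.0])
y = (X @ true_coef + rng.standard_normal(200) > 0).astype(float)

w, n_iter = newton_cholesky_logreg(X, y)
print(n_iter, np.max(np.abs(loss_and_grad(w, X, y, 1e-2)[1])))
```

Each iteration costs a factorization of an n_features x n_features matrix, which fits the n_samples >> n_features setting discussed in this thread. In scikit-learn itself the solver is selected with LogisticRegression(solver="newton-cholesky").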