ENH add newton-cholesky solver to LogisticRegression #24767

Merged
merged 14 commits into from
Nov 3, 2022

lorentzenchr
Member

Reference Issues/PRs

Completes #16634.
Follow-up of #24637.

What does this implement/fix? Explain your changes.

This adds the solver "newton-cholesky" to the classes LogisticRegression and LogisticRegressionCV.

Any other comments?

For multiclass problems, it uses a one-vs-rest strategy.

if solver in ("liblinear", "newton-cholesky") and multi_class == "multinomial":
    pytest.skip("'multinomial' is not supported by liblinear and newton-cholesky")
if solver == "newton-cholesky" and max_iter > 1:
    pytest.skip("solver newton-cholesky might converge very fast")
Member

Excellent!

@lorentzenchr lorentzenchr added this to the 1.2 milestone Oct 27, 2022
Member

@ogrisel ogrisel left a comment

I would be curious to see a benchmark with n_samples >> n_features on a pipeline with one-hot encoded categorical variables with 1/3 of rare yet predictive categories.

In particular, compared to "lbfgs", "newton-cg", "sag" and "saga".

@lorentzenchr
Member Author

lorentzenchr commented Oct 27, 2022

@ogrisel Do you know a good dataset? For binary classification, I'll run one benchmark with the kicks dataset.

@ogrisel
Member

ogrisel commented Oct 27, 2022

You can always threshold the y of the regression benchmark we used previously.

@ogrisel
Member

ogrisel commented Oct 27, 2022

Or you can try with Adult Census and some baseline ColumnTransformer with a OneHotEncoder, it's not very big but maybe it's enough to see an improvement.

@lorentzenchr
Member Author

With the kicks dataset, every point is one run with tol=1e-x for x from 1 to 10.

[Figure: suboptimality by iterations, sparse X]
[Figure: suboptimality by time, sparse X]

import warnings
from pathlib import Path
from time import perf_counter
import numpy as np
from sklearn._loss import HalfBinomialLoss
from sklearn.compose import ColumnTransformer
from sklearn.datasets import fetch_openml
from sklearn.impute import SimpleImputer
from sklearn.linear_model._glm.glm import _GeneralizedLinearRegressor
from sklearn.linear_model._linear_loss import LinearModelLoss
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.exceptions import ConvergenceWarning
from sklearn.model_selection import train_test_split
import pandas as pd
import joblib
import matplotlib.pyplot as plt



def prepare_data():
    df = fetch_openml(data_id=41162, as_frame=True, parser="auto").frame
    linear_model_preprocessor = ColumnTransformer(
        [
            (
                "passthrough_numeric",
                make_pipeline(SimpleImputer(), StandardScaler()),
                [
                    "MMRAcquisitionAuctionAveragePrice",
                    "MMRAcquisitionAuctionCleanPrice",
                    "MMRCurrentAuctionAveragePrice",
                    "MMRCurrentAuctionCleanPrice",
                    "MMRCurrentRetailAveragePrice",
                    "MMRCurrentRetailCleanPrice",
                    "MMRCurrentRetailAveragePrice",
                    "MMRCurrentRetailCleanPrice",
                    "VehBCost",
                    "VehOdo",
                    "VehYear",
                    "VehicleAge",
                    "WarrantyCost",
                ],
            ),
            (
                "onehot_categorical",
                OneHotEncoder(min_frequency=10),
                [
                    "Auction",
                    "Color",
                    "IsOnlineSale",
                    "Make",
                    "Model",
                    "Nationality",
                    "Size",
                    "SubModel",
                    "Transmission",
                    "Trim",
                    "WheelType",
                ],
            ),
        ],
        remainder="drop",
    )
    y = np.asarray(df["IsBadBuy"] == "1", dtype=float)
    X = linear_model_preprocessor.fit_transform(df)
    return X, y


X, y = prepare_data()
# X = X.toarray()
X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=0.9, random_state=0
)
print(f"{X_train.shape = }")

results = []
loss_sw = np.full_like(y_train, fill_value=(1. / y_train.shape[0]))
alpha = 1e-4
for tol in np.logspace(-1, -10, 10):
    for solver in ["lbfgs", "liblinear", "newton-cg", "newton-cholesky", "sag", "saga"]:
        tic = perf_counter()
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=ConvergenceWarning)
            reg = LogisticRegression(
                solver=solver, C=1 / alpha / X_train.shape[0], tol=tol, max_iter=10_000
            ).fit(X_train, y_train)
        toc = perf_counter()
        train_time = toc - tic
        train_loss = LinearModelLoss(
            base_loss=HalfBinomialLoss(), fit_intercept=reg.fit_intercept
        ).loss(
            coef=np.r_[np.squeeze(reg.coef_), reg.intercept_],
            X=X_train,
            y=y_train,
            l2_reg_strength=alpha,
            sample_weight=loss_sw,
        )
        result = {
            "solver": solver,
            "tol": tol,
            "train_loss": train_loss,
            "train_time": train_time,
            "train_score": reg.score(X_train, y_train),
            "test_score": reg.score(X_test, y_test),
            "n_iter": np.squeeze(reg.n_iter_),
            "converged": np.squeeze(reg.n_iter_) < np.squeeze(reg.max_iter),
        }
        print(result)
        results.append(result)

results = pd.DataFrame.from_records(results)

results["suboptimality"] = results["train_loss"] - results["train_loss"].min() + 1e-15

fig, ax = plt.subplots(figsize=(8, 6))
for label, group in results.groupby("solver"):
    ax.plot(group.n_iter, group.suboptimality, marker="o", label=label)
ax.set_xscale("log")
ax.set_yscale("log")
ax.legend()
ax.set_xlabel("n_iter")
ax.set_ylabel("suboptimality")
ax.set_title("Suboptimality by iterations, sparse X")


fig, ax = plt.subplots(figsize=(8, 6))
for label, group in results.groupby("solver"):
    ax.plot(group.train_time, group.suboptimality, marker="o", label=label)
ax.set_xscale("log")
ax.set_yscale("log")
ax.legend()
ax.set_xlabel("train_time")
ax.set_ylabel("suboptimality")
ax.set_title("Suboptimality by time, sparse X")
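
The script sets C from alpha and a sample count. A short derivation of that mapping, assuming the L2-penalized objective documented in the scikit-learn user guide for LogisticRegression and a mean loss whose sample weights sum to one (here \ell_i is the per-sample log-loss and n is the number of samples passed to fit):

```latex
\min_{w}\; C \sum_{i=1}^{n} \ell_i(w) + \frac{1}{2}\lVert w \rVert_2^2
\quad\Longleftrightarrow\quad
\min_{w}\; \frac{1}{n} \sum_{i=1}^{n} \ell_i(w) + \frac{\alpha}{2}\lVert w \rVert_2^2 ,
\qquad \alpha = \frac{1}{C\,n}
\;\;\Longleftrightarrow\;\; C = \frac{1}{\alpha\, n}.
```

Dividing the first objective by the positive constant C n does not change its minimizer, which is why the loss evaluated with l2_reg_strength=alpha can be compared across solvers.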

@ogrisel
Member

ogrisel commented Oct 27, 2022

I suspect that using SplineTransformer on numerical features and/or PolynomialCountSketch on all the features can make the problem harder to optimize. The latter was reported in: #15583 (comment)

@ogrisel
Member

ogrisel commented Oct 27, 2022

In your benchmark you used OneHotEncoder(min_frequency=10) but I suspect that some solvers will suffer more from a lower value for min_frequency (although I agree that the derived rare one-hot features will not help the model generalize better).

@lorentzenchr
Member Author

Do you want to see more benchmarks, or is it enough?

@ogrisel
Member

ogrisel commented Oct 27, 2022

Do you want to see more benchmarks, or is it enough?

I would like to have a better understanding under which practical conditions the lbfgs solver dramatically fails. I suppose that it happens when the rank of the Hessian is significantly larger than the memory size of LBFGS, yet with an ill-conditioned spectrum and this can probably happen with those rare categorical variables but maybe also with other kinds of common feature engineering strategies.
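
One ingredient of this hypothesis, the ill-conditioning caused by rare categories, can be illustrated with a small NumPy sketch. It only computes the condition number of the L2-penalized logistic Hessian at coef = 0 for a single one-hot encoded column without intercept; the helper name, the counts and alpha are assumptions for illustration, and it makes no claim about when lbfgs actually fails.

```python
import numpy as np


def hessian_condition_number(counts, alpha):
    """Condition number of the penalized logistic Hessian at coef = 0.

    At coef = 0 every sample has p * (1 - p) = 0.25, so the Hessian of the
    mean log-loss plus 0.5 * alpha * ||w||^2 is 0.25 * X.T @ X / n + alpha * I.
    """
    counts = np.asarray(counts, dtype=int)
    n_samples = counts.sum()
    # One-hot design matrix: row i of the identity repeated counts[i] times.
    X = np.repeat(np.eye(len(counts)), counts, axis=0)
    hessian = 0.25 * X.T @ X / n_samples + alpha * np.eye(len(counts))
    return np.linalg.cond(hessian)


alpha = 1e-4
cond_balanced = hessian_condition_number([2500, 2500, 2500, 2500], alpha)
cond_rare = hessian_condition_number([4990, 4990, 10, 10], alpha)
print(cond_balanced, cond_rare)
```

With balanced categories all eigenvalues coincide, while rare categories contribute very small eigenvalues and the condition number grows by orders of magnitude.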

@lorentzenchr
Member Author

One has to distinguish 2 points:

@ogrisel
Member

ogrisel commented Oct 30, 2022

Out of curiosity, I just re-ran a quick bench to test the behavior of the new solver on a problem derived from the PolynomialCountSketch example from the 0.24 release highlights where we observed the convergence warning with LBFGS.

In [30]: from sklearn.datasets import fetch_covtype
    ...: from sklearn.pipeline import make_pipeline
    ...: from sklearn.model_selection import train_test_split
    ...: from sklearn.preprocessing import MinMaxScaler
    ...: from sklearn.kernel_approximation import PolynomialCountSketch
    ...: from sklearn.linear_model import LogisticRegression
    ...: 
    ...: X, y = fetch_covtype(return_X_y=True)
    ...: y = y == 2
    ...: pipe = make_pipeline(
    ...:     MinMaxScaler(),
    ...:     PolynomialCountSketch(degree=2, n_components=1000, random_state=0),
    ...:     LogisticRegression(max_iter=1000, solver="lbfgs"),
    ...: )
    ...: X_train, X_test, y_train, y_test = train_test_split(
    ...:     X, y, train_size=50000, test_size=1000, random_state=42
    ...: )
    ...: %time pipe.fit(X_train, y_train).score(X_test, y_test)
CPU times: user 2min 17s, sys: 11 s, total: 2min 28s
Wall time: 23 s
Out[30]: 0.808

In [31]: from sklearn.datasets import fetch_covtype
    ...: from sklearn.pipeline import make_pipeline
    ...: from sklearn.model_selection import train_test_split
    ...: from sklearn.preprocessing import MinMaxScaler
    ...: from sklearn.kernel_approximation import PolynomialCountSketch
    ...: from sklearn.linear_model import LogisticRegression
    ...: 
    ...: X, y = fetch_covtype(return_X_y=True)
    ...: y = y == 2
    ...: pipe = make_pipeline(
    ...:     MinMaxScaler(),
    ...:     PolynomialCountSketch(degree=2, n_components=1000, random_state=0),
    ...:     LogisticRegression(max_iter=10, solver="newton-cholesky"),
    ...: )
    ...: X_train, X_test, y_train, y_test = train_test_split(
    ...:     X, y, train_size=50000, test_size=1000, random_state=42
    ...: )
    ...: %time pipe.fit(X_train, y_train).score(X_test, y_test)
CPU times: user 42.8 s, sys: 3.68 s, total: 46.5 s
Wall time: 8.5 s
Out[31]: 0.808

So the new solver converges faster and well within the default max_iter=100 (even within the max_iter=10 set in this case) without any convergence warning.
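
To give an intuition for why so few iterations can be enough, here is a minimal NumPy sketch of Newton steps solved through a Cholesky factorization for L2-penalized logistic regression. It is not the scikit-learn implementation: it assumes a dense X, no intercept and no line search, and the function name and the synthetic data are hypothetical.

```python
import numpy as np


def newton_cholesky_logreg(X, y, alpha=1e-2, max_iter=20, tol=1e-8):
    """Minimize mean log-loss + 0.5 * alpha * ||w||^2 with plain Newton steps."""
    n_samples, n_features = X.shape
    w = np.zeros(n_features)
    for n_iter in range(1, max_iter + 1):
        p = 1.0 / (1.0 + np.exp(-(X @ w)))  # predicted probabilities
        grad = X.T @ (p - y) / n_samples + alpha * w
        if np.max(np.abs(grad)) <= tol:
            return w, n_iter - 1
        # Hessian: X.T @ diag(p * (1 - p)) @ X / n + alpha * I (positive definite).
        hessian = (X.T * (p * (1.0 - p))) @ X / n_samples + alpha * np.eye(n_features)
        # Factorize H = L @ L.T, then solve L z = grad and L.T step = z.
        # np.linalg.solve is a general solver; a triangular solver would be cheaper.
        L = np.linalg.cholesky(hessian)
        step = np.linalg.solve(L.T, np.linalg.solve(L, grad))
        w -= step
    return w, max_iter


rng = np.random.default_rng(0)
X = rng.standard_normal((1000, 5))
true_w = np.array([1.0, -1.0, 0.5, 0.0, 0.25])
y = (rng.random(1000) < 1.0 / (1.0 + np.exp(-(X @ true_w)))).astype(float)

w, n_iter = newton_cholesky_logreg(X, y)
print(n_iter, w)
```

Each iteration costs one factorization of an n_features x n_features matrix, which is why this kind of solver is attractive mainly when n_samples >> n_features.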

Member

@ogrisel ogrisel left a comment

LGTM.

@jjerphan jjerphan self-requested a review October 31, 2022 14:56
Member

@jjerphan jjerphan left a comment

LGTM. Thank you, @lorentzenchr.

Here are some minor suggestions.

sklearn/linear_model/tests/test_logistic.py (outdated)
if solver in ("sag", "saga"):
    clf.set_params(tol=1e-5, max_iter=10000, random_state=0)
clf.fit(X, y)
assert_array_almost_equal(clf.coef_, clf_lbfgs.coef_, decimal=4)


def test_logistic_regression_sample_weights():
Member

@jjerphan jjerphan Nov 2, 2022

This test is long and asserts two things:

  1. the solvers' results' consistency (before l.747)
  2. results' consistency when dual=True for l1 penalty and l2 penalty independently (after l.747)

Would it make sense to split this test into two tests for 1. and 2. respectively? Also can the logic of 1. (and 2.) be simplified using test parametrisation?

Member Author

2 x yes. I would suggest doing it in another PR in order not to bloat this one.

Member

@glemaitre glemaitre left a comment

As a quick comment, do not forget to update the solver section from the User Guide: https://scikit-learn.org/dev/modules/linear_model.html#solvers

I don't think this is automatically updated.

Member

@glemaitre glemaitre left a comment

Minor doc and styling changes.

sklearn/linear_model/tests/test_logistic.py
if solver in ("sag", "saga"):
    clf.set_params(tol=1e-5, max_iter=10000, random_state=0)
clf.fit(X, y)
assert_array_almost_equal(clf.coef_, clf_lbfgs.coef_, decimal=4)
Member

It would be better to use assert_allclose with a tolerance instead of assert_array_almost_equal.
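
As a sketch of what such a replacement could look like, the example below compares the two NumPy helpers on hypothetical values (the arrays are made up, not taken from the test suite). It relies on the documented behavior that assert_array_almost_equal with decimal=4 checks an absolute error below 1.5e-4, whereas assert_allclose states its relative and absolute tolerances explicitly.

```python
import numpy as np
from numpy.testing import assert_allclose, assert_array_almost_equal

reference = np.array([1.0, 100.0, 1e-3])
estimate = reference + 1e-4  # hypothetical coefficients from another solver

# decimal=4 checks abs(desired - actual) < 1.5e-4 for every element.
assert_array_almost_equal(estimate, reference, decimal=4)

# A comparable absolute check, with the tolerance spelled out.
assert_allclose(estimate, reference, rtol=0, atol=1.5e-4)

# A relative tolerance scales with the magnitude of each reference value,
# so the same absolute error on the smallest entry is now rejected.
try:
    assert_allclose(estimate, reference, rtol=1e-4)
    relative_check_passed = True
except AssertionError:
    relative_check_passed = False

print(relative_check_passed)
```

Making the tolerance explicit forces a choice between absolute and relative error, which matters when coefficients span several orders of magnitude.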

Member Author

Yes. We could open an issue for a sprint to replace assert_array_almost_equal by assert_allclose.

Member Author

Here, I wanted to keep the change small.

@glemaitre
Member

glemaitre commented Nov 3, 2022 via email

@lorentzenchr
Member Author

@glemaitre Thanks for the 3rd pair of eyes; three pairs just spot more than two 😄

@lorentzenchr
Member Author

As a quick comment, do not forget to update the solver section from the User Guide: https://scikit-learn.org/dev/modules/linear_model.html#solvers

Done. If that part is ok and CI green, it's ready again, I guess.
Note to myself: the UG on linear model solvers could use some updates.

@glemaitre
Member

The error is unrelated.

Member

@glemaitre glemaitre left a comment

LGTM. Merging. Thanks, @lorentzenchr.

@glemaitre glemaitre merged commit bb080aa into scikit-learn:main Nov 3, 2022
@lorentzenchr lorentzenchr deleted the logistic_newton_cholesky branch November 3, 2022 13:40
andportnoy pushed a commit to andportnoy/scikit-learn that referenced this pull request Nov 5, 2022
@ogrisel
Member

ogrisel commented Dec 6, 2022

@lorentzenchr you might be interested in this work to generalize this solver to non-smooth penalized models: https://gbareilles.fr/software/#additive_nonsmooth_minimization
