From 3b5f648fd093428a4b4178940f742054554587b1 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Sat, 4 Jan 2025 01:59:17 -0600 Subject: [PATCH 1/6] [python-package] make sub-classing scikit-learn estimators easier --- python-package/lightgbm/sklearn.py | 47 ++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 108ef1e14498..9f7d91630eb1 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -488,6 +488,7 @@ class LGBMModel(_LGBMModelBase): def __init__( self, + *, boosting_type: str = "gbdt", num_leaves: int = 31, max_depth: int = -1, @@ -627,6 +628,7 @@ def __init__( For multi-class task, y_pred is a numpy 2-D array of shape = [n_samples, n_classes], and grad and hess should be returned in the same format. """ + print("LGBMModel.__init__()") if not SKLEARN_INSTALLED: raise LightGBMError( "scikit-learn is required for lightgbm.sklearn. " @@ -745,7 +747,36 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]: params : dict Parameter names mapped to their values. """ + # Based on: https://github.com/dmlc/xgboost/blob/bd92b1c9c0db3e75ec3dfa513e1435d518bb535d/python-package/xgboost/sklearn.py#L941 + # which was based on: https://stackoverflow.com/questions/59248211 + # + # `get_params()` flows like this: + # + # 0. Return parameters in subclass (self.__class__) first, by using inspect. + # 1. Return parameters in all parent classes (especially `LGBMModel`). + # 2. Return whatever is in `**kwargs`. + # 3. Merge them. + # + # This needs to accommodate being called recursively in the following + # inheritance graphs (and similar for classification and ranking): + # + # DaskLGBMRegressor -> LGBMRegressor -> LGBMModel -> BaseEstimator + # (custom subclass) -> LGBMRegressor -> LGBMModel -> BaseEstimator + # LGBMRegressor -> LGBMModel -> BaseEstimator + # (custom subclass) -> LGBMModel -> BaseEstimator + # LGBMModel -> BaseEstimator + # params = super().get_params(deep=deep) + cp = copy.copy(self) + print(f"--- {cp.__class__.__bases__}") + # If the immediate parent defines get_params(), use that. + if callable(getattr(cp.__class__.__bases__[0], "get_params", None)): + cp.__class__ = cp.__class__.__bases__[0] + # Otherwise, skip it and assume the next class will have it. + # This is here primarily for cases where the first class in MRO is a scikit-learn mixin. 
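+ # For example: LGBMRegressor's bases are (_LGBMRegressorBase, LGBMModel).
+ # The scikit-learn mixin in slot 0 does not define get_params(), so the
+ # lookup below falls through to LGBMModel in slot 1.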
+ else: + cp.__class__ = cp.__class__.__bases__[1] + params.update(cp.__class__.get_params(cp, deep)) params.update(self._other_params) return params @@ -1285,6 +1316,12 @@ def feature_names_in_(self) -> None: class LGBMRegressor(_LGBMRegressorBase, LGBMModel): """LightGBM regressor.""" + def __init__(self, **kwargs: Any): + print("LGBMRegressor.__init__()") + super().__init__(**kwargs) + + __init__.__doc__ = LGBMModel.__init__.__doc__ + def _more_tags(self) -> Dict[str, Any]: # handle the case where RegressorMixin possibly provides _more_tags() if callable(getattr(_LGBMRegressorBase, "_more_tags", None)): @@ -1344,6 +1381,11 @@ def fit( # type: ignore[override] class LGBMClassifier(_LGBMClassifierBase, LGBMModel): """LightGBM classifier.""" + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + __init__.__doc__ = LGBMModel.__init__.__doc__ + def _more_tags(self) -> Dict[str, Any]: # handle the case where ClassifierMixin possibly provides _more_tags() if callable(getattr(_LGBMClassifierBase, "_more_tags", None)): @@ -1554,6 +1596,11 @@ class LGBMRanker(LGBMModel): Please use this class mainly for training and applying ranking models in common sklearnish way. """ + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + __init__.__doc__ = LGBMModel.__init__.__doc__ + def fit( # type: ignore[override] self, X: _LGBM_ScikitMatrixLike, From 02c48c3e474bde0d6ac1ea42beca179ed7ca1305 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Sat, 4 Jan 2025 02:08:53 -0600 Subject: [PATCH 2/6] tests passing --- python-package/lightgbm/dask.py | 129 ++--------------------------- python-package/lightgbm/sklearn.py | 5 ++ 2 files changed, 11 insertions(+), 123 deletions(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index dcdacba7366c..76285fde183a 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -1115,52 +1115,13 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): def __init__( self, - boosting_type: str = "gbdt", - num_leaves: int = 31, - max_depth: int = -1, - learning_rate: float = 0.1, - n_estimators: int = 100, - subsample_for_bin: int = 200000, - objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None, - class_weight: Optional[Union[dict, str]] = None, - min_split_gain: float = 0.0, - min_child_weight: float = 1e-3, - min_child_samples: int = 20, - subsample: float = 1.0, - subsample_freq: int = 0, - colsample_bytree: float = 1.0, - reg_alpha: float = 0.0, - reg_lambda: float = 0.0, - random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None, - n_jobs: Optional[int] = None, - importance_type: str = "split", + *, client: Optional[Client] = None, **kwargs: Any, ): """Docstring is inherited from the lightgbm.LGBMClassifier.__init__.""" self.client = client - super().__init__( - boosting_type=boosting_type, - num_leaves=num_leaves, - max_depth=max_depth, - learning_rate=learning_rate, - n_estimators=n_estimators, - subsample_for_bin=subsample_for_bin, - objective=objective, - class_weight=class_weight, - min_split_gain=min_split_gain, - min_child_weight=min_child_weight, - min_child_samples=min_child_samples, - subsample=subsample, - subsample_freq=subsample_freq, - colsample_bytree=colsample_bytree, - reg_alpha=reg_alpha, - reg_lambda=reg_lambda, - random_state=random_state, - n_jobs=n_jobs, - importance_type=importance_type, - **kwargs, - ) + super().__init__(**kwargs) _base_doc = LGBMClassifier.__init__.__doc__ _before_kwargs, 
_kwargs, _after_kwargs = _base_doc.partition("**kwargs") # type: ignore @@ -1318,52 +1279,13 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): def __init__( self, - boosting_type: str = "gbdt", - num_leaves: int = 31, - max_depth: int = -1, - learning_rate: float = 0.1, - n_estimators: int = 100, - subsample_for_bin: int = 200000, - objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None, - class_weight: Optional[Union[dict, str]] = None, - min_split_gain: float = 0.0, - min_child_weight: float = 1e-3, - min_child_samples: int = 20, - subsample: float = 1.0, - subsample_freq: int = 0, - colsample_bytree: float = 1.0, - reg_alpha: float = 0.0, - reg_lambda: float = 0.0, - random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None, - n_jobs: Optional[int] = None, - importance_type: str = "split", + *, client: Optional[Client] = None, **kwargs: Any, ): """Docstring is inherited from the lightgbm.LGBMRegressor.__init__.""" self.client = client - super().__init__( - boosting_type=boosting_type, - num_leaves=num_leaves, - max_depth=max_depth, - learning_rate=learning_rate, - n_estimators=n_estimators, - subsample_for_bin=subsample_for_bin, - objective=objective, - class_weight=class_weight, - min_split_gain=min_split_gain, - min_child_weight=min_child_weight, - min_child_samples=min_child_samples, - subsample=subsample, - subsample_freq=subsample_freq, - colsample_bytree=colsample_bytree, - reg_alpha=reg_alpha, - reg_lambda=reg_lambda, - random_state=random_state, - n_jobs=n_jobs, - importance_type=importance_type, - **kwargs, - ) + super().__init__(**kwargs) _base_doc = LGBMRegressor.__init__.__doc__ _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition("**kwargs") # type: ignore @@ -1485,52 +1407,13 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): def __init__( self, - boosting_type: str = "gbdt", - num_leaves: int = 31, - max_depth: int = -1, - learning_rate: float = 0.1, - n_estimators: int = 100, - subsample_for_bin: int = 200000, - objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None, - class_weight: Optional[Union[dict, str]] = None, - min_split_gain: float = 0.0, - min_child_weight: float = 1e-3, - min_child_samples: int = 20, - subsample: float = 1.0, - subsample_freq: int = 0, - colsample_bytree: float = 1.0, - reg_alpha: float = 0.0, - reg_lambda: float = 0.0, - random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None, - n_jobs: Optional[int] = None, - importance_type: str = "split", + *, client: Optional[Client] = None, **kwargs: Any, ): """Docstring is inherited from the lightgbm.LGBMRanker.__init__.""" self.client = client - super().__init__( - boosting_type=boosting_type, - num_leaves=num_leaves, - max_depth=max_depth, - learning_rate=learning_rate, - n_estimators=n_estimators, - subsample_for_bin=subsample_for_bin, - objective=objective, - class_weight=class_weight, - min_split_gain=min_split_gain, - min_child_weight=min_child_weight, - min_child_samples=min_child_samples, - subsample=subsample, - subsample_freq=subsample_freq, - colsample_bytree=colsample_bytree, - reg_alpha=reg_alpha, - reg_lambda=reg_lambda, - random_state=random_state, - n_jobs=n_jobs, - importance_type=importance_type, - **kwargs, - ) + super().__init__(**kwargs) _base_doc = LGBMRanker.__init__.__doc__ _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition("**kwargs") # type: ignore diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 
9f7d91630eb1..a484d991b96a 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -1317,6 +1317,11 @@ class LGBMRegressor(_LGBMRegressorBase, LGBMModel): """LightGBM regressor.""" def __init__(self, **kwargs: Any): + # TODO(jameslamb): add links to PR: + # - https://github.com/microsoft/LightGBM/issues/4426 + # - https://stackoverflow.com/questions/40025406/inherit-from-scikit-learns-lassocv-model/40027200#40027200 + # - https://stackoverflow.com/questions/79320289/why-cant-i-wrap-lgbm + # - https://github.com/dmlc/xgboost/blob/bd92b1c9c0db3e75ec3dfa513e1435d518bb535d/python-package/xgboost/sklearn.py#L941 print("LGBMRegressor.__init__()") super().__init__(**kwargs) From 7b720cb509a06fd97c3aacf44c23879991c9d690 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Tue, 7 Jan 2025 00:26:43 -0600 Subject: [PATCH 3/6] add docs --- docs/FAQ.rst | 39 +++++ python-package/lightgbm/dask.py | 1 - python-package/lightgbm/sklearn.py | 9 +- tests/python_package_test/test_sklearn.py | 189 +++++++++++++++++++++- 4 files changed, 230 insertions(+), 8 deletions(-) diff --git a/docs/FAQ.rst b/docs/FAQ.rst index 14c7f7dd7265..6f8b71378ddf 100644 --- a/docs/FAQ.rst +++ b/docs/FAQ.rst @@ -377,3 +377,42 @@ We strongly recommend installation from the ``conda-forge`` channel and not from For some specific examples, see `this comment `__. In addition, as of ``lightgbm==4.4.0``, the ``conda-forge`` package automatically supports CUDA-based GPU acceleration. + +5. How do I subclass ``scikit-learn`` estimators? +------------------------------------------------- + +For ``lightgbm <= 4.5.0``, copy all of the constructor arguments from the corresponding +``lightgbm`` class into the constructor of your custom estimator. + +For later versions, just ensure that the constructor of your custom estimator calls ``super().__init__()``. + +Consider the example below, which implements a regressor that allows creation of truncated predictions. +This pattern will work with ``lightgbm > 4.5.0``. + +.. 
code-block:: python + + import numpy as np + from lightgbm import LGBMRegressor + from sklearn.datasets import make_regression + + class TruncatedRegressor(LGBMRegressor): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def predict(self, X, max_score: float = np.inf): + preds = super().predict(X) + preds[np.where(preds > max_score)] = max_score + return preds + + X, y = make_regression(n_samples=1_000, n_features=4) + + reg_trunc = TruncatedRegressor().fit(X, y) + + preds = reg_trunc.predict(X) + print(f"mean: {preds.mean():.2f}, max: {preds.max():.2f}") + # mean: -6.81, max: 345.10 + + preds_trunc = reg_trunc.predict(X, max_score = preds.mean()) + print(f"mean: {preds_trunc.mean():.2f}, max: {preds_trunc.max():.2f}") + # mean: -56.50, max: -6.81 diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 76285fde183a..cd1648c8c1cf 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -44,7 +44,6 @@ LGBMModel, LGBMRanker, LGBMRegressor, - _LGBM_ScikitCustomObjectiveFunction, _LGBM_ScikitEvalMetricType, _lgbmmodel_doc_custom_eval_note, _lgbmmodel_doc_fit, diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index a484d991b96a..87c8e680476b 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -628,7 +628,6 @@ def __init__( For multi-class task, y_pred is a numpy 2-D array of shape = [n_samples, n_classes], and grad and hess should be returned in the same format. """ - print("LGBMModel.__init__()") if not SKLEARN_INSTALLED: raise LightGBMError( "scikit-learn is required for lightgbm.sklearn. " @@ -752,9 +751,9 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]: # # `get_params()` flows like this: # - # 0. Return parameters in subclass (self.__class__) first, by using inspect. - # 1. Return parameters in all parent classes (especially `LGBMModel`). - # 2. Return whatever is in `**kwargs`. + # 0. Get parameters in subclass (self.__class__) first, by using inspect. + # 1. Get parameters in all parent classes (especially `LGBMModel`). + # 2. Get whatever was passed via `**kwargs`. # 3. Merge them. # # This needs to accommodate being called recursively in the following @@ -768,7 +767,6 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]: # params = super().get_params(deep=deep) cp = copy.copy(self) - print(f"--- {cp.__class__.__bases__}") # If the immediate parent defines get_params(), use that. 
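+ # (note: getattr() finds get_params() anywhere in the parent's MRO, so this
+ # branch is also taken when the parent merely inherits it, e.g. LGBMRegressor
+ # inheriting get_params() from LGBMModel)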
if callable(getattr(cp.__class__.__bases__[0], "get_params", None)): cp.__class__ = cp.__class__.__bases__[0] @@ -1322,7 +1320,6 @@ def __init__(self, **kwargs: Any): # - https://stackoverflow.com/questions/40025406/inherit-from-scikit-learns-lassocv-model/40027200#40027200 # - https://stackoverflow.com/questions/79320289/why-cant-i-wrap-lgbm # - https://github.com/dmlc/xgboost/blob/bd92b1c9c0db3e75ec3dfa513e1435d518bb535d/python-package/xgboost/sklearn.py#L941 - print("LGBMRegressor.__init__()") super().__init__(**kwargs) __init__.__doc__ = LGBMModel.__init__.__doc__ diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index 1cdd047f1857..88cd4b5a5ebc 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -22,6 +22,7 @@ import lightgbm as lgb from lightgbm.compat import ( + DASK_INSTALLED, DATATABLE_INSTALLED, PANDAS_INSTALLED, _sklearn_version, @@ -83,6 +84,30 @@ def __call__(self, env): env.model.attr_set_inside_callback = env.iteration * 10 +class ExtendedLGBMClassifier(lgb.LGBMClassifier): + """Class for testing that inheriting from LGBMClassifier works""" + + def __init__(self, *, some_other_param: str = "lgbm-classifier", **kwargs): + self.some_other_param = some_other_param + super().__init__(**kwargs) + + +class ExtendedLGBMRanker(lgb.LGBMRanker): + """Class for testing that inheriting from LGBMRanker works""" + + def __init__(self, *, some_other_param: str = "lgbm-ranker", **kwargs): + self.some_other_param = some_other_param + super().__init__(**kwargs) + + +class ExtendedLGBMRegressor(lgb.LGBMRegressor): + """Class for testing that inheriting from LGBMRegressor works""" + + def __init__(self, *, some_other_param: str = "lgbm-regressor", **kwargs): + self.some_other_param = some_other_param + super().__init__(**kwargs) + + def custom_asymmetric_obj(y_true, y_pred): residual = (y_true - y_pred).astype(np.float64) grad = np.where(residual < 0, -2 * 10.0 * residual, -2 * residual) @@ -475,6 +500,165 @@ def test_clone_and_property(): assert isinstance(clf.feature_importances_, np.ndarray) +def test_subclassing_get_params_works(): + expected_params = { + "boosting_type": "gbdt", + "class_weight": None, + "colsample_bytree": 1.0, + "importance_type": "split", + "learning_rate": 0.1, + "max_depth": -1, + "min_child_samples": 20, + "min_child_weight": 0.001, + "min_split_gain": 0.0, + "n_estimators": 100, + "n_jobs": None, + "num_leaves": 31, + "objective": None, + "random_state": None, + "reg_alpha": 0.0, + "reg_lambda": 0.0, + "subsample": 1.0, + "subsample_for_bin": 200000, + "subsample_freq": 0, + } + + # Overrides, used to test that passing through **kwargs works as expected. + # + # why these? 
+ # + # - 'n_estimators" directly matches a keyword arg for the scikit-learn estimators + # - 'eta' is a parameter alias for 'learning_rate' + overrides = {"n_estimators": 13, "eta": 0.07} + + # lightgbm-official classes + for est in [lgb.LGBMModel, lgb.LGBMClassifier, lgb.LGBMRanker, lgb.LGBMRegressor]: + assert est().get_params() == expected_params + assert est(**overrides).get_params() == { + **expected_params, + "eta": 0.07, + "n_estimators": 13, + "learning_rate": 0.1, + } + + if DASK_INSTALLED: + for est in [lgb.DaskLGBMClassifier, lgb.DaskLGBMRanker, lgb.DaskLGBMRegressor]: + assert est().get_params() == { + **expected_params, + "client": None, + } + assert est(**overrides).get_params() == { + **expected_params, + "eta": 0.07, + "n_estimators": 13, + "learning_rate": 0.1, + "client": None, + } + + # custom sub-classes + assert ExtendedLGBMClassifier().get_params() == {**expected_params, "some_other_param": "lgbm-classifier"} + assert ExtendedLGBMClassifier(**overrides).get_params() == { + **expected_params, + "eta": 0.07, + "n_estimators": 13, + "learning_rate": 0.1, + "some_other_param": "lgbm-classifier", + } + assert ExtendedLGBMRanker().get_params() == { + **expected_params, + "some_other_param": "lgbm-ranker", + } + assert ExtendedLGBMRanker(**overrides).get_params() == { + **expected_params, + "eta": 0.07, + "n_estimators": 13, + "learning_rate": 0.1, + "some_other_param": "lgbm-ranker", + } + assert ExtendedLGBMRegressor().get_params() == { + **expected_params, + "some_other_param": "lgbm-regressor", + } + assert ExtendedLGBMRegressor(**overrides).get_params() == { + **expected_params, + "eta": 0.07, + "n_estimators": 13, + "learning_rate": 0.1, + "some_other_param": "lgbm-regressor", + } + + +@pytest.mark.parametrize("task", all_tasks) +def test_subclassing_works(task): + # param values to make training deterministic and + # just train a small, cheap model + params = { + "deterministic": True, + "force_row_wise": True, + "n_jobs": 1, + "n_estimators": 5, + "num_leaves": 11, + "random_state": 708, + } + + X, y, g = _create_data(task=task) + if task == "ranking": + est = lgb.LGBMRanker(**params).fit(X, y, group=g) + est_sub = ExtendedLGBMRanker(**params).fit(X, y, group=g) + elif task.endswith("classification"): + est = lgb.LGBMClassifier(**params).fit(X, y) + est_sub = ExtendedLGBMClassifier(**params).fit(X, y) + else: + est = lgb.LGBMRegressor(**params).fit(X, y) + est_sub = ExtendedLGBMRegressor(**params).fit(X, y) + + np.testing.assert_allclose(est.predict(X), est_sub.predict(X)) + + +@pytest.mark.parametrize( + "estimator_to_task", + [ + (lgb.LGBMClassifier, "binary-classification"), + (ExtendedLGBMClassifier, "binary-classification"), + (lgb.LGBMRanker, "ranking"), + (ExtendedLGBMRanker, "ranking"), + (lgb.LGBMRegressor, "regression"), + (ExtendedLGBMRegressor, "regression"), + ], +) +def test_parameter_aliases_are_handled_correctly(estimator_to_task): + estimator, task = estimator_to_task + # scikit-learn estimators should remember every parameter passed + # via keyword arguments in the estimator constructor, but then + # only pass the correct value down to LightGBM's C++ side + params = { + "eta": 0.08, + "num_iterations": 3, + "num_leaves": 5, + } + X, y, g = _create_data(task=task) + mod = estimator(**params) + if task == "ranking": + mod.fit(X, y, group=g) + else: + mod.fit(X, y) + + # scikit-learn get_params() + p = mod.get_params() + assert p["eta"] == 0.08 + assert p["learning_rate"] == 0.1 + + # lgb.Booster's 'params' attribute + p = mod.booster_.params + assert 
p["eta"] == 0.08
assert p["learning_rate"] == 0.1

# Config in the 'LightGBM::Booster' on the C++ side
p = mod.booster_._get_loaded_param()
assert p["learning_rate"] == 0.1
assert "eta" not in p
+

def test_joblib(tmp_path):
X, y = make_synthetic_regression()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
@@ -1463,7 +1647,10 @@ def _get_expected_failed_tests(estimator):
return estimator._more_tags()["_xfail_checks"]
-@parametrize_with_checks([lgb.LGBMClassifier(), lgb.LGBMRegressor()], expected_failed_checks=_get_expected_failed_tests)
+@parametrize_with_checks(
+ [ExtendedLGBMClassifier(), ExtendedLGBMRegressor(), lgb.LGBMClassifier(), lgb.LGBMRegressor()],
+ expected_failed_checks=_get_expected_failed_tests,
+)
def test_sklearn_integration(estimator, check):
estimator.set_params(min_child_samples=1, min_data_in_bin=1)
check(estimator)

From 51b5e6468d5b2b6d79a66230331cc0c6dc6a6a73 Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Fri, 10 Jan 2025 00:40:52 -0600
Subject: [PATCH 4/6] Update tests/python_package_test/test_sklearn.py

---
tests/python_package_test/test_sklearn.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py
index 88cd4b5a5ebc..991b3e5f8cf8 100644
--- a/tests/python_package_test/test_sklearn.py
+++ b/tests/python_package_test/test_sklearn.py
@@ -527,7 +527,7 @@ def test_subclassing_get_params_works():
#
# why these?
#
- # - 'n_estimators" directly matches a keyword arg for the scikit-learn estimators
+ # - 'n_estimators' directly matches a keyword arg for the scikit-learn estimators
# - 'eta' is a parameter alias for 'learning_rate'
overrides = {"n_estimators": 13, "eta": 0.07}

From 81178fd7714bd793950752de1f418e3f33b58c73 Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Fri, 10 Jan 2025 00:47:28 -0600
Subject: [PATCH 5/6] remove docs links

---
python-package/lightgbm/sklearn.py | 5 -----
1 file changed, 5 deletions(-)

diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py
index 87c8e680476b..fc5e716692a3 100644
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -1315,11 +1315,6 @@ class LGBMRegressor(_LGBMRegressorBase, LGBMModel):
"""LightGBM regressor."""

def __init__(self, **kwargs: Any):
- # TODO(jameslamb): add links to PR:
- # - https://github.com/microsoft/LightGBM/issues/4426
- # - https://stackoverflow.com/questions/40025406/inherit-from-scikit-learns-lassocv-model/40027200#40027200
- # - https://stackoverflow.com/questions/79320289/why-cant-i-wrap-lgbm
- # - https://github.com/dmlc/xgboost/blob/bd92b1c9c0db3e75ec3dfa513e1435d518bb535d/python-package/xgboost/sklearn.py#L941
super().__init__(**kwargs)

__init__.__doc__ = LGBMModel.__init__.__doc__

From d80b0df657b8473228822fc5592698cfacd68b78 Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Sun, 12 Jan 2025 23:24:22 -0600
Subject: [PATCH 6/6] fix Dask tests

---
tests/python_package_test/test_dask.py | 15 ++++++++-------
1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py
index b5e17991f63d..dacef3305547 100644
--- a/tests/python_package_test/test_dask.py
+++ b/tests/python_package_test/test_dask.py
@@ -1385,13 +1385,14 @@ def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except
sklearn_spec = inspect.getfullargspec(classes[1])
assert dask_spec.varargs == sklearn_spec.varargs
assert dask_spec.varkw == sklearn_spec.varkw
- assert dask_spec.kwonlyargs == sklearn_spec.kwonlyargs
- assert dask_spec.kwonlydefaults == sklearn_spec.kwonlydefaults
- # "client" should be the only different, and the final argument
- assert dask_spec.args[:-1] == sklearn_spec.args
- assert dask_spec.defaults[:-1] == sklearn_spec.defaults
- assert dask_spec.args[-1] == "client"
- assert dask_spec.defaults[-1] is None
+ assert dask_spec.kwonlyargs == [*sklearn_spec.kwonlyargs, "client"]
+ assert dask_spec.kwonlydefaults == {"client": None}
+ assert sklearn_spec.kwonlydefaults is None
+
+ # only positional argument should be 'self'
+ assert dask_spec.args == sklearn_spec.args
+ assert dask_spec.args == ["self"]
+ assert dask_spec.defaults is None
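A note on the final assertions above: they lean entirely on how ``inspect.getfullargspec`` reports keyword-only signatures. The sketch below is illustrative only and not part of the patch series (the ``Base`` and ``DaskStyle`` names are invented stand-ins for the scikit-learn and Dask estimators). It shows why ``args`` collapses to ``["self"]`` and ``defaults`` to ``None`` once every constructor parameter after ``self`` is keyword-only, with ``kwonlyargs`` and ``kwonlydefaults`` carrying everything else.

.. code-block:: python

    import inspect


    class Base:
        """Stand-in for an LGBMModel-style estimator with keyword-only params."""

        def __init__(self, *, n_estimators: int = 100, **kwargs):
            self.n_estimators = n_estimators


    class DaskStyle(Base):
        """Stand-in for a Dask estimator that adds a keyword-only 'client'."""

        def __init__(self, *, client=None, **kwargs):
            self.client = client
            super().__init__(**kwargs)


    spec = inspect.getfullargspec(DaskStyle.__init__)

    # 'self' is the only positional argument, so there are no positional defaults
    assert spec.args == ["self"]
    assert spec.defaults is None

    # everything else is reported as keyword-only
    assert spec.varkw == "kwargs"
    assert spec.kwonlyargs == ["client"]
    assert spec.kwonlydefaults == {"client": None}

This is exactly the shape the updated test checks for each Dask estimator against its scikit-learn equivalent.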