Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT Common parameter validation #22722

Merged
merged 42 commits into from
May 16, 2022

Conversation

jeremiedbb
Copy link
Member

@jeremiedbb jeremiedbb commented Mar 7, 2022

This PR proposes a unified design for parameter validation across estimators, classes and functions.

The goal is to have a consistent way to raise an informative error message when a parameter does not have a valid type/value. Here's an example:

>>> KMeans(init="wrong").fit(X)
ValueError: The 'init' parameter of KMeans must be a str among {'k-means++', 'random'}, a callable or an array-like. Got 'wrong' instead.

It's also meant to centralize all these checks in one place, i.e. being the first instruction of fit or of a function. Currently they can be spread throughout fit making it hard to follow and slow to fail. I also find that having all this boilerplate inside fit makes the actual interesting code of the algorithm hard to find and mixed up with non-relevant code.
In addition, these checks are currently often done for a small subset of the parameters and often not tested. And when tested, it's often spread inside several tests.

This PR only deals with non co-dependent types and values between parameters. For instance if a value of a parameter is valid only if some value of another parameter is set.

I propose to add to BaseEstimator a method _validate_params that performs validation for all parameters of estimators and a decorator validate_params for public functions. Validation is made against a dict param_name: constraint where constraint is a list of valid types/values.

# param validation of an estimator
class SomeEstimator(BaseEstimator):
    _parameter_constraints = {
        "n_clusters": [Interval(Integral, 1, None, closed="left")],
        "init": [StrOptions(["k-means++", "random"]), callable, "array-like")],
        "tol": [Interval(Real, 0, None, closed="left")],
        "algorithm": [StrOptions(["lloyd", "elkan", "auto", "full"], deprecated={"auto", "full"})],
        "max_no_improvement": [None,  Interval(Integral, 0, None, closed="left")]
    }

    def fit(X, y):
        self._validate_params()
# param validation of a function
@validate_params(
    {
        "n_clusters": [Interval(Integral, 1, None, closed="left")],
        "init": [StrOptions(["k-means++", "random"]), callable, "array-like")],
        "tol": [Interval(Real, 0, None, closed="left")],
        "algorithm": [StrOptions(["lloyd", "elkan", "auto", "full"], deprecated={"auto", "full"})],
        "max_no_improvement": [None,  Interval(Integral, 0, None, closed="left")]
    }
)
def some_func(n_clusters, init, tol, algorithm, max_no_improvement):
    ...

I also propose to add a new common test that makes sure this is done for all estimators (almost all of them being skipped right now).

closes #14721

Copy link
Member Author

@jeremiedbb jeremiedbb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here are some comments for possible extensions of this work.

sklearn/cluster/_kmeans.py Outdated Show resolved Hide resolved
"n_clusters": [(numbers.Integral, Interval(1, None, closed="left"))],
"init": [
(str, {"k-means++", "random"}),
(callable,),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Future work: we can imagine defining the subset of callables with a specific signature to pass here as valid values

sklearn/cluster/_kmeans.py Outdated Show resolved Hide resolved
@@ -1166,7 +1190,8 @@ def fit(self, X, y=None, sample_weight=None):
accept_large_sparse=False,
)

self._check_params(X)
self._check_params_vs_input(X)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a second round of checks after data validation that deals with valid values that depend on the data or on other parameters.

Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for opening the PR on this topic!

The way the specification is defined is a dictionary of list of tuples where each tuple is: (valid_type, constraint). I like thinking of everything as a constraint.

As for the developer API, I see two parts:

  1. Defining the constraints
  2. Actually performing the validation.

In this PR, item 1 is a dictionary, and item 2 is a function call to validate_param. There is also another API for validate_params that combines item 1 and item 2 that is used in function calls. My preference is to have one API instead of two.

As we are already defining a Interval object, I think it's okay to go straight to defining a Validator object:

validator = Validator(
    n_clusters=[Interval(Integral, 1, None, closed="left")],
    init=[Options(["k-means++", "random"]), callable, "array-like")],
    tol=[Interval(Real, 0, None, closed="left")],
    algorithm=[Options(["lloyd", "elkan", "auto", "full"], deprecated={"auto", "full"})],
    max_no_improvement=[None,  Interval(Integral, 0, None, closed="left")]
)
validator.validate(n_clusters=2, ...)

The above can be used directly in functions.

For estimators:

class MyEstimator:
    _validator = Validator(...)

    def fit(self, X, y):
        self._validator.validate(self.get_params())

I think the dictionary of lists of tuples has semantics that makes it harder to parse and a validator object makes the semantics clear.

Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the dictionary of constraints idea!

sklearn/cluster/_kmeans.py Outdated Show resolved Hide resolved
sklearn/utils/_param_validation.py Outdated Show resolved Hide resolved
sklearn/utils/_param_validation.py Outdated Show resolved Hide resolved
Copy link
Member

@jjerphan jjerphan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for tackling this, @jeremiedbb.

I think this very valuable for maintenance on the long term.

Here is a first review.

sklearn/utils/_param_validation.py Outdated Show resolved Hide resolved
sklearn/utils/_param_validation.py Outdated Show resolved Hide resolved
sklearn/utils/_param_validation.py Outdated Show resolved Hide resolved
sklearn/utils/_param_validation.py Outdated Show resolved Hide resolved
@jjerphan jjerphan changed the title Common parameter validation MAINT Common parameter validation Mar 11, 2022
Copy link
Member

@adrinjalali adrinjalali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks quite interesting, and I'm quite happy with it. But I really would like to see what @jnothman thinks about it. I don't think this adds too much complexity and it's not a required API for developers.

sklearn/cluster/_kmeans.py Outdated Show resolved Hide resolved
Copy link
Member

@adrinjalali adrinjalali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks really nice. I'm the second approver here, but since it's quite major, I'd like another set of eyes giving a thumb up before merging.

@adrinjalali
Copy link
Member

Seems like there hasn't been any objections since I left my last comment a month ago. Will update the branch, and merge after if CI passes.

@adrinjalali
Copy link
Member

@jeremiedbb CI fails here.

@lorentzenchr
Copy link
Member

+1 for this change. It is a step in the right direction! Thank @jeremiedbb for your effort!

@jjerphan
Copy link
Member

jjerphan commented May 16, 2022

29c12fe resolves the failing tests. Feel free to cherry-pick (I've tried to do a PR on top of the branch of this PR but I can't).

Thank you once again, @jeremiedbb. I am looking forward to the merge.

@jeremiedbb
Copy link
Member Author

jeremiedbb commented May 16, 2022

thanks @jjerphan. I was also looking at this :)
It actually needed a little bit more fixes. Should be ok now

@jjerphan
Copy link
Member

I would wait for another 3rd approval before merging this one. What do you think, @adrinjalali?

@adrinjalali
Copy link
Member

We also have @lorentzenchr 's +1 here. I think we can merge. I think there's been enough time to object if there were concerns.

@adrinjalali adrinjalali merged commit 2b09fa0 into scikit-learn:main May 16, 2022
@lorentzenchr
Copy link
Member

How about opening a follow-up issue to track progress on the modules (making PARAM_VALIDATION_ESTIMATORS_TO_IGNORE smaller)?
Maybe also some PR/issue for the documentation?

lesteve pushed a commit to lesteve/scikit-learn that referenced this pull request May 19, 2022
* common parameter validation

* black

* cln

* wip

* wip

* rework

* renaming and cleaning

* lint

* re lint

* cln

* add tests

* lint

* make random_state constraint

* lint

* closed positional

* increase coverage + validate constraints

* exp typing

* trigger ci ?

* lint

* cln

* rev type hints

* cln

* interval closed kwarg only

* address comments

* address comments + more tests + cln + improve err msg

* lint

* cln

* cln

* address comments

* address comments

* lint

* adapt or skip new estimators

* lint

Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
ivannz added a commit to ivannz/scikit-learn that referenced this pull request Aug 29, 2022
…cSVM scikit-learn#24001)

finish v1.2 deprecation of params kwargs in `.fit` of SVDD (similar to ocSVM scikit-learn#20843)
removed SVDD param-validation exception from test_common.py since scikit-learn#23462 is go (scikit-learn#22722)
ivannz added a commit to ivannz/scikit-learn that referenced this pull request Sep 5, 2022
…cSVM scikit-learn#24001)

finish v1.2 deprecation of params kwargs in `.fit` of SVDD (similar to ocSVM scikit-learn#20843)
TST ensure SVDD passes param-validation test_common.py due to scikit-learn#23462 (scikit-learn#22722)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Use decorators for simple input validations
9 participants