Skip to content

Commit

Permalink
MAINT Parameter validation for linear_model.orthogonal_mp (#25817)
Browse files Browse the repository at this point in the history
  • Loading branch information
choudharynishu authored Mar 14, 2023
1 parent 59a48db commit 263b428
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
19 changes: 14 additions & 5 deletions sklearn/linear_model/_omp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ..utils import as_float_array, check_array
from ..utils.parallel import delayed, Parallel
from ..utils._param_validation import Hidden, Interval, StrOptions
from ..utils._param_validation import validate_params
from ..model_selection import check_cv

premature = (
Expand Down Expand Up @@ -281,6 +282,18 @@ def _gram_omp(
return gamma, indices[:n_active], n_active


@validate_params(
{
"X": ["array-like"],
"y": [np.ndarray],
"n_nonzero_coefs": [Interval(Integral, 1, None, closed="left"), None],
"tol": [Interval(Real, 0, None, closed="left"), None],
"precompute": ["boolean", StrOptions({"auto"})],
"copy_X": ["boolean"],
"return_path": ["boolean"],
"return_n_iter": ["boolean"],
}
)
def orthogonal_mp(
X,
y,
Expand Down Expand Up @@ -308,7 +321,7 @@ def orthogonal_mp(
Parameters
----------
X : ndarray of shape (n_samples, n_features)
X : array-like of shape (n_samples, n_features)
Input data. Columns are assumed to have unit norm.
y : ndarray of shape (n_samples,) or (n_samples, n_targets)
Expand Down Expand Up @@ -380,10 +393,6 @@ def orthogonal_mp(
# default for n_nonzero_coefs is 0.1 * n_features
# but at least one.
n_nonzero_coefs = max(int(0.1 * X.shape[1]), 1)
if tol is not None and tol < 0:
raise ValueError("Epsilon cannot be negative")
if tol is None and n_nonzero_coefs <= 0:
raise ValueError("The number of atoms must be positive")
if tol is None and n_nonzero_coefs > X.shape[1]:
raise ValueError(
"The number of atoms cannot be more than the number of features"
Expand Down
2 changes: 1 addition & 1 deletion sklearn/linear_model/tests/test_omp.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_unreachable_accuracy():
@pytest.mark.parametrize("positional_params", [(X, y), (G, Xy)])
@pytest.mark.parametrize(
"keyword_params",
[{"tol": -1}, {"n_nonzero_coefs": -1}, {"n_nonzero_coefs": n_features + 1}],
[{"n_nonzero_coefs": n_features + 1}],
)
def test_bad_input(positional_params, keyword_params):
with pytest.raises(ValueError):
Expand Down
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def _check_function_param_validation(
"sklearn.feature_selection.f_regression",
"sklearn.feature_selection.mutual_info_classif",
"sklearn.feature_selection.r_regression",
"sklearn.linear_model.orthogonal_mp",
"sklearn.metrics.accuracy_score",
"sklearn.metrics.auc",
"sklearn.metrics.average_precision_score",
Expand Down

0 comments on commit 263b428

Please sign in to comment.