ENH add option to cross_validate to return estimators fitted on each …
bellet authored and jnothman committed Feb 27, 2018
1 parent 080e22d commit d9c2122
Showing 4 changed files with 54 additions and 12 deletions.
11 changes: 8 additions & 3 deletions doc/modules/cross_validation.rst
@@ -182,7 +182,8 @@ The ``cross_validate`` function differs from ``cross_val_score`` in two ways -
 
 - It allows specifying multiple metrics for evaluation.
 
-- It returns a dict containing training scores, fit-times and score-times in
+- It returns a dict containing fit-times, score-times
+  (and optionally training scores as well as fitted estimators) in
   addition to the test score.
 
 For single metric evaluation, where the scoring parameter is a string,
@@ -196,6 +197,9 @@ following keys -
 for all the scorers. If train scores are not needed, this should be set to
 ``False`` explicitly.
 
+You may also retain the estimator fitted on each training set by setting
+``return_estimator=True``.
+
 The multiple metrics can be specified either as a list, tuple or set of
 predefined scorer names::
@@ -226,9 +230,10 @@ Or as a dict mapping scorer name to a predefined or custom scoring function::
 Here is an example of ``cross_validate`` using a single metric::
 
     >>> scores = cross_validate(clf, iris.data, iris.target,
-    ...                         scoring='precision_macro')
+    ...                         scoring='precision_macro',
+    ...                         return_estimator=True)
     >>> sorted(scores.keys())
-    ['fit_time', 'score_time', 'test_score', 'train_score']
+    ['estimator', 'fit_time', 'score_time', 'test_score', 'train_score']
 
 
 Obtaining predictions by cross-validation
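Editor's note: the documented behavior above can be exercised end to end with a minimal, runnable sketch (not part of the commit; it assumes scikit-learn >= 0.20, where this change landed, and LogisticRegression with the iris data are arbitrary choices):

# Editor's sketch (not part of the commit): exercising return_estimator.
# Assumes scikit-learn >= 0.20; estimator and dataset choices are arbitrary.
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate

iris = datasets.load_iris()
clf = LogisticRegression()

scores = cross_validate(clf, iris.data, iris.target,
                        scoring='precision_macro',
                        return_estimator=True)

# The dict gains an 'estimator' key holding one fitted clone per CV split.
print(sorted(scores.keys()))
# ['estimator', 'fit_time', 'score_time', 'test_score', 'train_score']
for est in scores['estimator']:
    print(est.coef_.shape)  # each clone was fit on a different training fold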
4 changes: 4 additions & 0 deletions doc/whats_new/v0.20.rst
@@ -166,6 +166,10 @@ Model evaluation and meta-estimators
   group-based CV strategies. :issue:`9085` by :user:`Laurent Direr <ldirer>`
   and `Andreas Müller`_.
 
+- Add `return_estimator` parameter in :func:`model_selection.cross_validate` to
+  return estimators fitted on each split. :issue:`9686` by :user:`Aurélien Bellet
+  <bellet>`.
+
 Metrics
 
 - :func:`metrics.roc_auc_score` now supports binary ``y_true`` other than
34 changes: 28 additions & 6 deletions sklearn/model_selection/_validation.py
@@ -39,7 +39,8 @@
 
 def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None,
                    n_jobs=1, verbose=0, fit_params=None,
-                   pre_dispatch='2*n_jobs', return_train_score="warn"):
+                   pre_dispatch='2*n_jobs', return_train_score="warn",
+                   return_estimator=False):
     """Evaluate metric(s) by cross-validation and also record fit/score times.
 
     Read more in the :ref:`User Guide <multimetric_cross_validation>`.
@@ -129,6 +130,9 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None,
         expensive and is not strictly required to select the parameters that
         yield the best generalization performance.
 
+    return_estimator : boolean, default False
+        Whether to return the estimators fitted on each split.
+
     Returns
     -------
     scores : dict of float arrays of shape=(n_splits,)
@@ -150,6 +154,10 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None,
                 The time for scoring the estimator on the test set for each
                 cv split. (Note time for scoring on the train set is not
                 included even if ``return_train_score`` is set to ``True``
+            ``estimator``
+                The estimator objects for each cv split.
+                This is available only if ``return_estimator`` parameter
+                is set to ``True``.
 
     Examples
     --------
@@ -203,21 +211,26 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None,
         delayed(_fit_and_score)(
             clone(estimator), X, y, scorers, train, test, verbose, None,
             fit_params, return_train_score=return_train_score,
-            return_times=True)
+            return_times=True, return_estimator=return_estimator)
         for train, test in cv.split(X, y, groups))
 
+    zipped_scores = list(zip(*scores))
     if return_train_score:
-        train_scores, test_scores, fit_times, score_times = zip(*scores)
+        train_scores = zipped_scores.pop(0)
         train_scores = _aggregate_score_dicts(train_scores)
-    else:
-        test_scores, fit_times, score_times = zip(*scores)
+    if return_estimator:
+        fitted_estimators = zipped_scores.pop()
+    test_scores, fit_times, score_times = zipped_scores
     test_scores = _aggregate_score_dicts(test_scores)
 
     # TODO: replace by a dict in 0.21
     ret = DeprecationDict() if return_train_score == 'warn' else {}
     ret['fit_time'] = np.array(fit_times)
     ret['score_time'] = np.array(score_times)
 
+    if return_estimator:
+        ret['estimator'] = fitted_estimators
+
     for name in scorers:
         ret['test_%s' % name] = np.array(test_scores[name])
         if return_train_score:
@@ -347,7 +360,8 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
 def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
                    parameters, fit_params, return_train_score=False,
                    return_parameters=False, return_n_test_samples=False,
-                   return_times=False, error_score='raise'):
+                   return_times=False, return_estimator=False,
+                   error_score='raise'):
     """Fit estimator and compute scores for a given dataset split.
 
     Parameters
@@ -405,6 +419,9 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
     return_times : boolean, optional, default: False
         Whether to return the fit/score times.
 
+    return_estimator : boolean, optional, default: False
+        Whether to return the fitted estimator.
+
     Returns
     -------
     train_scores : dict of scorer name -> float, optional
@@ -425,6 +442,9 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
     parameters : dict or None, optional
         The parameters that have been evaluated.
+
+    estimator : estimator object
+        The fitted estimator
     """
     if verbose > 1:
         if parameters is None:
@@ -513,6 +533,8 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
         ret.extend([fit_time, score_time])
     if return_parameters:
         ret.append(parameters)
+    if return_estimator:
+        ret.append(estimator)
     return ret


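Editor's note on the unpacking logic above: _fit_and_score builds its result list positionally, with the optional train scores first and the optional fitted estimator appended last, which is why cross_validate can peel them off with pop(0) and pop() before unpacking the three mandatory fields. A standalone sketch of that idiom (unpack_results and the placeholder values are hypothetical, not library internals):

# Hypothetical illustration of the pop(0)/pop() unpacking idiom used in
# cross_validate above; none of these names are library code.

def unpack_results(results, return_train_score, return_estimator):
    """Transpose per-fold result lists into per-field tuples."""
    zipped = list(zip(*results))  # one tuple per field across folds
    train_scores = zipped.pop(0) if return_train_score else None
    estimators = zipped.pop() if return_estimator else None
    # The three mandatory fields always sit in the middle, in this order.
    test_scores, fit_times, score_times = zipped
    return train_scores, test_scores, fit_times, score_times, estimators

# Two folds, both optional fields requested:
fold_results = [
    [0.99, 0.95, 0.011, 0.002, 'estimator_fold0'],
    [0.98, 0.93, 0.012, 0.003, 'estimator_fold1'],
]
print(unpack_results(fold_results, True, True))
# ((0.99, 0.98), (0.95, 0.93), (0.011, 0.012), (0.002, 0.003),
#  ('estimator_fold0', 'estimator_fold1'))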
17 changes: 14 additions & 3 deletions sklearn/model_selection/tests/test_validation.py
@@ -368,20 +368,23 @@ def test_cross_validate():
     test_mse_scores = []
     train_r2_scores = []
     test_r2_scores = []
+    fitted_estimators = []
     for train, test in cv.split(X, y):
         est = clone(reg).fit(X[train], y[train])
         train_mse_scores.append(mse_scorer(est, X[train], y[train]))
         train_r2_scores.append(r2_scorer(est, X[train], y[train]))
         test_mse_scores.append(mse_scorer(est, X[test], y[test]))
         test_r2_scores.append(r2_scorer(est, X[test], y[test]))
+        fitted_estimators.append(est)
 
     train_mse_scores = np.array(train_mse_scores)
     test_mse_scores = np.array(test_mse_scores)
     train_r2_scores = np.array(train_r2_scores)
     test_r2_scores = np.array(test_r2_scores)
+    fitted_estimators = np.array(fitted_estimators)
 
     scores = (train_mse_scores, test_mse_scores, train_r2_scores,
-              test_r2_scores)
+              test_r2_scores, fitted_estimators)
 
     yield check_cross_validate_single_metric, est, X, y, scores
     yield check_cross_validate_multi_metric, est, X, y, scores
@@ -411,7 +414,7 @@ def test_cross_validate_return_train_score_warn():
 
 def check_cross_validate_single_metric(clf, X, y, scores):
     (train_mse_scores, test_mse_scores, train_r2_scores,
-     test_r2_scores) = scores
+     test_r2_scores, fitted_estimators) = scores
     # Test single metric evaluation when scoring is string or singleton list
     for (return_train_score, dict_len) in ((True, 4), (False, 3)):
         # Single metric passed as a string
@@ -443,11 +446,19 @@ def check_cross_validate_single_metric(clf, X, y, scores):
         assert_equal(len(r2_scores_dict), dict_len)
         assert_array_almost_equal(r2_scores_dict['test_r2'], test_r2_scores)
 
+    # Test return_estimator option
+    mse_scores_dict = cross_validate(clf, X, y, cv=5,
+                                     scoring='neg_mean_squared_error',
+                                     return_estimator=True)
+    for k, est in enumerate(mse_scores_dict['estimator']):
+        assert_almost_equal(est.coef_, fitted_estimators[k].coef_)
+        assert_almost_equal(est.intercept_, fitted_estimators[k].intercept_)
+
 
 def check_cross_validate_multi_metric(clf, X, y, scores):
     # Test multimetric evaluation when scoring is a list / dict
     (train_mse_scores, test_mse_scores, train_r2_scores,
-     test_r2_scores) = scores
+     test_r2_scores, fitted_estimators) = scores
     all_scoring = (('r2', 'neg_mean_squared_error'),
                    {'r2': make_scorer(r2_score),
                     'neg_mean_squared_error': 'neg_mean_squared_error'})
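The new assertions above check, fold by fold, that the estimators returned by cross_validate match clones fitted manually on the same training indices. The same invariant can be verified outside the test suite; a minimal editor's sketch (assuming scikit-learn >= 0.20; Ridge, make_regression and 5-fold KFold are arbitrary choices, not mandated by the commit):

import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold, cross_validate

X, y = make_regression(n_samples=100, n_features=5, random_state=0)
reg = Ridge()
cv = KFold(n_splits=5)  # deterministic splits, so they can be replayed below

result = cross_validate(reg, X, y, cv=cv, return_estimator=True)
for (train, _), est in zip(cv.split(X, y), result['estimator']):
    manual = clone(reg).fit(X[train], y[train])
    np.testing.assert_allclose(est.coef_, manual.coef_)
    np.testing.assert_allclose(est.intercept_, manual.intercept_)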
