Add tags to classifiers and regressors to identify them as such.
amueller committed Mar 31, 2015
1 parent e2dfd23 commit d89c215
Showing 16 changed files with 261 additions and 61 deletions.
13 changes: 13 additions & 0 deletions doc/developers/index.rst
@@ -883,6 +883,19 @@ take arguments ``X, y``, even if y is not used. Similarly, for ``score`` to be
usable, the last step of the pipeline needs to have a ``score`` function that
accepts an optional ``y``.

Estimator types
---------------
Some common functionality depends on the kind of estimator passed.
For example, cross-validation in :class:`grid_search.GridSearchCV` and
:func:`cross_validation.cross_val_score` defaults to stratified folds when used
on a classifier, but not otherwise. Similarly, scorers such as average precision
that take a continuous prediction need to call ``decision_function`` for
classifiers, but ``predict`` for regressors. This distinction is implemented via
the string-valued ``_estimator_type`` attribute, which should be set to
``"classifier"`` for classifiers and ``"regressor"`` for regressors. Inheriting
from ``ClassifierMixin`` or ``RegressorMixin`` sets the attribute automatically.
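
For illustration, the tag can be read back with ``sklearn.base.is_classifier``;
a minimal sketch (not part of this commit; ``MyClassifier`` is a hypothetical
estimator)::

    from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier

    class MyClassifier(BaseEstimator, ClassifierMixin):
        # ClassifierMixin sets _estimator_type = "classifier" on the class
        def fit(self, X, y):
            return self

    is_classifier(MyClassifier())  # True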

Working notes
-------------

28 changes: 12 additions & 16 deletions sklearn/base.py
@@ -244,14 +244,14 @@ def set_params(self, **params):
if len(split) > 1:
# nested objects case
name, sub_name = split
-                if not name in valid_params:
+                if name not in valid_params:
raise ValueError('Invalid parameter %s for estimator %s' %
(name, self))
sub_object = valid_params[name]
sub_object.set_params(**{sub_name: value})
else:
# simple objects case
-            if not key in valid_params:
+            if key not in valid_params:
raise ValueError('Invalid parameter %s ' 'for estimator %s'
% (key, self.__class__.__name__))
setattr(self, key, value)
@@ -266,6 +266,7 @@ def __repr__(self):
###############################################################################
class ClassifierMixin(object):
"""Mixin class for all classifiers in scikit-learn."""
_estimator_type = "classifier"

def score(self, X, y, sample_weight=None):
"""Returns the mean accuracy on the given test data and labels.
@@ -298,6 +299,7 @@ def score(self, X, y, sample_weight=None):
###############################################################################
class RegressorMixin(object):
"""Mixin class for all regression estimators in scikit-learn."""
_estimator_type = "regressor"

def score(self, X, y, sample_weight=None):
"""Returns the coefficient of determination R^2 of the prediction.
@@ -331,6 +333,8 @@ def score(self, X, y, sample_weight=None):
###############################################################################
class ClusterMixin(object):
"""Mixin class for all cluster estimators in scikit-learn."""
_estimator_type = "clusterer"

def fit_predict(self, X, y=None):
"""Performs clustering on X and returns cluster labels.
@@ -443,20 +447,12 @@ class MetaEstimatorMixin(object):


###############################################################################
-# XXX: Temporary solution to figure out if an estimator is a classifier
-
-def _get_sub_estimator(estimator):
-    """Returns the final estimator if there is any."""
-    if hasattr(estimator, 'estimator'):
-        # GridSearchCV and other CV-tuned estimators
-        return _get_sub_estimator(estimator.estimator)
-    if hasattr(estimator, 'steps'):
-        # Pipeline
-        return _get_sub_estimator(estimator.steps[-1][1])
-    return estimator
-
-
def is_classifier(estimator):
    """Returns True if the given estimator is (probably) a classifier."""
-    estimator = _get_sub_estimator(estimator)
-    return isinstance(estimator, ClassifierMixin)
+    return getattr(estimator, "_estimator_type", None) == "classifier"


def is_regressor(estimator):
"""Returns True if the given estimator is (probably) a regressor."""
return getattr(estimator, "_estimator_type", None) == "regressor"
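
Since the check is now duck-typed rather than isinstance-based, any object that
carries the tag is recognized. A minimal sketch of the intended behavior (not
part of the diff; ``Custom`` is a hypothetical estimator):

from sklearn.base import is_classifier, is_regressor

class Custom(object):
    # a third-party estimator can opt in via the tag alone
    _estimator_type = "regressor"

is_regressor(Custom())   # True
is_classifier(Custom())  # False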
92 changes: 82 additions & 10 deletions sklearn/ensemble/gradient_boosting.py
@@ -35,7 +35,7 @@
from ..base import ClassifierMixin
from ..base import RegressorMixin
from ..utils import check_random_state, check_array, check_X_y, column_or_1d
-from ..utils import check_consistent_length
+from ..utils import check_consistent_length, deprecated
from ..utils.extmath import logsumexp
from ..utils.fixes import expit, bincount
from ..utils.stats import _weighted_percentile
@@ -438,7 +438,7 @@ class ClassificationLossFunction(six.with_metaclass(ABCMeta, LossFunction)):
def _score_to_proba(self, score):
"""Template method to convert scores to probabilities.
If the loss does not support probabilities, raises TypeError.
"""
raise TypeError('%s does not support predict_proba' % type(self).__name__)

@@ -1044,9 +1044,10 @@ def _fit_stages(self, X, y, y_pred, sample_weight, random_state,
self.train_score_[i] = loss_(y[sample_mask],
y_pred[sample_mask],
sample_weight[sample_mask])
-                self.oob_improvement_[i] = (old_oob_score -
-                    loss_(y[~sample_mask], y_pred[~sample_mask],
-                          sample_weight[~sample_mask]))
+                self.oob_improvement_[i] = (
+                    old_oob_score - loss_(y[~sample_mask],
+                                          y_pred[~sample_mask],
+                                          sample_weight[~sample_mask]))
else:
# no need to fancy index w/ no subsampling
self.train_score_[i] = loss_(y, y_pred, sample_weight)
@@ -1082,6 +1083,7 @@ def _decision_function(self, X):
predict_stages(self.estimators_, X, self.learning_rate, score)
return score

@deprecated(" and will be removed in 0.19")
def decision_function(self, X):
"""Compute the decision function of ``X``.
@@ -1104,7 +1106,7 @@ def decision_function(self, X):
return score.ravel()
return score

-    def staged_decision_function(self, X):
+    def _staged_decision_function(self, X):
"""Compute decision function of ``X`` for each iteration.
This method allows monitoring (i.e. determine error on testing set)
@@ -1129,6 +1131,30 @@ def staged_decision_function(self, X):
predict_stage(self.estimators_, i, X, self.learning_rate, score)
yield score.copy()

@deprecated(" and will be removed in 0.19")
def staged_decision_function(self, X):
"""Compute decision function of ``X`` for each iteration.
This method allows monitoring (i.e. determine error on testing set)
after each stage.
Parameters
----------
X : array-like of shape = [n_samples, n_features]
The input samples.
Returns
-------
score : generator of array, shape = [n_samples, k]
The decision function of the input samples. The order of the
classes corresponds to that in the attribute `classes_`.
Regression and binary classification are special cases with
``k == 1``, otherwise ``k==n_classes``.
"""
for dec in self._staged_decision_function(X):
# no yield from in Python2.X
yield dec

@property
def feature_importances_(self):
"""Return the feature importances (the higher, the more important the
@@ -1315,6 +1341,51 @@ def _validate_y(self, y):
self.n_classes_ = len(self.classes_)
return y

def decision_function(self, X):
"""Compute the decision function of ``X``.
Parameters
----------
X : array-like of shape = [n_samples, n_features]
The input samples.
Returns
-------
score : array, shape = [n_samples, n_classes] or [n_samples]
The decision function of the input samples. The order of the
classes corresponds to that in the attribute `classes_`.
Regression and binary classification produce an array of shape
[n_samples].
"""
X = check_array(X, dtype=DTYPE, order="C")
score = self._decision_function(X)
if score.shape[1] == 1:
return score.ravel()
return score

def staged_decision_function(self, X):
"""Compute decision function of ``X`` for each iteration.
This method allows monitoring (i.e. determine error on testing set)
after each stage.
Parameters
----------
X : array-like of shape = [n_samples, n_features]
The input samples.
Returns
-------
score : generator of array, shape = [n_samples, k]
The decision function of the input samples. The order of the
classes corresponds to that in the attribute `classes_`.
Regression and binary classification are special cases with
``k == 1``, otherwise ``k==n_classes``.
"""
for dec in self._staged_decision_function(X):
# no yield from in Python2.X
yield dec

def predict(self, X):
"""Predict class for X.
@@ -1348,7 +1419,7 @@ def staged_predict(self, X):
y : generator of array of shape = [n_samples]
The predicted value of the input samples.
"""
-        for score in self.staged_decision_function(X):
+        for score in self._staged_decision_function(X):
decisions = self.loss_._score_to_decision(score)
yield self.classes_.take(decisions, axis=0)

@@ -1419,7 +1490,7 @@ def staged_predict_proba(self, X):
The predicted value of the input samples.
"""
try:
-            for score in self.staged_decision_function(X):
+            for score in self._staged_decision_function(X):
yield self.loss_._score_to_proba(score)
except NotFittedError:
raise
@@ -1594,7 +1665,8 @@ def predict(self, X):
y : array of shape = [n_samples]
The predicted values.
"""
-        return self.decision_function(X).ravel()
+        X = check_array(X, dtype=DTYPE, order="C")
+        return self._decision_function(X).ravel()

def staged_predict(self, X):
"""Predict regression target at each stage for X.
@@ -1612,5 +1684,5 @@ def staged_predict(self, X):
y : generator of array of shape = [n_samples]
The predicted value of the input samples.
"""
-        for y in self.staged_decision_function(X):
+        for y in self._staged_decision_function(X):
yield y.ravel()
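
The net effect above: ``decision_function`` and ``staged_decision_function``
stay public on the classifier but are deprecated on the regressor, whose
``predict`` now goes through the private helper. A quick sketch of the
intended usage on toy data (not part of the diff):

from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor

X = [[0.], [1.], [2.], [3.]]
y = [0, 0, 1, 1]

clf = GradientBoostingClassifier(n_estimators=5).fit(X, y)
clf.decision_function(X)  # still supported for classifiers

reg = GradientBoostingRegressor(n_estimators=5).fit(X, y)
reg.predict(X)            # preferred; reg.decision_function(X) now warns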
21 changes: 11 additions & 10 deletions sklearn/ensemble/tests/test_gradient_boosting.py
@@ -1,7 +1,7 @@
"""
Testing for the gradient boosting module (sklearn.ensemble.gradient_boosting).
"""

import warnings
import numpy as np

from sklearn import datasets
@@ -171,8 +171,9 @@ def test_boston():
for loss in ("ls", "lad", "huber"):
for subsample in (1.0, 0.5):
last_y_pred = None
-            for i, sample_weight in enumerate((None, np.ones(len(boston.target)),
-                                               2 * np.ones(len(boston.target)))):
+            for i, sample_weight in enumerate(
+                    (None, np.ones(len(boston.target)),
+                     2 * np.ones(len(boston.target)))):
clf = GradientBoostingRegressor(n_estimators=100, loss=loss,
max_depth=4, subsample=subsample,
min_samples_split=1,
@@ -343,6 +344,7 @@ def test_check_max_features():
max_features=-0.1)
assert_raises(ValueError, clf.fit, X, y)


def test_max_feature_regression():
# Test to make sure random state is set properly.
X, y = datasets.make_hastie_10_2(n_samples=12000, random_state=1)
@@ -455,7 +457,8 @@ def test_staged_functions_defensive():
if staged_func is None:
# regressor has no staged_predict_proba
continue
-        staged_result = list(staged_func(X))
+        with warnings.catch_warnings(record=True):
+            staged_result = list(staged_func(X))
staged_result[1][:] = 0
assert_true(np.all(staged_result[0] != 0))

@@ -843,7 +846,7 @@ def test_complete_classification():
k = 4

est = GradientBoostingClassifier(n_estimators=20, max_depth=None,
-                                     random_state=1, max_leaf_nodes=k+1)
+                                     random_state=1, max_leaf_nodes=k + 1)
est.fit(X, y)

tree = est.estimators_[0, 0].tree_
@@ -858,7 +861,7 @@ def test_complete_regression():
k = 4

est = GradientBoostingRegressor(n_estimators=20, max_depth=None,
-                                    random_state=1, max_leaf_nodes=k+1)
+                                    random_state=1, max_leaf_nodes=k + 1)
est.fit(boston.data, boston.target)

tree = est.estimators_[-1, 0].tree_
@@ -971,8 +974,7 @@ def test_non_uniform_weights_toy_edge_case_reg():
X = [[1, 0],
[1, 0],
[1, 0],
-         [0, 1],
-         ]
+         [0, 1]]
y = [0, 0, 1, 0]
# ignore the first 2 training samples by setting their weight to 0
sample_weight = [0, 0, 1, 1]
@@ -1002,8 +1004,7 @@ def test_non_uniform_weights_toy_edge_case_clf():
X = [[1, 0],
[1, 0],
[1, 0],
-         [0, 1],
-         ]
+         [0, 1]]
y = [0, 0, 1, 0]
# ignore the first 2 training samples by setting their weight to 0
sample_weight = [0, 0, 1, 1]
4 changes: 4 additions & 0 deletions sklearn/grid_search.py
@@ -331,6 +331,10 @@ def __init__(self, estimator, scoring=None,
self.pre_dispatch = pre_dispatch
self.error_score = error_score

@property
def _estimator_type(self):
return self.estimator._estimator_type

def score(self, X, y=None):
"""Returns the score on the given data, if the estimator has been refit
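The new property forwards the wrapped estimator's tag, so a meta-estimator
reports the type of whatever it wraps. A sketch of the expected behavior (not
part of the diff):

from sklearn.base import is_classifier
from sklearn.grid_search import GridSearchCV
from sklearn.svm import SVC

search = GridSearchCV(SVC(), param_grid={"C": [1, 10]})
is_classifier(search)  # True: the tag is read from the wrapped SVC
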
8 changes: 6 additions & 2 deletions sklearn/linear_model/base.py
@@ -24,7 +24,7 @@
from ..externals import six
from ..externals.joblib import Parallel, delayed
from ..base import BaseEstimator, ClassifierMixin, RegressorMixin
-from ..utils import as_float_array, check_array, check_X_y
+from ..utils import as_float_array, check_array, check_X_y, deprecated
from ..utils.extmath import safe_sparse_dot
from ..utils.sparsefuncs import mean_variance_axis, inplace_column_scale
from ..utils.fixes import sparse_lsqr
@@ -119,6 +119,7 @@ class LinearModel(six.with_metaclass(ABCMeta, BaseEstimator)):
def fit(self, X, y):
"""Fit model."""

@deprecated(" and will be removed in 0.19.")
def decision_function(self, X):
"""Decision function of the linear model.
@@ -132,6 +133,9 @@ def decision_function(self, X):
C : array, shape = (n_samples,)
Returns predicted values.
"""
return self._decision_function(X)

def _decision_function(self, X):
check_is_fitted(self, "coef_")

X = check_array(X, accept_sparse=['csr', 'csc', 'coo'])
@@ -151,7 +155,7 @@ def predict(self, X):
C : array, shape = (n_samples,)
Returns predicted values.
"""
-        return self.decision_function(X)
+        return self._decision_function(X)

_center_data = staticmethod(center_data)

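Here ``decision_function`` on linear regressors becomes a thin deprecated
wrapper while ``predict`` switches to the private helper. A sketch of how the
deprecation surfaces to users (not part of the diff):

import warnings
from sklearn.linear_model import LinearRegression

est = LinearRegression().fit([[0.], [1.]], [0., 1.])
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    est.decision_function([[2.]])  # emits a DeprecationWarning
est.predict([[2.]])                # preferred, warning-free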