[MRG+1] Add scorer based on brier_score_loss (scikit-learn#9521)
qinhanmin2014 authored and jnothman committed Aug 21, 2017
1 parent d9fdd8b commit ee2025f
Showing 4 changed files with 22 additions and 9 deletions.
3 changes: 2 additions & 1 deletion doc/modules/model_evaluation.rst
@@ -60,6 +60,7 @@ Scoring Function
**Classification**
'accuracy' :func:`metrics.accuracy_score`
'average_precision' :func:`metrics.average_precision_score`
+ 'brier_score_loss' :func:`metrics.brier_score_loss`
'f1' :func:`metrics.f1_score` for binary targets
'f1_micro' :func:`metrics.f1_score` micro-averaged
'f1_macro' :func:`metrics.f1_score` macro-averaged
@@ -102,7 +103,7 @@ Usage examples:
>>> model = svm.SVC()
>>> cross_val_score(model, X, y, scoring='wrong_choice')
Traceback (most recent call last):
- ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_mutual_info_score', 'adjusted_rand_score', 'average_precision', 'completeness_score', 'explained_variance', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted', 'fowlkes_mallows_score', 'homogeneity_score', 'mutual_info_score', 'neg_log_loss', 'neg_mean_absolute_error', 'neg_mean_squared_error', 'neg_mean_squared_log_error', 'neg_median_absolute_error', 'normalized_mutual_info_score', 'precision', 'precision_macro', 'precision_micro', 'precision_samples', 'precision_weighted', 'r2', 'recall', 'recall_macro', 'recall_micro', 'recall_samples', 'recall_weighted', 'roc_auc', 'v_measure_score']
+ ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_mutual_info_score', 'adjusted_rand_score', 'average_precision', 'brier_score_loss', 'completeness_score', 'explained_variance', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted', 'fowlkes_mallows_score', 'homogeneity_score', 'mutual_info_score', 'neg_log_loss', 'neg_mean_absolute_error', 'neg_mean_squared_error', 'neg_mean_squared_log_error', 'neg_median_absolute_error', 'normalized_mutual_info_score', 'precision', 'precision_macro', 'precision_micro', 'precision_samples', 'precision_weighted', 'r2', 'recall', 'recall_macro', 'recall_micro', 'recall_samples', 'recall_weighted', 'roc_auc', 'v_measure_score']

.. note::

17 changes: 11 additions & 6 deletions doc/whats_new.rst
@@ -36,6 +36,14 @@ Classifiers and regressors
via ``n_iter_no_change``, ``validation_fraction`` and ``tol``. :issue:`7071`
by `Raghav RV`_

+ Enhancements
+ ............
+
+ Model evaluation and meta-estimators
+
+ - A scorer based on :func:`metrics.brier_score_loss` is also available.
+ :issue:`9521` by :user:`Hanmin Qin <qinhanmin2014>`.
+
Bug fixes
.........

@@ -185,9 +193,6 @@ Model selection and evaluation
:class:`model_selection.RepeatedStratifiedKFold`.
:issue:`8120` by `Neeraj Gangwar`_.

- - Added a scorer based on :class:`metrics.explained_variance_score`.
- :issue:`9259` by `Hanmin Qin <https://github.com/qinhanmin2014>`_.
-
Miscellaneous

- Validation that input data contains no NaN or inf can now be suppressed
@@ -287,9 +292,6 @@ Decomposition, manifold learning and clustering
``singular_values_``, like in :class:`decomposition.IncrementalPCA`.
:issue:`7685` by :user:`Tommy Löfstedt <tomlof>`

- - Fixed the implementation of noise_variance_ in :class:`decomposition.PCA`.
- :issue:`9108` by `Hanmin Qin <https://github.com/qinhanmin2014>`_.
-
- :class:`decomposition.NMF` now faster when ``beta_loss=0``.
:issue:`9277` by :user:`hongkahjun`.

@@ -380,6 +382,9 @@ Model evaluation and meta-estimators
- More clustering metrics are now available through :func:`metrics.get_scorer`
and ``scoring`` parameters. :issue:`8117` by `Raghav RV`_.

+ - A scorer based on :func:`metrics.explained_variance_score` is also available.
+ :issue:`9259` by :user:`Hanmin Qin <qinhanmin2014>`.
+
Metrics

- :func:`metrics.matthews_corrcoef` now support multiclass classification.
9 changes: 8 additions & 1 deletion sklearn/metrics/scorer.py
@@ -27,7 +27,7 @@
mean_squared_error, mean_squared_log_error, accuracy_score,
f1_score, roc_auc_score, average_precision_score,
precision_score, recall_score, log_loss,
- explained_variance_score)
+ explained_variance_score, brier_score_loss)

from .cluster import adjusted_rand_score
from .cluster import homogeneity_score
@@ -135,7 +135,10 @@ def __call__(self, clf, X, y, sample_weight=None):
"""
super(_ProbaScorer, self).__call__(clf, X, y,
sample_weight=sample_weight)
+ y_type = type_of_target(y)
y_pred = clf.predict_proba(X)
+ if y_type == "binary":
+ y_pred = y_pred[:, 1]
if sample_weight is not None:
return self._sign * self._score_func(y, y_pred,
sample_weight=sample_weight,
@@ -514,6 +517,9 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
log_loss_scorer = make_scorer(log_loss, greater_is_better=False,
needs_proba=True)
log_loss_scorer._deprecation_msg = deprecation_msg
+ brier_score_loss_scorer = make_scorer(brier_score_loss,
+ greater_is_better=False,
+ needs_proba=True)


# Clustering scores
@@ -540,6 +546,7 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
average_precision=average_precision_scorer,
log_loss=log_loss_scorer,
neg_log_loss=neg_log_loss_scorer,
+ brier_score_loss=brier_score_loss_scorer,
# Cluster metrics that use supervised evaluation
adjusted_rand_score=adjusted_rand_scorer,
homogeneity_score=homogeneity_scorer,
2 changes: 1 addition & 1 deletion sklearn/metrics/tests/test_score_objects.py
@@ -51,7 +51,7 @@
'roc_auc', 'average_precision', 'precision',
'precision_weighted', 'precision_macro', 'precision_micro',
'recall', 'recall_weighted', 'recall_macro', 'recall_micro',
- 'neg_log_loss', 'log_loss']
+ 'neg_log_loss', 'log_loss', 'brier_score_loss']

# All supervised cluster scorers (They behave like classification metric)
CLUSTER_SCORERS = ["adjusted_rand_score",

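Note on the scorer's behaviour: the change to `_ProbaScorer.__call__` in sklearn/metrics/scorer.py keeps only the positive-class column of `predict_proba` for binary targets, and `make_scorer(..., greater_is_better=False)` negates the loss so that a higher score means a better model. The sketch below illustrates those two steps in plain Python. It is a simplified illustration, not the scikit-learn implementation: `brier_score`, `ConstantProbaClassifier` and `neg_brier_scorer` are hypothetical names introduced here, and sample weights and `pos_label` handling are left out.

```python
# Simplified sketch; the names below are illustrative, not scikit-learn's code.

def brier_score(y_true, y_prob):
    """Mean squared gap between the predicted probability and the 0/1 label."""
    return sum((p - y) ** 2 for y, p in zip(y_true, y_prob)) / len(y_true)


class ConstantProbaClassifier:
    """Hypothetical classifier that predicts the same probability for every sample."""

    def __init__(self, p_positive):
        self.p_positive = p_positive

    def predict_proba(self, X):
        # One row per sample: [P(class 0), P(class 1)].
        return [[1.0 - self.p_positive, self.p_positive] for _ in X]


def neg_brier_scorer(clf, X, y):
    """Score a binary classifier so that a higher value means a better model."""
    proba = clf.predict_proba(X)
    # Binary case: keep only the positive-class column, mirroring y_pred[:, 1].
    y_prob = [row[1] for row in proba]
    # The loss is negated (greater_is_better=False) so model selection can maximize it.
    return -brier_score(y, y_prob)


X = [[0], [1], [2], [3]]
y = [0, 1, 1, 0]

score_half = neg_brier_scorer(ConstantProbaClassifier(0.5), X, y)
score_skewed = neg_brier_scorer(ConstantProbaClassifier(0.9), X, y)
print(score_half, score_skewed)
```

Both scores are negative, and the classifier that always predicts 0.5 scores higher than the one that always predicts 0.9 on this balanced toy target. Because the value returned is a negated loss, later scikit-learn releases expose this scorer under the name 'neg_brier_score'.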