Skip to content

Commit

Permalink
[MRG+1] Issue#7998 : Consistent parameters between QDA and LDA (sciki…
Browse files Browse the repository at this point in the history
…t-learn#8130)

* for scikit-learn#7998

* Fix some style error and add test

* Add local variable store_covariance

* better deprecation

* fix bug

* Style check

* fix covariance_

* style check

* Update

* modify test

* Formating

* update

* Update

* Add whats_new.rst

* Revert "Add whats_new.rst"

This reverts commit 4e5977d.

* whats_new

* Update for FutureWarning

* Remove warning from the setter

* add fit in test

* drop back

* Quick fix

* Small fix

* Fix

* update new

* Fix space

* Fix docstring

* fix style

* Fix

* fix assert
  • Loading branch information
mrbeann authored and dmohns committed Aug 7, 2017
1 parent a0b247b commit 2488540
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 18 deletions.
7 changes: 7 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,13 @@ Miscellaneous
:mod:`utils` have been removed or deprecated accordingly.
:issue:`8854` and :issue:`8874` by :user:`Naoya Kanai <naoyak>`

- The ``store_covariances`` and ``covariances_`` parameters of
:class:`discriminant_analysis.QuadraticDiscriminantAnalysis`
has been renamed to ``store_covariance`` and ``covariance_`` to be
consistent with the corresponding parameter names of the
:class:`discriminant_analysis.LinearDiscriminantAnalysis`. They will be
removed in version 0.21. :issue:`7998` by :user:`Jiacheng <mrbeann>`

Removed in 0.19:

- ``utils.fixes.argpartition``
Expand Down
41 changes: 29 additions & 12 deletions sklearn/discriminant_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

from __future__ import print_function
import warnings

import numpy as np
from .utils import deprecated
from scipy import linalg
from .externals.six import string_types
from .externals.six.moves import xrange
Expand Down Expand Up @@ -170,7 +170,8 @@ class LinearDiscriminantAnalysis(BaseEstimator, LinearClassifierMixin,
Number of components (< n_classes - 1) for dimensionality reduction.
store_covariance : bool, optional
Additionally compute class covariance matrix (default False).
Additionally compute class covariance matrix (default False), used
only in 'svd' solver.
.. versionadded:: 0.17
Expand Down Expand Up @@ -245,6 +246,7 @@ class LinearDiscriminantAnalysis(BaseEstimator, LinearClassifierMixin,
>>> print(clf.predict([[-0.8, -1]]))
[1]
"""

def __init__(self, solver='svd', shrinkage=None, priors=None,
n_components=None, store_covariance=False, tol=1e-4):
self.solver = solver
Expand Down Expand Up @@ -554,7 +556,7 @@ class QuadraticDiscriminantAnalysis(BaseEstimator, ClassifierMixin):
Regularizes the covariance estimate as
``(1-reg_param)*Sigma + reg_param*np.eye(n_features)``
store_covariances : boolean
store_covariance : boolean
If True the covariance matrices are computed and stored in the
`self.covariances_` attribute.
Expand All @@ -567,7 +569,7 @@ class QuadraticDiscriminantAnalysis(BaseEstimator, ClassifierMixin):
Attributes
----------
covariances_ : list of array-like, shape = [n_features, n_features]
covariance_ : list of array-like, shape = [n_features, n_features]
Covariance matrices of each class.
means_ : array-like, shape = [n_classes, n_features]
Expand Down Expand Up @@ -597,7 +599,8 @@ class QuadraticDiscriminantAnalysis(BaseEstimator, ClassifierMixin):
>>> clf.fit(X, y)
... # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
QuadraticDiscriminantAnalysis(priors=None, reg_param=0.0,
store_covariances=False, tol=0.0001)
store_covariance=False,
store_covariances=None, tol=0.0001)
>>> print(clf.predict([[-0.8, -1]]))
[1]
Expand All @@ -607,21 +610,30 @@ class QuadraticDiscriminantAnalysis(BaseEstimator, ClassifierMixin):
Discriminant Analysis
"""

def __init__(self, priors=None, reg_param=0., store_covariances=False,
tol=1.0e-4):
def __init__(self, priors=None, reg_param=0., store_covariance=False,
tol=1.0e-4, store_covariances=None):
self.priors = np.asarray(priors) if priors is not None else None
self.reg_param = reg_param
self.store_covariances = store_covariances
self.store_covariance = store_covariance
self.tol = tol

@property
@deprecated("Attribute covariances_ was deprecated in version"
" 0.19 and will be removed in 0.21. Use "
"covariance_ instead")
def covariances_(self):
return self.covariance_

def fit(self, X, y):
"""Fit the model according to the given training data and parameters.
.. versionchanged:: 0.19
*store_covariance* has been moved to main constructor.
``store_covariances`` has been moved to main constructor as
``store_covariance``
.. versionchanged:: 0.19
*tol* has been moved to main constructor.
``tol`` has been moved to main constructor.
Parameters
----------
Expand All @@ -645,7 +657,12 @@ def fit(self, X, y):
self.priors_ = self.priors

cov = None
store_covariance = self.store_covariance or self.store_covariances
if self.store_covariances:
warnings.warn("'store_covariances' was renamed to store_covariance"
" in version 0.19 and will be removed in 0.21.",
DeprecationWarning)
if store_covariance:
cov = []
means = []
scalings = []
Expand All @@ -665,13 +682,13 @@ def fit(self, X, y):
warnings.warn("Variables are collinear")
S2 = (S ** 2) / (len(Xg) - 1)
S2 = ((1 - self.reg_param) * S2) + self.reg_param
if self.store_covariances:
if self.store_covariance or store_covariance:
# cov = V * (S^2 / (n-1)) * V.T
cov.append(np.dot(S2 * Vt.T, Vt))
scalings.append(S2)
rotations.append(Vt.T)
if self.store_covariances:
self.covariances_ = cov
if self.store_covariance or store_covariance:
self.covariance_ = cov
self.means_ = np.asarray(means)
self.scalings_ = scalings
self.rotations_ = rotations
Expand Down
60 changes: 54 additions & 6 deletions sklearn/tests/test_discriminant_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_almost_equal
from sklearn.utils.testing import assert_true
from sklearn.utils.testing import assert_false
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_raise_message
from sklearn.utils.testing import assert_warns
from sklearn.utils.testing import assert_warns_message
from sklearn.utils.testing import assert_greater
from sklearn.utils.testing import ignore_warnings

Expand Down Expand Up @@ -223,6 +225,38 @@ def test_lda_scaling():
'using covariance: %s' % solver)


def test_lda_store_covariance():
# Test for slover 'lsqr' and 'eigen'
# 'store_covariance' has no effect on 'lsqr' and 'eigen' solvers
for solver in ('lsqr', 'eigen'):
clf = LinearDiscriminantAnalysis(solver=solver).fit(X6, y6)
assert_true(hasattr(clf, 'covariance_'))

# Test the actual attribute:
clf = LinearDiscriminantAnalysis(solver=solver,
store_covariance=True).fit(X6, y6)
assert_true(hasattr(clf, 'covariance_'))

assert_array_almost_equal(
clf.covariance_,
np.array([[0.422222, 0.088889], [0.088889, 0.533333]])
)

# Test for SVD slover, the default is to not set the covariances_ attribute
clf = LinearDiscriminantAnalysis(solver='svd').fit(X6, y6)
assert_false(hasattr(clf, 'covariance_'))

# Test the actual attribute:
clf = LinearDiscriminantAnalysis(solver=solver,
store_covariance=True).fit(X6, y6)
assert_true(hasattr(clf, 'covariance_'))

assert_array_almost_equal(
clf.covariance_,
np.array([[0.422222, 0.088889], [0.088889, 0.533333]])
)


def test_qda():
# QDA classification.
# This checks that QDA implements fit and predict and returns
Expand Down Expand Up @@ -262,26 +296,40 @@ def test_qda_priors():
assert_greater(n_pos2, n_pos)


def test_qda_store_covariances():
def test_qda_store_covariance():
# The default is to not set the covariances_ attribute
clf = QuadraticDiscriminantAnalysis().fit(X6, y6)
assert_true(not hasattr(clf, 'covariances_'))
assert_false(hasattr(clf, 'covariance_'))

# Test the actual attribute:
clf = QuadraticDiscriminantAnalysis(store_covariances=True).fit(X6, y6)
assert_true(hasattr(clf, 'covariances_'))
clf = QuadraticDiscriminantAnalysis(store_covariance=True).fit(X6, y6)
assert_true(hasattr(clf, 'covariance_'))

assert_array_almost_equal(
clf.covariances_[0],
clf.covariance_[0],
np.array([[0.7, 0.45], [0.45, 0.7]])
)

assert_array_almost_equal(
clf.covariances_[1],
clf.covariance_[1],
np.array([[0.33333333, -0.33333333], [-0.33333333, 0.66666667]])
)


def test_qda_deprecation():
# Test the deprecation
clf = QuadraticDiscriminantAnalysis(store_covariances=True)
assert_warns_message(DeprecationWarning, "'store_covariances' was renamed"
" to store_covariance in version 0.19 and will be "
"removed in 0.21.", clf.fit, X, y)

# check that covariance_ (and covariances_ with warning) is stored
assert_warns_message(DeprecationWarning, "Attribute covariances_ was "
"deprecated in version 0.19 and will be removed "
"in 0.21. Use covariance_ instead", getattr, clf,
'covariances_')


def test_qda_regularization():
# the default is reg_param=0. and will cause issues
# when there is a constant variable
Expand Down

0 comments on commit 2488540

Please sign in to comment.