Skip to content

Commit

Permalink
FIX Issue scikit-learn#8173 - pass n_neighbors in MI computation (sci…
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre authored and raghavrv committed Jan 19, 2017
1 parent 4826883 commit aaebee1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
2 changes: 1 addition & 1 deletion sklearn/feature_selection/mutual_info_.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def _estimate_mi(X, y, discrete_features='auto', discrete_target=False,
y = scale(y, with_mean=False)
y += 1e-10 * np.maximum(1, np.mean(np.abs(y))) * rng.randn(n_samples)

mi = [_compute_mi(x, y, discrete_feature, discrete_target) for
mi = [_compute_mi(x, y, discrete_feature, discrete_target, n_neighbors) for
x, discrete_feature in moves.zip(_iterate_columns(X), discrete_mask)]

return np.array(mi)
Expand Down
16 changes: 14 additions & 2 deletions sklearn/feature_selection/tests/test_mutual_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from scipy.sparse import csr_matrix

from sklearn.utils.testing import (assert_array_equal, assert_almost_equal,
assert_false, assert_raises, assert_equal)
assert_false, assert_raises, assert_equal,
assert_allclose, assert_greater)
from sklearn.feature_selection.mutual_info_ import (
mutual_info_regression, mutual_info_classif, _compute_mi)

Expand Down Expand Up @@ -158,8 +159,19 @@ def test_mutual_info_classif_mixed():
y = ((0.5 * X[:, 0] + X[:, 2]) > 0.5).astype(int)
X[:, 2] = X[:, 2] > 0.5

mi = mutual_info_classif(X, y, discrete_features=[2], random_state=0)
mi = mutual_info_classif(X, y, discrete_features=[2], n_neighbors=3,
random_state=0)
assert_array_equal(np.argsort(-mi), [2, 0, 1])
for n_neighbors in [5, 7, 9]:
mi_nn = mutual_info_classif(X, y, discrete_features=[2],
n_neighbors=n_neighbors, random_state=0)
# Check that the continuous values have a higher MI with greater
# n_neighbors
assert_greater(mi_nn[0], mi[0])
assert_greater(mi_nn[1], mi[1])
# n_neighbors should not have any effect on the discrete feature,
# so its MI should be the same
assert_equal(mi_nn[2], mi[2])


def test_mutual_info_options():
Expand Down

0 comments on commit aaebee1

Please sign in to comment.