Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add lightgbm.booster support #270

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add _lgb_n_targets test
  • Loading branch information
qh582 authored May 31, 2018
commit 3c213e93f63a56491ff25a2ac066019196546ef1
18 changes: 17 additions & 1 deletion tests/test_lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from lightgbm import LGBMClassifier, LGBMRegressor

from eli5 import explain_weights, explain_prediction
from eli5.lightgbm import _check_booster_args
from eli5.lightgbm import _check_booster_args, _lgb_n_targets
from .test_sklearn_explain_weights import (
test_explain_tree_classifier as _check_rf_classifier,
test_explain_random_forest_and_tree_feature_filter as _check_rf_feature_filter,
Expand Down Expand Up @@ -258,3 +258,19 @@ def test_explain_prediction_booster_binary(
flt_pos_features = get_all_features(flt_res.targets[0].feature_weights.pos)
assert 'graphics' in flt_pos_features
assert 'computer' not in flt_pos_features

def test_lgb_n_targets():
clf = LGBMClassifier(min_data=1)
clf.fit(np.array([[0], [1]]), np.array([0, 1]))
assert _lgb_n_targets(clf) == 1

clf = LGBMClassifier(min_data=1)
clf.fit(np.array([[0], [1], [2]]), np.array([0, 1, 2]))
assert _lgb_n_targets(clf) == 3

reg = LGBMRegressor(min_data=1)
reg.fit(np.array([[0], [1], [2]]), np.array([0, 1, 2]))
assert _lgb_n_targets(reg) == 1

with pytest.raises(TypeError):
_lgb_n_targets(object())