Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add lightgbm.booster support #270

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add test_explain_prediction_booster_binary
  • Loading branch information
qh582 authored May 31, 2018
commit f0533992b50dcb248bd3ff19a9fba13f08ba22d0
30 changes: 30 additions & 0 deletions tests/test_lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,33 @@ def test_explain_prediction_booster_multitarget(newsgroups_train):
assert len(top_target_res.targets) == 2
assert sorted(t.proba for t in top_target_res.targets) == sorted(
t.proba for t in res.targets)[-2:]

def test_explain_prediction_booster_binary(
newsgroups_train_binary_big):
docs, ys, target_names = newsgroups_train_binary_big
vec = CountVectorizer(stop_words='english', dtype=np.float64)
xs = vec.fit_transform(docs)
explain_kwargs = {}
clf = lightgbm.train(
params={'objective': 'binary', 'verbose_eval': -1, 'max_depth': 2,'n_estimators':100,
'min_child_samples':1, 'min_child_weight':1},
train_set=lightgbm.Dataset(xs.toarray(), label=ys))

get_res = lambda **kwargs: explain_prediction(
clf, 'computer graphics in space: a sign of atheism',
vec=vec, target_names=target_names, **kwargs)
res = get_res()
for expl in format_as_all(res, clf, show_feature_values=True):
assert 'graphics' in expl
check_targets_scores(res)
weights = res.targets[0].feature_weights
pos_features = get_all_features(weights.pos)
neg_features = get_all_features(weights.neg)
assert 'graphics' in pos_features
assert 'computer' in pos_features
assert 'atheism' in neg_features

flt_res = get_res(feature_re='gra')
flt_pos_features = get_all_features(flt_res.targets[0].feature_weights.pos)
assert 'graphics' in flt_pos_features
assert 'computer' not in flt_pos_features