Add lightgbm.Booster support #270

Open: wants to merge 22 commits into master (showing changes from 2 commits).
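What the PR enables, sketched as a minimal usage example (the toy data and training parameters here are illustrative, not taken from the PR):

    import numpy as np
    import lightgbm
    import eli5

    # Train a raw Booster, bypassing the LGBMClassifier/LGBMRegressor wrappers.
    X = np.random.rand(100, 4)
    y = (X[:, 0] > 0.5).astype(int)
    booster = lightgbm.train({'objective': 'binary'},
                             lightgbm.Dataset(X, label=y),
                             num_boost_round=10)

    # Previously only the scikit-learn wrappers were registered; with this PR
    # a lightgbm.Booster can be passed to eli5 directly.
    eli5.explain_weights(booster)
    eli5.explain_prediction(booster, X[0])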
eli5/lightgbm.py: 88 changes (65 additions, 23 deletions)
@@ -17,7 +17,7 @@
 all values sum to 1.
 """


+@explain_weights.register(lightgbm.Booster)
 @explain_weights.register(lightgbm.LGBMClassifier)
 @explain_weights.register(lightgbm.LGBMRegressor)
 def explain_weights_lightgbm(lgb,
@@ -32,7 +32,7 @@ def explain_weights_lightgbm(lgb,
                              ):
     """
     Return an explanation of a LightGBM estimator (via scikit-learn wrapper
-    LGBMClassifier or LGBMRegressor) as feature importances.
+    LGBMClassifier or LGBMRegressor, or via lightgbm.Booster) as feature importances.

     See :func:`eli5.explain_weights` for description of
     ``top``, ``feature_names``,
Expand All @@ -51,8 +51,9 @@ def explain_weights_lightgbm(lgb,
across all trees
- 'weight' - the same as 'split', for compatibility with xgboost
"""
coef = _get_lgb_feature_importances(lgb, importance_type)
lgb_feature_names = lgb.booster_.feature_name()
booster, is_regression = _check_booster_args(lgb)
coef = _get_lgb_feature_importances(booster, importance_type)
lgb_feature_names = booster.feature_name()
return get_feature_importance_explanation(lgb, vec, coef,
feature_names=feature_names,
estimator_feature_names=lgb_feature_names,
@@ -64,7 +65,7 @@ def explain_weights_lightgbm(lgb,
         is_regression=isinstance(lgb, lightgbm.LGBMRegressor),
     )


+@explain_prediction.register(lightgbm.Booster)
 @explain_prediction.register(lightgbm.LGBMClassifier)
 @explain_prediction.register(lightgbm.LGBMRegressor)
 def explain_prediction_lightgbm(
@@ -80,7 +81,7 @@ def explain_prediction_lightgbm(
         vectorized=False,
         ):
     """ Return an explanation of LightGBM prediction (via scikit-learn wrapper
-    LGBMClassifier or LGBMRegressor) as feature weights.
+    LGBMClassifier or LGBMRegressor, or via lightgbm.Booster) as feature weights.

     See :func:`eli5.explain_prediction` for description of
     ``top``, ``top_targets``, ``target_names``, ``targets``,
@@ -108,20 +109,45 @@
     Weights of all features sum to the output score of the estimator.
     """

-    vec, feature_names = handle_vec(lgb, doc, vec, vectorized, feature_names)
+    booster, is_regression = _check_booster_args(lgb)
+    lgb_feature_names = booster.feature_name()
+    vec, feature_names = handle_vec(lgb, doc, vec, vectorized, feature_names,
+                                    num_features=len(lgb_feature_names))
     if feature_names.bias_name is None:
         # LightGBM estimators do not have an intercept, but here we interpret
         # them as having an intercept
         feature_names.bias_name = '<BIAS>'
     X = get_X(doc, vec, vectorized=vectorized)

+    if isinstance(lgb, lightgbm.Booster):
+        prediction = lgb.predict(X)
+        n_targets = prediction.shape[-1]
+        if is_regression is None:
+            # When n_targets is 1, this can be classification too,
+            # but it's safer to assume regression.
+            # If n_targets > 1, it must be classification.
+            is_regression = n_targets == 1
+        if is_regression:
+            proba = None
+        else:
+            if n_targets == 1:
+                p, = prediction
+                proba = np.array([1 - p, p])
+            else:
+                proba, = prediction
+    else:
+        proba = predict_proba(lgb, X)
+        n_targets = _lgb_n_targets(lgb)
+
-    proba = predict_proba(lgb, X)
-    weight_dicts = _get_prediction_feature_weights(lgb, X, _lgb_n_targets(lgb))
-    x = get_X0(add_intercept(X))
+    if is_regression:
+        names = ['y']
+    elif isinstance(lgb, lightgbm.Booster):
+        names = np.arange(max(2, n_targets))
+    else:
+        names = lgb.classes_
+
-    is_regression = isinstance(lgb, lightgbm.LGBMRegressor)
-    is_multiclass = _lgb_n_targets(lgb) > 2
-    names = lgb.classes_ if not is_regression else ['y']
+    weight_dicts = _get_prediction_feature_weights(booster, X, n_targets)
+    x = get_X0(add_intercept(X))

     def get_score_weights(_label_id):
         _weights = _target_feature_weights(
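One subtlety in the hunk above: for a binary objective, a raw Booster's predict() returns only the positive-class probability, while the scikit-learn wrapper's predict_proba() returns one column per class; the np.array([1 - p, p]) line rebuilds that two-column layout. A toy illustration with a made-up value:

    import numpy as np

    p = 0.8                       # hypothetical Booster.predict() output for one doc
    proba = np.array([1 - p, p])  # array([0.2, 0.8]), same layout as predict_proba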
@@ -144,23 +170,39 @@ def get_score_weights(_label_id):
         target_names=target_names,
         targets=targets,
         top_targets=top_targets,
-        is_regression=is_regression,
-        is_multiclass=is_multiclass,
+        is_regression=isinstance(lgb, lightgbm.LGBMRegressor),
+        is_multiclass=n_targets > 1,
         proba=proba,
         get_score_weights=get_score_weights,
     )


+def _check_booster_args(lgb, is_regression=None):
+    # type: (Any, bool) -> Tuple[Booster, bool]
Contributor @lopuhin commented on May 29, 2018:

I hope something like # type: (Any, bool) -> Tuple[lightgbm.Booster, bool], plus adding Tuple and Any to the typing imports at the top, should fix the build (https://travis-ci.org/TeamHG-Memex/eli5/jobs/385021206#L764-L766). To check locally, you can try installing mypy==0.550 and running mypy --check-untyped-defs eli5.
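A minimal sketch of the suggested fix, assuming the imports sit at the top of eli5/lightgbm.py as the comment describes:

    from typing import Any, Tuple

    import lightgbm


    def _check_booster_args(lgb, is_regression=None):
        # type: (Any, bool) -> Tuple[lightgbm.Booster, bool]
        ...  # body as in the diff below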

+    if isinstance(lgb, lightgbm.Booster):
+        booster = lgb
+    else:
+        booster = lgb.booster()
+        _is_regression = isinstance(lgb, lightgbm.LGBMRegressor)
+        if is_regression is not None and is_regression != _is_regression:
+            raise ValueError(
+                'Inconsistent is_regression={} passed. '
+                'You don\'t have to pass it when using scikit-learn API'
+                .format(is_regression))
+        is_regression = _is_regression
+    return booster, is_regression
+
+
 def _lgb_n_targets(lgb):
     if isinstance(lgb, lightgbm.LGBMClassifier):
-        return lgb.n_classes_
-    else:
+        return 1 if lgb.n_classes_ == 2 else lgb.n_classes_
+    elif isinstance(lgb, lightgbm.LGBMRegressor):
         return 1
+    else:
+        raise TypeError


-def _get_lgb_feature_importances(lgb, importance_type):
+def _get_lgb_feature_importances(booster, importance_type):
     aliases = {'weight': 'split'}
-    coef = lgb.booster_.feature_importance(
+    coef = booster.feature_importance(
         importance_type=aliases.get(importance_type, importance_type)
     )
     norm = coef.sum()
@@ -237,17 +279,17 @@ def walk(tree, parent_id=-1):
     return leaf_index, split_index


-def _get_prediction_feature_weights(lgb, X, n_targets):
+def _get_prediction_feature_weights(booster, X, n_targets):
     """
     Return a list of {feat_id: value} dicts with feature weights,
     following ideas from http://blog.datadive.net/interpreting-random-forests/
     """
     if n_targets == 2:
         n_targets = 1
-    dump = lgb.booster_.dump_model()
+    dump = booster.dump_model()
     tree_info = dump['tree_info']
     _compute_node_values(tree_info)
-    pred_leafs = lgb.booster_.predict(X, pred_leaf=True).reshape(-1, n_targets)
+    pred_leafs = booster.predict(X, pred_leaf=True).reshape(-1, n_targets)
     tree_info = np.array(tree_info).reshape(-1, n_targets)
     assert pred_leafs.shape == tree_info.shape

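For background on _get_prediction_feature_weights: following the decision-path idea from the linked blog post, the change in expected value at each split along a root-to-leaf path is credited to the feature split at that node, and the root's expected value becomes the bias. A toy sketch of that attribution (hypothetical simplified path format; eli5 actually walks the dump_model() output):

    def path_contributions(path_values, path_features):
        # path_values: expected value at each node on the root-to-leaf path,
        # ending with the leaf value; path_features: feature split at each
        # internal node. Each step's value change is credited to its feature.
        contributions = {}
        for i, feat in enumerate(path_features):
            delta = path_values[i + 1] - path_values[i]
            contributions[feat] = contributions.get(feat, 0.0) + delta
        return contributions

    # Root value 0.25 (the bias), split on feature 0 -> 0.5, on feature 3 -> 0.75:
    print(path_contributions([0.25, 0.5, 0.75], [0, 3]))
    # {0: 0.25, 3: 0.25}; bias 0.25 + 0.25 + 0.25 == leaf value 0.75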