forked from interpretml/interpret
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added specific spec tests. Added more extensions. Fix for preserve de…
…faults.
- Loading branch information
1 parent
021ecad
commit e318e7b
Showing
13 changed files
with
183 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Copyright (c) 2019 Microsoft Corporation | ||
# Distributed under the MIT software license | ||
|
||
import sys | ||
from interpret.ext.extension_utils import load_class_extensions | ||
from interpret.ext.extension import DATA_EXTENSION_KEY, _is_valid_data_explainer | ||
|
||
load_class_extensions( | ||
sys.modules[__name__], DATA_EXTENSION_KEY, _is_valid_data_explainer | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Copyright (c) 2019 Microsoft Corporation | ||
# Distributed under the MIT software license | ||
|
||
import sys | ||
from interpret.ext.extension_utils import load_class_extensions | ||
from interpret.ext.extension import PERF_EXTENSION_KEY, _is_valid_perf_explainer | ||
|
||
load_class_extensions( | ||
sys.modules[__name__], PERF_EXTENSION_KEY, _is_valid_perf_explainer | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Copyright (c) 2019 Microsoft Corporation | ||
# Distributed under the MIT software license | ||
|
||
from ..utils.shap import shap_explain_local | ||
from sklearn.base import is_classifier | ||
|
||
from ..api.base import ExplainerMixin | ||
from ..utils import unify_predict_fn, unify_data | ||
|
||
|
||
class ShapTree(ExplainerMixin): | ||
available_explanations = ["local"] | ||
explainer_type = "specific" | ||
|
||
def __init__( | ||
self, | ||
model, | ||
data, | ||
feature_names=None, | ||
feature_types=None, | ||
explain_kwargs=None, | ||
n_jobs=1, | ||
**kwargs | ||
): | ||
import shap | ||
|
||
self.model = model | ||
if is_classifier(self): | ||
predict_fn = lambda x: self.model.predict_proba(x)[:, 1] | ||
else: | ||
predict_fn = self.model.predict | ||
self.predict_fn = unify_predict_fn(predict_fn, self.data) | ||
|
||
self.data, _, self.feature_names, self.feature_types = unify_data( | ||
data, None, feature_names, feature_types | ||
) | ||
self.n_jobs = n_jobs | ||
|
||
self.explain_kwargs = explain_kwargs | ||
self.kwargs = kwargs | ||
|
||
self.shap = shap.TreeExplainer(model, data, **self.kwargs) | ||
|
||
def explain_local(self, X, y=None, name=None): | ||
""" Provides local explanations for provided instances. | ||
Args: | ||
X: Numpy array for X to explain. | ||
y: Numpy vector for y to explain. | ||
name: User-defined explanation name. | ||
Returns: | ||
An explanation object, visualizing feature-value pairs | ||
for each instance as horizontal bar charts. | ||
""" | ||
return shap_explain_local(self, X, y=y, name=name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Copyright (c) 2019 Microsoft Corporation | ||
# Distributed under the MIT software license | ||
|
||
from ..api.templates import FeatureValueExplanation | ||
from . import gen_name_from_class, unify_data, perf_dict, gen_local_selector | ||
|
||
|
||
def shap_explain_local(explainer, X, y=None, name=None): | ||
if name is None: | ||
name = gen_name_from_class(explainer) | ||
X, y, _, _ = unify_data(X, y, explainer.feature_names, explainer.feature_types) | ||
|
||
all_shap_values = explainer.shap.shap_values(X) | ||
predictions = explainer.predict_fn(X) | ||
|
||
data_dicts = [] | ||
scores_list = all_shap_values | ||
perf_list = [] | ||
for i, instance in enumerate(X): | ||
shap_values = all_shap_values[i] | ||
perf_dict_obj = perf_dict(y, predictions, i) | ||
|
||
perf_list.append(perf_dict_obj) | ||
|
||
data_dict = { | ||
"type": "univariate", | ||
"names": explainer.feature_names, | ||
"perf": perf_dict_obj, | ||
"scores": shap_values, | ||
"values": instance, | ||
"extra": { | ||
"names": ["Base Value"], | ||
"scores": [explainer.shap.expected_value], | ||
"values": [1], | ||
}, | ||
} | ||
data_dicts.append(data_dict) | ||
|
||
internal_obj = { | ||
"overall": None, | ||
"specific": data_dicts, | ||
"mli": [ | ||
{ | ||
"explanation_type": "local_feature_importance", | ||
"value": { | ||
"scores": scores_list, | ||
"intercept": explainer.shap.expected_value, | ||
"perf": perf_list, | ||
}, | ||
} | ||
], | ||
} | ||
internal_obj["mli"].append( | ||
{ | ||
"explanation_type": "evaluation_dataset", | ||
"value": {"dataset_x": X, "dataset_y": y}, | ||
} | ||
) | ||
selector = gen_local_selector(X, y, predictions) | ||
|
||
return FeatureValueExplanation( | ||
"local", | ||
internal_obj, | ||
feature_names=explainer.feature_names, | ||
feature_types=explainer.feature_types, | ||
name=name, | ||
selector=selector, | ||
) |