Skip to content

Commit

Permalink
Added specific spec tests. Added more extensions. Fix for preserve de…
Browse files Browse the repository at this point in the history
…faults.
  • Loading branch information
interpret-ml committed Sep 17, 2019
1 parent 021ecad commit e318e7b
Show file tree
Hide file tree
Showing 13 changed files with 183 additions and 69 deletions.
2 changes: 1 addition & 1 deletion python/interpret-core/interpret/api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class ExplainerMixin(ABC):
available_explanations: A list of strings subsetting the following
- "perf", "data", "local", "global".
explainer_type: A string that is one of the following
- "blackbox", "model", "specific".
- "blackbox", "model", "specific", "data", "perf".
"""

@property
Expand Down
2 changes: 1 addition & 1 deletion python/interpret-core/interpret/blackbox/lime.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from ..utils import gen_name_from_class, gen_local_selector
from ..utils import perf_dict
from ..utils import unify_data, unify_predict_fn
from lime.lime_tabular import LimeTabularExplainer
import warnings


Expand All @@ -27,6 +26,7 @@ def __init__(
n_jobs=1,
**kwargs
):
from lime.lime_tabular import LimeTabularExplainer

self.data, _, self.feature_names, self.feature_types = unify_data(
data, None, feature_names, feature_types
Expand Down
68 changes: 4 additions & 64 deletions python/interpret-core/interpret/blackbox/shap.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
# Copyright (c) 2019 Microsoft Corporation
# Distributed under the MIT software license
from ..utils.shap import shap_explain_local

from ..api.base import ExplainerMixin
from ..api.templates import FeatureValueExplanation
from ..utils import unify_predict_fn, unify_data
from ..utils import perf_dict, gen_name_from_class, gen_local_selector
import warnings

import shap


class ShapKernel(ExplainerMixin):
available_explanations = ["local"]
Expand All @@ -26,6 +23,8 @@ def __init__(
**kwargs
):

import shap

self.data, _, self.feature_names, self.feature_types = unify_data(
data, None, feature_names, feature_types
)
Expand All @@ -41,63 +40,4 @@ def __init__(
self.shap = shap.KernelExplainer(self.predict_fn, data, **self.kwargs)

def explain_local(self, X, y=None, name=None):
if name is None:
name = gen_name_from_class(self)
X, y, _, _ = unify_data(X, y, self.feature_names, self.feature_types)

all_shap_values = self.shap.shap_values(X)
predictions = self.predict_fn(X)

data_dicts = []
scores_list = all_shap_values
perf_list = []
for i, instance in enumerate(X):
shap_values = all_shap_values[i]
perf_dict_obj = perf_dict(y, predictions, i)

perf_list.append(perf_dict_obj)

data_dict = {
"type": "univariate",
"names": self.feature_names,
"perf": perf_dict_obj,
"scores": shap_values,
"values": instance,
"extra": {
"names": ["Base Value"],
"scores": [self.shap.expected_value],
"values": [1],
},
}
data_dicts.append(data_dict)

internal_obj = {
"overall": None,
"specific": data_dicts,
"mli": [
{
"explanation_type": "local_feature_importance",
"value": {
"scores": scores_list,
"intercept": self.shap.expected_value,
"perf": perf_list,
},
}
],
}
internal_obj["mli"].append(
{
"explanation_type": "evaluation_dataset",
"value": {"dataset_x": X, "dataset_y": y},
}
)
selector = gen_local_selector(X, y, predictions)

return FeatureValueExplanation(
"local",
internal_obj,
feature_names=self.feature_names,
feature_types=self.feature_types,
name=name,
selector=selector,
)
return shap_explain_local(self, X, y=y, name=name)
10 changes: 10 additions & 0 deletions python/interpret-core/interpret/ext/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2019 Microsoft Corporation
# Distributed under the MIT software license

import sys
from interpret.ext.extension_utils import load_class_extensions
from interpret.ext.extension import DATA_EXTENSION_KEY, _is_valid_data_explainer

load_class_extensions(
sys.modules[__name__], DATA_EXTENSION_KEY, _is_valid_data_explainer
)
10 changes: 10 additions & 0 deletions python/interpret-core/interpret/ext/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
PROVIDER_EXTENSION_KEY = "interpret_ext_provider"
BLACKBOX_EXTENSION_KEY = "interpret_ext_blackbox"
GREYBOX_EXTENSION_KEY = "interpret_ext_greybox"
DATA_EXTENSION_KEY = "interpret_ext_data"
PERF_EXTENSION_KEY = "interpret_ext_perf"


def _is_valid_explainer(target_explainer_type, proposed_explainer):
Expand Down Expand Up @@ -45,6 +47,14 @@ def _is_valid_greybox_explainer(proposed_explainer):
return _is_valid_explainer("specific", proposed_explainer)


def _is_valid_data_explainer(proposed_explainer):
return _is_valid_explainer("data", proposed_explainer)


def _is_valid_perf_explainer(proposed_explainer):
return _is_valid_explainer("perf", proposed_explainer)


def _is_valid_provider(proposed_provider):
try:
has_render_method = hasattr(proposed_provider, "render")
Expand Down
4 changes: 2 additions & 2 deletions python/interpret-core/interpret/ext/greybox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import sys
from interpret.ext.extension_utils import load_class_extensions
from interpret.ext.extension import BLACKBOX_EXTENSION_KEY, _is_valid_blackbox_explainer
from interpret.ext.extension import GREYBOX_EXTENSION_KEY, _is_valid_greybox_explainer

load_class_extensions(
sys.modules[__name__], BLACKBOX_EXTENSION_KEY, _is_valid_blackbox_explainer
sys.modules[__name__], GREYBOX_EXTENSION_KEY, _is_valid_greybox_explainer
)
10 changes: 10 additions & 0 deletions python/interpret-core/interpret/ext/perf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2019 Microsoft Corporation
# Distributed under the MIT software license

import sys
from interpret.ext.extension_utils import load_class_extensions
from interpret.ext.extension import PERF_EXTENSION_KEY, _is_valid_perf_explainer

load_class_extensions(
sys.modules[__name__], PERF_EXTENSION_KEY, _is_valid_perf_explainer
)
1 change: 1 addition & 0 deletions python/interpret-core/interpret/greybox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
# Distributed under the MIT software license

from .treeinterpreter import TreeInterpreter # noqa: F401
from .shaptree import ShapTree # noqa: F401
56 changes: 56 additions & 0 deletions python/interpret-core/interpret/greybox/shaptree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) 2019 Microsoft Corporation
# Distributed under the MIT software license

from ..utils.shap import shap_explain_local
from sklearn.base import is_classifier

from ..api.base import ExplainerMixin
from ..utils import unify_predict_fn, unify_data


class ShapTree(ExplainerMixin):
available_explanations = ["local"]
explainer_type = "specific"

def __init__(
self,
model,
data,
feature_names=None,
feature_types=None,
explain_kwargs=None,
n_jobs=1,
**kwargs
):
import shap

self.model = model
if is_classifier(self):
predict_fn = lambda x: self.model.predict_proba(x)[:, 1]
else:
predict_fn = self.model.predict
self.predict_fn = unify_predict_fn(predict_fn, self.data)

self.data, _, self.feature_names, self.feature_types = unify_data(
data, None, feature_names, feature_types
)
self.n_jobs = n_jobs

self.explain_kwargs = explain_kwargs
self.kwargs = kwargs

self.shap = shap.TreeExplainer(model, data, **self.kwargs)

def explain_local(self, X, y=None, name=None):
""" Provides local explanations for provided instances.
Args:
X: Numpy array for X to explain.
y: Numpy vector for y to explain.
name: User-defined explanation name.
Returns:
An explanation object, visualizing feature-value pairs
for each instance as horizontal bar charts.
"""
return shap_explain_local(self, X, y=y, name=name)
5 changes: 5 additions & 0 deletions python/interpret-core/interpret/provider/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ def render(self, explanation, key=-1, **kwargs):
selector_key = kwargs.pop('selector_key', None)
file_name = kwargs.pop('file_name', None)

# NOTE: Preserve didn't support returning everything. If key is -1 default to key is None.
# This is for backward-compatibility. All of this will be deprecated shortly anyway.
if key == -1:
key = None

# Get visual object
visual = explanation.visualize(key=key)

Expand Down
8 changes: 7 additions & 1 deletion python/interpret-core/interpret/test/test_explainers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2019 Microsoft Corporation
# Distributed under the MIT software license

from sklearn.ensemble import RandomForestClassifier

from .utils import synthetic_classification, get_all_explainers
from .utils import assert_valid_explanation, assert_valid_model_explainer
Expand All @@ -10,12 +11,16 @@
import pytest


# TODO: Generalize specific models (currently only testing trees)
@pytest.mark.slow
def test_spec_synthetic():
all_explainers = get_all_explainers()
data = synthetic_classification()

blackbox = LogisticRegression()
blackbox.fit(data["train"]["X"], data["train"]["y"])
tree = RandomForestClassifier()
tree.fit(data["train"]["X"], data["train"]["y"])

predict_fn = lambda x: blackbox.predict_proba(x) # noqa: E731

Expand All @@ -24,9 +29,10 @@ def test_spec_synthetic():
explainer = explainer_class(predict_fn, data["train"]["X"])
elif explainer_class.explainer_type == "model":
explainer = explainer_class()

explainer.fit(data["train"]["X"], data["train"]["y"])
assert_valid_model_explainer(explainer, data["test"]["X"].head())
elif explainer_class.explainer_type == "specific":
explainer = explainer_class(tree, data["train"]["X"])
elif explainer_class.explainer_type == "data":
explainer = explainer_class()
elif explainer_class.explainer_type == "perf":
Expand Down
8 changes: 8 additions & 0 deletions python/interpret-core/interpret/test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from ..blackbox import MorrisSensitivity
from ..blackbox import PartialDependence

from ..greybox import TreeInterpreter
from ..greybox import ShapTree

from ..glassbox import LogisticRegression, LinearRegression
from ..glassbox import ClassificationTree, RegressionTree
from ..glassbox import DecisionListClassifier
Expand All @@ -35,6 +38,10 @@ def get_all_explainers():
LinearRegression,
ExplainableBoostingRegressor,
]
specific_explainer_classes = [
TreeInterpreter,
ShapTree,
]
blackbox_explainer_classes = [
LimeTabular,
ShapKernel,
Expand All @@ -43,6 +50,7 @@ def get_all_explainers():
]
all_explainers = []
all_explainers.extend(model_explainer_classes)
all_explainers.extend(specific_explainer_classes)
all_explainers.extend(blackbox_explainer_classes)
all_explainers.extend(data_explainer_classes)
all_explainers.extend(perf_explainer_classes)
Expand Down
68 changes: 68 additions & 0 deletions python/interpret-core/interpret/utils/shap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright (c) 2019 Microsoft Corporation
# Distributed under the MIT software license

from ..api.templates import FeatureValueExplanation
from . import gen_name_from_class, unify_data, perf_dict, gen_local_selector


def shap_explain_local(explainer, X, y=None, name=None):
if name is None:
name = gen_name_from_class(explainer)
X, y, _, _ = unify_data(X, y, explainer.feature_names, explainer.feature_types)

all_shap_values = explainer.shap.shap_values(X)
predictions = explainer.predict_fn(X)

data_dicts = []
scores_list = all_shap_values
perf_list = []
for i, instance in enumerate(X):
shap_values = all_shap_values[i]
perf_dict_obj = perf_dict(y, predictions, i)

perf_list.append(perf_dict_obj)

data_dict = {
"type": "univariate",
"names": explainer.feature_names,
"perf": perf_dict_obj,
"scores": shap_values,
"values": instance,
"extra": {
"names": ["Base Value"],
"scores": [explainer.shap.expected_value],
"values": [1],
},
}
data_dicts.append(data_dict)

internal_obj = {
"overall": None,
"specific": data_dicts,
"mli": [
{
"explanation_type": "local_feature_importance",
"value": {
"scores": scores_list,
"intercept": explainer.shap.expected_value,
"perf": perf_list,
},
}
],
}
internal_obj["mli"].append(
{
"explanation_type": "evaluation_dataset",
"value": {"dataset_x": X, "dataset_y": y},
}
)
selector = gen_local_selector(X, y, predictions)

return FeatureValueExplanation(
"local",
internal_obj,
feature_names=explainer.feature_names,
feature_types=explainer.feature_types,
name=name,
selector=selector,
)

0 comments on commit e318e7b

Please sign in to comment.