Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove internal state from AnchorImage explainer #460

Merged
merged 13 commits into from
Aug 4, 2021
Prev Previous commit
Next Next commit
Add __call__ and revert black
  • Loading branch information
Adrian Gonzalez-Martin committed Aug 2, 2021
commit 904f46c68876469aa606b8518e218f2243c16d1c
78 changes: 30 additions & 48 deletions alibi/explainers/anchor_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,16 @@
logger = logging.getLogger(__name__)

DEFAULT_SEGMENTATION_KWARGS = {
"felzenszwalb": {},
"quickshift": {},
"slic": {"n_segments": 10, "compactness": 10, "sigma": 0.5},
'felzenszwalb': {},
'quickshift': {},
'slic': {'n_segments': 10, 'compactness': 10, 'sigma': .5}
}


class AnchorImage(Explainer):
def __init__(
self,
predictor: Callable,
image_shape: tuple,
segmentation_fn: Any = "slic",
segmentation_kwargs: dict = None,
images_background: np.ndarray = None,
seed: int = None,
) -> None:
def __init__(self, predictor: Callable, image_shape: tuple, segmentation_fn: Any = 'slic',
segmentation_kwargs: dict = None, images_background: np.ndarray = None,
seed: int = None) -> None:
"""
Initialize anchor image explainer.

Expand Down Expand Up @@ -62,54 +56,48 @@ def __init__(
segmentation_kwargs = DEFAULT_SEGMENTATION_KWARGS[segmentation_fn] # type: ignore
except KeyError:
logger.warning(
"DEFAULT_SEGMENTATION_KWARGS did not contain any entry"
"for segmentation method {}. No kwargs will be passed to"
"the segmentation function!".format(segmentation_fn)
'DEFAULT_SEGMENTATION_KWARGS did not contain any entry'
'for segmentation method {}. No kwargs will be passed to'
'the segmentation function!'.format(segmentation_fn)
)
segmentation_kwargs = {}
elif callable(segmentation_fn) and segmentation_kwargs:
logger.warning(
"Specified both a segmentation function to create superpixels and "
"keyword arguments for built segmentation functions. By default "
"the specified segmentation function will be used."
'Specified both a segmentation function to create superpixels and '
'keyword arguments for built segmentation functions. By default '
'the specified segmentation function will be used.'
)

# set the predictor
self.image_shape = image_shape
self.predictor = self._transform_predictor(predictor)

# segmentation function is either a user-defined function or one of the values in
fn_options = {
"felzenszwalb": felzenszwalb,
"slic": slic,
"quickshift": quickshift,
}
fn_options = {'felzenszwalb': felzenszwalb, 'slic': slic, 'quickshift': quickshift}
if callable(segmentation_fn):
self.custom_segmentation = True
self.segmentation_fn = segmentation_fn
else:
self.custom_segmentation = False
self.segmentation_fn = partial(
fn_options[segmentation_fn], **segmentation_kwargs
)
self.segmentation_fn = partial(fn_options[segmentation_fn], **segmentation_kwargs)

self.images_background = images_background
# a superpixel is perturbed with prob 1 - p_sample
self.p_sample = 0.5 # type: float

# update metadata
self.meta["params"].update(
self.meta['params'].update(
custom_segmentation=self.custom_segmentation,
segmentation_kwargs=segmentation_kwargs,
p_sample=self.p_sample,
seed=seed,
image_shape=self.image_shape,
images_background=self.images_background,
images_background=self.images_background
)
if not self.custom_segmentation:
self.meta["params"].update(segmentation_fn=segmentation_fn)
self.meta['params'].update(segmentation_fn=segmentation_fn)
else:
self.meta["params"].update(segmentation_fn="custom")
self.meta['params'].update(segmentation_fn='custom')

def generate_superpixels(self, image: np.ndarray) -> np.ndarray:
"""
Expand Down Expand Up @@ -221,7 +209,7 @@ def explain(
"""
# get params for storage in meta
params = locals()
remove = ["image", "self"]
remove = ['image', 'self']
for key in remove:
params.pop(key)

Expand All @@ -237,11 +225,10 @@ def explain(

# get anchors and add metadata
mab = AnchorBaseBeam(
samplers=[sampler.sample],
samplers=[sampler],
sample_cache_size=binary_cache_size,
cache_margin=cache_margin,
**kwargs,
)
**kwargs)
result = mab.anchor_beam(
desired_confidence=threshold,
delta=delta,
Expand Down Expand Up @@ -285,13 +272,13 @@ def build_explanation(
Parameters passed to `explain`
"""

result["instance"] = image
result["instances"] = np.expand_dims(image, 0)
result["prediction"] = np.array([predicted_label])
result['instance'] = image
result['instances'] = np.expand_dims(image, 0)
result['prediction'] = np.array([predicted_label])

# overlay image with anchor mask
anchor = self.overlay_mask(image, sampler.segments, result["feature"])
exp = AnchorExplanation("image", result)
anchor = self.overlay_mask(image, sampler.segments, result['feature'])
exp = AnchorExplanation('image', result)

# output explanation dictionary
data = copy.deepcopy(DEFAULT_DATA_ANCHOR_IMG)
Expand All @@ -300,23 +287,18 @@ def build_explanation(
segments=sampler.segments,
precision=exp.precision(),
coverage=exp.coverage(),
raw=exp.exp_map,
raw=exp.exp_map
)

# create explanation object
explanation = Explanation(meta=copy.deepcopy(self.meta), data=data)

# params passed to explain
explanation.meta["params"].update(params)
explanation.meta['params'].update(params)
return explanation

def overlay_mask(
self,
image: np.ndarray,
segments: np.ndarray,
mask_features: list,
scale: tuple = (0, 255),
) -> np.ndarray:
def overlay_mask(self, image: np.ndarray, segments: np.ndarray, mask_features: list,
scale: tuple = (0, 255)) -> np.ndarray:
"""
Overlay image with mask described by the mask features.

Expand Down
2 changes: 1 addition & 1 deletion alibi/explainers/anchor_image_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
self.segment_labels = list(np.unique(self.segments))
self.instance_label = self.predictor(image[np.newaxis, ...])[0]

def sample(
def __call__(
self, anchor: Tuple[int, tuple], num_samples: int, compute_labels: bool = True
) -> List[Union[np.ndarray, float, int]]:
"""
Expand Down
2 changes: 1 addition & 1 deletion alibi/explainers/tests/test_anchor_image_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_sampler(models, mnist_data):
assert superpixels_mask.sum(axis=1).any() <= segmentation_kwargs["n_segments"]
assert superpixels_mask.any() <= 1

cov_true, cov_false, labels, data, coverage, _ = sampler.sample(
cov_true, cov_false, labels, data, coverage, _ = sampler(
(0, ()), num_samples
)
assert data.shape[0] == labels.shape[0]
Expand Down