From 2529d4e69f9de987c75aa2951e883b0b60382867 Mon Sep 17 00:00:00 2001 From: Adrian Gonzalez-Martin Date: Fri, 30 Jul 2021 11:33:49 +0200 Subject: [PATCH 01/13] Refactor sampler into its own class --- alibi/explainers/anchor_image.py | 341 +++++++---------------- alibi/explainers/anchor_image_sampler.py | 260 +++++++++++++++++ alibi/explainers/anchor_image_utils.py | 25 ++ 3 files changed, 378 insertions(+), 248 deletions(-) create mode 100644 alibi/explainers/anchor_image_sampler.py create mode 100644 alibi/explainers/anchor_image_utils.py diff --git a/alibi/explainers/anchor_image.py b/alibi/explainers/anchor_image.py index d730897ee..bcb208034 100644 --- a/alibi/explainers/anchor_image.py +++ b/alibi/explainers/anchor_image.py @@ -4,29 +4,36 @@ import numpy as np from functools import partial -from typing import Any, Callable, List, Tuple, Union +from typing import Any, Callable from alibi.utils.wrappers import ArgmaxTransformer from alibi.api.interfaces import Explainer, Explanation from alibi.api.defaults import DEFAULT_META_ANCHOR, DEFAULT_DATA_ANCHOR_IMG from .anchor_base import AnchorBaseBeam from .anchor_explanation import AnchorExplanation +from .anchor_image_utils import scale_image +from .anchor_image_sampler import AnchorImageSampler from skimage.segmentation import felzenszwalb, slic, quickshift logger = logging.getLogger(__name__) DEFAULT_SEGMENTATION_KWARGS = { - 'felzenszwalb': {}, - 'quickshift': {}, - 'slic': {'n_segments': 10, 'compactness': 10, 'sigma': .5} + "felzenszwalb": {}, + "quickshift": {}, + "slic": {"n_segments": 10, "compactness": 10, "sigma": 0.5}, } class AnchorImage(Explainer): - - def __init__(self, predictor: Callable, image_shape: tuple, segmentation_fn: Any = 'slic', - segmentation_kwargs: dict = None, images_background: np.ndarray = None, - seed: int = None) -> None: + def __init__( + self, + predictor: Callable, + image_shape: tuple, + segmentation_fn: Any = "slic", + segmentation_kwargs: dict = None, + images_background: np.ndarray = None, + seed: int = None, + ) -> None: """ Initialize anchor image explainer. @@ -55,16 +62,16 @@ def __init__(self, predictor: Callable, image_shape: tuple, segmentation_fn: Any segmentation_kwargs = DEFAULT_SEGMENTATION_KWARGS[segmentation_fn] # type: ignore except KeyError: logger.warning( - 'DEFAULT_SEGMENTATION_KWARGS did not contain any entry' - 'for segmentation method {}. No kwargs will be passed to' - 'the segmentation function!'.format(segmentation_fn) + "DEFAULT_SEGMENTATION_KWARGS did not contain any entry" + "for segmentation method {}. No kwargs will be passed to" + "the segmentation function!".format(segmentation_fn) ) segmentation_kwargs = {} elif callable(segmentation_fn) and segmentation_kwargs: logger.warning( - 'Specified both a segmentation function to create superpixels and ' - 'keyword arguments for built segmentation functions. By default ' - 'the specified segmentation function will be used.' + "Specified both a segmentation function to create superpixels and " + "keyword arguments for built segmentation functions. By default " + "the specified segmentation function will be used." 
) # set the predictor @@ -72,35 +79,37 @@ def __init__(self, predictor: Callable, image_shape: tuple, segmentation_fn: Any self.predictor = self._transform_predictor(predictor) # segmentation function is either a user-defined function or one of the values in - fn_options = {'felzenszwalb': felzenszwalb, 'slic': slic, 'quickshift': quickshift} + fn_options = { + "felzenszwalb": felzenszwalb, + "slic": slic, + "quickshift": quickshift, + } if callable(segmentation_fn): self.custom_segmentation = True self.segmentation_fn = segmentation_fn else: self.custom_segmentation = False - self.segmentation_fn = partial(fn_options[segmentation_fn], **segmentation_kwargs) + self.segmentation_fn = partial( + fn_options[segmentation_fn], **segmentation_kwargs + ) self.images_background = images_background - # [H, W] int array; each int is a superpixel labels - self.segments = None # type: np.ndarray - self.segment_labels = None # type: list - self.image = None # type: np.ndarray # a superpixel is perturbed with prob 1 - p_sample self.p_sample = 0.5 # type: float # update metadata - self.meta['params'].update( + self.meta["params"].update( custom_segmentation=self.custom_segmentation, segmentation_kwargs=segmentation_kwargs, p_sample=self.p_sample, seed=seed, image_shape=self.image_shape, - images_background=self.images_background + images_background=self.images_background, ) if not self.custom_segmentation: - self.meta['params'].update(segmentation_fn=segmentation_fn) + self.meta["params"].update(segmentation_fn=segmentation_fn) else: - self.meta['params'].update(segmentation_fn='custom') + self.meta["params"].update(segmentation_fn="custom") def generate_superpixels(self, image: np.ndarray) -> np.ndarray: """ @@ -142,184 +151,26 @@ def _preprocess_img(self, image: np.ndarray) -> np.ndarray: return image_preproc - def _choose_superpixels(self, num_samples: int, p_sample: float = 0.5) -> np.ndarray: - """ - Generates a binary mask of dimension [num_samples, M] where M is the number of - image superpixels (segments). - - Parameters - ---------- - num_samples - Number of perturbed images to be generated - p_sample: - The probability that a superpixel is perturbed - - Returns - ------- - data - Binary 2D mask, where each non-zero entry in a row indicates that - the values of the particular image segment will not be perturbed. - """ - - n_features = len(self.segment_labels) - data = np.random.choice([0, 1], num_samples * n_features, p=[p_sample, 1 - p_sample]) - data = data.reshape((num_samples, n_features)) - - return data - - def sampler(self, anchor: Tuple[int, tuple], num_samples: int, compute_labels: bool = True) -> \ - Union[List[Union[np.ndarray, np.ndarray, np.ndarray, np.ndarray, float, int]], List[np.ndarray]]: - """ - Sample images from a perturbation distribution by masking randomly chosen superpixels - from the original image and replacing them with pixel values from superimposed images - if background images are provided to the explainer. Otherwise, the superpixels from the - original image are replaced with their average values. - - Parameters - ---------- - anchor - int: order of anchor in the batch - tuple: features (= superpixels) present in the proposed anchor - num_samples - Number of samples used - compute_labels - If True, an array of comparisons between predictions on perturbed samples and - instance to be explained is returned. 
- - Returns - ------- - If compute_labels=True, a list containing the following is returned: - - covered_true: perturbed examples where the anchor applies and the model prediction - on perturbed is the same as the instance prediction - - covered_false: perturbed examples where the anchor applies and the model prediction - on pertrurbed sample is NOT the same as the instance prediction - - labels: num_samples ints indicating whether the prediction on the perturbed sample - matches (1) the label of the instance to be explained or not (0) - - data: Matrix with 1s and 0s indicating whether the values in a superpixel will - remain unchanged (1) or will be perturbed (0), for each sample - - 1.0: indicates exact coverage is not computed for this algorithm - - anchor[0]: position of anchor in the batch request - Otherwise, a list containing the data matrix only is returned. - """ - - if compute_labels: - raw_data, data = self.perturbation(anchor[1], num_samples) - labels = self.compare_labels(raw_data) - covered_true = raw_data[labels][:self.n_covered_ex] - covered_true = [self._scale(img) for img in covered_true] - covered_false = raw_data[np.logical_not(labels)][:self.n_covered_ex] - covered_false = [self._scale(img) for img in covered_false] - # coverage set to -1.0 as we can't compute 'true'coverage for this model - - return [covered_true, covered_false, labels.astype(int), data, -1.0, anchor[0]] # type: ignore - - else: - data = self._choose_superpixels(num_samples) - data[:, anchor[1]] = 1 # superpixels in candidate anchor are not perturbed - - return [data] - - def perturbation(self, anchor: tuple, num_samples: int) -> Tuple[np.ndarray, np.ndarray]: - """ - Perturbs an image by altering the values of selected superpixels. If a dataset of image - backgrounds is provided to the explainer, then the superpixels are replaced with the - equivalent superpixels from the background image. Otherwise, the superpixels are replaced - by their average value. - - Parameters - ---------- - anchor: - Contains the superpixels whose values are not going to be perturbed. - num_samples: - Number of perturbed samples to be returned. - - Returns - ------- - imgs - A [num_samples, H, W, C] array of perturbed images. - segments_mask - A [num_samples, M] binary mask, where M is the number of image superpixels - segments. 1 indicates the values in that particular superpixels are not - perturbed. 
- """ - - image = self.image - segments = self.segments - - # choose superpixels to be perturbed - segments_mask = self._choose_superpixels(num_samples, p_sample=self.p_sample) - segments_mask[:, anchor] = 1 - - # for each sample, need to sample one of the background images if provided - if self.images_background: - backgrounds = np.random.choice( - range(len(self.images_background)), - segments_mask.shape[0], - replace=True, - ) - segments_mask = np.hstack((segments_mask, backgrounds.reshape(-1, 1))) - else: - backgrounds = [None] * segments_mask.shape[0] - # create fudged image where the pixel value in each superpixel is set to the - # average over the superpixel for each channel - fudged_image = image.copy() - n_channels = image.shape[-1] - for x in np.unique(segments): - fudged_image[segments == x] = [np.mean(image[segments == x][:, i]) for i in range(n_channels)] - - pert_imgs = [] - for mask, background_idx in zip(segments_mask, backgrounds): - temp = copy.deepcopy(image) - to_perturb = np.where(mask == 0)[0] - # create mask for each superpixel not present in the sample - mask = np.zeros(segments.shape).astype(bool) - for superpixel in to_perturb: - mask[segments == superpixel] = True - if background_idx: - # replace values with those of background image - temp[mask] = self.images_background[background_idx][mask] - else: - # ... or with the averaged superpixel value - temp[mask] = fudged_image[mask] - pert_imgs.append(temp) - - return np.array(pert_imgs), segments_mask - - def compare_labels(self, samples: np.ndarray) -> np.ndarray: - """ - Compute the agreement between a classifier prediction on an instance to be explained - and the prediction on a set of samples which have a subset of perturbed superpixels. - - Parameters - ---------- - samples - Samples whose labels are to be compared with the instance label. - - Returns - ------- - A boolean array indicating whether the prediction was the same as the instance label. - """ - - return self.predictor(samples) == self.instance_label - - def explain(self, # type: ignore - image: np.ndarray, - p_sample: float = 0.5, - threshold: float = 0.95, - delta: float = 0.1, - tau: float = 0.15, - batch_size: int = 100, - coverage_samples: int = 10000, - beam_size: int = 1, - stop_on_first: bool = False, - max_anchor_size: int = None, - min_samples_start: int = 100, - n_covered_ex: int = 10, - binary_cache_size: int = 10000, - cache_margin: int = 1000, - verbose: bool = False, - verbose_every: int = 1, - **kwargs: Any) -> Explanation: + def explain( + self, # type: ignore + image: np.ndarray, + p_sample: float = 0.5, + threshold: float = 0.95, + delta: float = 0.1, + tau: float = 0.15, + batch_size: int = 100, + coverage_samples: int = 10000, + beam_size: int = 1, + stop_on_first: bool = False, + max_anchor_size: int = None, + min_samples_start: int = 100, + n_covered_ex: int = 10, + binary_cache_size: int = 10000, + cache_margin: int = 1000, + verbose: bool = False, + verbose_every: int = 1, + **kwargs: Any + ) -> Explanation: """ Explain instance and return anchor with metadata. 
@@ -370,23 +221,27 @@ def explain(self, # type: ignore """ # get params for storage in meta params = locals() - remove = ['image', 'self'] + remove = ["image", "self"] for key in remove: params.pop(key) - self.image = image - self.n_covered_ex = n_covered_ex - self.p_sample = p_sample - self.segments = self.generate_superpixels(image) - self.segment_labels = list(np.unique(self.segments)) - self.instance_label = self.predictor(image[np.newaxis, ...])[0] + sampler = AnchorImageSampler( + predictor=self.predictor, + segmentation_fn=self.segmentation_fn, + custom_segmentation=self.custom_segmentation, + image=image, + images_background=self.images_background, + p_sample=p_sample, + n_covered_ex=n_covered_ex, + ) # get anchors and add metadata mab = AnchorBaseBeam( - samplers=[self.sampler], + samplers=[sampler.sample], sample_cache_size=binary_cache_size, cache_margin=cache_margin, - **kwargs) + **kwargs, + ) result = mab.anchor_beam( desired_confidence=threshold, delta=delta, @@ -403,9 +258,18 @@ def explain(self, # type: ignore ) # type: Any self.mab = mab - return self.build_explanation(image, result, self.instance_label, params) + return self.build_explanation( + image, result, sampler.instance_label, params, sampler + ) - def build_explanation(self, image: np.ndarray, result: dict, predicted_label: int, params: dict) -> Explanation: + def build_explanation( + self, + image: np.ndarray, + result: dict, + predicted_label: int, + params: dict, + sampler: AnchorImageSampler, + ) -> Explanation: """ Uses the metadata returned by the anchor search algorithm together with the instance to be explained to build an explanation object. @@ -422,57 +286,38 @@ def build_explanation(self, image: np.ndarray, result: dict, predicted_label: in Parameters passed to `explain` """ - result['instance'] = image - result['instances'] = np.expand_dims(image, 0) - result['prediction'] = np.array([predicted_label]) + result["instance"] = image + result["instances"] = np.expand_dims(image, 0) + result["prediction"] = np.array([predicted_label]) # overlay image with anchor mask - anchor = self.overlay_mask(image, self.segments, result['feature']) - exp = AnchorExplanation('image', result) + anchor = self.overlay_mask(image, sampler.segments, result["feature"]) + exp = AnchorExplanation("image", result) # output explanation dictionary data = copy.deepcopy(DEFAULT_DATA_ANCHOR_IMG) data.update( anchor=anchor, - segments=self.segments, + segments=sampler.segments, precision=exp.precision(), coverage=exp.coverage(), - raw=exp.exp_map + raw=exp.exp_map, ) # create explanation object explanation = Explanation(meta=copy.deepcopy(self.meta), data=data) # params passed to explain - explanation.meta['params'].update(params) + explanation.meta["params"].update(params) return explanation - @staticmethod - def _scale(image: np.ndarray, scale: tuple = (0, 255)) -> np.ndarray: - """ - Scales an image in a specified range. - - Parameters - ---------- - image - Image to be scale. - scale - The scaling interval. - - Returns - ------- - img_scaled - Scaled image. 
- """ - - img_max, img_min = image.max(), image.min() - img_std = (image - img_min) / (img_max - img_min) - img_scaled = img_std * (scale[1] - scale[0]) + scale[0] - - return img_scaled - - def overlay_mask(self, image: np.ndarray, segments: np.ndarray, mask_features: list, - scale: tuple = (0, 255)) -> np.ndarray: + def overlay_mask( + self, + image: np.ndarray, + segments: np.ndarray, + mask_features: list, + scale: tuple = (0, 255), + ) -> np.ndarray: """ Overlay image with mask described by the mask features. @@ -496,7 +341,7 @@ def overlay_mask(self, image: np.ndarray, segments: np.ndarray, mask_features: l mask = np.zeros(segments.shape) for f in mask_features: mask[segments == f] = 1 - image = self._scale(image, scale=scale) + image = scale_image(image, scale=scale) masked_image = (image * np.expand_dims(mask, 2)).astype(int) return masked_image diff --git a/alibi/explainers/anchor_image_sampler.py b/alibi/explainers/anchor_image_sampler.py new file mode 100644 index 000000000..0614d5376 --- /dev/null +++ b/alibi/explainers/anchor_image_sampler.py @@ -0,0 +1,260 @@ +import numpy as np +import copy + +from typing import Tuple, Callable, List, Union + +from .anchor_image_utils import scale_image + + +class AnchorImageSampler: + def __init__( + self, + # TODO: Should we call `predictor`, `prediction_fn` instead? + predictor: Callable, + segmentation_fn: Callable, + custom_segmentation: bool, + image: np.ndarray, + images_background: np.ndarray = None, + p_sample: float = 0.5, + n_covered_ex: int = 10, + ): + """ + Initialize anchor image sampler. + + Parameters + ---------- + predictor + A callable that takes a tensor of N data points as inputs and returns N outputs. + segmentation_fn + Function used to segment the images. + image + Image to be explained. + images_background + Images to overlay superpixels on. + p_sample + Probability for a pixel to be represented by the average value of its superpixel. + n_covered_ex + How many examples where anchors apply to store for each anchor sampled during search + (both examples where prediction on samples agrees/disagrees with desired_label are stored). + """ + self.predictor = predictor + self.segmentation_fn = segmentation_fn + self.custom_segmentation = custom_segmentation + self.image = image + self.images_background = images_background + self.n_covered_ex = n_covered_ex + self.p_sample = p_sample + self.segments = self.generate_superpixels(image) + self.segment_labels = list(np.unique(self.segments)) + self.instance_label = self.predictor(image[np.newaxis, ...])[0] + + def sample( + self, anchor: Tuple[int, tuple], num_samples: int, compute_labels: bool = True + ) -> List[Union[np.ndarray, float, int]]: + """ + Sample images from a perturbation distribution by masking randomly chosen superpixels + from the original image and replacing them with pixel values from superimposed images + if background images are provided to the explainer. Otherwise, the superpixels from the + original image are replaced with their average values. + + Parameters + ---------- + anchor + int: order of anchor in the batch + tuple: features (= superpixels) present in the proposed anchor + num_samples + Number of samples used + compute_labels + If True, an array of comparisons between predictions on perturbed samples and + instance to be explained is returned. 
+ + Returns + ------- + If compute_labels=True, a list containing the following is returned: + - covered_true: perturbed examples where the anchor applies and the model prediction + on perturbed is the same as the instance prediction + - covered_false: perturbed examples where the anchor applies and the model prediction + on pertrurbed sample is NOT the same as the instance prediction + - labels: num_samples ints indicating whether the prediction on the perturbed sample + matches (1) the label of the instance to be explained or not (0) + - data: Matrix with 1s and 0s indicating whether the values in a superpixel will + remain unchanged (1) or will be perturbed (0), for each sample + - 1.0: indicates exact coverage is not computed for this algorithm + - anchor[0]: position of anchor in the batch request + Otherwise, a list containing the data matrix only is returned. + """ + + if compute_labels: + raw_data, data = self.perturbation(anchor[1], num_samples) + labels = self.compare_labels(raw_data) + covered_true = raw_data[labels][: self.n_covered_ex] + covered_true = [scale_image(img) for img in covered_true] + covered_false = raw_data[np.logical_not(labels)][: self.n_covered_ex] + covered_false = [scale_image(img) for img in covered_false] + # coverage set to -1.0 as we can't compute 'true'coverage for this model + + return [covered_true, covered_false, labels.astype(int), data, -1.0, anchor[0]] # type: ignore + + else: + data = self._choose_superpixels(num_samples) + data[:, anchor[1]] = 1 # superpixels in candidate anchor are not perturbed + + return [data] + + def compare_labels(self, samples: np.ndarray) -> np.ndarray: + """ + Compute the agreement between a classifier prediction on an instance to be explained + and the prediction on a set of samples which have a subset of perturbed superpixels. + + Parameters + ---------- + samples + Samples whose labels are to be compared with the instance label. + + Returns + ------- + A boolean array indicating whether the prediction was the same as the instance label. + """ + + return self.predictor(samples) == self.instance_label + + def _choose_superpixels( + self, num_samples: int, p_sample: float = 0.5 + ) -> np.ndarray: + """ + Generates a binary mask of dimension [num_samples, M] where M is the number of + image superpixels (segments). + + Parameters + ---------- + num_samples + Number of perturbed images to be generated + p_sample: + The probability that a superpixel is perturbed + + Returns + ------- + data + Binary 2D mask, where each non-zero entry in a row indicates that + the values of the particular image segment will not be perturbed. + """ + + n_features = len(self.segment_labels) + data = np.random.choice( + [0, 1], num_samples * n_features, p=[p_sample, 1 - p_sample] + ) + data = data.reshape((num_samples, n_features)) + + return data + + def perturbation( + self, anchor: tuple, num_samples: int + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Perturbs an image by altering the values of selected superpixels. If a dataset of image + backgrounds is provided to the explainer, then the superpixels are replaced with the + equivalent superpixels from the background image. Otherwise, the superpixels are replaced + by their average value. + + Parameters + ---------- + anchor: + Contains the superpixels whose values are not going to be perturbed. + num_samples: + Number of perturbed samples to be returned. + + Returns + ------- + imgs + A [num_samples, H, W, C] array of perturbed images. 
+ segments_mask + A [num_samples, M] binary mask, where M is the number of image superpixels + segments. 1 indicates the values in that particular superpixels are not + perturbed. + """ + + image = self.image + segments = self.segments + + # choose superpixels to be perturbed + segments_mask = self._choose_superpixels(num_samples, p_sample=self.p_sample) + segments_mask[:, anchor] = 1 + + # for each sample, need to sample one of the background images if provided + if self.images_background: + backgrounds = np.random.choice( + range(len(self.images_background)), + segments_mask.shape[0], + replace=True, + ) + segments_mask = np.hstack((segments_mask, backgrounds.reshape(-1, 1))) + else: + backgrounds = [None] * segments_mask.shape[0] + # create fudged image where the pixel value in each superpixel is set to the + # average over the superpixel for each channel + fudged_image = image.copy() + n_channels = image.shape[-1] + for x in np.unique(segments): + fudged_image[segments == x] = [ + np.mean(image[segments == x][:, i]) for i in range(n_channels) + ] + + pert_imgs = [] + for mask, background_idx in zip(segments_mask, backgrounds): + temp = copy.deepcopy(image) + to_perturb = np.where(mask == 0)[0] + # create mask for each superpixel not present in the sample + mask = np.zeros(segments.shape).astype(bool) + for superpixel in to_perturb: + mask[segments == superpixel] = True + if background_idx: + # replace values with those of background image + # TODO: Could images_background be None herre? + temp[mask] = self.images_background[background_idx][mask] + else: + # ... or with the averaged superpixel value + # TODO: Where is fudged_image defined? + temp[mask] = fudged_image[mask] + pert_imgs.append(temp) + + return np.array(pert_imgs), segments_mask + + def generate_superpixels(self, image: np.ndarray) -> np.ndarray: + """ + Generates superpixels from (i.e., segments) an image. + + Parameters + ---------- + image + A grayscale or RGB image. + + Returns + ------- + A [H, W] array of integers. Each integer is a segment (superpixel) label. + """ + + image_preproc = self._preprocess_img(image) + + return self.segmentation_fn(image_preproc) + + def _preprocess_img(self, image: np.ndarray) -> np.ndarray: + """ + Applies necessary transformations to the image prior to segmentation. + + Parameters + ---------- + image + A grayscale or RGB image. + + Returns + ------- + A preprocessed image. + """ + + # Grayscale images are repeated across channels + if not self.custom_segmentation and image.shape[-1] == 1: + image_preproc = np.repeat(image, 3, axis=2) + else: + image_preproc = image.copy() + + return image_preproc diff --git a/alibi/explainers/anchor_image_utils.py b/alibi/explainers/anchor_image_utils.py new file mode 100644 index 000000000..1e97b9d52 --- /dev/null +++ b/alibi/explainers/anchor_image_utils.py @@ -0,0 +1,25 @@ +import numpy as np + + +def scale_image(image: np.ndarray, scale: tuple = (0, 255)) -> np.ndarray: + """ + Scales an image in a specified range. + + Parameters + ---------- + image + Image to be scale. + scale + The scaling interval. + + Returns + ------- + img_scaled + Scaled image. 
+ """ + + img_max, img_min = image.max(), image.min() + img_std = (image - img_min) / (img_max - img_min) + img_scaled = img_std * (scale[1] - scale[0]) + scale[0] + + return img_scaled From 795f7ce21486b8391f404c2ed18d2581d42b0596 Mon Sep 17 00:00:00 2001 From: Adrian Gonzalez-Martin Date: Fri, 30 Jul 2021 12:37:09 +0200 Subject: [PATCH 02/13] Add tests for scale_image --- alibi/explainers/tests/test_anchor_image_utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 alibi/explainers/tests/test_anchor_image_utils.py diff --git a/alibi/explainers/tests/test_anchor_image_utils.py b/alibi/explainers/tests/test_anchor_image_utils.py new file mode 100644 index 000000000..7ec9d82d6 --- /dev/null +++ b/alibi/explainers/tests/test_anchor_image_utils.py @@ -0,0 +1,15 @@ +import numpy as np + +from alibi.explainers.anchor_image_utils import scale_image + + +def test_scale_image(): + image_shape = (28, 28, 1) + scaling_offset = 260 + min_val = 0 + max_val = 255 + + fake_img = np.random.random(size=image_shape) + scaling_offset + scaled_img = scale_image(fake_img, scale=(min_val, max_val)) + assert (scaled_img <= max_val).all() + assert (scaled_img >= min_val).all() From 8ca1c60dd84f5c6c17e3d4919ca7c03abf7de389 Mon Sep 17 00:00:00 2001 From: Adrian Gonzalez-Martin Date: Fri, 30 Jul 2021 12:37:41 +0200 Subject: [PATCH 03/13] Add tests for sampler --- .../tests/test_anchor_image_sampler.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 alibi/explainers/tests/test_anchor_image_sampler.py diff --git a/alibi/explainers/tests/test_anchor_image_sampler.py b/alibi/explainers/tests/test_anchor_image_sampler.py new file mode 100644 index 000000000..6e3c597da --- /dev/null +++ b/alibi/explainers/tests/test_anchor_image_sampler.py @@ -0,0 +1,62 @@ +import pytest + +import numpy as np +from alibi.api.defaults import DEFAULT_META_ANCHOR, DEFAULT_DATA_ANCHOR_IMG +from alibi.explainers import AnchorImage +from alibi.explainers.anchor_image_sampler import AnchorImageSampler + + +@pytest.mark.parametrize( + "models", + [("mnist-cnn-tf2.2.0",), ("mnist-cnn-tf1.15.2.h5",)], + ids="model={}".format, + indirect=True, +) +def test_sampler(models, mnist_data): + eps = 0.0001 # tolerance for tensor comparisons + num_samples = 10 + + x_train = mnist_data["X_train"] + segmentation_fn = "slic" + segmentation_kwargs = {"n_segments": 10, "compactness": 10, "sigma": 0.5} + image_shape = (28, 28, 1) + predict_fn = lambda x: models[0].predict(x) # noqa: E731 + explainer = AnchorImage( + predict_fn, + image_shape, + segmentation_fn=segmentation_fn, + segmentation_kwargs=segmentation_kwargs, + ) + + image = x_train[0] + p_sample = 0.5 # probability of perturbing a superpixel + n_covered_ex = 3 # nb of examples where the anchor applies that are saved + sampler = AnchorImageSampler( + predictor=explainer.predictor, + segmentation_fn=explainer.segmentation_fn, + custom_segmentation=explainer.custom_segmentation, + image=image, + images_background=explainer.images_background, + p_sample=p_sample, + n_covered_ex=n_covered_ex, + ) + + image_preproc = sampler._preprocess_img(image) + superpixels_mask = sampler._choose_superpixels(num_samples=num_samples) + + # grayscale image should be replicated across channel dim before segmentation + assert image_preproc.shape[-1] == 3 + for channel in range(image_preproc.shape[-1]): + assert (image.squeeze() - image_preproc[..., channel] <= eps).all() + # check superpixels mask + assert superpixels_mask.shape[0] == num_samples + assert 
superpixels_mask.shape[1] == len(list(np.unique(sampler.segments))) + assert superpixels_mask.sum(axis=1).any() <= segmentation_kwargs["n_segments"] + assert superpixels_mask.any() <= 1 + + cov_true, cov_false, labels, data, coverage, _ = sampler.sample( + (0, ()), num_samples + ) + assert data.shape[0] == labels.shape[0] + assert data.shape[1] == len(np.unique(sampler.segments)) + assert coverage == -1 From 99c1cc0ea93ae3c57415d2f769c879901e439fe5 Mon Sep 17 00:00:00 2001 From: Adrian Gonzalez-Martin Date: Fri, 30 Jul 2021 12:48:16 +0200 Subject: [PATCH 04/13] Refactor anchor image test --- alibi/explainers/tests/test_anchor_image.py | 91 +++++++++------------ 1 file changed, 37 insertions(+), 54 deletions(-) diff --git a/alibi/explainers/tests/test_anchor_image.py b/alibi/explainers/tests/test_anchor_image.py index d6484f902..5a906c384 100644 --- a/alibi/explainers/tests/test_anchor_image.py +++ b/alibi/explainers/tests/test_anchor_image.py @@ -3,24 +3,22 @@ import numpy as np from alibi.api.defaults import DEFAULT_META_ANCHOR, DEFAULT_DATA_ANCHOR_IMG from alibi.explainers import AnchorImage +from alibi.explainers.anchor_image_sampler import AnchorImageSampler -@pytest.mark.parametrize('models', - [('mnist-cnn-tf2.2.0',), ('mnist-cnn-tf1.15.2.h5',)], - ids='model={}'.format, - indirect=True) +@pytest.mark.parametrize( + "models", + [("mnist-cnn-tf2.2.0",), ("mnist-cnn-tf1.15.2.h5",)], + ids="model={}".format, + indirect=True, +) def test_anchor_image(models, mnist_data): - x_train = mnist_data['X_train'] - segmentation_fn = 'slic' - segmentation_kwargs = {'n_segments': 10, 'compactness': 10, 'sigma': .5} + x_train = mnist_data["X_train"] + image = x_train[0] + + segmentation_fn = "slic" + segmentation_kwargs = {"n_segments": 10, "compactness": 10, "sigma": 0.5} image_shape = (28, 28, 1) - p_sample = 0.5 # probability of perturbing a superpixel - num_samples = 10 - # img scaling settings - scaling_offset = 260 - min_val = 0 - max_val = 255 - eps = 0.0001 # tolerance for tensor comparisons n_covered_ex = 3 # nb of examples where the anchor applies that are saved # define and train model @@ -33,54 +31,39 @@ def test_anchor_image(models, mnist_data): segmentation_fn=segmentation_fn, segmentation_kwargs=segmentation_kwargs, ) + + p_sample = 0.5 # probability of perturbing a superpixel + n_covered_ex = 3 # nb of examples where the anchor applies that are saved + sampler = AnchorImageSampler( + predictor=explainer.predictor, + segmentation_fn=explainer.segmentation_fn, + custom_segmentation=explainer.custom_segmentation, + image=image, + images_background=explainer.images_background, + p_sample=p_sample, + n_covered_ex=n_covered_ex, + ) + # test explainer initialization assert explainer.predictor(np.zeros((1,) + image_shape)).shape == (1,) assert explainer.custom_segmentation is False - # test sampling and segmentation functions - image = x_train[0] - explainer.instance_label = predict_fn(image[np.newaxis, ...])[0] - explainer.image = image - explainer.n_covered_ex = n_covered_ex - explainer.p_sample = p_sample - segments = explainer.generate_superpixels(image) - explainer.segments = segments - image_preproc = explainer._preprocess_img(image) - explainer.segment_labels = list(np.unique(segments)) - superpixels_mask = explainer._choose_superpixels(num_samples=num_samples) - - # grayscale image should be replicated across channel dim before segmentation - assert image_preproc.shape[-1] == 3 - for channel in range(image_preproc.shape[-1]): - assert (image.squeeze() - image_preproc[..., channel] <= 
eps).all() - # check superpixels mask - assert superpixels_mask.shape[0] == num_samples - assert superpixels_mask.shape[1] == len(list(np.unique(segments))) - assert superpixels_mask.sum(axis=1).any() <= segmentation_kwargs['n_segments'] - assert superpixels_mask.any() <= 1 - - cov_true, cov_false, labels, data, coverage, _ = explainer.sampler((0, ()), num_samples) - assert data.shape[0] == labels.shape[0] - assert data.shape[1] == len(np.unique(segments)) - assert coverage == -1 - # test explanation - threshold = .95 - explanation = explainer.explain(image, threshold=threshold) + threshold = 0.95 + explanation = explainer.explain(image, threshold=threshold, n_covered_ex=3) - if explanation.raw['feature']: - assert len(explanation.raw['examples'][-1]['covered_true']) <= explainer.n_covered_ex - assert len(explanation.raw['examples'][-1]['covered_false']) <= explainer.n_covered_ex + if explanation.raw["feature"]: + assert ( + len(explanation.raw["examples"][-1]["covered_true"]) <= sampler.n_covered_ex + ) + assert ( + len(explanation.raw["examples"][-1]["covered_false"]) + <= sampler.n_covered_ex + ) else: - assert not explanation.raw['examples'] + assert not explanation.raw["examples"] assert explanation.anchor.shape == image_shape assert explanation.precision >= threshold - assert len(np.unique(explanation.segments)) == len(np.unique(segments)) + assert len(np.unique(explanation.segments)) == len(np.unique(sampler.segments)) assert explanation.meta.keys() == DEFAULT_META_ANCHOR.keys() assert explanation.data.keys() == DEFAULT_DATA_ANCHOR_IMG.keys() - - # test scaling - fake_img = np.random.random(size=image_shape) + scaling_offset - scaled_img = explainer._scale(fake_img, scale=(min_val, max_val)) - assert (scaled_img <= max_val).all() - assert (scaled_img >= min_val).all() From 4d0bde49f117d3d4196508537bd657bc6b95ce78 Mon Sep 17 00:00:00 2001 From: Adrian Gonzalez-Martin Date: Fri, 30 Jul 2021 12:51:09 +0200 Subject: [PATCH 05/13] Don't store mab object internally --- alibi/explainers/anchor_image.py | 1 - 1 file changed, 1 deletion(-) diff --git a/alibi/explainers/anchor_image.py b/alibi/explainers/anchor_image.py index bcb208034..4ef180181 100644 --- a/alibi/explainers/anchor_image.py +++ b/alibi/explainers/anchor_image.py @@ -256,7 +256,6 @@ def explain( verbose_every=verbose_every, **kwargs, ) # type: Any - self.mab = mab return self.build_explanation( image, result, sampler.instance_label, params, sampler From 68c77ecb301ada805a874b3e1dada84fed3a60f9 Mon Sep 17 00:00:00 2001 From: Adrian Gonzalez-Martin Date: Fri, 30 Jul 2021 12:57:37 +0200 Subject: [PATCH 06/13] Add test to ensure explain method is stateless --- alibi/explainers/tests/test_anchor_image.py | 30 +++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/alibi/explainers/tests/test_anchor_image.py b/alibi/explainers/tests/test_anchor_image.py index 5a906c384..1b77f3ca6 100644 --- a/alibi/explainers/tests/test_anchor_image.py +++ b/alibi/explainers/tests/test_anchor_image.py @@ -67,3 +67,33 @@ def test_anchor_image(models, mnist_data): assert len(np.unique(explanation.segments)) == len(np.unique(sampler.segments)) assert explanation.meta.keys() == DEFAULT_META_ANCHOR.keys() assert explanation.data.keys() == DEFAULT_DATA_ANCHOR_IMG.keys() + + +@pytest.mark.parametrize( + "models", + [("mnist-cnn-tf2.2.0",)], + ids="model={}".format, + indirect=True, +) +def test_stateless_explainer(models, mnist_data): + predict_fn = lambda x: models[0].predict(x) # noqa: E731 + image_shape = (28, 28, 1) + segmentation_fn = 
"slic" + segmentation_kwargs = {"n_segments": 10, "compactness": 10, "sigma": 0.5} + + explainer = AnchorImage( + predict_fn, + image_shape, + segmentation_fn=segmentation_fn, + segmentation_kwargs=segmentation_kwargs, + ) + + x_train = mnist_data["X_train"] + image = x_train[0] + threshold = 0.95 + + before_explain = explainer.__dict__ + explainer.explain(image, threshold=threshold, n_covered_ex=3) + after_explain = explainer.__dict__ + + assert before_explain == after_explain From 904f46c68876469aa606b8518e218f2243c16d1c Mon Sep 17 00:00:00 2001 From: Adrian Gonzalez-Martin Date: Mon, 2 Aug 2021 14:48:19 +0200 Subject: [PATCH 07/13] Add __call__ and revert black --- alibi/explainers/anchor_image.py | 78 +++++++------------ alibi/explainers/anchor_image_sampler.py | 2 +- .../tests/test_anchor_image_sampler.py | 2 +- 3 files changed, 32 insertions(+), 50 deletions(-) diff --git a/alibi/explainers/anchor_image.py b/alibi/explainers/anchor_image.py index 4ef180181..1b5ad4ed0 100644 --- a/alibi/explainers/anchor_image.py +++ b/alibi/explainers/anchor_image.py @@ -18,22 +18,16 @@ logger = logging.getLogger(__name__) DEFAULT_SEGMENTATION_KWARGS = { - "felzenszwalb": {}, - "quickshift": {}, - "slic": {"n_segments": 10, "compactness": 10, "sigma": 0.5}, + 'felzenszwalb': {}, + 'quickshift': {}, + 'slic': {'n_segments': 10, 'compactness': 10, 'sigma': .5} } class AnchorImage(Explainer): - def __init__( - self, - predictor: Callable, - image_shape: tuple, - segmentation_fn: Any = "slic", - segmentation_kwargs: dict = None, - images_background: np.ndarray = None, - seed: int = None, - ) -> None: + def __init__(self, predictor: Callable, image_shape: tuple, segmentation_fn: Any = 'slic', + segmentation_kwargs: dict = None, images_background: np.ndarray = None, + seed: int = None) -> None: """ Initialize anchor image explainer. @@ -62,16 +56,16 @@ def __init__( segmentation_kwargs = DEFAULT_SEGMENTATION_KWARGS[segmentation_fn] # type: ignore except KeyError: logger.warning( - "DEFAULT_SEGMENTATION_KWARGS did not contain any entry" - "for segmentation method {}. No kwargs will be passed to" - "the segmentation function!".format(segmentation_fn) + 'DEFAULT_SEGMENTATION_KWARGS did not contain any entry' + 'for segmentation method {}. No kwargs will be passed to' + 'the segmentation function!'.format(segmentation_fn) ) segmentation_kwargs = {} elif callable(segmentation_fn) and segmentation_kwargs: logger.warning( - "Specified both a segmentation function to create superpixels and " - "keyword arguments for built segmentation functions. By default " - "the specified segmentation function will be used." + 'Specified both a segmentation function to create superpixels and ' + 'keyword arguments for built segmentation functions. By default ' + 'the specified segmentation function will be used.' 
) # set the predictor @@ -79,37 +73,31 @@ def __init__( self.predictor = self._transform_predictor(predictor) # segmentation function is either a user-defined function or one of the values in - fn_options = { - "felzenszwalb": felzenszwalb, - "slic": slic, - "quickshift": quickshift, - } + fn_options = {'felzenszwalb': felzenszwalb, 'slic': slic, 'quickshift': quickshift} if callable(segmentation_fn): self.custom_segmentation = True self.segmentation_fn = segmentation_fn else: self.custom_segmentation = False - self.segmentation_fn = partial( - fn_options[segmentation_fn], **segmentation_kwargs - ) + self.segmentation_fn = partial(fn_options[segmentation_fn], **segmentation_kwargs) self.images_background = images_background # a superpixel is perturbed with prob 1 - p_sample self.p_sample = 0.5 # type: float # update metadata - self.meta["params"].update( + self.meta['params'].update( custom_segmentation=self.custom_segmentation, segmentation_kwargs=segmentation_kwargs, p_sample=self.p_sample, seed=seed, image_shape=self.image_shape, - images_background=self.images_background, + images_background=self.images_background ) if not self.custom_segmentation: - self.meta["params"].update(segmentation_fn=segmentation_fn) + self.meta['params'].update(segmentation_fn=segmentation_fn) else: - self.meta["params"].update(segmentation_fn="custom") + self.meta['params'].update(segmentation_fn='custom') def generate_superpixels(self, image: np.ndarray) -> np.ndarray: """ @@ -221,7 +209,7 @@ def explain( """ # get params for storage in meta params = locals() - remove = ["image", "self"] + remove = ['image', 'self'] for key in remove: params.pop(key) @@ -237,11 +225,10 @@ def explain( # get anchors and add metadata mab = AnchorBaseBeam( - samplers=[sampler.sample], + samplers=[sampler], sample_cache_size=binary_cache_size, cache_margin=cache_margin, - **kwargs, - ) + **kwargs) result = mab.anchor_beam( desired_confidence=threshold, delta=delta, @@ -285,13 +272,13 @@ def build_explanation( Parameters passed to `explain` """ - result["instance"] = image - result["instances"] = np.expand_dims(image, 0) - result["prediction"] = np.array([predicted_label]) + result['instance'] = image + result['instances'] = np.expand_dims(image, 0) + result['prediction'] = np.array([predicted_label]) # overlay image with anchor mask - anchor = self.overlay_mask(image, sampler.segments, result["feature"]) - exp = AnchorExplanation("image", result) + anchor = self.overlay_mask(image, sampler.segments, result['feature']) + exp = AnchorExplanation('image', result) # output explanation dictionary data = copy.deepcopy(DEFAULT_DATA_ANCHOR_IMG) @@ -300,23 +287,18 @@ def build_explanation( segments=sampler.segments, precision=exp.precision(), coverage=exp.coverage(), - raw=exp.exp_map, + raw=exp.exp_map ) # create explanation object explanation = Explanation(meta=copy.deepcopy(self.meta), data=data) # params passed to explain - explanation.meta["params"].update(params) + explanation.meta['params'].update(params) return explanation - def overlay_mask( - self, - image: np.ndarray, - segments: np.ndarray, - mask_features: list, - scale: tuple = (0, 255), - ) -> np.ndarray: + def overlay_mask(self, image: np.ndarray, segments: np.ndarray, mask_features: list, + scale: tuple = (0, 255)) -> np.ndarray: """ Overlay image with mask described by the mask features. 
diff --git a/alibi/explainers/anchor_image_sampler.py b/alibi/explainers/anchor_image_sampler.py index 0614d5376..d5d66b836 100644 --- a/alibi/explainers/anchor_image_sampler.py +++ b/alibi/explainers/anchor_image_sampler.py @@ -48,7 +48,7 @@ def __init__( self.segment_labels = list(np.unique(self.segments)) self.instance_label = self.predictor(image[np.newaxis, ...])[0] - def sample( + def __call__( self, anchor: Tuple[int, tuple], num_samples: int, compute_labels: bool = True ) -> List[Union[np.ndarray, float, int]]: """ diff --git a/alibi/explainers/tests/test_anchor_image_sampler.py b/alibi/explainers/tests/test_anchor_image_sampler.py index 6e3c597da..cce3a2669 100644 --- a/alibi/explainers/tests/test_anchor_image_sampler.py +++ b/alibi/explainers/tests/test_anchor_image_sampler.py @@ -54,7 +54,7 @@ def test_sampler(models, mnist_data): assert superpixels_mask.sum(axis=1).any() <= segmentation_kwargs["n_segments"] assert superpixels_mask.any() <= 1 - cov_true, cov_false, labels, data, coverage, _ = sampler.sample( + cov_true, cov_false, labels, data, coverage, _ = sampler( (0, ()), num_samples ) assert data.shape[0] == labels.shape[0] From 7fb0ca3d0e9dfd67bb2f7fab3b99ca542502cba6 Mon Sep 17 00:00:00 2001 From: Adrian Gonzalez-Martin Date: Mon, 2 Aug 2021 15:29:47 +0200 Subject: [PATCH 08/13] fix linter --- alibi/explainers/tests/test_anchor_image_sampler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/alibi/explainers/tests/test_anchor_image_sampler.py b/alibi/explainers/tests/test_anchor_image_sampler.py index cce3a2669..96535d5f9 100644 --- a/alibi/explainers/tests/test_anchor_image_sampler.py +++ b/alibi/explainers/tests/test_anchor_image_sampler.py @@ -1,7 +1,6 @@ import pytest import numpy as np -from alibi.api.defaults import DEFAULT_META_ANCHOR, DEFAULT_DATA_ANCHOR_IMG from alibi.explainers import AnchorImage from alibi.explainers.anchor_image_sampler import AnchorImageSampler From 09673f57b9c2732e8490b04bea90b6ec706c983d Mon Sep 17 00:00:00 2001 From: Adrian Gonzalez-Martin Date: Tue, 3 Aug 2021 10:39:17 +0200 Subject: [PATCH 09/13] fix mypy --- alibi/explainers/anchor_image.py | 39 +++++++++++++++----------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/alibi/explainers/anchor_image.py b/alibi/explainers/anchor_image.py index 1b5ad4ed0..1c08506e9 100644 --- a/alibi/explainers/anchor_image.py +++ b/alibi/explainers/anchor_image.py @@ -139,27 +139,24 @@ def _preprocess_img(self, image: np.ndarray) -> np.ndarray: return image_preproc - def explain( - self, # type: ignore - image: np.ndarray, - p_sample: float = 0.5, - threshold: float = 0.95, - delta: float = 0.1, - tau: float = 0.15, - batch_size: int = 100, - coverage_samples: int = 10000, - beam_size: int = 1, - stop_on_first: bool = False, - max_anchor_size: int = None, - min_samples_start: int = 100, - n_covered_ex: int = 10, - binary_cache_size: int = 10000, - cache_margin: int = 1000, - verbose: bool = False, - verbose_every: int = 1, - **kwargs: Any - ) -> Explanation: - + def explain(self, # type: ignore + image: np.ndarray, + p_sample: float = 0.5, + threshold: float = 0.95, + delta: float = 0.1, + tau: float = 0.15, + batch_size: int = 100, + coverage_samples: int = 10000, + beam_size: int = 1, + stop_on_first: bool = False, + max_anchor_size: int = None, + min_samples_start: int = 100, + n_covered_ex: int = 10, + binary_cache_size: int = 10000, + cache_margin: int = 1000, + verbose: bool = False, + verbose_every: int = 1, + **kwargs: Any) -> Explanation: """ Explain instance and return 
anchor with metadata. From ae66df4c0d2b2816aa2a726a3840301a87f28603 Mon Sep 17 00:00:00 2001 From: Adrian Gonzalez-Martin Date: Tue, 3 Aug 2021 17:11:22 +0200 Subject: [PATCH 10/13] Move sampler and util to same module --- alibi/explainers/anchor_image.py | 279 +++++++++++++++++- alibi/explainers/anchor_image_sampler.py | 260 ---------------- alibi/explainers/anchor_image_utils.py | 25 -- alibi/explainers/tests/test_anchor_image.py | 100 +++++-- .../tests/test_anchor_image_sampler.py | 61 ---- .../tests/test_anchor_image_utils.py | 15 - 6 files changed, 349 insertions(+), 391 deletions(-) delete mode 100644 alibi/explainers/anchor_image_sampler.py delete mode 100644 alibi/explainers/anchor_image_utils.py delete mode 100644 alibi/explainers/tests/test_anchor_image_sampler.py delete mode 100644 alibi/explainers/tests/test_anchor_image_utils.py diff --git a/alibi/explainers/anchor_image.py b/alibi/explainers/anchor_image.py index 1c08506e9..9ddf57ae5 100644 --- a/alibi/explainers/anchor_image.py +++ b/alibi/explainers/anchor_image.py @@ -11,8 +11,6 @@ from alibi.api.defaults import DEFAULT_META_ANCHOR, DEFAULT_DATA_ANCHOR_IMG from .anchor_base import AnchorBaseBeam from .anchor_explanation import AnchorExplanation -from .anchor_image_utils import scale_image -from .anchor_image_sampler import AnchorImageSampler from skimage.segmentation import felzenszwalb, slic, quickshift logger = logging.getLogger(__name__) @@ -23,6 +21,283 @@ 'slic': {'n_segments': 10, 'compactness': 10, 'sigma': .5} } +def scale_image(image: np.ndarray, scale: tuple = (0, 255)) -> np.ndarray: + """ + Scales an image in a specified range. + + Parameters + ---------- + image + Image to be scale. + scale + The scaling interval. + + Returns + ------- + img_scaled + Scaled image. + """ + + img_max, img_min = image.max(), image.min() + img_std = (image - img_min) / (img_max - img_min) + img_scaled = img_std * (scale[1] - scale[0]) + scale[0] + + return img_scaled + + +class AnchorImageSampler: + def __init__( + self, + # TODO: Should we call `predictor`, `prediction_fn` instead? + predictor: Callable, + segmentation_fn: Callable, + custom_segmentation: bool, + image: np.ndarray, + images_background: np.ndarray = None, + p_sample: float = 0.5, + n_covered_ex: int = 10, + ): + """ + Initialize anchor image sampler. + + Parameters + ---------- + predictor + A callable that takes a tensor of N data points as inputs and returns N outputs. + segmentation_fn + Function used to segment the images. + image + Image to be explained. + images_background + Images to overlay superpixels on. + p_sample + Probability for a pixel to be represented by the average value of its superpixel. + n_covered_ex + How many examples where anchors apply to store for each anchor sampled during search + (both examples where prediction on samples agrees/disagrees with desired_label are stored). 
+ """ + self.predictor = predictor + self.segmentation_fn = segmentation_fn + self.custom_segmentation = custom_segmentation + self.image = image + self.images_background = images_background + self.n_covered_ex = n_covered_ex + self.p_sample = p_sample + self.segments = self.generate_superpixels(image) + self.segment_labels = list(np.unique(self.segments)) + self.instance_label = self.predictor(image[np.newaxis, ...])[0] + + def __call__( + self, anchor: Tuple[int, tuple], num_samples: int, compute_labels: bool = True + ) -> List[Union[np.ndarray, float, int]]: + """ + Sample images from a perturbation distribution by masking randomly chosen superpixels + from the original image and replacing them with pixel values from superimposed images + if background images are provided to the explainer. Otherwise, the superpixels from the + original image are replaced with their average values. + + Parameters + ---------- + anchor + int: order of anchor in the batch + tuple: features (= superpixels) present in the proposed anchor + num_samples + Number of samples used + compute_labels + If True, an array of comparisons between predictions on perturbed samples and + instance to be explained is returned. + + Returns + ------- + If compute_labels=True, a list containing the following is returned: + - covered_true: perturbed examples where the anchor applies and the model prediction + on perturbed is the same as the instance prediction + - covered_false: perturbed examples where the anchor applies and the model prediction + on pertrurbed sample is NOT the same as the instance prediction + - labels: num_samples ints indicating whether the prediction on the perturbed sample + matches (1) the label of the instance to be explained or not (0) + - data: Matrix with 1s and 0s indicating whether the values in a superpixel will + remain unchanged (1) or will be perturbed (0), for each sample + - 1.0: indicates exact coverage is not computed for this algorithm + - anchor[0]: position of anchor in the batch request + Otherwise, a list containing the data matrix only is returned. + """ + + if compute_labels: + raw_data, data = self.perturbation(anchor[1], num_samples) + labels = self.compare_labels(raw_data) + covered_true = raw_data[labels][: self.n_covered_ex] + covered_true = [scale_image(img) for img in covered_true] + covered_false = raw_data[np.logical_not(labels)][: self.n_covered_ex] + covered_false = [scale_image(img) for img in covered_false] + # coverage set to -1.0 as we can't compute 'true'coverage for this model + + return [covered_true, covered_false, labels.astype(int), data, -1.0, anchor[0]] # type: ignore + + else: + data = self._choose_superpixels(num_samples) + data[:, anchor[1]] = 1 # superpixels in candidate anchor are not perturbed + + return [data] + + def compare_labels(self, samples: np.ndarray) -> np.ndarray: + """ + Compute the agreement between a classifier prediction on an instance to be explained + and the prediction on a set of samples which have a subset of perturbed superpixels. + + Parameters + ---------- + samples + Samples whose labels are to be compared with the instance label. + + Returns + ------- + A boolean array indicating whether the prediction was the same as the instance label. + """ + + return self.predictor(samples) == self.instance_label + + def _choose_superpixels( + self, num_samples: int, p_sample: float = 0.5 + ) -> np.ndarray: + """ + Generates a binary mask of dimension [num_samples, M] where M is the number of + image superpixels (segments). 
+ + Parameters + ---------- + num_samples + Number of perturbed images to be generated + p_sample: + The probability that a superpixel is perturbed + + Returns + ------- + data + Binary 2D mask, where each non-zero entry in a row indicates that + the values of the particular image segment will not be perturbed. + """ + + n_features = len(self.segment_labels) + data = np.random.choice( + [0, 1], num_samples * n_features, p=[p_sample, 1 - p_sample] + ) + data = data.reshape((num_samples, n_features)) + + return data + + def perturbation( + self, anchor: tuple, num_samples: int + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Perturbs an image by altering the values of selected superpixels. If a dataset of image + backgrounds is provided to the explainer, then the superpixels are replaced with the + equivalent superpixels from the background image. Otherwise, the superpixels are replaced + by their average value. + + Parameters + ---------- + anchor: + Contains the superpixels whose values are not going to be perturbed. + num_samples: + Number of perturbed samples to be returned. + + Returns + ------- + imgs + A [num_samples, H, W, C] array of perturbed images. + segments_mask + A [num_samples, M] binary mask, where M is the number of image superpixels + segments. 1 indicates the values in that particular superpixels are not + perturbed. + """ + + image = self.image + segments = self.segments + + # choose superpixels to be perturbed + segments_mask = self._choose_superpixels(num_samples, p_sample=self.p_sample) + segments_mask[:, anchor] = 1 + + # for each sample, need to sample one of the background images if provided + if self.images_background: + backgrounds = np.random.choice( + range(len(self.images_background)), + segments_mask.shape[0], + replace=True, + ) + segments_mask = np.hstack((segments_mask, backgrounds.reshape(-1, 1))) + else: + backgrounds = [None] * segments_mask.shape[0] + # create fudged image where the pixel value in each superpixel is set to the + # average over the superpixel for each channel + fudged_image = image.copy() + n_channels = image.shape[-1] + for x in np.unique(segments): + fudged_image[segments == x] = [ + np.mean(image[segments == x][:, i]) for i in range(n_channels) + ] + + pert_imgs = [] + for mask, background_idx in zip(segments_mask, backgrounds): + temp = copy.deepcopy(image) + to_perturb = np.where(mask == 0)[0] + # create mask for each superpixel not present in the sample + mask = np.zeros(segments.shape).astype(bool) + for superpixel in to_perturb: + mask[segments == superpixel] = True + if background_idx: + # replace values with those of background image + # TODO: Could images_background be None herre? + temp[mask] = self.images_background[background_idx][mask] + else: + # ... or with the averaged superpixel value + # TODO: Where is fudged_image defined? + temp[mask] = fudged_image[mask] + pert_imgs.append(temp) + + return np.array(pert_imgs), segments_mask + + def generate_superpixels(self, image: np.ndarray) -> np.ndarray: + """ + Generates superpixels from (i.e., segments) an image. + + Parameters + ---------- + image + A grayscale or RGB image. + + Returns + ------- + A [H, W] array of integers. Each integer is a segment (superpixel) label. + """ + + image_preproc = self._preprocess_img(image) + + return self.segmentation_fn(image_preproc) + + def _preprocess_img(self, image: np.ndarray) -> np.ndarray: + """ + Applies necessary transformations to the image prior to segmentation. + + Parameters + ---------- + image + A grayscale or RGB image. 
+ + Returns + ------- + A preprocessed image. + """ + + # Grayscale images are repeated across channels + if not self.custom_segmentation and image.shape[-1] == 1: + image_preproc = np.repeat(image, 3, axis=2) + else: + image_preproc = image.copy() + + return image_preproc + class AnchorImage(Explainer): def __init__(self, predictor: Callable, image_shape: tuple, segmentation_fn: Any = 'slic', diff --git a/alibi/explainers/anchor_image_sampler.py b/alibi/explainers/anchor_image_sampler.py deleted file mode 100644 index d5d66b836..000000000 --- a/alibi/explainers/anchor_image_sampler.py +++ /dev/null @@ -1,260 +0,0 @@ -import numpy as np -import copy - -from typing import Tuple, Callable, List, Union - -from .anchor_image_utils import scale_image - - -class AnchorImageSampler: - def __init__( - self, - # TODO: Should we call `predictor`, `prediction_fn` instead? - predictor: Callable, - segmentation_fn: Callable, - custom_segmentation: bool, - image: np.ndarray, - images_background: np.ndarray = None, - p_sample: float = 0.5, - n_covered_ex: int = 10, - ): - """ - Initialize anchor image sampler. - - Parameters - ---------- - predictor - A callable that takes a tensor of N data points as inputs and returns N outputs. - segmentation_fn - Function used to segment the images. - image - Image to be explained. - images_background - Images to overlay superpixels on. - p_sample - Probability for a pixel to be represented by the average value of its superpixel. - n_covered_ex - How many examples where anchors apply to store for each anchor sampled during search - (both examples where prediction on samples agrees/disagrees with desired_label are stored). - """ - self.predictor = predictor - self.segmentation_fn = segmentation_fn - self.custom_segmentation = custom_segmentation - self.image = image - self.images_background = images_background - self.n_covered_ex = n_covered_ex - self.p_sample = p_sample - self.segments = self.generate_superpixels(image) - self.segment_labels = list(np.unique(self.segments)) - self.instance_label = self.predictor(image[np.newaxis, ...])[0] - - def __call__( - self, anchor: Tuple[int, tuple], num_samples: int, compute_labels: bool = True - ) -> List[Union[np.ndarray, float, int]]: - """ - Sample images from a perturbation distribution by masking randomly chosen superpixels - from the original image and replacing them with pixel values from superimposed images - if background images are provided to the explainer. Otherwise, the superpixels from the - original image are replaced with their average values. - - Parameters - ---------- - anchor - int: order of anchor in the batch - tuple: features (= superpixels) present in the proposed anchor - num_samples - Number of samples used - compute_labels - If True, an array of comparisons between predictions on perturbed samples and - instance to be explained is returned. 
- - Returns - ------- - If compute_labels=True, a list containing the following is returned: - - covered_true: perturbed examples where the anchor applies and the model prediction - on perturbed is the same as the instance prediction - - covered_false: perturbed examples where the anchor applies and the model prediction - on pertrurbed sample is NOT the same as the instance prediction - - labels: num_samples ints indicating whether the prediction on the perturbed sample - matches (1) the label of the instance to be explained or not (0) - - data: Matrix with 1s and 0s indicating whether the values in a superpixel will - remain unchanged (1) or will be perturbed (0), for each sample - - 1.0: indicates exact coverage is not computed for this algorithm - - anchor[0]: position of anchor in the batch request - Otherwise, a list containing the data matrix only is returned. - """ - - if compute_labels: - raw_data, data = self.perturbation(anchor[1], num_samples) - labels = self.compare_labels(raw_data) - covered_true = raw_data[labels][: self.n_covered_ex] - covered_true = [scale_image(img) for img in covered_true] - covered_false = raw_data[np.logical_not(labels)][: self.n_covered_ex] - covered_false = [scale_image(img) for img in covered_false] - # coverage set to -1.0 as we can't compute 'true'coverage for this model - - return [covered_true, covered_false, labels.astype(int), data, -1.0, anchor[0]] # type: ignore - - else: - data = self._choose_superpixels(num_samples) - data[:, anchor[1]] = 1 # superpixels in candidate anchor are not perturbed - - return [data] - - def compare_labels(self, samples: np.ndarray) -> np.ndarray: - """ - Compute the agreement between a classifier prediction on an instance to be explained - and the prediction on a set of samples which have a subset of perturbed superpixels. - - Parameters - ---------- - samples - Samples whose labels are to be compared with the instance label. - - Returns - ------- - A boolean array indicating whether the prediction was the same as the instance label. - """ - - return self.predictor(samples) == self.instance_label - - def _choose_superpixels( - self, num_samples: int, p_sample: float = 0.5 - ) -> np.ndarray: - """ - Generates a binary mask of dimension [num_samples, M] where M is the number of - image superpixels (segments). - - Parameters - ---------- - num_samples - Number of perturbed images to be generated - p_sample: - The probability that a superpixel is perturbed - - Returns - ------- - data - Binary 2D mask, where each non-zero entry in a row indicates that - the values of the particular image segment will not be perturbed. - """ - - n_features = len(self.segment_labels) - data = np.random.choice( - [0, 1], num_samples * n_features, p=[p_sample, 1 - p_sample] - ) - data = data.reshape((num_samples, n_features)) - - return data - - def perturbation( - self, anchor: tuple, num_samples: int - ) -> Tuple[np.ndarray, np.ndarray]: - """ - Perturbs an image by altering the values of selected superpixels. If a dataset of image - backgrounds is provided to the explainer, then the superpixels are replaced with the - equivalent superpixels from the background image. Otherwise, the superpixels are replaced - by their average value. - - Parameters - ---------- - anchor: - Contains the superpixels whose values are not going to be perturbed. - num_samples: - Number of perturbed samples to be returned. - - Returns - ------- - imgs - A [num_samples, H, W, C] array of perturbed images. 
- segments_mask - A [num_samples, M] binary mask, where M is the number of image superpixels - segments. 1 indicates the values in that particular superpixels are not - perturbed. - """ - - image = self.image - segments = self.segments - - # choose superpixels to be perturbed - segments_mask = self._choose_superpixels(num_samples, p_sample=self.p_sample) - segments_mask[:, anchor] = 1 - - # for each sample, need to sample one of the background images if provided - if self.images_background: - backgrounds = np.random.choice( - range(len(self.images_background)), - segments_mask.shape[0], - replace=True, - ) - segments_mask = np.hstack((segments_mask, backgrounds.reshape(-1, 1))) - else: - backgrounds = [None] * segments_mask.shape[0] - # create fudged image where the pixel value in each superpixel is set to the - # average over the superpixel for each channel - fudged_image = image.copy() - n_channels = image.shape[-1] - for x in np.unique(segments): - fudged_image[segments == x] = [ - np.mean(image[segments == x][:, i]) for i in range(n_channels) - ] - - pert_imgs = [] - for mask, background_idx in zip(segments_mask, backgrounds): - temp = copy.deepcopy(image) - to_perturb = np.where(mask == 0)[0] - # create mask for each superpixel not present in the sample - mask = np.zeros(segments.shape).astype(bool) - for superpixel in to_perturb: - mask[segments == superpixel] = True - if background_idx: - # replace values with those of background image - # TODO: Could images_background be None herre? - temp[mask] = self.images_background[background_idx][mask] - else: - # ... or with the averaged superpixel value - # TODO: Where is fudged_image defined? - temp[mask] = fudged_image[mask] - pert_imgs.append(temp) - - return np.array(pert_imgs), segments_mask - - def generate_superpixels(self, image: np.ndarray) -> np.ndarray: - """ - Generates superpixels from (i.e., segments) an image. - - Parameters - ---------- - image - A grayscale or RGB image. - - Returns - ------- - A [H, W] array of integers. Each integer is a segment (superpixel) label. - """ - - image_preproc = self._preprocess_img(image) - - return self.segmentation_fn(image_preproc) - - def _preprocess_img(self, image: np.ndarray) -> np.ndarray: - """ - Applies necessary transformations to the image prior to segmentation. - - Parameters - ---------- - image - A grayscale or RGB image. - - Returns - ------- - A preprocessed image. - """ - - # Grayscale images are repeated across channels - if not self.custom_segmentation and image.shape[-1] == 1: - image_preproc = np.repeat(image, 3, axis=2) - else: - image_preproc = image.copy() - - return image_preproc diff --git a/alibi/explainers/anchor_image_utils.py b/alibi/explainers/anchor_image_utils.py deleted file mode 100644 index 1e97b9d52..000000000 --- a/alibi/explainers/anchor_image_utils.py +++ /dev/null @@ -1,25 +0,0 @@ -import numpy as np - - -def scale_image(image: np.ndarray, scale: tuple = (0, 255)) -> np.ndarray: - """ - Scales an image in a specified range. - - Parameters - ---------- - image - Image to be scale. - scale - The scaling interval. - - Returns - ------- - img_scaled - Scaled image. 
- """ - - img_max, img_min = image.max(), image.min() - img_std = (image - img_min) / (img_max - img_min) - img_scaled = img_std * (scale[1] - scale[0]) + scale[0] - - return img_scaled diff --git a/alibi/explainers/tests/test_anchor_image.py b/alibi/explainers/tests/test_anchor_image.py index 1b77f3ca6..481a0777b 100644 --- a/alibi/explainers/tests/test_anchor_image.py +++ b/alibi/explainers/tests/test_anchor_image.py @@ -3,8 +3,75 @@ import numpy as np from alibi.api.defaults import DEFAULT_META_ANCHOR, DEFAULT_DATA_ANCHOR_IMG from alibi.explainers import AnchorImage +from alibi.explainers.anchor_image import AnchorImageSampler, scale_image from alibi.explainers.anchor_image_sampler import AnchorImageSampler +def test_scale_image(): + image_shape = (28, 28, 1) + scaling_offset = 260 + min_val = 0 + max_val = 255 + + fake_img = np.random.random(size=image_shape) + scaling_offset + scaled_img = scale_image(fake_img, scale=(min_val, max_val)) + assert (scaled_img <= max_val).all() + assert (scaled_img >= min_val).all() + + +@pytest.mark.parametrize( + "models", + [("mnist-cnn-tf2.2.0",), ("mnist-cnn-tf1.15.2.h5",)], + ids="model={}".format, + indirect=True, +) +def test_sampler(models, mnist_data): + eps = 0.0001 # tolerance for tensor comparisons + num_samples = 10 + + x_train = mnist_data["X_train"] + segmentation_fn = "slic" + segmentation_kwargs = {"n_segments": 10, "compactness": 10, "sigma": 0.5} + image_shape = (28, 28, 1) + predict_fn = lambda x: models[0].predict(x) # noqa: E731 + explainer = AnchorImage( + predict_fn, + image_shape, + segmentation_fn=segmentation_fn, + segmentation_kwargs=segmentation_kwargs, + ) + + image = x_train[0] + p_sample = 0.5 # probability of perturbing a superpixel + n_covered_ex = 3 # nb of examples where the anchor applies that are saved + sampler = AnchorImageSampler( + predictor=explainer.predictor, + segmentation_fn=explainer.segmentation_fn, + custom_segmentation=explainer.custom_segmentation, + image=image, + images_background=explainer.images_background, + p_sample=p_sample, + n_covered_ex=n_covered_ex, + ) + + image_preproc = sampler._preprocess_img(image) + superpixels_mask = sampler._choose_superpixels(num_samples=num_samples) + + # grayscale image should be replicated across channel dim before segmentation + assert image_preproc.shape[-1] == 3 + for channel in range(image_preproc.shape[-1]): + assert (image.squeeze() - image_preproc[..., channel] <= eps).all() + # check superpixels mask + assert superpixels_mask.shape[0] == num_samples + assert superpixels_mask.shape[1] == len(list(np.unique(sampler.segments))) + assert superpixels_mask.sum(axis=1).any() <= segmentation_kwargs["n_segments"] + assert superpixels_mask.any() <= 1 + + cov_true, cov_false, labels, data, coverage, _ = sampler( + (0, ()), num_samples + ) + assert data.shape[0] == labels.shape[0] + assert data.shape[1] == len(np.unique(sampler.segments)) + assert coverage == -1 @pytest.mark.parametrize( "models", @@ -50,8 +117,13 @@ def test_anchor_image(models, mnist_data): # test explanation threshold = 0.95 + + before_explain = explainer.__dict__ explanation = explainer.explain(image, threshold=threshold, n_covered_ex=3) + after_explain = explainer.__dict__ + # Ensure that explainer's internal state doesn't change + assert before_explain == after_explain if explanation.raw["feature"]: assert ( len(explanation.raw["examples"][-1]["covered_true"]) <= sampler.n_covered_ex @@ -69,31 +141,3 @@ def test_anchor_image(models, mnist_data): assert explanation.data.keys() == 
DEFAULT_DATA_ANCHOR_IMG.keys() -@pytest.mark.parametrize( - "models", - [("mnist-cnn-tf2.2.0",)], - ids="model={}".format, - indirect=True, -) -def test_stateless_explainer(models, mnist_data): - predict_fn = lambda x: models[0].predict(x) # noqa: E731 - image_shape = (28, 28, 1) - segmentation_fn = "slic" - segmentation_kwargs = {"n_segments": 10, "compactness": 10, "sigma": 0.5} - - explainer = AnchorImage( - predict_fn, - image_shape, - segmentation_fn=segmentation_fn, - segmentation_kwargs=segmentation_kwargs, - ) - - x_train = mnist_data["X_train"] - image = x_train[0] - threshold = 0.95 - - before_explain = explainer.__dict__ - explainer.explain(image, threshold=threshold, n_covered_ex=3) - after_explain = explainer.__dict__ - - assert before_explain == after_explain diff --git a/alibi/explainers/tests/test_anchor_image_sampler.py b/alibi/explainers/tests/test_anchor_image_sampler.py deleted file mode 100644 index 96535d5f9..000000000 --- a/alibi/explainers/tests/test_anchor_image_sampler.py +++ /dev/null @@ -1,61 +0,0 @@ -import pytest - -import numpy as np -from alibi.explainers import AnchorImage -from alibi.explainers.anchor_image_sampler import AnchorImageSampler - - -@pytest.mark.parametrize( - "models", - [("mnist-cnn-tf2.2.0",), ("mnist-cnn-tf1.15.2.h5",)], - ids="model={}".format, - indirect=True, -) -def test_sampler(models, mnist_data): - eps = 0.0001 # tolerance for tensor comparisons - num_samples = 10 - - x_train = mnist_data["X_train"] - segmentation_fn = "slic" - segmentation_kwargs = {"n_segments": 10, "compactness": 10, "sigma": 0.5} - image_shape = (28, 28, 1) - predict_fn = lambda x: models[0].predict(x) # noqa: E731 - explainer = AnchorImage( - predict_fn, - image_shape, - segmentation_fn=segmentation_fn, - segmentation_kwargs=segmentation_kwargs, - ) - - image = x_train[0] - p_sample = 0.5 # probability of perturbing a superpixel - n_covered_ex = 3 # nb of examples where the anchor applies that are saved - sampler = AnchorImageSampler( - predictor=explainer.predictor, - segmentation_fn=explainer.segmentation_fn, - custom_segmentation=explainer.custom_segmentation, - image=image, - images_background=explainer.images_background, - p_sample=p_sample, - n_covered_ex=n_covered_ex, - ) - - image_preproc = sampler._preprocess_img(image) - superpixels_mask = sampler._choose_superpixels(num_samples=num_samples) - - # grayscale image should be replicated across channel dim before segmentation - assert image_preproc.shape[-1] == 3 - for channel in range(image_preproc.shape[-1]): - assert (image.squeeze() - image_preproc[..., channel] <= eps).all() - # check superpixels mask - assert superpixels_mask.shape[0] == num_samples - assert superpixels_mask.shape[1] == len(list(np.unique(sampler.segments))) - assert superpixels_mask.sum(axis=1).any() <= segmentation_kwargs["n_segments"] - assert superpixels_mask.any() <= 1 - - cov_true, cov_false, labels, data, coverage, _ = sampler( - (0, ()), num_samples - ) - assert data.shape[0] == labels.shape[0] - assert data.shape[1] == len(np.unique(sampler.segments)) - assert coverage == -1 diff --git a/alibi/explainers/tests/test_anchor_image_utils.py b/alibi/explainers/tests/test_anchor_image_utils.py deleted file mode 100644 index 7ec9d82d6..000000000 --- a/alibi/explainers/tests/test_anchor_image_utils.py +++ /dev/null @@ -1,15 +0,0 @@ -import numpy as np - -from alibi.explainers.anchor_image_utils import scale_image - - -def test_scale_image(): - image_shape = (28, 28, 1) - scaling_offset = 260 - min_val = 0 - max_val = 255 - - 
fake_img = np.random.random(size=image_shape) + scaling_offset - scaled_img = scale_image(fake_img, scale=(min_val, max_val)) - assert (scaled_img <= max_val).all() - assert (scaled_img >= min_val).all() From 2ee5bc62549db9eb72247e762987cb9759563ee2 Mon Sep 17 00:00:00 2001 From: Adrian Gonzalez-Martin Date: Tue, 3 Aug 2021 18:12:31 +0200 Subject: [PATCH 11/13] Fix tests --- alibi/explainers/anchor_image.py | 2 +- alibi/explainers/tests/test_anchor_image.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/alibi/explainers/anchor_image.py b/alibi/explainers/anchor_image.py index 9ddf57ae5..b79ccddc3 100644 --- a/alibi/explainers/anchor_image.py +++ b/alibi/explainers/anchor_image.py @@ -4,7 +4,7 @@ import numpy as np from functools import partial -from typing import Any, Callable +from typing import Any, Callable, List, Union, Tuple from alibi.utils.wrappers import ArgmaxTransformer from alibi.api.interfaces import Explainer, Explanation diff --git a/alibi/explainers/tests/test_anchor_image.py b/alibi/explainers/tests/test_anchor_image.py index 481a0777b..283121218 100644 --- a/alibi/explainers/tests/test_anchor_image.py +++ b/alibi/explainers/tests/test_anchor_image.py @@ -2,9 +2,7 @@ import numpy as np from alibi.api.defaults import DEFAULT_META_ANCHOR, DEFAULT_DATA_ANCHOR_IMG -from alibi.explainers import AnchorImage -from alibi.explainers.anchor_image import AnchorImageSampler, scale_image -from alibi.explainers.anchor_image_sampler import AnchorImageSampler +from alibi.explainers.anchor_image import AnchorImage, AnchorImageSampler, scale_image def test_scale_image(): image_shape = (28, 28, 1) From 7d500edd1e99aa7e0d0fe039961f09cd5d34fb3f Mon Sep 17 00:00:00 2001 From: Adrian Gonzalez-Martin Date: Tue, 3 Aug 2021 18:14:33 +0200 Subject: [PATCH 12/13] Fix linter --- alibi/explainers/anchor_image.py | 1 + alibi/explainers/tests/test_anchor_image.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/alibi/explainers/anchor_image.py b/alibi/explainers/anchor_image.py index b79ccddc3..e956a2f7e 100644 --- a/alibi/explainers/anchor_image.py +++ b/alibi/explainers/anchor_image.py @@ -21,6 +21,7 @@ 'slic': {'n_segments': 10, 'compactness': 10, 'sigma': .5} } + def scale_image(image: np.ndarray, scale: tuple = (0, 255)) -> np.ndarray: """ Scales an image in a specified range. 
diff --git a/alibi/explainers/tests/test_anchor_image.py b/alibi/explainers/tests/test_anchor_image.py index 283121218..4a7b40d7a 100644 --- a/alibi/explainers/tests/test_anchor_image.py +++ b/alibi/explainers/tests/test_anchor_image.py @@ -4,6 +4,7 @@ from alibi.api.defaults import DEFAULT_META_ANCHOR, DEFAULT_DATA_ANCHOR_IMG from alibi.explainers.anchor_image import AnchorImage, AnchorImageSampler, scale_image + def test_scale_image(): image_shape = (28, 28, 1) scaling_offset = 260 @@ -71,6 +72,7 @@ def test_sampler(models, mnist_data): assert data.shape[1] == len(np.unique(sampler.segments)) assert coverage == -1 + @pytest.mark.parametrize( "models", [("mnist-cnn-tf2.2.0",), ("mnist-cnn-tf1.15.2.h5",)], @@ -137,5 +139,3 @@ def test_anchor_image(models, mnist_data): assert len(np.unique(explanation.segments)) == len(np.unique(sampler.segments)) assert explanation.meta.keys() == DEFAULT_META_ANCHOR.keys() assert explanation.data.keys() == DEFAULT_DATA_ANCHOR_IMG.keys() - - From f967f53c9b00e80675a7e2795d6832c5ba2133a3 Mon Sep 17 00:00:00 2001 From: Adrian Gonzalez-Martin Date: Tue, 3 Aug 2021 18:14:58 +0200 Subject: [PATCH 13/13] Remove TODO comment --- alibi/explainers/anchor_image.py | 1 - 1 file changed, 1 deletion(-) diff --git a/alibi/explainers/anchor_image.py b/alibi/explainers/anchor_image.py index e956a2f7e..454158bec 100644 --- a/alibi/explainers/anchor_image.py +++ b/alibi/explainers/anchor_image.py @@ -49,7 +49,6 @@ def scale_image(image: np.ndarray, scale: tuple = (0, 255)) -> np.ndarray: class AnchorImageSampler: def __init__( self, - # TODO: Should we call `predictor`, `prediction_fn` instead? predictor: Callable, segmentation_fn: Callable, custom_segmentation: bool,
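For context, the snippet below sketches how the refactored sampler is driven end to end, mirroring test_sampler and test_scale_image in the patches above. It is a minimal, untested sketch: the constant predictor and the random image are placeholders standing in for the MNIST CNN and training data used by the real tests, and only names introduced in this patch series are assumed to exist.

import numpy as np

from alibi.explainers.anchor_image import AnchorImage, AnchorImageSampler, scale_image


def predict_fn(x: np.ndarray) -> np.ndarray:
    # Placeholder predictor that always returns class 0; a real model such as
    # the MNIST CNN used in the tests would be plugged in here instead.
    return np.zeros((x.shape[0],), dtype=int)


image_shape = (28, 28, 1)
explainer = AnchorImage(
    predict_fn,
    image_shape,
    segmentation_fn="slic",
    segmentation_kwargs={"n_segments": 10, "compactness": 10, "sigma": 0.5},
)

# Random stand-in for an MNIST digit; any [H, W, C] image works.
image = np.random.random(size=image_shape)

# The sampler now owns the per-instance state (segments, segment labels,
# instance label) that previously lived on the explainer itself.
sampler = AnchorImageSampler(
    predictor=explainer.predictor,
    segmentation_fn=explainer.segmentation_fn,
    custom_segmentation=explainer.custom_segmentation,
    image=image,
    images_background=explainer.images_background,
    p_sample=0.5,
    n_covered_ex=3,
)

# Draw perturbed samples for the empty anchor, as test_sampler does.
covered_true, covered_false, labels, data, coverage, _ = sampler((0, ()), 10)

# scale_image rescales pixel values into a given range, here [0, 255].
scaled_img = scale_image(image, scale=(0, 255))

Keeping the per-instance state on the sampler rather than on the explainer is what allows the stateless-explainer assertion in test_anchor_image (comparing explainer.__dict__ before and after explain) to hold.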