Skip to content

Commit

Permalink
misc updates & fixes (#104)
Browse files Browse the repository at this point in the history
Signed-off-by: Soundarya Krishnan <soundaryak4898@gmail.com>
  • Loading branch information
soundarya98 authored Mar 10, 2021
1 parent 9cc0d2f commit bfdd5c1
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 407 deletions.
3 changes: 2 additions & 1 deletion dice_ml/explainer_interfaces/dice_KD.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _generate_counterfactuals(self, query_instance, total_CFs, desired_range=Non
test_pred = self.predict_fn(query_instance)[0]

query_instance[self.data_interface.outcome_name] = test_pred
self.misc_init(stopping_threshold, desired_class, desired_range, test_pred)
desired_class = self.misc_init(stopping_threshold, desired_class, desired_range, test_pred)
if desired_range != None:
if desired_range[0] > desired_range[1]:
raise ValueError("Invalid Range!")
Expand Down Expand Up @@ -106,6 +106,7 @@ def _generate_counterfactuals(self, query_instance, total_CFs, desired_range=Non
stopping_threshold,
posthoc_sparsity_param,
posthoc_sparsity_algorithm, verbose)
self.cfs_preds = cfs_preds

return exp.CounterfactualExamples(data_interface=self.data_interface,
final_cfs_df=self.final_cfs_df,
Expand Down
15 changes: 8 additions & 7 deletions dice_ml/explainer_interfaces/dice_genetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""

from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
import math
import numpy as np
import pandas as pd
import random
Expand Down Expand Up @@ -151,7 +150,7 @@ def do_KD_init(self, features_to_vary, query_instance, cfs, desired_class, desir
remaining_cfs = self.do_random_init(self.population_size - len(uniques), features_to_vary, query_instance, desired_class, desired_range)
self.cfs = np.concatenate([uniques, remaining_cfs])

def do_cf_initializations(self, total_CFs, initialization, algorithm, features_to_vary, permitted_range, desired_range, desired_class, query_instance, query_instance_df_dummies, verbose):
def do_cf_initializations(self, total_CFs, initialization, algorithm, features_to_vary, desired_range, desired_class, query_instance, query_instance_df_dummies, verbose):
"""Intializes CFs and other related variables."""
self.cf_init_weights = [total_CFs, algorithm, features_to_vary]

Expand Down Expand Up @@ -195,7 +194,9 @@ def do_param_initializations(self, total_CFs, initialization, desired_range, des

self.feature_range = self.get_valid_feature_range(normalized=False)
if len(self.cfs) != total_CFs:
self.do_cf_initializations(total_CFs, initialization, algorithm, features_to_vary, permitted_range, desired_range, desired_class, query_instance, query_instance_df_dummies, verbose)
self.do_cf_initializations(total_CFs, initialization, algorithm, features_to_vary, desired_range, desired_class, query_instance, query_instance_df_dummies, verbose)
else:
self.total_CFs = total_CFs
self.do_loss_initializations(yloss_type, diversity_loss_type, feature_weights, encoding='label')
self.update_hyperparameters(proximity_weight, sparsity_weight, diversity_weight, categorical_penalty)

Expand Down Expand Up @@ -247,7 +248,7 @@ def _generate_counterfactuals(self, query_instance, total_CFs, initialization="k
test_pred = self.predict_fn(query_instance)
self.test_pred = test_pred

self.misc_init(stopping_threshold, desired_class, desired_range, test_pred)
desired_class = self.misc_init(stopping_threshold, desired_class, desired_range, test_pred)

query_instance_df_dummies = pd.get_dummies(query_instance_orig)
for col in pd.get_dummies(self.data_interface.data_df[self.data_interface.feature_names]).columns:
Expand All @@ -256,7 +257,7 @@ def _generate_counterfactuals(self, query_instance, total_CFs, initialization="k

self.do_param_initializations(total_CFs, initialization, desired_range, desired_class, query_instance, query_instance_df_dummies, algorithm, features_to_vary, permitted_range, yloss_type, diversity_loss_type, feature_weights, proximity_weight, sparsity_weight, diversity_weight, categorical_penalty, verbose)

query_instance_df = self.find_counterfactuals(query_instance, desired_range, desired_class, features_to_vary, stopping_threshold, posthoc_sparsity_param, posthoc_sparsity_algorithm, maxiterations, thresh, verbose)
query_instance_df = self.find_counterfactuals(query_instance, desired_range, desired_class, features_to_vary, maxiterations, thresh, verbose)

return exp.CounterfactualExamples(data_interface=self.data_interface,
test_instance_df=query_instance_df,
Expand Down Expand Up @@ -357,7 +358,7 @@ def mate(self, k1, k2, features_to_vary, query_instance):
one_init[j] = query_instance[j]
return one_init

def find_counterfactuals(self, query_instance, desired_range, desired_class, features_to_vary, stopping_threshold, posthoc_sparsity_param, posthoc_sparsity_algorithm, maxiterations, thresh, verbose):
def find_counterfactuals(self, query_instance, desired_range, desired_class, features_to_vary, maxiterations, thresh, verbose):
"""Finds counterfactuals by generating cfs through the genetic algorithm"""
population = self.cfs.copy()
iterations = 0
Expand All @@ -367,7 +368,7 @@ def find_counterfactuals(self, query_instance, desired_range, desired_class, fea
cfs_preds = [np.inf]*self.total_CFs
to_pred = None

while iterations < maxiterations or len(population) == self.total_CFs:
while iterations < maxiterations and self.total_CFs > 0:
if abs(previous_best_loss - current_best_loss) <= thresh and (self.model.model_type == 'classifier' and all(i == desired_class for i in cfs_preds) or (self.model.model_type == 'regressor' and all(desired_range[0] <= i <= desired_range[1] for i in cfs_preds))):
stop_cnt += 1
else:
Expand Down
8 changes: 2 additions & 6 deletions dice_ml/explainer_interfaces/explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,8 @@

import numpy as np
import pandas as pd
import random
import timeit
import copy
from collections.abc import Iterable
from sklearn.neighbors import KDTree

import dice_ml.diverse_counterfactuals as exp
from dice_ml.counterfactual_explanations import CounterfactualExplanations
from dice_ml.utils.exception import UserConfigValidationException

Expand Down Expand Up @@ -444,6 +439,7 @@ def misc_init(self, stopping_threshold, desired_class, desired_range, test_pred)

elif self.model.model_type == 'regressor':
self.target_cf_range = self.infer_target_cfs_range(desired_range)
return desired_class

def infer_target_cfs_class(self, desired_class_input, original_pred,
num_output_nodes):
Expand All @@ -465,7 +461,7 @@ def infer_target_cfs_class(self, desired_class_input, original_pred,
if desired_class_input >= 0 and desired_class_input < num_output_nodes:
target_class = desired_class_input
else:
raise ValueError("Desired class should be within 0 and num_classes-1.")
raise ValueError("Desired class not present in training data!")
return target_class

def infer_target_cfs_range(self, desired_range_input):
Expand Down
Loading

0 comments on commit bfdd5c1

Please sign in to comment.