Skip to content

Commit

Permalink
1. Printing CV variab. info, 2. Printing split info in GENERATE
Browse files Browse the repository at this point in the history
  • Loading branch information
jvalegre committed Aug 15, 2024
1 parent dfed597 commit 287f158
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 23 deletions.
12 changes: 8 additions & 4 deletions robert/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,13 @@ def separate_test(self, csv_df, csv_X, csv_y):
if self.args.test_set != 0:
if len(csv_df[self.args.y]) < 50:
self.args.test_set = 0
self.args.log.write(f'\nx WARNING! The database contains {len(csv_df[self.args.y])} datapoints, the data will be split in training and validation with no points separated as external test set (too few points to reach a reliable splitting). You can bypass this option and include test points with "--auto_test False".')
self.args.test_set
self.args.log.write(f'\nx WARNING! The database contains {len(csv_df[self.args.y])} datapoints, the data will be split into training and validation sets with no points separated as external test set (too few points to reach a reliable splitting). You can bypass this option and include test points with "--auto_test False".')
elif self.args.test_set < 0.1:
self.args.test_set = 0.1
self.args.log.write(f'\nx WARNING! The test_set option was set to {self.args.test_set}, this value will be raised to 0.1 to include a meaningful amount of points in the test set. You can bypass this option and include less test points with "--auto_test False".')
else:
self.args.log.write(f'\no Before hyproptimization, {int(self.args.test_set*100)}% of the data was separated as test set, using an even distribution of data points across the range of y values. The remaining data points will be split into training and validation.')


if self.args.test_set > 0:
n_of_points = int(len(csv_X)*(self.args.test_set))
Expand Down Expand Up @@ -278,8 +280,10 @@ def adjust_split(self, csv_df):
# when using databases with a small number of points (less than 250 datapoints)
if len(csv_df[self.args.y]) < 250 and self.args.split.lower() == 'rnd':
self.args.split = 'KN'
self.args.log.write(f'\nx WARNING! The database contains {len(csv_df[self.args.y])} datapoints, KN data splitting will replace the default random splitting (too few points to reach a reliable splitting). You can use random splitting with "--auto_kn False".')

self.args.log.write(f'\no The database contains {len(csv_df[self.args.y])} datapoints, k-means data splitting will replace the default random splitting to select training and validation sets (too few points to reach a reliable splitting). You can use random splitting with "--auto_kn False".')
elif self.args.split.lower() == 'rnd':
self.args.log.write(f'\no The database contains {len(csv_df[self.args.y])} datapoints, the default random splitting will be used to select training and validation sets.')

# when using unbalanced databases (using an arbitrary imbalance ratio of 10 with the two halves of the data)
mid_value = max(csv_df[self.args.y])-((max(csv_df[self.args.y])-min(csv_df[self.args.y]))/2)
high_vals = len([i for i in csv_df[self.args.y] if i >= mid_value])
Expand Down
9 changes: 7 additions & 2 deletions robert/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
print_predict,
shap_analysis,
PFI_plot,
outlier_plot
outlier_plot,
print_cv_var
)
from robert.utils import (load_variables,
load_db_n_params,
Expand Down Expand Up @@ -94,14 +95,18 @@ def __init__(self, **kwargs):
Xy_data = load_n_predict(self, params_dict, Xy_data, mapie=True)

# save predictions for all sets
path_n_suffix, name_points = save_predictions(self,Xy_data,params_dir,Xy_test_df)
path_n_suffix, name_points, Xy_data = save_predictions(self,Xy_data,params_dir,Xy_test_df)

# represent y vs predicted y
colors = plot_predictions(self, params_dict, Xy_data, path_n_suffix)

# print results
_ = print_predict(self,Xy_data,params_dict,path_n_suffix)

# print CV variation (for regression)
if params_dict['type'].lower() == 'reg':
_ = print_cv_var(self,Xy_data,params_dict,path_n_suffix)

# SHAP analysis
_ = shap_analysis(self,Xy_data,params_dict,path_n_suffix)

Expand Down
32 changes: 29 additions & 3 deletions robert/predict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,6 @@ def save_predictions(self,Xy_data,params_dir,Xy_test_df):
Saves CSV files with the different sets and their predicted results
'''

cv_type = Xy_data['cv_type']
Xy_orig_df, Xy_path, params_df, _, _, suffix_title = load_dfs(self,params_dir,'no_print')
base_csv_name = '_'.join(os.path.basename(Path(Xy_path)).replace('.csv','_').split('_')[0:2])
base_csv_name = f'PREDICT/{base_csv_name}'
Expand All @@ -344,6 +343,7 @@ def save_predictions(self,Xy_data,params_dir,Xy_test_df):
valid_path = f'{base_csv_path}_valid_{suffix_title}.csv'
_ = Xy_orig_valid.to_csv(valid_path, index = None, header=True)
print_preds += f'\n - Validation set with predicted results: PREDICT/{os.path.basename(valid_path)}'
Xy_data['csv_pred_name'] = os.path.basename(valid_path)
# saves test predictions
Xy_orig_test = None
if 'X_test_scaled' in Xy_data:
Expand All @@ -354,7 +354,8 @@ def save_predictions(self,Xy_data,params_dir,Xy_test_df):
test_path = f'{base_csv_path}_test_{suffix_title}.csv'
_ = Xy_orig_test.to_csv(test_path, index = None, header=True)
print_preds += f'\n - Test set with predicted results: PREDICT/{os.path.basename(test_path)}'

Xy_data['csv_pred_name'] = os.path.basename(test_path)

# saves prediction for external test in --csv_test
if self.args.csv_test != '':
Xy_test_df[f'{params_df["y"][0]}_pred'] = Xy_data['y_pred_csv_test']
Expand All @@ -367,6 +368,7 @@ def save_predictions(self,Xy_data,params_dir,Xy_test_df):
csv_test_path = f'{folder_csv}/{csv_name}'
_ = Xy_test_df.to_csv(csv_test_path, index = None, header=True)
print_preds += f'\n - External set with predicted results: PREDICT/csv_test/{os.path.basename(csv_test_path)}'
Xy_data['csv_pred_name'] = f'csv_test/{os.path.basename(csv_test_path)}'

self.args.log.write(print_preds)

Expand All @@ -385,7 +387,7 @@ def save_predictions(self,Xy_data,params_dir,Xy_test_df):
if Xy_orig_test is not None:
name_points['test'] = Xy_orig_test[self.args.names]

return path_n_suffix, name_points
return path_n_suffix, name_points, Xy_data


def print_predict(self,Xy_data,params_dict,path_n_suffix):
Expand Down Expand Up @@ -445,6 +447,30 @@ def print_predict(self,Xy_data,params_dict,path_n_suffix):
dat_results.close()


def print_cv_var(self,Xy_data,params_dict,path_n_suffix):
    '''
    Prints and saves a summary of the cross-validation variability of the
    predictions: which CV scheme was used, where the CV-variability graph and
    the per-point standard deviations were stored, and the average SD.

    Parameters
    ----------
    self : object with an `args.log` logger (the PREDICT job object)
    Xy_data : dict with 'cv_type', 'csv_pred_name' and 'avg_sd' entries
    params_dict : dict with the name of the y column under 'y'
    path_n_suffix : str, base path (dir + basename prefix) for output files

    Side effects: logs the summary via self.args.log.write() and writes it to
    a CV_variability_*.dat file next to path_n_suffix.
    '''

    # path of the CV-variability graph generated elsewhere; only the last two
    # path components (e.g. PREDICT/CV_variability_X.png) are shown to the user
    cv_plot_file = f'{os.path.dirname(path_n_suffix)}/CV_variability_{os.path.basename(path_n_suffix)}.png'
    path_reduced = '/'.join(f'{cv_plot_file}'.replace('\\','/').split('/')[-2:])

    # human-readable name of the CV scheme stored in Xy_data['cv_type']
    if Xy_data['cv_type'] == 'loocv':
        cv_type = 'Jackknife CV'
    else:
        # assumes cv_type is formatted like '..._K_..._...' with the fold count
        # third from the end (as produced by calc_ci_n_sd) — TODO confirm
        kfold = Xy_data['cv_type'].split('_')[-3]
        cv_type = f'{kfold}-fold CV'

    cv_var_summary = f"\n o Cross-validation variation (with {cv_type}) graph saved in {path_reduced}:"
    cv_var_summary += f"\n - Standard deviations saved in PREDICT/{Xy_data['csv_pred_name']} in the {params_dict['y']}_pred_sd column"
    cv_var_summary += f"\n - Average SD = {round(Xy_data['avg_sd'],2)}"

    self.args.log.write(cv_var_summary)

    # persist the same summary next to the graph; `with` guarantees the file
    # handle is closed even if the write fails
    cv_var_file = f'{os.path.dirname(path_n_suffix)}/CV_variability_{os.path.basename(path_n_suffix)}.dat'
    with open(cv_var_file, "w") as dat_results:
        dat_results.write(cv_var_summary)


def shap_analysis(self,Xy_data,params_dict,path_n_suffix):
'''
Plots and prints the results of the SHAP analysis
Expand Down
6 changes: 4 additions & 2 deletions robert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,10 +1079,12 @@ def calc_ci_n_sd(self,loaded_model,data,set_mapie):
# assuming normal population doesn't add very significant errors even in low-data regimes (i.e. for 20 points,
# Student's t value is 2.086 instead of 1.96)
dict_alpha = {0.05: 1.96, 0.1: 1.645, 0.5: 0.674}
y_test_sd = y_interval_width / (2 * dict_alpha[self.args.alpha])
y_pred_sd = y_interval_width / (2 * dict_alpha[self.args.alpha])
avg_sd = np.mean(y_pred_sd) # average SD

# Add 'y_pred_SET_cv' and 'y_pred_SET_sd' entry to data dictionary
data[f'y_pred_{set_mapie}_sd'] = y_test_sd
data[f'y_pred_{set_mapie}_sd'] = y_pred_sd
data['avg_sd'] = avg_sd
data['cv_type'] = cv_type

return data
Expand Down
26 changes: 14 additions & 12 deletions robert/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from matplotlib.legend_handler import HandlerPatch
import matplotlib.patches as mpatches
from matplotlib.patches import FancyArrowPatch, ArrowStyle
import matplotlib.lines as lines
import numpy as np
import pandas as pd
from pathlib import Path
Expand Down Expand Up @@ -398,12 +399,12 @@ def plot_metrics(self,params_path,suffix_title,verify_metrics,verify_results):

sb.reset_defaults()
sb.set(style="ticks")
_, (ax1, ax2) = plt.subplots(1, 2, sharex=False, sharey= False, figsize=(7.45,6),
fig, (ax1, ax2) = plt.subplots(1, 2, sharex=False, sharey= False, figsize=(7.45,6),
constrained_layout=True, gridspec_kw={
'width_ratios': [1, 1.3],
'wspace': 0.05})
'wspace': 0.07})

width_1 = 0.7 # respect to the original size of the bar (i.e. single bar takes whole graph)
width_1 = 0.67 # respect to the original size of the bar (i.e. single bar takes whole graph)
width_2 = 0.75
for test_metric,test_name,test_color in zip(verify_metrics['metrics'],verify_metrics['test_names'],verify_metrics['colors']):
# flawed models
Expand All @@ -422,22 +423,15 @@ def plot_metrics(self,params_path,suffix_title,verify_metrics,verify_results):
# styling preferences
ax1.tick_params(axis='y', labelsize=14)
ax1.tick_params(axis='x', labelsize=14)
ax2.tick_params(axis='y', labelsize=14, labelleft=False)
ax2.tick_params(axis='y', labelsize=14, labelleft=False, left = False)
ax2.tick_params(axis='x', labelsize=14)

# title and labels of the axis
ax1.set_ylabel(f'{verify_results["error_type"].upper()}', fontsize=14)

# borders
ax1.spines[['right', 'top']].set_visible(False)
ax2.spines[['right', 'top']].set_visible(False)

# titles
fontsize = 14
title_verify = f"VERIFY tests of {os.path.basename(path_n_suffix)}"
ax1.set_title(f'Model & cross-valid.', fontsize=14, style='italic', y=0.96)
ax2.set_title('"Flawed" models', fontsize=14, style='italic', y=0.96)
plt.suptitle(title_verify, y=1.06, fontsize = fontsize, fontweight="bold")
ax2.spines[['right', 'top', 'left']].set_visible(False)

# axis limits
max_val = max(verify_metrics['metrics'])
Expand All @@ -452,6 +446,14 @@ def plot_metrics(self,params_path,suffix_title,verify_metrics,verify_results):
ax1.set_ylim([min_lim, max_lim])
ax2.set_ylim([min_lim, max_lim])

# titles and line separating titles
fontsize = 14
title_verify = f"VERIFY tests of {os.path.basename(path_n_suffix)}"
ax1.set_title(f'Model & cross-valid.', fontsize=14, y=0.96)
ax2.set_title('"Flawed" models', fontsize=14, y=0.96)
fig.add_artist(lines.Line2D([0.41, 0.62], [0.975, 0.975],color='k',linewidth=1)) # format: [x1,x2], [y1,y2]
plt.suptitle(title_verify, y=1.06, fontsize = fontsize, fontweight="bold")

# add threshold line and arrow indicating passed test direction
arrow_length = np.abs(max_lim-min_lim)/11

Expand Down

0 comments on commit 287f158

Please sign in to comment.