##########################################################################
# Author: Gianfranco Paternò (paterno@fe.infn.it) #
# collaboration of INFN Ferrara and INFN Napoli for the next_AIM project #
# Last modification date: March 2024 #
# #
# SPDX-License-Identifier: EUPL-1.2 #
# Copyright 2024 Istituto Nazionale di Fisica Nucleare #
##########################################################################
# Read the results produced by trained models.
# Typically used to analyze the results obtained on subsets (folds) of a given dataset.

# Import the required libraries
import os

# The backend must be selected before matplotlib is imported, otherwise the
# MPLBACKEND environment variable has no effect.
os.environ['MPLBACKEND'] = 'Agg'

import json
import statistics

import numpy as np
import matplotlib.pyplot as plt

MPLBACKEND = os.environ['MPLBACKEND']
print('MPLBACKEND:', MPLBACKEND, '\n')
# Input
###############################################################################################
# set data to read
dataset_path = '/media/paterno/Verbatim HDD/datasets/Database_DBT/Database_unified_1/sampling_1/python_output_ResNet18_TL/'
set_list_legend = ['1', '2', '3', '4', '5']
add_to_all_pre = 'fully_processed_'
add_to_all_post = '_ResNet18_TL_300x300_aug'
file_to_open = 'results_test.txt'
remove_TPR = True
do_plots = True
save_analysis = True
save_path = dataset_path
save_filename = 'results_' + add_to_all_pre.removesuffix('_') + add_to_all_post
###############################################################################################
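
# Expected input format (an assumption inferred from the parsing below, not
# from the original data files): each results_test.txt is a flat JSON object
# mapping metric names to scalar scores, e.g.
#   {"accuracy": 0.91, "precision": 0.89, ..., "TPR": 0.90}
# with 9 entries for 3-class models (TPR at index 1) and presumably at least
# 7 for 2-class models (TPR at index 6).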
# Read data (dictionaries) from files
set_list = [add_to_all_pre + item + add_to_all_post for item in set_list_legend]
data_set = {}  # dictionary with the set names as keys
for item in set_list:
    filename = dataset_path + item + '/' + file_to_open
    with open(filename) as f:
        data = f.read()
    js = json.loads(data)
    keys = list(js.keys())
    temp = []
    for key in keys:
        temp.append(js[key])
    data_set[item] = temp
    if remove_TPR:
        if len(keys) == 9:  # 3 classes
            itpr = 1
        else:  # 2 classes
            itpr = 6
        data_set[item].pop(itpr)
if remove_TPR:
    keys.pop(itpr)  # itpr keeps its value from the last loop iteration
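
# Optional sanity check (a sketch, not part of the original script): the bar
# plot below pairs each remaining key with one score per set, so every list
# of scores should have the same length as keys.
# assert all(len(scores) == len(keys) for scores in data_set.values())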
# Plot data (bar plots)
if do_plots and len(set_list) <= 4:
    y = np.linspace(0., 1., 6)  # for yticks
    x = np.arange(len(keys))  # the label locations
    width = 0.20  # the bar width
    strtitle = 'Comparison of different datasets'
    if len(set_list) == 4:
        coeff = [-2/4, 2/4, 6/4, 10/4]  # used to position the bars
        bars_shift = 0.20  # shift applied to all the bars
    elif len(set_list) == 3:
        coeff = [-2/3, 1/3, 4/3]
        bars_shift = 0.075
    elif len(set_list) == 2:
        coeff = [-1/2, 1/2]
        bars_shift = 0.
    else:
        coeff = [0.]
        bars_shift = -0.0075
        strtitle = ''
    fig, ax = plt.subplots(figsize=[12, 8])
    fs = 16
    bar_groups = []
    for i, Set in enumerate(set_list):
        bar_loc = x + width*coeff[i] - bars_shift
        bar_groups.append(ax.bar(bar_loc, data_set[Set], width, label=set_list_legend[i]))
    ax.set_ylabel('Score', fontsize=fs)
    ax.set_xticks(x, keys, fontsize=fs, rotation=0)
    ax.set_yticks(y)
    ax.yaxis.set_tick_params(labelsize=fs, rotation=0)
    # ax.yaxis.grid(True)
    ax.set_ylim([0., 1.3])
    ax.set_title(strtitle, fontsize=fs)
    ax.legend(fontsize=fs, loc='upper right')
    # Optionally annotate each bar with its value:
    # for group in bar_groups:
    #     ax.bar_label(group, padding=3)
    fig.tight_layout()
    if save_analysis:
        plt.savefig(save_path + save_filename + '.jpg')
    if MPLBACKEND != 'Agg':
        plt.show()
    plt.close()
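
# Bar placement example (illustrative): with two sets, width = 0.20,
# coeff = [-1/2, 1/2] and bars_shift = 0, the two bars of each group are
# centred at x - 0.10 and x + 0.10, i.e. they touch at the metric's tick.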
# Calculate the statistics of the metrics across the sets
results = {}
for j, key in enumerate(keys):
    temp = []
    for Set in set_list:
        temp.append(data_set[Set][j])
    results[key] = temp
# print(results, '\n')

metrics_stat = {}
for key in keys:
    mean = round(statistics.mean(results[key]), 3)
    stdev = round(statistics.stdev(results[key]), 3)  # needs at least two sets
    metrics_stat[key] = [mean, stdev]
    metrics_str = '%s = %.2f +/- %.2f' % (key, mean, stdev)
    print(metrics_str)
print('\n')
if save_analysis:
    metrics_stat_file = save_path + save_filename
    with open(metrics_stat_file + '.txt', 'w') as f:
        f.write('metrics statistics:\n')
        for key in keys:
            metrics_str = '%s = %.2f +/- %.2f' % (key, metrics_stat[key][0], metrics_stat[key][1])
            f.write(metrics_str + '\n')
    with open(metrics_stat_file + '.json', 'w') as f:
        json.dump(metrics_stat, f)
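
# Reload example (a sketch, assuming save_analysis was True so the .json file
# exists): the saved statistics can be read back in a later session with
#   with open(metrics_stat_file + '.json') as f:
#       metrics_stat = json.load(f)
# where metrics_stat maps each metric name to its [mean, stdev] pair.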