# utils.py
import numpy as np
# Avoid a matplotlib backend error when running inside a virtual env on macOS:
# https://markhneedham.com/blog/2018/05/04/python-runtime-error-osx-matplotlib-not-installed-as-framework-mac/
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, precision_recall_curve
from sklearn.utils.multiclass import unique_labels


def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'

    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    # Only use the labels that appear in the data
    classes = classes[unique_labels(y_true, y_pred)]
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # We want to show all ticks...
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           # ... and label them with the respective list entries
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return ax
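
# Example usage (a minimal sketch, not part of the original module; it relies
# on the `svm`, `datasets`, and `train_test_split` imports above, and the
# variable names are illustrative):
#
#     iris = datasets.load_iris()
#     X_train, X_test, y_train, y_test = train_test_split(
#         iris.data, iris.target, random_state=0)
#     clf = svm.SVC(kernel='linear').fit(X_train, y_train)
#     plot_confusion_matrix(y_test, clf.predict(X_test),
#                           classes=np.array(iris.target_names), normalize=True)
#     plt.show()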


class IssueLabeler:
    def __init__(self,
                 body_text_preprocessor,
                 title_text_preprocessor,
                 model,
                 class_names=['bug', 'feature_request', 'question']):
        """
        Parameters
        ----------
        body_text_preprocessor: ktext.preprocess.processor
            the text preprocessor trained on issue bodies
        title_text_preprocessor: ktext.preprocess.processor
            the text preprocessor trained on issue titles
        model: tensorflow.keras.models
            a keras model that takes as input two tensors: the vectorized
            issue body and issue title.
        class_names: list
            class names as they correspond to the integer indices
            supplied to the model.
        """
        self.body_pp = body_text_preprocessor
        self.title_pp = title_text_preprocessor
        self.model = model
        self.class_names = class_names

    def get_probabilities(self, body: str, title: str):
        """
        Get the probability of each class.

        Parameters
        ----------
        body: str
            the issue body
        title: str
            the issue title

        Returns
        -------
        Dict[str, float]

        Example
        -------
        >>> issue_labeler = IssueLabeler(body_pp, title_pp, model)
        >>> issue_labeler.get_probabilities('hello world', 'hello world')
        {'bug': 0.08372017741203308,
         'feature_request': 0.6401631832122803,
         'question': 0.2761166989803314}
        """
        # transform raw text into arrays of ints
        vec_body = self.body_pp.transform([body])
        vec_title = self.title_pp.transform([title])
        # get predictions from the model
        probs = self.model.predict(x=[vec_body, vec_title]).tolist()[0]
        return {k: v for k, v in zip(self.class_names, probs)}
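
# Constructing an IssueLabeler (a hedged sketch, not part of the original
# module: the artifact file names, and the assumption that the ktext
# preprocessors were serialized with dill, are illustrative only):
#
#     import dill
#     from tensorflow.keras.models import load_model
#     with open('body_pp.dpkl', 'rb') as f:
#         body_pp = dill.load(f)
#     with open('title_pp.dpkl', 'rb') as f:
#         title_pp = dill.load(f)
#     issue_labeler = IssueLabeler(body_pp, title_pp, load_model('model.h5'))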


def plot_precision_recall_vs_threshold(y, y_hat, class_names, precision_threshold):
    "Plot precision-recall curves, focused on precision."
    # credit: https://github.com/ageron/handson-ml/blob/master/03_classification.ipynb
    assert len(class_names) == y_hat.shape[-1], 'number of class names must equal number of classes in the data'
    assert y.shape == y_hat.shape, 'shape of ground truth and predictions must be the same.'

    for class_name in class_names:
        class_int = class_names.index(class_name)
        precisions, recalls, thresholds = precision_recall_curve(y[:, class_int], y_hat[:, class_int])
        # get the first index where precision meets the threshold
        precision_idx = np.argmax(precisions >= precision_threshold)
        # the probability cutoff at that index
        prob_thresh = thresholds[precision_idx]
        # the recall achieved at that index
        recall_at_thresh = recalls[precision_idx]

        plt.figure(figsize=(8, 4))
        plt.plot(thresholds, precisions[:-1], "b--", label="Precision", linewidth=2)
        plt.plot(thresholds, recalls[:-1], "g-", label="Recall", linewidth=2)
        plt.axhline(y=precision_threshold, label=f'{precision_threshold:.2f}', linewidth=1)
        plt.xlabel("Threshold", fontsize=11)
        plt.legend(loc="lower left", fontsize=10)
        plt.title(f'Precision vs. Recall For Label: {class_name}')
        plt.ylim([0, 1])
        plt.xlim([0, 1])
        plt.show()

        print(f'Label "{class_name}" @ {precision_threshold:.2f} precision:')
        print(f'  Cutoff: {prob_thresh:.2f}')
        print(f'  Recall: {recall_at_thresh:.2f}')
        print('\n')
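
# Example usage (a minimal sketch, not part of the original module; `y_true`
# and `y_probs` are assumed to be arrays of shape (n_samples, n_classes)
# holding one-hot ground-truth labels and predicted probabilities):
#
#     plot_precision_recall_vs_threshold(
#         y_true, y_probs,
#         class_names=['bug', 'feature_request', 'question'],
#         precision_threshold=0.7)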