-
Notifications
You must be signed in to change notification settings - Fork 0
/
result_plot.py
66 lines (51 loc) · 2.39 KB
/
result_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import matplotlib.pyplot as plt
import numpy as np
from data_generator import DataGenerator
class ResultPlot:
    """Scatter-plot helper for 2-D samples with two-class targets.

    ``create_dots`` samples a prediction function on a regular grid and
    ``create_plot`` draws that grid (as grey-level dots) together with the
    labelled samples and an optional highlighted input sample.
    """

    def _get_target_color(self, value):
        """Return the matplotlib color code for a target value.

        A target equal to 0 is drawn red (``'r'``); anything else is drawn
        green (``'g'``).
        """
        if value == 0:
            return 'r'
        return 'g'

    def _get_dot_color(self, value):
        """Map a prediction value to a grey-level ``[r, g, b]`` triple.

        The distance of ``value`` from 0.5 is doubled, raised to the power
        0.1 and clipped to ``[0, 1]``. A value of exactly 0.5 therefore maps
        to black, while 0 and 1 map to white; the small exponent makes the
        grey level rise steeply as soon as the value leaves 0.5.
        """
        value = abs(value - 0.5) * 2
        value = value ** 0.1
        value = np.clip(value, 0, 1)
        return [value, value, value]

    def _float_range(self, r_min, r_max, r_step):
        """Return multiples of ``r_step`` spanning ``r_min`` to ``r_max``.

        Values are rounded to 5 decimals. The index range ends at
        ``int(r_max / r_step) + 1`` inclusive, so the last value can lie one
        step beyond ``r_max`` (e.g. ``_float_range(0, 1, 0.5)`` yields
        ``[0.0, 0.5, 1.0, 1.5]``).
        NOTE(review): the extra step presumably compensates for ``int()``
        truncating inexact float quotients - confirm before tightening.
        """
        first = int(r_min / r_step)
        last = int(r_max / r_step) + 1
        return [round(index * r_step, 5) for index in range(first, last + 1)]

    def create_dots(self, xmin=0, xmax=5, ymin=0, ymax=5, predict_function=None, step=0.05):
        """Evaluate ``predict_function`` on a regular 2-D grid.

        Args:
            xmin, xmax: Bounds of the grid along x (see ``_float_range`` for
                the exact values generated).
            ymin, ymax: Bounds of the grid along y.
            predict_function: Callable invoked once per grid point with a
                ``(1, 2)`` array ``[[x, y]]``; the scalar prediction is read
                from ``result[0][0][0]``. When ``None`` no grid is evaluated.
            step: Grid spacing used for both axes.

        Returns:
            A float array of shape ``(k, 3)`` whose rows are
            ``[x, y, prediction]``, ordered x-major. The array is empty
            (shape ``(0, 3)``) when ``predict_function`` is ``None`` or the
            grid has no points.
        """
        if predict_function is None:
            # Nothing to evaluate: return a correctly shaped empty result
            # rather than a fake (0, 0, 0) sample.
            return np.empty((0, 3))

        x_values = self._float_range(xmin, xmax, step)
        y_values = self._float_range(ymin, ymax, step)

        # Collect rows in a list and convert once; growing an array with
        # np.append copies the whole array on every grid point.
        rows = []
        for x in x_values:
            for y in y_values:
                value = predict_function(np.array([[x, y]]))[0][0][0]
                rows.append([x, y, value])
        return np.array(rows, dtype=float).reshape(-1, 3)

    def create_plot(self, vectors, targets, input_vector=None, input_target=None, show_plot=False, dots=None,
                    xlim=(0, 3), ylim=(0, 3)):
        """Draw the samples (and optional prediction grid) on a new figure.

        Args:
            vectors: Sequence of samples; each entry is unpacked with
                ``entry[0]`` and the result is used as an ``(x, y)`` point.
            targets: One target per sample, colored via
                ``_get_target_color``.
            input_vector: Optional extra sample, unpacked with
                ``input_vector[0]`` and drawn as a large star. Only drawn
                when ``input_target`` is given as well.
            input_target: Target of ``input_vector``; selects the star color.
            show_plot: When true, ``plt.show()`` is called before the figure
                is closed.
            dots: Optional ``(k, 3)`` array of ``[x, y, prediction]`` rows as
                produced by ``create_dots``; drawn underneath the samples.
            xlim: ``(low, high)`` limits of the x axis.
            ylim: ``(low, high)`` limits of the y axis.

        Returns:
            The matplotlib figure. It has already been closed in pyplot's
            figure registry when it is returned.
        """
        # unpack vector
        vectors = np.array([v[0] for v in vectors])
        target_colors = np.array([self._get_target_color(t) for t in targets])

        fig = plt.figure()
        axis = fig.gca()
        axis.set_xlim(xlim)
        axis.set_ylim(ylim)

        # draw dots first so the samples are painted on top of them
        if dots is not None and len(dots) > 0:
            dots = np.asarray(dots)
            dot_colors = np.array([self._get_dot_color(v) for v in dots[:, 2]])
            axis.scatter(dots[:, 0], dots[:, 1], color=dot_colors, s=15)

        # An empty sample list becomes a 1-D array that cannot be sliced by
        # column, so skip the scatter in that case.
        if vectors.size > 0:
            axis.scatter(vectors[:, 0], vectors[:, 1], color=target_colors, s=20)

        if input_vector is not None and input_target is not None:
            # unpack vector
            input_vector = input_vector[0]
            input_target_color = self._get_target_color(input_target)
            axis.scatter(input_vector[0], input_vector[1], color=input_target_color, s=100, marker='*')

        if show_plot:
            plt.show()
        # Close exactly the figure created here so figures do not pile up in
        # pyplot and unrelated figures are never touched.
        plt.close(fig)
        return fig
if __name__ == "__main__":
    # Demo entry point: plot 50 generated training samples.
    result_plot = ResultPlot()

    # No predict function is passed here, so no model is evaluated on the grid.
    dots = result_plot.create_dots(xmin=0, xmax=5, ymin=0, ymax=5, step=0.1)

    generator = DataGenerator()
    training_input_vectors, training_targets = generator.generate_training_data(50, nested=True)

    result_plot.create_plot(
        training_input_vectors,
        training_targets,
        show_plot=True,
        dots=dots,
    )