-
Notifications
You must be signed in to change notification settings - Fork 43
/
Copy pathloss_plot.py
91 lines (68 loc) · 2.66 KB
/
loss_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Usage: python3 scripts/loss_plot.py train-log-002.txt title
# Do not include figures generated by this script in an academic paper, because:
# 1. this version of HandyRL does not report the results of every match, and
# 2. the smoothing used here is a weighted moving average, not a simple moving average.
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
n = 10  # default smoothing window length (number of log entries averaged)


def kernel(n):
    """Return normalized weights for a smoothing window of length ``n``.

    The weights rise 1, 2, ..., ceil(n/2) and then fall from n//2 + 1 down
    to 2; e.g. ``kernel(10)`` is proportional to [1, 2, 3, 4, 5, 6, 5, 4, 3, 2].
    They are divided by their sum, so convolving data with them gives a
    weighted moving average rather than a simple one.
    """
    a = np.array(list(range(1, 1 + (n + 1) // 2)) + list(range(1 + n // 2, 1, -1)))
    return a / a.sum()


def _clip(values, window):
    """Drop the leading/trailing entries for which a 'valid' convolution with
    a kernel of length ``window`` produces no output."""
    head = window // 2
    # Same as the old `-window//2 + 1` slice end, but expressed as an absolute
    # index so that window sizes 1 and 2 do not yield an end index of 0
    # (which silently emptied the list).
    tail = (window + 1) // 2 - 1
    return values[head:max(len(values) - tail, 0)]


def get_loss_list(path, window=None):
    """Parse a HandyRL training log and smooth every loss series in it.

    Lines starting with ``updated`` contribute an epoch index and the step
    count found in parentheses; lines starting with ``loss`` contribute one
    dict of ``name:value`` pairs (tokens after the first two words); for
    lines starting with ``epoch `` the last whitespace-separated token of the
    previous line is recorded as the game count.

    Args:
        path: path of the training log text file.
        window: smoothing window length; defaults to the module-level ``n``.

    Returns:
        A tuple ``(clipped_epoch_list, clipped_step_list, clipped_game_list,
        averaged_loss_lists, start_epoch)`` where the lists are trimmed to the
        span covered by the smoothing window, ``averaged_loss_lists`` maps each
        loss name to its smoothed numpy array, and ``start_epoch`` is 0.

    Raises:
        ValueError: if the log is too short for the smoothing window.
    """
    if window is None:
        window = n

    epoch_data_list = []
    epoch_list = [0]
    step_list = [0]
    game_list = [0]
    with open(path) as f:  # closed deterministically (was leaked before)
        lines = f.readlines()

    prev_line = ''
    for line in lines:
        if line.startswith('updated'):
            epoch_list.append(len(epoch_list))
            step_list.append(int(line.split('(')[1].rstrip().rstrip(')')))
        if line.startswith('loss'):
            epoch_data_list.append({})
            for e in line.split()[2:]:
                name, loss = e.split(':')
                epoch_data_list[-1][name] = float(loss)
        if line.startswith('epoch '):
            print(line.rstrip(), len(epoch_list))
            if ' ' in prev_line:
                game_list.append(int(prev_line.split()[-1]))
        prev_line = line
    game_list = game_list[:len(epoch_data_list)]

    # The second condition also guards the loss data itself: with no entries
    # `epoch_data_list[0]` would fail, and with fewer entries than the window
    # np.convolve(mode='valid') swaps its operands and returns a bogus length.
    if ((len(epoch_list) - window // 2 + 1 - window // 2) - window <= 0
            or len(epoch_data_list) < window):
        raise ValueError('Range to average `n` is too large, set a smaller number')

    clipped_epoch_list = _clip(epoch_list, window)
    clipped_step_list = _clip(step_list, window)
    clipped_game_list = _clip(game_list, window)

    kn = kernel(window)
    averaged_loss_lists = {}
    for name in epoch_data_list[0]:
        data = [d[name] for d in epoch_data_list]
        averaged_loss_lists[name] = np.convolve(data, kn, mode='valid')
    start_epoch = 0
    return clipped_epoch_list, clipped_step_list, clipped_game_list, averaged_loss_lists, start_epoch
def main(argv=None):
    """Plot every smoothed loss series of a training log against games played.

    Args:
        argv: command-line style argument list, defaulting to ``sys.argv``.
            ``argv[1]`` is the training log path; the optional ``argv[2]`` is
            the figure title (the log path is used when it is omitted).
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        sys.exit('Usage: python3 scripts/loss_plot.py train-log-002.txt [title]')
    log_path = argv[1]
    title = argv[2] if len(argv) > 2 else log_path

    # Plot
    flatui = ["#9b59b6", "#95a5a6", "#34495e", "#3498db", "#e74c3c", "#2ecc71", "#b22222"]
    sns.set_palette(sns.color_palette(flatui, 24))
    _, _, clipped_game_list, averaged_loss_lists, start_epoch = get_loss_list(log_path)

    fig_num = len(averaged_loss_lists)
    n_rows = (fig_num - 1) // 2 + 1  # two subplots per row
    fig = plt.figure()
    for i, (k, v) in enumerate(averaged_loss_lists.items()):
        ax = fig.add_subplot(n_rows, 2, i + 1)
        # NOTE(review): assumes one game count per loss entry in the log, so x
        # and y have equal length after clipping -- confirm for your log format.
        ax.plot(clipped_game_list[start_epoch:], v[start_epoch:], label=k)
        ax.set_xlabel('Games')
        ax.set_ylabel(k)
        ax.set_title(f"loss: {k}")
        ax.grid()
    fig.suptitle(title)
    fig.tight_layout()
    plt.show()


if __name__ == '__main__':
    main()