eval_txt.py
import os

from sklearn.metrics import accuracy_score, f1_score, classification_report


class Eval:
    """Offline evaluation of per-frame expression predictions written to txt files."""

    def __init__(self):
        self.exp_name = 'swap_epoch_v2_valid_set2'
        # Predictions are expected under results/<exp_name>/
        self.results_folder = os.path.join('results', self.exp_name)
        # Root folder of the Aff-Wild2 annotation txt files
        self.gt_folder = '/mnt/data3/jinyue/dataset/opensource/AffWild2/Annotations/annotations/Annotations/annotations/Annotations/all_annotations'
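
    # Assumed on-disk layout (inferred from eval_expr() below, not verified against
    # the original dataset release): both results/<exp_name>/EXPR_Set/ and
    # <gt_folder>/EXPR_Set/ contain one txt file per video with the same file name,
    # a single header line, and then one integer expression label per frame,
    # with -1 marking an invalid/unlabelled frame.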
    def eval_expr(self):
        """Compare predicted expression labels against the ground truth, frame by frame."""
        predicted = []
        true_labels = []
        results_folder = os.path.join(self.results_folder, 'EXPR_Set')
        for txt_file in os.listdir(results_folder):
            # Read the predictions, skipping the header line.
            with open(os.path.join(results_folder, txt_file), 'r') as f_pred:
                preds = f_pred.readlines()[1:]
                # print(len(preds))
            # Read the ground-truth file of the same name, also skipping the header.
            with open(os.path.join(self.gt_folder, 'EXPR_Set', txt_file), 'r') as f_gt:
                for idx, line in enumerate(f_gt.readlines()[1:]):
                    gt = int(line.strip())
                    if idx < len(preds):
                        pred = int(preds[idx].strip())
                        # Frames labelled -1 (invalid) on either side are ignored.
                        if gt != -1 and pred != -1:
                            true_labels.append(gt)
                            predicted.append(pred)
        acc = accuracy_score(true_labels, predicted)
        total_f1 = f1_score(true_labels, predicted, average='macro')
        print(acc, total_f1)
        return acc, total_f1
    # def eval_au(self, eval_loader):
    #     self.model.eval()
    #     predicted = []
    #     true_labels = []
    #     correct = 0
    #     total = 0
    #     for imgs, labels in tqdm(iter(eval_loader)):
    #         imgs = imgs.to(self.conf.device)
    #         labels = labels.to(self.conf.device)
    #         logits = self.model(imgs)[1]
    #         predicts = torch.greater(logits, 0).type_as(labels)
    #         cmp = predicts.eq(labels).cpu().numpy()
    #         correct += cmp.sum()
    #         total += len(cmp) * 12
    #         predicted += predicts.cpu().numpy().tolist()
    #         true_labels += labels.cpu().numpy().tolist()
    #     acc = correct / total
    #     print('acc:', acc)
    #     total_f1 = f1_score(true_labels, predicted, average='macro')
    #     print('f1:', total_f1)
    #     class_names = ['AU1', 'AU2', 'AU4', 'AU6', 'AU7', 'AU10', 'AU12',
    #                    'AU15', 'AU23', 'AU24', 'AU25', 'AU26']
    #     print(classification_report(true_labels,
    #                                 predicted, target_names=class_names))
    #     return acc, total_f1


if __name__ == '__main__':
    eval_module = Eval()
    eval_module.eval_expr()
    # eval_module.eval_au()
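
# Minimal sketch (not part of the original script) of how a prediction file that
# eval_expr() can consume might be written; write_pred_file, frame_preds and the
# header text are illustrative assumptions, not names used elsewhere in this repo:
#
# def write_pred_file(out_path, frame_preds):
#     with open(out_path, 'w') as f:
#         f.write('expression\n')          # any single header line; eval_expr skips it
#         for label in frame_preds:
#             f.write('%d\n' % label)      # one integer label per frame, -1 for invalid frames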