ood_detection.py
import argparse
import os
import os.path as osp
from operator import itemgetter

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import torch
import torch.multiprocessing
from scipy.special import xlogy
from tqdm import tqdm

import mmcv
from mmcv.parallel import MMDataParallel, collate, scatter
from mmaction.apis import init_recognizer
from mmaction.datasets import build_dataloader, build_dataset
from mmaction.datasets.pipelines import Compose

torch.multiprocessing.set_sharing_strategy('file_system')
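# Note (added commentary): the 'file_system' sharing strategy above is slower
# than PyTorch's default 'file_descriptor' strategy, but it avoids exhausting
# open file descriptors when many DataLoader workers pass tensors between
# processes.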

def parse_args():
    parser = argparse.ArgumentParser(description='MMAction2 test')
    # model config
    parser.add_argument('--config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file/url')
    parser.add_argument('--uncertainty', choices=['BALD', 'Entropy', 'EDL'],
                        help='the uncertainty estimation method')
    parser.add_argument('--forward_pass', type=int, default=10,
                        help='the number of forward passes')
    # data config
    parser.add_argument('--ind_data', help='the split file of in-distribution testing data')
    parser.add_argument('--ood_data', help='the split file of out-of-distribution testing data')
    # env config
    parser.add_argument('--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    parser.add_argument('--result_prefix', help='result file prefix')
    args = parser.parse_args()
    return args

def apply_dropout(m):
    if type(m) == torch.nn.Dropout:
        m.train()
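# Note (added commentary): applying this after model.eval() re-enables only the
# dropout layers while keeping batch norm in eval mode, which is the standard
# recipe for Monte Carlo dropout: the repeated test-time forward passes then
# act as samples from an approximate posterior (Gal & Ghahramani, 2016).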

def update_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

def compute_uncertainty(predictions, method='BALD'):
    """Compute uncertainty scores from stacked predictions.

    predictions: (B, C, T) array of class scores from T forward passes.
    Returns a (B,) array of uncertainty scores.
    """
    expected_p = np.mean(predictions, axis=-1)  # mean over forward passes, (B, C)
    entropy_expected_p = - np.sum(xlogy(expected_p, expected_p), axis=1)  # entropy of expected_p across classes, (B,)
    if method == 'Entropy':
        uncertain_score = entropy_expected_p
    elif method == 'BALD':
        # mean of the per-pass entropies (across classes), (B,)
        expected_entropy = - np.mean(np.sum(xlogy(predictions, predictions), axis=1), axis=-1)
        uncertain_score = entropy_expected_p - expected_entropy
    else:
        raise NotImplementedError
    if not np.all(np.isfinite(uncertain_score)):
        uncertain_score[~np.isfinite(uncertain_score)] = 9999
    return uncertain_score
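# A minimal sanity check (illustrative, not part of the original script): when
# all T passes agree, H[E[p]] == E[H[p]], so BALD (their difference, i.e. the
# mutual information between prediction and model) is ~0, while Entropy is
# simply the entropy of the shared distribution:
#
#   p = np.array([0.7, 0.2, 0.1])
#   preds = np.tile(p[None, :, None], (1, 1, 4))      # (B=1, C=3, T=4)
#   compute_uncertainty(preds, method='BALD')         # -> ~[0.0]
#   compute_uncertainty(preds, method='Entropy')      # -> ~[0.80] nats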

def run_stochastic_inference(model, data_loader, npass=10):
    # run inference
    model = MMDataParallel(model, device_ids=[0])
    all_confidences, all_uncertainties, all_results, all_gts = [], [], [], []
    prog_bar = mmcv.ProgressBar(len(data_loader.dataset))
    for i, data in enumerate(data_loader):
        all_scores = []
        with torch.no_grad():
            for n in range(npass):
                # set a new random seed for each forward pass
                update_seed(n * 1234)
                scores = model(return_loss=False, **data)
                # gather results
                all_scores.append(np.expand_dims(scores, axis=-1))
        all_scores = np.concatenate(all_scores, axis=-1)  # (B, C, T)
        # compute the uncertainty
        uncertainty = compute_uncertainty(all_scores, method=args.uncertainty)
        all_uncertainties.append(uncertainty)
        # compute the predictions and save labels
        mean_scores = np.mean(all_scores, axis=-1)
        preds = np.argmax(mean_scores, axis=1)
        all_results.append(preds)
        conf = np.max(mean_scores, axis=1)
        all_confidences.append(conf)
        labels = data['label'].numpy()
        all_gts.append(labels)
        # use the first key as the main key to calculate the batch size
        batch_size = len(next(iter(data.values())))
        for _ in range(batch_size):
            prog_bar.update()
    all_confidences = np.concatenate(all_confidences, axis=0)
    all_uncertainties = np.concatenate(all_uncertainties, axis=0)
    all_results = np.concatenate(all_results, axis=0)
    all_gts = np.concatenate(all_gts, axis=0)
    return all_confidences, all_uncertainties, all_results, all_gts
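# Note (added commentary): the stochasticity across the npass forward passes
# comes from the model itself: dropout kept active at test time for 'dnn'
# configs, or sampled weights for 'bnn' configs (see main() below). Reseeding
# with a different seed per pass keeps the passes distinct but reproducible.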

def run_evidence_inference(model, data_loader, evidence='exp'):
    # set new random seed
    update_seed(1234)
    # get the evidence function
    if evidence == 'relu':
        from mmaction.models.losses.edl_loss import relu_evidence as get_evidence
    elif evidence == 'exp':
        from mmaction.models.losses.edl_loss import exp_evidence as get_evidence
    elif evidence == 'softplus':
        from mmaction.models.losses.edl_loss import softplus_evidence as get_evidence
    else:
        raise NotImplementedError
    num_classes = model.cls_head.num_classes
    # run inference
    model = MMDataParallel(model, device_ids=[0])
    all_confidences, all_uncertainties, all_results, all_gts = [], [], [], []
    prog_bar = mmcv.ProgressBar(len(data_loader.dataset))
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            output = model(return_loss=False, **data)
            evid = get_evidence(torch.from_numpy(output))
            alpha = evid + 1
            uncertainty = num_classes / torch.sum(alpha, dim=1)
            scores = alpha / torch.sum(alpha, dim=1, keepdim=True)
        all_uncertainties.append(uncertainty.numpy())
        # compute the predictions and save labels
        preds = np.argmax(scores.numpy(), axis=1)
        all_results.append(preds)
        conf = np.max(scores.numpy(), axis=1)
        all_confidences.append(conf)
        labels = data['label'].numpy()
        all_gts.append(labels)
        # use the first key as the main key to calculate the batch size
        batch_size = len(next(iter(data.values())))
        for _ in range(batch_size):
            prog_bar.update()
    all_confidences = np.concatenate(all_confidences, axis=0)
    all_uncertainties = np.concatenate(all_uncertainties, axis=0)
    all_results = np.concatenate(all_results, axis=0)
    all_gts = np.concatenate(all_gts, axis=0)
    return all_confidences, all_uncertainties, all_results, all_gts
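# Background (added commentary): this follows evidential deep learning
# (Sensoy et al., 2018). The non-negative evidence e parameterizes a Dirichlet
# with alpha = e + 1 and strength S = sum(alpha); vacuity uncertainty is
# u = K / S for K classes, and the expected class probabilities are
# p_k = alpha_k / S, so a single deterministic pass yields both a prediction
# and an uncertainty score, with no repeated sampling needed.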

def run_inference(model, datalist_file, npass=10):
    # switch config for different dataset
    cfg = model.cfg
    cfg.data.test.ann_file = datalist_file
    cfg.data.test.data_prefix = os.path.join(os.path.dirname(datalist_file), 'videos')
    evidence = cfg.get('evidence', 'exp')
    # build the dataloader
    dataset = build_dataset(cfg.data.test, dict(test_mode=True))
    dataloader_setting = dict(
        videos_per_gpu=cfg.data.get('videos_per_gpu', 1),
        workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
        dist=False,
        shuffle=False,
        pin_memory=False)
    dataloader_setting = dict(dataloader_setting, **cfg.data.get('test_dataloader', {}))
    data_loader = build_dataloader(dataset, **dataloader_setting)
    if args.uncertainty != 'EDL':
        all_confidences, all_uncertainties, all_results, all_gts = \
            run_stochastic_inference(model, data_loader, npass)
    else:
        all_confidences, all_uncertainties, all_results, all_gts = \
            run_evidence_inference(model, data_loader, evidence)
    return all_confidences, all_uncertainties, all_results, all_gts

def main():
    # build the recognizer from a config file and checkpoint file/url
    model = init_recognizer(
        args.config,
        args.checkpoint,
        device=device,
        use_frames=False)
    cfg = model.cfg
    if args.uncertainty != 'EDL':
        # use dropout in the testing stage
        if 'dnn' in args.config:
            model.apply(apply_dropout)
        if 'bnn' in args.config:
            model.test_cfg.npass = 1
    # set cudnn benchmark
    torch.backends.cudnn.benchmark = True
    cfg.data.test.test_mode = True
    result_file = args.result_prefix + '_result.npz'
    if not os.path.exists(result_file):
        # prepare the result path
        result_dir = os.path.dirname(result_file)
        if not os.path.exists(result_dir):
            os.makedirs(result_dir)
        # run inference (OOD)
        ood_confidences, ood_uncertainties, ood_results, ood_labels = run_inference(
            model, args.ood_data, npass=args.forward_pass)
        # run inference (IND)
        ind_confidences, ind_uncertainties, ind_results, ind_labels = run_inference(
            model, args.ind_data, npass=args.forward_pass)
        # save (np.savez re-appends the .npz suffix stripped here)
        np.savez(result_file[:-4], ind_conf=ind_confidences, ood_conf=ood_confidences,
                 ind_unctt=ind_uncertainties, ood_unctt=ood_uncertainties,
                 ind_pred=ind_results, ood_pred=ood_results,
                 ind_label=ind_labels, ood_label=ood_labels)
    else:
        results = np.load(result_file, allow_pickle=True)
        ind_confidences = results['ind_conf']
        ood_confidences = results['ood_conf']
        ind_uncertainties = results['ind_unctt']  # (N1,)
        ood_uncertainties = results['ood_unctt']  # (N2,)
        ind_results = results['ind_pred']  # (N1,)
        ood_results = results['ood_pred']  # (N2,)
        ind_labels = results['ind_label']
        ood_labels = results['ood_label']
    # visualize: min-max normalize the uncertainties before plotting
    ind_uncertainties = np.array(ind_uncertainties)
    ind_uncertainties = (ind_uncertainties - np.min(ind_uncertainties)) / (
        np.max(ind_uncertainties) - np.min(ind_uncertainties))
    ood_uncertainties = np.array(ood_uncertainties)
    ood_uncertainties = (ood_uncertainties - np.min(ood_uncertainties)) / (
        np.max(ood_uncertainties) - np.min(ood_uncertainties))
    dataName_ind = args.ind_data.split('/')[-2].upper()
    dataName_ood = args.ood_data.split('/')[-2].upper()
    if dataName_ind == 'UCF101':
        dataName_ind = 'UCF-101'
    if dataName_ood == 'MIT':
        dataName_ood = 'MiT-v2'
    if dataName_ood == 'HMDB51':
        dataName_ood = 'HMDB-51'
    plt.figure(figsize=(5, 4))  # (w, h)
    plt.rcParams["font.family"] = "Arial"  # Times New Roman
    fontsize = 15
    plt.hist([ind_uncertainties, ood_uncertainties], 50,
             density=True, histtype='bar', color=['blue', 'red'],
             label=['in-distribution (%s)' % (dataName_ind),
                    'out-of-distribution (%s)' % (dataName_ood)])
    plt.legend(fontsize=fontsize)
    plt.xlabel('%s uncertainty' % (args.uncertainty), fontsize=fontsize)
    plt.ylabel('density', fontsize=fontsize)
    plt.xticks(fontsize=fontsize)
    plt.yticks(fontsize=fontsize)
    plt.xlim(0, 1.01)
    plt.ylim(0, 10.01)
    plt.tight_layout()
    plt.savefig(args.result_prefix + '_distribution.png')
    plt.savefig(args.result_prefix + '_distribution.pdf')

if __name__ == '__main__':

    args = parse_args()
    # assign the desired device
    device = torch.device(args.device)
    main()
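# Example invocation (a sketch; the config, checkpoint, and split paths below
# are illustrative placeholders, not files shipped with this script). Note that
# run_inference() expects the videos to live in a 'videos/' folder next to each
# split file:
#
#   python ood_detection.py \
#       --config configs/i3d_edl.py \
#       --checkpoint work_dirs/i3d_edl/latest.pth \
#       --uncertainty EDL \
#       --ind_data data/ucf101/test_split.txt \
#       --ood_data data/hmdb51/test_split.txt \
#       --result_prefix results/i3d_edl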