-
Notifications
You must be signed in to change notification settings - Fork 2
/
test.py
130 lines (98 loc) · 4.08 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import argparse
import torch
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from config.config import get_cfg
from dataset import build_data_loader
from CoSODNet import CoSODNet
import transforms as trans
def get_args_parser():
    """
    Create the option parser for the CoSOD test entry point.

    The parser is built with ``add_help=False`` so it can be handed to another
    ``ArgumentParser`` via ``parents=[...]`` (the ``__main__`` block does
    this). Options use a single leading dash, e.g. ``-max_num 10``.

    Returns:
        argparse.ArgumentParser: parser with all test-time options registered.
    """
    option_specs = [
        ("-config_file", {"default": "./config/cosod.yaml", "metavar": "FILE",
                          "help": "path to config file"}),
        ("-num_works", {"default": 1, "type": int}),
        ("-batch_size", {"default": 1, "type": int}),
        ("-device_id", {"type": str, "default": "1"}),
        ("-img_size", {"type": int, "default": 256}),
        # Upper bound on the number of images fed to the model in one pass.
        ("-max_num", {"type": int, "default": 25}),
        ("-model_root_dir", {"default": "./checkpoint"}),
        ("-test_data_root", {"type": str, "default": "./dataset/test_data"}),
        ("-test_datasets", {"nargs": '+',
                            "default": ["CoCA", "CoSal2015", "CoSOD3k"]}),
        ("-save_dir", {"type": str, "default": 'prediction'}),
    ]

    parser = argparse.ArgumentParser("CoSOD_Test", add_help=False)
    # Registration order is kept so the generated usage text is unchanged.
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser
def _get_cfg(cfg_file):
    """
    Build the run configuration from the defaults plus ``cfg_file``.

    Starts from ``get_cfg()``, overlays the values found in ``cfg_file`` with
    ``merge_from_file`` and then calls ``freeze()`` on the result (presumably
    a yacs-style CfgNode, which makes it read-only -- confirm against
    config.config).

    Args:
        cfg_file: path to the YAML config file to merge.

    Returns:
        The merged, frozen config object.
    """
    merged = get_cfg()
    merged.merge_from_file(cfg_file)
    merged.freeze()
    return merged
def test_group(model, group_data, save_root, max_num):
img_num = group_data['imgs'].shape[1]
groups = list(range(0, img_num + 1, max_num))
if groups[-1] != img_num:
groups.append(img_num)
print(groups)
for i in range(len(groups) - 1):
if i == len(groups) - 2:
end = groups[i + 1]
start = max(0, end - max_num)
else:
start = groups[i]
end = groups[i + 1]
print(start, end)
inputs = Variable(group_data['imgs'][:, start:end].squeeze(0).cuda())
subpaths = group_data['subpaths'][start:end]
ori_sizes = group_data['ori_sizes'][start:end]
# img_name = '_'.join(subpaths[0][0][:-4].split('/')).replace(' ', '_')
with torch.no_grad():
result = model(inputs)
co_preds = result.pop("co_pred")
pred_prob = torch.sigmoid(co_preds)
save_final_path = os.path.join(save_root, subpaths[0][0].split('/')[0])
os.makedirs(save_final_path, exist_ok=True)
for p_id in range(end - start):
pre = pred_prob[p_id, :, :, :].data.cpu()
subpath = subpaths[p_id][0]
ori_size = (ori_sizes[p_id][1].item(),
ori_sizes[p_id][0].item())
transform = trans.Compose([
trans.ToPILImage(),
trans.Scale(ori_size)
])
outputImage = transform(pre)
filename = subpath.split('/')[1]
outputImage.save(os.path.join(save_final_path, filename))
def main(args):
    """
    Load the trained CoSODNet checkpoint and write predictions for every test set.

    The checkpoint is read from
    ``<args.model_root_dir>/<current working directory name>/DMT_model.pth``.
    Predictions are written under ``<args.save_dir>/<dataset>/<same name>/``.

    Each group is first run with ``args.max_num`` images per forward pass. If
    that raises a ``RuntimeError`` (e.g. CUDA out of memory), the window is
    shrunk by one image and the group retried. Once the window is down to a
    single image, the error is re-raised instead of retrying forever. Any
    other exception type propagates immediately.

    Args:
        args: parsed namespace from ``get_args_parser()``.
    """
    cfg = _get_cfg(args.config_file)
    model = CoSODNet(cfg)
    model.cuda()
    # Inference only: switch dropout / batch-norm layers (if any) to eval mode.
    model.eval()

    # Checkpoints are organised per experiment, keyed by the cwd's name.
    model_name = os.path.basename(os.path.abspath(''))
    model_dir = args.model_root_dir
    ckpt_path = os.path.join(model_dir, model_name, "DMT_model.pth")
    # Deserialize onto the CPU so the checkpoint never has to be restored onto
    # the GPU index it was saved from (which CUDA_VISIBLE_DEVICES may have
    # remapped); load_state_dict then copies the weights into the CUDA params.
    model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
    print("Model loaded from {}".format(ckpt_path))

    test_loaders = build_data_loader(args, mode='test')
    for dataset, data_loader in test_loaders.items():
        save_root = os.path.join(args.save_dir, dataset, model_name)
        print("testing on {}".format(dataset))
        for idx, group_data in enumerate(data_loader):
            print('{}/{}'.format(idx, len(data_loader)))
            max_num = args.max_num
            while True:
                try:
                    test_group(model, group_data, save_root, max_num)
                    break
                except RuntimeError:
                    # Typically CUDA OOM for a large window: shrink and retry,
                    # but never below a single image.
                    if max_num <= 1:
                        raise
                    max_num -= 1
                    torch.cuda.empty_cache()
                    print("set max_num as {}".format(max_num))
if __name__ == '__main__':
    # Top-level parser: adds -h/--help on top of the shared option set.
    cli = argparse.ArgumentParser("CoSOD testing script",
                                  parents=[get_args_parser()])
    args = cli.parse_args()

    # GPU selection must happen before the first CUDA call. No CUDA call is
    # made above this line (this relies on torch initialising CUDA lazily).
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id
    cudnn.benchmark = True

    main(args)