# -*- coding: utf-8 -*-
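"""Real-time hand-gesture interaction inference script.

Pipeline (as implemented below): YOLOv5 hand detection -> ResNet-50 keypoint
regression (21 points per hand) -> gesture tracking (Tracker) -> on-screen
interaction (Interactor). Results can be displayed live and optionally saved
as images or video. Run `python inference.py --help` for the options.
"""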
import sys

# Must be on sys.path before the hand_detection imports so that yolov5's
# internal absolute imports (models, utils) resolve correctly.
sys.path.insert(0, './hand_detection')

import argparse
import os
import time
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn

from interaction.interaction import Interactor
from hand_key.models.resnet import resnet50
from hand_track.tracker import Tracker
from hand_detection.utils.datasets import LoadImages, LoadStreams
from hand_detection.utils.general import check_img_size, non_max_suppression, scale_coords, check_imshow
from hand_detection.utils.torch_utils import select_device, time_synchronized
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# torch.cuda.set_device(0)


def detect(opt):
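    """Run the full pipeline on `opt.source`.

    Each frame is letterboxed for YOLOv5, detected hand boxes are cropped and
    passed to ResNet-50 for keypoints, and the Tracker/Interactor turn those
    keypoints into gesture actions drawn onto the frame.
    """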
    # Check arguments ----------------------------------------------------------
    out, source, yolov5_weights, view, save = opt.output, opt.source, opt.yolov5_weights, opt.view, opt.save
    webcam = source == '0' or source == '1' or source.endswith('.txt') or source.startswith(
        'rtsp') or source.startswith('http')
    pose_thres = opt.pose_thres
    res50_img_size, res50_weight = opt.res50_img_size, opt.res50_weight
    vid_path, vid_writer = None, None
    # Use FP16 half precision whenever CUDA is available
    device = select_device(opt.device)
    half = device.type != 'cpu'
    # Create the output folder if it does not exist
    if not os.path.exists(out):
        os.makedirs(out)
    # Load YOLOv5 ---------------------------------------------------------------
    model_yolo5 = torch.load(yolov5_weights, map_location=device)['model']
    model_yolo5.float().to(device).eval()
    stride = int(model_yolo5.stride.max())  # model stride
    yolov5_img_size = check_img_size(opt.yolov5_img_size, s=stride)  # verify img_size is a multiple of the stride
    if half:
        model_yolo5.half()  # to FP16
    names = model_yolo5.module.names if hasattr(model_yolo5, 'module') else model_yolo5.names  # class names
    print('load model : {}'.format(yolov5_weights))
    # Load ResNet-50 ------------------------------------------------------------
    model_res50 = resnet50(num_classes=42, img_size=res50_img_size[0])  # 42 outputs = 21 keypoints x (x, y)
    model_res50.to(device).eval()
    if half:
        model_res50.half()
    chkpt = torch.load(res50_weight, map_location=device)
    model_res50.load_state_dict(chkpt)
    print('load model : {}'.format(res50_weight))
    # Initialize the tracking state machine ---------------------------------------
    tracker = Tracker(opt.pose_cfg, pose_thres=pose_thres)
    # Initialize the interaction module -------------------------------------------
    interactor = Interactor()
    # Dataloader ------------------------------------------------------------------
    if webcam:
        # Camera / stream input
        view = check_imshow()
        cudnn.benchmark = True  # speeds up inference on constant-size video frames
        dataset = LoadStreams(source, img_size=yolov5_img_size, stride=stride)
    else:
        view = True
        dataset = LoadImages(source, img_size=yolov5_img_size, stride=stride)
    # Run inference ---------------------------------------------------------------------------------------
    t0 = time.time()
    # Warm up both models with a dummy forward pass
    img = torch.zeros((1, 3, yolov5_img_size, yolov5_img_size), device=device)
    _ = model_yolo5(img.half() if half else img) if device.type != 'cpu' else None
    img = torch.zeros((1, 3, res50_img_size[0], res50_img_size[0]), device=device)
    _ = model_res50(img.half() if half else img) if device.type != 'cpu' else None
    # Timestamp used to name saved results
    str_time_now = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
    # Process each image/frame: (source path, letterboxed image (3,h,w), original image (h,w,3), video capture or None)
    for frame_idx, (path, img, im0s, vid_cap) in enumerate(dataset):
        # YOLO preprocessing -----------------------------------------------------
        img = torch.from_numpy(img).to(device)  # to a tensor on the target device
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)  # add a batch dimension
        # Inference
        t1 = time_synchronized()
        pred = model_yolo5(img, augment=opt.augment)[0]  # raw prediction list
        # 1. Non-maximum suppression (max detections 2; min box side: fist side length * 0.9)
        pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=[0])
        # Unpack pred; for video there is a single iteration holding a tensor [2, 6]: 2 hands, 6 values each
        for b, det_box in enumerate(pred):  # iterate over the detections of each image in the batch
            if webcam:  # batch_size >= 1
                p, s, im0 = path[b], '%g: ' % b, im0s[b].copy()
            else:
                p, s, im0 = path, '', im0s
            s += '%gx%g ' % img.shape[2:]
            save_path = os.path.join(out, str_time_now + '_' + Path(p).name)
            if det_box is not None and len(det_box):  # hands detected: track both hands and draw the result
                # Rescale boxes from yolov5_img_size back to original image coordinates
                det_box[:, :4] = scale_coords(img.shape[2:], det_box[:, :4], im0.shape).round()
                # Hand detection count
                s += '%g %ss ' % (len(det_box), names[0])
                # 2. Predict keypoints for each hand and map them back to original coordinates
                keypoint_list = []
                for one_hand_box in det_box.data.cpu().numpy():
                    # ResNet-50 preprocessing ------------------------------------
                    # Crop the detected hand region for keypoint detection
                    cut_img = im0[int(one_hand_box[1]):int(one_hand_box[3]),
                                  int(one_hand_box[0]):int(one_hand_box[2])]  # slice y first, then x
                    key_img = cv2.resize(cut_img, (res50_img_size[1], res50_img_size[0]),
                                         interpolation=cv2.INTER_CUBIC)  # resize to the ResNet-50 input size
                    key_img = (key_img.astype(np.float32) - 128.) / 256.  # roughly normalize pixels to [-0.5, 0.5)
                    key_img = torch.from_numpy(key_img.transpose(2, 0, 1)).unsqueeze_(0)  # HWC -> 1CHW
                    if torch.cuda.is_available():
                        key_img = key_img.cuda()
                    key_img = key_img.half() if half else key_img.float()
                    # Keypoint inference
                    key_output = model_res50(key_img)
                    key_output = key_output.cpu().detach().numpy()
                    key_output = np.squeeze(key_output)  # (x, y) pairs normalized to [0, 1] within the crop
                    # Scale each keypoint by the crop size and offset it by the box origin
                    hand_ = []
                    for i in range(int(key_output.shape[0] / 2)):
                        x = (key_output[i * 2 + 0] * float(cut_img.shape[1])) + int(one_hand_box[0])
                        y = (key_output[i * 2 + 1] * float(cut_img.shape[0])) + int(one_hand_box[1])
                        hand_.append([x, y])
                    keypoint_list.append(hand_)
                # 3. Post-process: track & draw ------------------------------------------------------------
                tracker.update(det_box, keypoint_list)
            else:
                tracker.update_nodata([0, 1])  # tolerate frames where fast motion drops the detection
            tracker.plot(im0)
            # Feed the index-finger coordinates and the gesture id into the interaction features
            im0 = interactor.interact(im0, tracker.get_order())
            # Print time (yolov5 + NMS + keypoint + track + draw + interact + ...)
            t2 = time_synchronized()
            # print('%s (%.3fs)' % (s, t2 - t1))  # timing
            # Stream results: show in a window if requested
            if view:
                cv2.imshow(p, im0)
                if cv2.waitKey(1) == ord('q'):  # press q to quit
                    raise StopIteration
            # Save results (image with detections)
            if save:
                if dataset.mode == 'image':
                    print('saving img!')
                    cv2.imwrite(save_path, im0)
                else:
                    print('saving video!')
                    if vid_path != save_path:  # new video
                        vid_path = save_path
                        if isinstance(vid_writer, cv2.VideoWriter):
                            vid_writer.release()  # release the previous video writer
                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*opt.fourcc), fps, (w, h))
                    vid_writer.write(im0)
    if save:
        print('Results saved to %s' % os.path.join(os.getcwd(), out))
        print(save_path)
    print('Done. (%.3fs)' % (time.time() - t0))


def parse_argument():
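    """Build the CLI; the defaults assume the repo's inference/weights layout is present."""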
    parser = argparse.ArgumentParser()
    # file/folder, or 0 for webcam; image formats are not supported. e.g. inference/input/test_video2.mp4
    # parser.add_argument('--source', type=str, default='inference/input/Snipaste_2023-02-27_15-36-08.png', help='source')
    parser.add_argument('--source', type=str, default='0', help='source')  # camera input; 0: built-in laptop camera, 1: USB camera
    # Output folder
    parser.add_argument('--output', type=str, default='inference/output', help='output folder')
    # Output video codec
    parser.add_argument('--fourcc', type=str, default='mp4v', help='output video codec (verify ffmpeg support)')
    # Whether to display results (on by default)
    parser.add_argument('--view', action='store_true', default=True, help='display results')
    # Whether to save results
    parser.add_argument('--save', action='store_true', help='save results')
    # Device selection; a CUDA device enables half precision
    parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')  # default can be changed to cpu --by Steven
    # YOLOv5 weight path; alternatives: hand_v5s/best, best_YOLOv5l, best_yolo5s_half, yolov5l_best
    parser.add_argument('--yolov5_weights', type=str, default='inference/weights/hand_weight/best_YOLOv5l.pt')
    # YOLOv5 input size
    parser.add_argument('--yolov5_img_size', type=int, default=640, help='inference size (pixels)')
    # Test-time augmentation (multi-scale, flips) during YOLOv5 inference
    parser.add_argument('--augment', action='store_true', default=False, help='augmented inference')
    # NMS confidence threshold
    parser.add_argument('--conf-thres', type=float, default=0.1, help='object confidence threshold')
    # NMS IoU threshold
    parser.add_argument('--iou-thres', type=float, default=0.3, help='IOU threshold for NMS')
    # Tracker 2D angle constraint threshold
    parser.add_argument("--pose_thres", type=float, default=0.4, help='pose angle threshold')
    # Tracker gesture dictionary config file
    parser.add_argument("--pose_cfg", type=str, default='inference/weights/cfg_pose.json', help='pose_cfg')
    # ResNet-50 weight path
    parser.add_argument('--res50_weight', type=str, default='inference/weights/pose_weight/resnet50_2021-418.pth',
                        help='res50_weight')
    # ResNet-50 input size (h w); type=tuple would mis-parse CLI input, so take two ints instead
    parser.add_argument('--res50_img_size', type=int, nargs=2, default=(256, 256), help='res50_img_size')
    opt = parser.parse_args()
    print(opt)
    return opt


if __name__ == '__main__':
    with torch.no_grad():
        detect(parse_argument())
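
# Minimal usage sketches (assuming the default weight files exist under inference/weights/):
#   python inference.py --source 0 --device cpu
#   python inference.py --source inference/input/test_video2.mp4 --save --device 0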