forked from Shreya-Adgaonkar/GREE-COCO
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
50 lines (40 loc) · 1.32 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from __future__ import division
from utils.utils import *
from utils.datasets import *
import cv2
from PIL import Image
import torch
from torchvision import transforms
def resize(image, size):
    """Rescale a single image tensor with nearest-neighbour interpolation.

    A leading batch axis is added so ``F.interpolate`` sees a batched
    input, the spatial dimensions are resized to ``size``, and the batch
    axis is removed again.  In this file it is fed the ``(C, H, W)``
    tensor produced by ``transforms.ToTensor`` and an int ``size``,
    which resizes both spatial dimensions to ``size``.
    """
    batched = image.unsqueeze(0)
    scaled = F.interpolate(batched, size=size, mode="nearest")
    return scaled.squeeze(0)
def yolo_prediction(model, device, image, class_names, *,
                    img_size=416, conf_thres=0.5, nms_thres=0.45):
    """Detect objects in a single BGR image with a YOLO model.

    The image is converted BGR -> RGB, turned into a tensor with
    ``transforms.ToTensor`` and stretched (no letterboxing) to a square
    ``img_size`` x ``img_size`` input via ``resize``.  The model output is
    filtered with ``non_max_suppression``, and each surviving box is scaled
    back to the original image's pixel grid.

    Args:
        model: detector module, called as ``model(batch)`` on a
            ``(1, C, img_size, img_size)`` tensor; it is put into eval mode.
        device: torch device the input batch is moved to.
        image: image array in BGR channel order, as produced by OpenCV
            (assumes H x W x 3 uint8 -- TODO confirm for other inputs).
        class_names: sequence indexed by the integer class id taken from
            the last column of each NMS row.
        img_size: side length of the square network input (default 416).
        conf_thres: confidence threshold forwarded to ``non_max_suppression``.
        nms_thres: IoU threshold forwarded to ``non_max_suppression``.

    Returns:
        A list with one ``[class_name, score, [cx, cy, w, h]]`` entry per
        detection.  ``score`` is column 4 of the NMS row (presumably the
        objectness confidence -- confirm against ``non_max_suppression``).
        The box is given as integer centre / width / height in
        original-image pixels.  An empty list is returned when there are
        no detections.

    Errors raised while decoding detections (e.g. a class id outside
    ``class_names``) now propagate instead of being silently swallowed.
    """
    # OpenCV images are BGR; PIL / ToTensor expect RGB.
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    imgs = transforms.ToTensor()(Image.fromarray(rgb))  # (C, H, W)
    _, orig_h, orig_w = imgs.shape
    # The input is stretched (not letterboxed), so x and y get independent
    # factors mapping network coordinates back to original pixels.
    box_scale = [orig_w / img_size, orig_h / img_size,
                 orig_w / img_size, orig_h / img_size]
    imgs = resize(imgs, img_size).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        outputs = model(imgs)
    outputs = non_max_suppression(outputs, conf_thres=conf_thres,
                                  nms_thres=nms_thres)

    # non_max_suppression (defined elsewhere) presumably yields None for an
    # image with no detections -- this is what the former bare `except`
    # absorbed.  Handle it explicitly so real errors are not hidden.
    if len(outputs) == 0 or outputs[0] is None:
        return []

    objects = []
    for detection in outputs[0].cpu().data:
        # Corner coordinates rescaled to original-image pixels.
        x1, y1, x2, y2 = (int(coord * scale)
                          for coord, scale in zip(detection[:4], box_scale))
        center_x = int((x1 + x2) / 2)
        center_y = int((y1 + y2) / 2)
        objects.append([
            class_names[int(detection[-1])],
            float(detection[4]),
            [center_x, center_y, x2 - x1, y2 - y1],
        ])
    return objects