Skip to content

Commit

Permalink
softnms
Browse files Browse the repository at this point in the history
  • Loading branch information
“iscyy” committed Aug 8, 2022
1 parent 7515630 commit 1a0d56f
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 2 deletions.
81 changes: 81 additions & 0 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from multiprocessing.pool import ThreadPool
from pathlib import Path
from subprocess import check_output
from utils.metrics import bbox_iou
from zipfile import ZipFile

import cv2
Expand Down Expand Up @@ -672,6 +673,86 @@ def clip_coords(boxes, shape):
boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2

def soft_nms(prediction, conf_thres=0.25, iou_thres=0.45, multi_label=False, sigma=0.5):
    """Run Gaussian Soft-NMS on raw inference results, one class at a time.

    Instead of discarding boxes that overlap a kept box, their confidence is
    multiplied by ``exp(-iou**2 / sigma)``; a box is dropped only once its
    decayed confidence is no longer above ``conf_thres``.

    Args:
        prediction: tensor indexed as (image, candidate, 5 + nc), where columns
            0-3 are passed to ``xywh2xyxy``, column 4 is the objectness score
            and columns 5+ are per-class scores.
        conf_thres: confidence threshold in [0, 1]; used both for the initial
            candidate filter and as the drop threshold after score decay.
        iou_thres: validated to lie in [0, 1] but otherwise unused; kept so the
            signature mirrors ``non_max_suppression``.
        multi_label: if True (and nc > 1) a box may yield one detection per
            class whose score exceeds ``conf_thres``; otherwise only the best
            class per box is kept.
        sigma: width of the Gaussian penalty; smaller values decay overlapping
            boxes faster. Must be > 0.

    Returns:
        list of detections, one (n, 6) tensor per image [xyxy, conf, cls],
        sorted by descending confidence.
    """

    nc = prediction.shape[2] - 5  # number of classes

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    assert sigma > 0, f'Invalid sigma {sigma}, valid values are greater than 0.0'

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_nms = 30000  # maximum number of boxes processed per class
    time_limit = 10.0  # seconds to quit after

    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)

    t = time.time()
    # One distinct empty tensor per image so results never alias each other
    output = [torch.zeros((0, 6), device=prediction.device) for _ in range(prediction.shape[0])]
    for xi, x in enumerate(prediction):  # image index, image inference
        x = x[x[:, 4] > conf_thres]  # confidence
        x = x[(x[:, 2:4] > min_wh).all(1) & (x[:, 2:4] < max_wh).all(1)]  # width-height
        if len(x) == 0:
            continue

        # Compute conf (x is a copy produced by boolean indexing, so this
        # in-place update does not touch the caller's tensor)
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1)
            x = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)[conf.view(-1) > conf_thres]

        if len(x) == 0:
            continue

        x = x[x[:, 4].argsort(descending=True)]  # sort by confidence

        # Per-class Soft-NMS
        det_max = []
        cls = x[:, -1]  # classes

        for c in cls.unique():
            dc = x[cls == c]  # detections of class c, highest confidence first
            if len(dc) > max_nms:
                dc = dc[:max_nms]
            while len(dc):
                det_max.append(dc[:1])  # keep the current highest-scoring box
                if len(dc) == 1:
                    break
                # NOTE(review): relies on bbox_iou accepting one box row plus an
                # (n, 6) tensor in xyxy layout and returning n IoUs - confirm
                # against the utils.metrics version in use.
                iou = bbox_iou(dc[0], dc[1:])
                dc = dc[1:]
                dc[:, 4] *= torch.exp(-iou ** 2 / sigma)  # Gaussian score decay
                dc = dc[dc[:, 4] > conf_thres]
                # Decay can reorder scores; re-sort so the next kept box is
                # again the highest-scoring remaining one.
                dc = dc[dc[:, 4].argsort(descending=True)]

        if len(det_max):
            det_max = torch.cat(det_max)
            output[xi] = det_max[(-det_max[:, 4]).argsort()]  # sort by final confidence
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output

def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
labels=(), max_det=300):
Expand Down
9 changes: 7 additions & 2 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from utils.callbacks import Callbacks
from utils.datasets import create_dataloader
from utils.general import (LOGGER, box_iou, check_dataset, check_img_size, check_requirements, check_yaml,
coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args,
coco80_to_coco91_class, colorstr, soft_nms, increment_path, non_max_suppression, print_args,
scale_coords, xywh2xyxy, xyxy2xywh)
from utils.metrics import ConfusionMatrix, ap_per_class
from utils.plots import output_to_target, plot_images, plot_val_study
Expand Down Expand Up @@ -120,6 +120,7 @@ def run(data,
plots=True,
callbacks=Callbacks(),
compute_loss=None,
soft=False,
):
# Initialize/load model and set device
training = model is not None
Expand Down Expand Up @@ -209,7 +210,10 @@ def run(data,
targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels
lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
t3 = time_sync()
out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
if soft:
out = soft_nms(out, conf_thres, iou_thres, multi_label=True)
else:
out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
dt[2] += time_sync() - t3

# Metrics
Expand Down Expand Up @@ -347,6 +351,7 @@ def parse_opt():
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
parser.add_argument('--soft', action='store_true', help='use soft nms')
opt = parser.parse_args()
opt.data = check_yaml(opt.data) # check YAML
opt.save_json |= opt.data.endswith('coco.yaml')
Expand Down

0 comments on commit 1a0d56f

Please sign in to comment.