Skip to content

Commit

Permalink
Merge pull request #4160 from rohis06/nms
Browse files Browse the repository at this point in the history
Adding Non-Maximum Suppression (NMS) utility to fiftyone.utils.labels through perform_nms()
  • Loading branch information
brimoor authored Mar 15, 2024
2 parents a326539 + a80ffe3 commit 1ace4e2
Showing 1 changed file with 85 additions and 0 deletions.
85 changes: 85 additions & 0 deletions fiftyone/utils/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import fiftyone.core.labels as fol
import fiftyone.core.utils as fou
import fiftyone.core.validation as fov
import fiftyone.utils.iou as foi


def objects_to_segmentations(
Expand Down Expand Up @@ -673,3 +674,87 @@ def classifications_to_detections(
detections.append(detection)

image[out_field] = fol.Detections(detections=detections)


def perform_nms(
sample_collection,
in_field,
out_field=None,
conf_threshold=0.6,
iou_threshold=0.5,
progress=None,
):
"""Performs Non-Maximum Suppression (NMS) on the
:class:`fiftyone.core.labels.Detections` field containing detections.
NMS is a post-processing technique used in object detection to eliminate
duplicate detections and select the most relevant detected objects. This
helps reduce false positives.
Args:
sample_collection: a
:class:`fiftyone.core.collections.SampleCollection`
in_field: the name of the :class:`fiftyone.core.labels.Detections` field
out_field (None): the name of the :class:`fiftyone.core.labels.Detections`
field to populate. If not specified (None), the input field is updated
in-place.
conf_threshold (0.6): a floating-point value between 0 and 1 representing
the minimum confidence score required for a detection to be considered
valid. Detections with confidence scores lower than this threshold
will be discarded during the Non-Maximum Suppression (NMS) process.
iou_threshold (0.5): a floating-point value between 0 and 1 representing
the Intersection over Union (IoU) threshold used in the NMS algorithm.
It determines the minimum overlap required between bounding boxes for
them to be considered duplicates. Bounding boxes with IoU values
greater than or equal to this threshold will be suppressed.
progress (None): whether to render a progress bar (True/False), use the
default value ``fiftyone.config.show_progress_bars`` (None), or a
progress callback function to invoke instead
"""
fov.validate_collection_label_fields(
sample_collection, in_field, fol.Detections
)

samples = sample_collection.select_fields(in_field)

with fou.ProgressBar(progress=progress) as pb:
for sample in pb(samples):
detections_data = sample[in_field].copy()
nms_processed_detections = _perform_nms(
detections_data, conf_threshold, iou_threshold
)
if out_field is None:
sample[in_field] = nms_processed_detections
else:
sample[out_field] = nms_processed_detections
sample.save()


def _perform_nms(detections_data, conf_threshold=0.6, iou_threshold=0.5):
detections = detections_data.detections

# Sort detections by confidence in descending order
detections.sort(key=lambda x: x.confidence, reverse=True)

# Remove detections with confidence less than the conf_threshold
detections = [d for d in detections if d.confidence >= conf_threshold]

nms_detections = []

while len(detections) > 0:
# Pick the detection with highest confidence
selected_detection = detections[0]
nms_detections.append(selected_detection)
del detections[0]

# Compare with other detections for NMS
for d in detections:
if d.label == selected_detection.label:
iou = foi.compute_ious([selected_detection], [d])[0][0]
if iou >= iou_threshold:
# Remove the detection if IoU is greater than iou_threshold
detections.remove(d)

# Update detections_data with NMS results
detections_data.detections = nms_detections
return detections_data

0 comments on commit 1ace4e2

Please sign in to comment.