Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Efficientdet #256

Merged
merged 8 commits into from
Feb 8, 2023
Prev Previous commit
Next Next commit
'class_arg' and 'method' are now passed from DetectSingleShotEfficien…
…tdet pipeline
  • Loading branch information
Manojkumarmuru committed Feb 8, 2023
commit 719438f9071ae75a2dc674ab7f033c8841ae72aa
9 changes: 6 additions & 3 deletions examples/efficientdet/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,16 @@ class DetectSingleShotEfficientDet(Processor):
"""
def __init__(self, model, class_names, score_thresh, nms_thresh,
mean=pr.RGB_IMAGENET_MEAN, variances=[1.0, 1.0, 1.0, 1.0],
draw=True):
class_arg=None, renormalize=False, method=0, draw=True):
self.model = model
self.class_names = class_names
self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
self.variances = variances
self.draw = draw
self.class_arg = class_arg
self.renormalize = renormalize
self.method = method
self.model.prior_boxes = model.prior_boxes * model.input_shape[1]

super(DetectSingleShotEfficientDet, self).__init__()
Expand All @@ -68,11 +71,11 @@ def call(self, image):
postprocessing = SequentialProcessor([
pr.Squeeze(axis=None),
pr.DecodeBoxes(self.model.prior_boxes, variances=self.variances),
RemoveClass(class_arg=None, renormalize=False),
RemoveClass(self.class_arg, self.renormalize),
ScaleBox(image_scales),
NonMaximumSuppressionPerClass(self.nms_thresh),
FilterBoxes(self.class_names, self.score_thresh),
ToBoxes2D(self.class_names, method=0)])
ToBoxes2D(self.class_names, self.method)])
outputs = process_outputs(outputs)
boxes2D = postprocessing(outputs)
if self.draw:
Expand Down