Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Efficientdet #250

Merged
merged 29 commits into from
Jan 27, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
e62e9d7
Input, conv, squeeze excitation and output blocks of EfficientNet mod…
Manojkumarmuru Jan 25, 2023
24bbf15
Conv2D() function wrapped inside MB_conv2D() function
Manojkumarmuru Jan 25, 2023
5e11969
SE_ratio renamed to excite_ratios
Manojkumarmuru Jan 25, 2023
b28dbf5
repeat_MB() function refactored
Manojkumarmuru Jan 25, 2023
7014478
Converted to pep8 standard
Manojkumarmuru Jan 25, 2023
8dc3404
Docstrings updated
Manojkumarmuru Jan 25, 2023
bc0fd6d
Ascii diagram refined
Manojkumarmuru Jan 25, 2023
5a29084
Static arguments taken out of for loop inside build_head_conv2D() fun…
Manojkumarmuru Jan 26, 2023
74fa38c
Bounding box visualization reverted to old version
Manojkumarmuru Jan 26, 2023
786904d
demo.py working for old weights- first version
Manojkumarmuru Jan 26, 2023
3c54bd0
duplicate efficientdet_portprocess() function removed
Manojkumarmuru Jan 26, 2023
43b8992
Changed DetectSingleShotEfficientDet according to the working version…
Manojkumarmuru Jan 26, 2023
1d7b02d
Basic cleanup
Manojkumarmuru Jan 26, 2023
b609ad7
Some cleanup
Manojkumarmuru Jan 26, 2023
043b589
Some cleanup
Manojkumarmuru Jan 26, 2023
04b5e3c
Converted to pep8 standard
Manojkumarmuru Jan 26, 2023
abb74af
draw.py removed as it is no longer used for time being
Manojkumarmuru Jan 26, 2023
74e5739
Converted to pep8 standard
Manojkumarmuru Jan 26, 2023
525b47f
Incorrect exception message corrected
Manojkumarmuru Jan 26, 2023
bfbe4dc
EFFICIENTDETD0COCO made
Manojkumarmuru Jan 26, 2023
987b000
Basic cleanup
Manojkumarmuru Jan 26, 2023
580d3f7
EFFICIENTDETD0VOC added
Manojkumarmuru Jan 26, 2023
822ca1b
Converted to pep8 standard
Manojkumarmuru Jan 26, 2023
2c32b09
demo_video.py added
Manojkumarmuru Jan 26, 2023
d2f5ca9
Docstrings updated
Manojkumarmuru Jan 27, 2023
5ef312f
Minor refactoring
Manojkumarmuru Jan 27, 2023
2c1f3ab
EFFICIENTDETCOCO class added for EfficientDet models D0-D7
Manojkumarmuru Jan 27, 2023
bad21ae
get_class_names() function call replaced with self.class_names inside…
Manojkumarmuru Jan 27, 2023
7d3cacb
draw.py and DrawBoxes2d() class restored
Manojkumarmuru Jan 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
demo.py working for old weights- first version
  • Loading branch information
Manojkumarmuru committed Jan 26, 2023
commit 786904d65d3361b080062c7b610085eddb894340
30 changes: 14 additions & 16 deletions examples/efficientdet/demo.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,25 @@
from tensorflow.keras.utils import get_file
from paz.backend.image import show_image, write_image
from paz.datasets import get_class_names
from paz.pipelines.detection import DetectSingleShot
from paz.processors.image import LoadImage
from detection import DetectSingleShotEfficientDet, get_class_name_efficientdet
from efficientdet import EFFICIENTDETD0
from detection import efficientdet_preprocess, efficientdet_postprocess

IMAGE_PATH = ('/home/manummk95/Desktop/efficientdet_BKP/paz/examples/efficientdet/img.jpg')

IMAGE_PATH = ('/home/manummk95/Desktop/efficientdet_BKP/paz/'
'examples/efficientdet/000132.jpg')
WEIGHT_PATH = (
'https://github.com/oarriaga/altamira-data/releases/download/v0.16/')
WEIGHT_FILE = 'efficientdet-d0-VOC-VOC_weights.hdf5'

if __name__ == "__main__":
raw_image = LoadImage()(IMAGE_PATH)
model = EFFICIENTDETD0(num_classes=21, base_weights='COCO',
head_weights=None)
weights_path = get_file(WEIGHT_FILE, WEIGHT_PATH + WEIGHT_FILE,
cache_subdir='paz/models')
model.load_weights(weights_path)

detect = DetectSingleShot(model, get_class_names('VOC'), 0.5, 0.45)
detections = detect(raw_image)
model = EFFICIENTDETD0(base_weights='COCO', head_weights='COCO')
model.prior_boxes = model.prior_boxes*512.0
image_size = model.input_shape[1]
input_image, image_scales = efficientdet_preprocess(raw_image, image_size)

show_image(detections['image'])
write_image('detections.png', detections['image'])
outputs = model(input_image)

image, detections = efficientdet_postprocess(
model, outputs, image_scales, raw_image)
print(detections)
write_image('paz_postprocess.jpg', image)
print('task completed')
333 changes: 332 additions & 1 deletion examples/efficientdet/detection.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import numpy as np
from paz import processors as pr
from paz.backend.image.draw import draw_rectangle
from paz.abstract import SequentialProcessor, Processor
from draw import (compute_text_bounds, draw_opaque_box, make_box_transparent,
put_text)

B_IMAGENET_STDEV, G_IMAGENET_STDEV, R_IMAGENET_STDEV = 57.3 , 57.1, 58.4
RGB_IMAGENET_STDEV = (R_IMAGENET_STDEV, G_IMAGENET_STDEV, B_IMAGENET_STDEV)
from paz.backend.image import resize_image
from paz.processors.image import RGB_IMAGENET_MEAN

class DrawBoxes2D(pr.DrawBoxes2D):
"""Draws bounding boxes from Boxes2D messages.
Expand Down Expand Up @@ -77,3 +82,329 @@ def call(self, image, boxes2D):
self.scale, text_color, text_thickness)
put_text(*args)
return image


class DetectSingleShot(Processor):
"""Single-shot object detection prediction.

# Arguments
model: Keras model.
class_names: List of strings indicating the class names.
score_thresh: Float between [0, 1]
nms_thresh: Float between [0, 1].
mean: List of three elements indicating the per channel mean.
draw: Boolean. If ``True`` prediction are drawn in the returned image.
"""
def __init__(self, model, class_names, score_thresh, nms_thresh,
mean=pr.BGR_IMAGENET_MEAN, variances=[0.1, 0.1, 0.2, 0.2],
draw=True):
self.model = model
self.class_names = class_names
self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
self.variances = variances
self.draw = draw

super(DetectSingleShot, self).__init__()
preprocessing = SequentialProcessor(
[pr.ResizeImage(self.model.input_shape[1:3]),
pr.ConvertColorSpace(pr.RGB2BGR),
pr.SubtractMeanImage(mean),
pr.CastImage(float),
pr.ExpandDims(axis=0)])
postprocessing = SequentialProcessor(
[pr.Squeeze(axis=None),
pr.DecodeBoxes(self.model.prior_boxes, self.variances),
pr.NonMaximumSuppressionPerClass(self.nms_thresh),
pr.FilterBoxes(self.class_names, self.score_thresh)])
self.predict = pr.Predict(self.model, preprocessing, postprocessing)

self.denormalize = pr.DenormalizeBoxes2D()
self.draw_boxes2D = pr.DrawBoxes2D(self.class_names)
self.wrap = pr.WrapOutput(['image', 'boxes2D'])

def call(self, image):
boxes2D = self.predict(image)
boxes2D = self.denormalize(image, boxes2D)
if self.draw:
image = self.draw_boxes2D(image, boxes2D)
return self.wrap(image, boxes2D)


class DetectSingleShotEfficientDet(Processor):
"""Single-shot object detection prediction.

# Arguments
model: Keras model.
class_names: List of strings indicating the class names.
score_thresh: Float between [0, 1]
nms_thresh: Float between [0, 1].
mean: List of three elements indicating the per channel mean.
draw: Boolean. If ``True`` prediction are drawn in the returned image.
"""
def __init__(self, model, class_names, score_thresh, nms_thresh,
mean=pr.BGR_IMAGENET_MEAN, variances=[1.0, 1.0, 1.0, 1.0],
draw=True):
self.model = model
self.class_names = class_names
self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
self.variances = variances
self.draw = draw

super(DetectSingleShotEfficientDet, self).__init__()
preprocessing = SequentialProcessor(
[pr.ResizeImage(self.model.input_shape[1:3]),
pr.ConvertColorSpace(pr.RGB2BGR),
pr.SubtractMeanImage(mean),
pr.CastImage(float),
pr.ExpandDims(axis=0)])
self.predict = pr.Predict(self.model, preprocessing, None)

postprocessing = SequentialProcessor(
[pr.Squeeze(axis=None),
pr.DecodeBoxes(model.prior_boxes*512.0, variances=self.variances),
ScaleBox(np.array(1.1953124860847313)),
pr.NonMaximumSuppressionPerClass(0.4),
pr.FilterBoxes(get_class_name_efficientdet('COCO'), 0.4)])

self.postprocessing = postprocessing
self.denormalize = pr.DenormalizeBoxes2D()
self.draw_boxes2D = pr.DrawBoxes2D(self.class_names)
self.wrap = pr.WrapOutput(['image', 'boxes2D'])

def call(self, image):
outputs = self.predict(image)
# outputs = process_outputs(outputs)
outputs = self.postprocessing(outputs)
# boxes2D = self.denormalize(image, outputs)
boxes2D = outputs
if self.draw:
image = self.draw_boxes2D(image.astype('uint8'), boxes2D)
return image


def get_class_name_efficientdet(dataset_name):
if dataset_name == 'COCO':
return ['person', 'bicycle', 'car', 'motorcycle',
'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', '0', 'stop sign', 'parking meter', 'bench',
'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
'bear', 'zebra', 'giraffe', '0', 'backpack', 'umbrella', '0',
'0', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
'snowboard', 'sports ball', 'kite', 'baseball bat',
'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', '0', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
'couch', 'potted plant', 'bed', '0', 'dining table', '0', '0',
'toilet', '0', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', '0', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']

elif dataset_name == 'VOC':
return ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
'diningtable', 'dog', 'horse', 'motorbike', 'person',
'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']


def efficientdet_postprocess(model, outputs, image_scales, raw_images=None):
"""EfficientDet output postprocessing function.

# Arguments
model: EfficientDet model
class_outputs: Tensor, logits for all classes corresponding to the
features associated with the box coordinates at each feature levels.
box_outputs: Tensor, box coordinate offsets for the corresponding prior
boxes at each feature levels.
image_scale: Numpy array, scale to reconstruct each of the raw images
to original size from the resized image.
raw_images: Numpy array, RGB image to draw the detections on the image.

# Returns
image: Numpy array, RGB input image with detections overlaid.
outputs: List of Box2D, containing the detections with bounding box
and class details.
"""
outputs = process_outputs(outputs)
postprocessing = SequentialProcessor(
[pr.Squeeze(axis=None),
pr.DecodeBoxes(model.prior_boxes, variances=[1, 1, 1, 1]),
ScaleBox(1.0), pr.NonMaximumSuppressionPerClass(0.4),
pr.FilterBoxes(get_class_name_efficientdet('COCO'), 0.4)])
outputs = postprocessing(outputs)
draw_boxes2D = pr.DrawBoxes2D(get_class_name_efficientdet('COCO'))
image = draw_boxes2D(raw_images.astype('uint8'), outputs)
return image, outputs


def process_outputs(outputs):
"""Merges all feature levels into single tensor and combines box offsets
and class scores.

# Arguments
class_outputs: Tensor, logits for all classes corresponding to the
features associated with the box coordinates at each feature levels.
box_outputs: Tensor, box coordinate offsets for the corresponding prior
boxes at each feature levels.
num_levels: Int, number of levels considered at efficientnet features.
num_classes: Int, number of classes in the dataset.

# Returns
outputs: Numpy array, Processed outputs by merging the features at
all levels. Each row corresponds to box coordinate offsets and
sigmoid of the class logits.
"""
outputs = outputs[0]
boxes, classes = outputs[:, :4], outputs[:, 4:]
s1, s2, s3, s4 = np.hsplit(boxes, 4)
boxes = np.concatenate([s2, s1, s4, s3], axis=1)
boxes = boxes[np.newaxis]
classes = classes[np.newaxis]
outputs = np.concatenate([boxes, classes], axis=2)
return outputs


def scale_box(predictions, image_scales=None):
"""
# Arguments
image: Numpy array.
boxes: Numpy array of shape `[num_boxes, N]` where N >= 4.
# Returns
Numpy array of shape `[num_boxes, N]`.
"""

if image_scales is not None:
boxes = predictions[:, :4]
scales = image_scales[np.newaxis][np.newaxis]
boxes = boxes * scales
predictions = np.concatenate([boxes, predictions[:, 4:]], 1)
return predictions


class ScaleBox(Processor):
"""Scale box coordinates of the prediction.
"""
def __init__(self, scales):
super(ScaleBox, self).__init__()
self.scales = scales

def call(self, boxes):
boxes = scale_box(boxes, self.scales)
return boxes


def efficientdet_preprocess(image, image_size):
"""Preprocess image for EfficientDet model.

# Arguments
image: Tensor, raw input image to be preprocessed
of shape [bs, h, w, c]
image_size: Tensor, size to resize the raw image
of shape [bs, new_h, new_w, c]

# Returns
image: Numpy array, resized and preprocessed image
image_scale: Numpy array, scale to reconstruct each of
the raw images to original size from the resized
image.
"""

preprocessing = SequentialProcessor([
pr.CastImage(float),
pr.SubtractMeanImage(mean=RGB_IMAGENET_MEAN),
DivideStandardDeviationImage(standard_deviation=RGB_IMAGENET_STDEV),
ScaledResize(image_size=image_size),
])
image, image_scale = preprocessing(image)
return image, image_scale

class DivideStandardDeviationImage(Processor):
"""Divide channel-wise standard deviation to image.

# Arguments
mean: List of length 3, containing the channel-wise mean.
"""
def __init__(self, standard_deviation):
self.standard_deviation = standard_deviation
super(DivideStandardDeviationImage, self).__init__()

def call(self, image):
return image / self.standard_deviation

class ScaledResize(Processor):
"""Resizes image by returning the scales to original image.

# Arguments
image_size: Int, desired size of the model input.

# Returns
output_images: Numpy array, image resized to match
image size.
image_scales: Numpy array, scale to reconstruct the
raw image from the output_images.
"""
def __init__(self, image_size):
self.image_size = image_size
super(ScaledResize, self).__init__()

def call(self, image):
"""
# Arguments
image: Numpy array, raw input image.
"""
crop_offset_y = np.array(0)
crop_offset_x = np.array(0)
height = np.array(image.shape[0]).astype('float32')
width = np.array(image.shape[1]).astype('float32')
image_scale_y = np.array(self.image_size).astype('float32') / height
image_scale_x = np.array(self.image_size).astype('float32') / width
image_scale = np.minimum(image_scale_x, image_scale_y)
scaled_height = (height * image_scale).astype('int32')
scaled_width = (width * image_scale).astype('int32')
scaled_image = resize_image(image, (scaled_width, scaled_height))
scaled_image = scaled_image[
crop_offset_y: crop_offset_y + self.image_size,
crop_offset_x: crop_offset_x + self.image_size,
:]
output_images = np.zeros((self.image_size,
self.image_size,
image.shape[2]))
output_images[:scaled_image.shape[0],
:scaled_image.shape[1],
:scaled_image.shape[2]] = scaled_image
image_scale = 1 / image_scale
output_images = output_images[np.newaxis]
return output_images, image_scale


def efficientdet_postprocess(model, outputs, image_scales, raw_images=None):
"""EfficientDet output postprocessing function.

# Arguments
model: EfficientDet model
class_outputs: Tensor, logits for all classes corresponding to the
features associated with the box coordinates at each feature levels.
box_outputs: Tensor, box coordinate offsets for the corresponding prior
boxes at each feature levels.
image_scale: Numpy array, scale to reconstruct each of the raw images
to original size from the resized image.
raw_images: Numpy array, RGB image to draw the detections on the image.

# Returns
image: Numpy array, RGB input image with detections overlaid.
outputs: List of Box2D, containing the detections with bounding box
and class details.
"""
outputs = process_outputs(outputs)
postprocessing = SequentialProcessor(
[pr.Squeeze(axis=None),
pr.DecodeBoxes(model.prior_boxes, variances=[1, 1, 1, 1]),
ScaleBox(image_scales), pr.NonMaximumSuppressionPerClass(0.4),
pr.FilterBoxes(get_class_name_efficientdet('COCO'), 0.8)])
outputs = postprocessing(outputs)
draw_boxes2D = pr.DrawBoxes2D(get_class_name_efficientdet('COCO'))
image = draw_boxes2D(raw_images.astype('uint8'), outputs)
return image, outputs