forked from broadinstitute/keras-rcnn
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request broadinstitute#96 from jhung0/final_detection_layer
Final detection layer
- Loading branch information
Showing 7 changed files with 209 additions and 86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import keras.engine.topology | ||
import keras.backend | ||
import keras_rcnn.backend | ||
import keras_rcnn.layers.object_detection._object_proposal | ||
|
||
class Detection(keras.engine.topology.Layer):
    """
    Get final detections + labels by unscaling back to image space, applying
    regression deltas, choosing box coordinates, and removing extra
    detections via NMS

    # Arguments
        threshold: objects whose maximum class score is not greater than
            threshold are thrown out
        test_nms: A float representing the threshold for deciding whether
            boxes overlap too much with respect to IoU
    """

    def __init__(self, threshold=0.05, test_nms=0.5, **kwargs):
        self.threshold = threshold

        self.TEST_NMS = test_nms

        super(Detection, self).__init__(**kwargs)

    def build(self, input_shape):
        # The layer has no trainable weights; defer to the base class.
        super(Detection, self).build(input_shape)

    def call(self, x, **kwargs):
        """
        # Inputs
            rois: output of proposal target (1, N, 4)
            pred_deltas: predicted deltas (1, N, 4*classes)
            pred_scores: score distributions (1, N, classes)
            metadata: image information (1, 3); the third entry is used as
                the image scale and the first two are passed to `clip`

        # Returns
            pred_boxes: final predicted boxes of the predicted class
                (1, M, 4), where M <= N is the number of detections that
                survive score thresholding and NMS
            pred_scores: score distribution over all classes (1, M, classes),
                note the box only corresponds to the most probable class,
                not the other classes
        """
        rois, pred_deltas, pred_scores, metadata = x[0], x[1], x[2], x[3]

        # Drop the (size 1) batch dimension.
        rois = rois[0, :, :]
        pred_deltas = pred_deltas[0, :, :]
        pred_scores = pred_scores[0, :, :]

        # unscale back to raw image space
        boxes = rois / metadata[0][2]

        # Apply bounding-box regression deltas
        pred_boxes = keras_rcnn.backend.bbox_transform_inv(boxes, pred_deltas)

        # NOTE(review): presumably clips boxes to the image extent given by
        # the first two metadata entries -- confirm against backend.clip.
        pred_boxes = keras_rcnn.backend.clip(pred_boxes, metadata[0][:2])

        # Final detections

        # keep detections above threshold
        best_scores = keras.backend.max(pred_scores, axis=1)
        indices_threshold = keras_rcnn.backend.where(
            keras.backend.greater(best_scores, self.threshold)
        )
        indices_threshold = keras.backend.reshape(indices_threshold, (-1,))
        pred_scores = keras.backend.gather(pred_scores, indices_threshold)
        pred_boxes = keras.backend.gather(pred_boxes, indices_threshold)

        # for each remaining object, get the top class. This is computed
        # AFTER thresholding so that `pred_classes` has exactly one entry per
        # surviving row; computing it on the unfiltered scores leaves it
        # misaligned with `indices` below whenever a row is dropped.
        pred_classes = keras.backend.argmax(pred_scores, axis=1)
        pred_classes = keras.backend.cast(pred_classes, 'int32')

        # score of the top class for every remaining object
        indices = keras.backend.arange(0, keras.backend.shape(pred_scores)[0])
        pred_scores_classes = keras_rcnn.backend.gather_nd(
            pred_scores,
            keras.backend.concatenate(
                [
                    keras.backend.expand_dims(indices),
                    keras.backend.expand_dims(pred_classes)
                ],
                axis=1
            )
        )

        # Columns 4*c .. 4*c+3 of `pred_boxes` hold the box of class c.
        # The index list is laid out coordinate-major: first column of every
        # row, then second column of every row, and so on.
        indices_boxes = keras.backend.concatenate(
            [
                4 * pred_classes,
                4 * pred_classes + 1,
                4 * pred_classes + 2,
                4 * pred_classes + 3
            ],
            0
        )
        indices = keras.backend.tile(indices, [4])

        pred_boxes = keras_rcnn.backend.gather_nd(
            pred_boxes,
            keras.backend.concatenate(
                [
                    keras.backend.expand_dims(indices),
                    keras.backend.expand_dims(indices_boxes)
                ],
                axis=1
            )
        )

        # The gathered vector is coordinate-major (4 runs of M values), so it
        # must be viewed as (4, M) and transposed to obtain (M, 4). Reshaping
        # straight to (-1, 4) would mix coordinates of different boxes.
        pred_boxes = keras.backend.reshape(pred_boxes, (4, -1))
        pred_boxes = keras.backend.transpose(pred_boxes)

        # remove overlapping detections, keeping at most all of them
        indices = keras_rcnn.backend.non_maximum_suppression(
            pred_boxes,
            pred_scores_classes,
            keras.backend.shape(pred_boxes)[0],
            self.TEST_NMS
        )
        pred_scores = keras.backend.gather(pred_scores, indices)
        pred_boxes = keras.backend.gather(pred_boxes, indices)

        # restore the batch dimension
        return [
            keras.backend.expand_dims(pred_boxes, 0),
            keras.backend.expand_dims(pred_scores, 0)
        ]

    def compute_output_shape(self, input_shape):
        # The number of detections is data dependent, hence `None`; the class
        # count is taken from the last dimension of the third input.
        return [(1, None, 4), (1, None, input_shape[2][2])]

    def compute_mask(self, inputs, mask=None):
        # One (absent) mask per output tensor.
        return 2 * [None]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import keras.backend | ||
import keras.utils | ||
import numpy | ||
|
||
import keras_rcnn.layers | ||
|
||
|
||
class TestDetection:
    """Shape checks for the `Detection` layer."""

    def test_call(self):
        """Run the layer on random inputs and verify the output shapes."""
        num_classes = 3
        layer = keras_rcnn.layers.Detection()

        # Random regression deltas, four values per class for each of the
        # 100 proposals.
        deltas = keras.backend.variable(
            numpy.random.random((1, 100, 4 * num_classes))
        )

        # Random integer proposal coordinates in [0, 224).
        rois = keras.backend.variable(
            numpy.random.choice(range(0, 224), (1, 100, 4))
        )

        # Random per-class scores for each proposal.
        scores = keras.backend.variable(
            numpy.random.random((1, 100, num_classes))
        )

        metadata = keras.backend.variable([[224, 224, 1.5]])

        boxes, classes = layer.call([rois, deltas, scores, metadata])

        classes_shape = keras.backend.eval(classes).shape
        boxes_shape = keras.backend.eval(boxes).shape

        # Both outputs describe the same set of detections.
        assert classes_shape[:2] == boxes_shape[:2]

        assert boxes_shape[-1] == 4

        assert classes_shape[-1] == num_classes
|