Skip to content

Commit

Permalink
Merge pull request #1 from Cpruce/master
Browse files Browse the repository at this point in the history
adding changes by Cpruce
  • Loading branch information
ps48 authored Jan 10, 2018
2 parents 1c51787 + 92f76e8 commit 1ee739c
Showing 1 changed file with 170 additions and 35 deletions.
205 changes: 170 additions & 35 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def clip_boxes_graph(boxes, window):
y2 = tf.maximum(tf.minimum(y2, wy2), wy1)
x2 = tf.maximum(tf.minimum(x2, wx2), wx1)
clipped = tf.concat([y1, x1, y2, x2], axis=1, name="clipped_boxes")
clipped.set_shape((clipped.shape[0], 4))
return clipped


Expand Down Expand Up @@ -737,6 +738,146 @@ def refine_detections(rois, probs, deltas, window, config):
class_scores[keep][..., np.newaxis]))
return result

def refine_detections_graph(rois, probs, deltas, window, config):
"""Refine classified proposals and filter overlaps and return final
detections.
Inputs:
rois: [N, (y1, x1, y2, x2)] in normalized coordinates
probs: [N, num_classes]. Class probabilities.
deltas: [N, num_classes, (dy, dx, log(dh), log(dw))]. Class-specific
bounding box deltas.
window: (y1, x1, y2, x2) in image coordinates. The part of the image
that contains the image excluding the padding.
Returns detections shaped: [N, (y1, x1, y2, x2, class_id, score)]
"""
# Class IDs per ROI
class_ids = tf.argmax(probs, axis=1, output_type=tf.int32)
# Class probability of the top class of each ROI
scores_select_size = class_ids.shape[0]
scores_select = tf.range(scores_select_size)
score_indices = tf.stack([scores_select,class_ids], axis=1)
class_scores = tf.gather_nd(probs, score_indices)
# Class-specific bounding box deltas
deltas_range_size = deltas.shape[0]
deltas_range = tf.range(deltas_range_size)
deltas_indices = tf.stack([deltas_range, class_ids], axis=1)
deltas_specific = tf.gather_nd(deltas, deltas_indices)
# Apply bounding box deltas
# Shape: [boxes, (y1, x1, y2, x2)] in normalized coordinates
refined_rois = apply_box_deltas_graph(
rois, deltas_specific * config.BBOX_STD_DEV)

# Convert coordiates to image domain
# TODO: better to keep them normalized until later
height, width = config.IMAGE_SHAPE[:2]
refined_rois *= tf.constant([height, width, height, width], dtype=tf.float32)
#np.array([height, width, height, width])
# Clip boxes to image window
refined_rois = clip_boxes_graph(refined_rois, window)
# Round and cast to int since we're deadling with pixels now
refined_rois = tf.to_int32(tf.rint(refined_rois))
# TODO: Filter out boxes with zero area

# Filter out background boxes
keep = tf.where(class_ids > 0)[:,0]
# Filter out low confidence boxes
if config.DETECTION_MIN_CONFIDENCE:
keep = tf.sets.set_intersection(
keep, tf.where(class_scores >= config.DETECTION_MIN_CONFIDENCE)[:,0])

# Apply per-class NMS
pre_nms_class_ids = tf.gather(class_ids, keep) #class_ids[keep]
pre_nms_scores = tf.gather(class_scores, keep) #class_scores[keep]
pre_nms_rois = tf.gather(refined_rois, keep) #refined_rois[keep]
print('pre_nms_class_ids = {}'.format(pre_nms_class_ids.shape))
print('pre_nms_scores = {}'.format(pre_nms_scores.shape))
print('pre_nms_rois = {}'.format(pre_nms_rois.shape))

uniq_pre_nms_class_ids = tf.unique(pre_nms_class_ids)[0]

nms_keep = []
def nms_keep_map(class_id):
print('pre_nms_class_ids.shape', pre_nms_class_ids.shape)
print('class_id', class_id.shape)
ixs = tf.where(pre_nms_class_ids == tf.expand_dims(class_id, -1))[0]

# Apply NMS
class_keep = tf.image.non_max_suppression(
tf.to_float(tf.gather(pre_nms_rois,ixs)),
tf.gather(pre_nms_scores, ixs),
ixs.shape[0],
iou_threshold=config.DETECTION_NMS_THRESHOLD)

# Map indicies
return tf.gather(keep, tf.gather(ixs, class_keep))

print('uniq_pre_nms_class_ids: {}'.format(uniq_pre_nms_class_ids.shape))
nms_keep = tf.to_int64(tf.unique(tf.concat(
tf.map_fn(nms_keep_map, uniq_pre_nms_class_ids), axis=0))[0])

print(keep.shape, nms_keep.shape)
print(keep.dtype, nms_keep.dtype)
"""keep = tf.sets.set_intersection(
tf.expand_dims(keep, 0),
tf.expand_dims(tf.sparse_to_dense(nms_keep), 0))[1]
"""#tf.to_int32(
#np.intersect1d(keep, nms_keep).astype(np.int32)

result_keep = tf.concat([keep,nms_keep], axis = 0)
print('result_keep: {}'.format(result_keep.shape))
output_keep, idx_keep, count_keep = tf.unique_with_counts(result_keep)
print('output_keep: {}, idx_keep: {}, count_keep: {}, 2const: {}'.format(output_keep.shape,
idx_keep.shape, count_keep.shape, tf.constant(2).shape))
new_idx_keep = tf.where(count_keep >= tf.constant(2))[1] # keep coordinates of true elems
print('keep bbefore: {}, new_idx_keep: {}'.format(keep.shape, new_idx_keep.shape))
keep = tf.gather(output_keep, new_idx_keep)

# Keep top detections
roi_count = tf.convert_to_tensor(config.DETECTION_MAX_INSTANCES)
print('class scores: {}'.format(class_scores.shape))
class_scores_keep = tf.gather(class_scores, keep)
num_keep = tf.minimum(tf.shape(class_scores_keep)[0], roi_count)
top_ids = tf.nn.top_k(class_scores_keep, k=num_keep, sorted=True)[1]

#np.argsort(class_scores[keep])[::-1][:roi_count]
print('keep before: {}'.format(keep.shape))
keep = tf.gather(keep, top_ids)
print('keep after: {}'.format(keep.shape))

refined_rois_keep = tf.gather(tf.to_float(refined_rois), keep)
class_ids_keep = tf.gather(tf.to_float(class_ids), keep)[..., tf.newaxis]
class_scores_keep = tf.gather(class_scores, keep)[..., tf.newaxis]
print('refined_rois_keep = ', refined_rois_keep.shape)
print('class_ids_keep = ', class_ids_keep.shape)
print('class_scores_keep = ', class_scores_keep.shape)

# Arrange output as [N, (y1, x1, y2, x2, class_id, score)]
# Coordinates are in image domain.
detections = tf.concat((refined_rois_keep, class_ids_keep,
class_scores_keep), axis=1)
print('detections.shape = ', detections.shape)
#np.hstack((refined_rois[keep],
# class_ids[keep][..., np.newaxis],
# class_scores[keep][..., np.newaxis]))

# Pad with zeros if detections < DETECTION_MAX_INSTANCES
num_detections = tf.shape(detections)[0]
gap = roi_count - num_detections
print(gap, roi_count, num_detections)
pred = tf.less(tf.constant(0), gap)
#assert gap >= 0
#if gap > 0:
# paddings = tf.constant([[0, gap], [0, 0]])
# detections = tf.pad(detections, paddings, "CONSTANT")
def pad_detections():
print(detections.shape)
return tf.pad(detections, [(0, gap), (0, 0)], "CONSTANT")

detections = tf.cond(pred, pad_detections, lambda: detections)

return tf.to_float(detections)

class DetectionLayer(KE.Layer):
"""Takes classified proposal boxes and their bounding box deltas and
Expand All @@ -751,29 +892,32 @@ def __init__(self, config=None, **kwargs):
self.config = config

def call(self, inputs):
def wrapper(rois, mrcnn_class, mrcnn_bbox, image_meta):
detections_batch = []
_, _, window, _ = parse_image_meta(image_meta)
for b in range(self.config.BATCH_SIZE):
detections = refine_detections(
rois[b], mrcnn_class[b], mrcnn_bbox[b], window[b], self.config)
# Pad with zeros if detections < DETECTION_MAX_INSTANCES
gap = self.config.DETECTION_MAX_INSTANCES - detections.shape[0]
assert gap >= 0
if gap > 0:
detections = np.pad(
detections, [(0, gap), (0, 0)], 'constant', constant_values=0)
detections_batch.append(detections)

# Stack detections and cast to float32
# TODO: track where float64 is introduced
detections_batch = np.array(detections_batch).astype(np.float32)
# Reshape output
# [batch, num_detections, (y1, x1, y2, x2, class_score)] in pixels
return np.reshape(detections_batch, [self.config.BATCH_SIZE, self.config.DETECTION_MAX_INSTANCES, 6])

# Return wrapped function
return tf.py_func(wrapper, inputs, tf.float32)
config = self.config
rois = inputs[0]
mrcnn_class = inputs[1]
mrcnn_bbox = inputs[2]
image_meta = inputs[3]
print(rois.shape, mrcnn_class.shape, mrcnn_bbox.shape, image_meta.shape)

#parse_image_meta can be reused as slicing works same way in TF & numpy
window = get_image_meta_window(image_meta)
print('window after: ', window.shape)
detections_batch = utils.batch_slice(
[rois, mrcnn_class, mrcnn_bbox, window],
lambda x, y, w, z: refine_detections_graph(x, y, w, z, self.config),
self.config.IMAGES_PER_GPU)

# Stack detections and cast to float32
# TODO: track where float64 is introduced
#detections_batch = tf.stack(detections_batch)
#detections_batch = np.array(detections_batch).astype(np.float32)
# Reshape output
# [batch, num_detections, (y1, x1, y2, x2, class_score)] in pixels

return tf.reshape(
detections_batch,
[self.config.BATCH_SIZE, self.config.DETECTION_MAX_INSTANCES, 6])


def compute_output_shape(self, input_shape):
return (None, self.config.DETECTION_MAX_INSTANCES, 6)
Expand Down Expand Up @@ -2476,7 +2620,6 @@ def compose_image_meta(image_id, image_shape, window, active_class_ids):
return meta


# Two functions (for Numpy and TF) to parse image_meta tensors.
def parse_image_meta(meta):
"""Parses an image info Numpy array to its components.
See compose_image_meta() for more details.
Expand All @@ -2487,19 +2630,11 @@ def parse_image_meta(meta):
active_class_ids = meta[:, 8:]
return image_id, image_shape, window, active_class_ids


def parse_image_meta_graph(meta):
"""Parses a tensor that contains image attributes to its components.
def get_image_meta_window(meta):
"""Parses an image info Numpy array to its components.
See compose_image_meta() for more details.
meta: [batch, meta length] where meta length depends on NUM_CLASSES
"""
image_id = meta[:, 0]
image_shape = meta[:, 1:4]
window = meta[:, 4:8]
active_class_ids = meta[:, 8:]
return [image_id, image_shape, window, active_class_ids]

return meta[:, 4:8] # (y1, x1, y2, x2) window of image in in pixels

def mold_image(images, config):
"""Takes RGB images with 0-255 values and subtraces
Expand Down

0 comments on commit 1ee739c

Please sign in to comment.