Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding changes by Cpruce #1

Merged
merged 5 commits into from
Jan 10, 2018
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Remove py_func for cross-environment serializability. Convert DetectionLayer/refine_detections to tf
  • Loading branch information
Cpruce committed Jan 4, 2018
commit 296d5b55206586fb77ca074d7da66594f1d6eae5
209 changes: 185 additions & 24 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,165 @@ def refine_detections(rois, probs, deltas, window, config):
class_scores[keep][..., np.newaxis]))
return result

def refine_detections_graph(rois, probs, deltas, window, config):
    """Refine classified proposals, filter overlaps, and return final
    detections. TF-graph replacement for the NumPy refine_detections(),
    so the model no longer needs tf.py_func.

    Inputs:
        rois: [N, (y1, x1, y2, x2)] in normalized coordinates
        probs: [N, num_classes]. Class probabilities.
        deltas: [N, num_classes, (dy, dx, log(dh), log(dw))]. Class-specific
                bounding box deltas.
        window: (y1, x1, y2, x2) in image coordinates. The part of the image
            that contains the image excluding the padding.

    Returns detections shaped:
        [DETECTION_MAX_INSTANCES, (y1, x1, y2, x2, class_id, score)]
        in image coordinates, zero-padded at the end if fewer detections
        than DETECTION_MAX_INSTANCES were found.
    """
    # Class IDs per ROI (top-scoring class for each proposal)
    class_ids = tf.argmax(probs, axis=1, output_type=tf.int32)
    # Class probability of the top class of each ROI.
    # Build [N, 2] (row, class) index pairs for gather_nd.
    indices = tf.stack([tf.range(probs.shape[0]), class_ids], axis=1)
    class_scores = tf.gather_nd(probs, indices)
    # Class-specific bounding box deltas
    deltas_specific = tf.gather_nd(deltas, indices)
    # Apply bounding box deltas
    # Shape: [boxes, (y1, x1, y2, x2)] in normalized coordinates
    refined_rois = apply_box_deltas_graph(
        rois, deltas_specific * config.BBOX_STD_DEV)
    # Convert coordinates to image domain
    # TODO: better to keep them normalized until later
    height, width = config.IMAGE_SHAPE[:2]
    refined_rois *= tf.constant([height, width, height, width],
                                dtype=tf.float32)
    # Clip boxes to image window
    refined_rois = clip_boxes_graph(refined_rois, window)
    # Round and cast to int since we're dealing with pixels now
    refined_rois = tf.to_int32(tf.rint(refined_rois))

    # TODO: Filter out boxes with zero area

    # Filter out background boxes.
    # tf.where returns [M, 1]; [:, 0] flattens to the index vector.
    keep = tf.where(class_ids > 0)[:, 0]
    # Filter out low confidence boxes.
    if config.DETECTION_MIN_CONFIDENCE:
        conf_keep = tf.where(
            class_scores >= config.DETECTION_MIN_CONFIDENCE)[:, 0]
        # tf.sets.set_intersection needs rank >= 2, hence the expand_dims,
        # and returns a SparseTensor that must be densified.
        keep = tf.sets.set_intersection(tf.expand_dims(keep, 0),
                                        tf.expand_dims(conf_keep, 0))
        keep = tf.sparse_tensor_to_dense(keep)[0]

    # Apply per-class NMS
    # 1. Prepare variables
    pre_nms_class_ids = tf.gather(class_ids, keep)
    pre_nms_scores = tf.gather(class_scores, keep)
    pre_nms_rois = tf.gather(refined_rois, keep)
    unique_pre_nms_class_ids = tf.unique(pre_nms_class_ids)[0]

    def nms_keep_map(class_id):
        """Run NMS over the ROIs of one class. Returns indices into the
        global index space (`keep`'s domain), padded with -1 to a fixed
        length so tf.map_fn can stack the per-class results."""
        # Indices (within the kept set) of ROIs of this class.
        # NOTE: in TF1 `tensor == x` is identity comparison; tf.equal
        # is required for elementwise comparison.
        ixs = tf.where(tf.equal(pre_nms_class_ids, class_id))[:, 0]
        # Apply NMS. max_output_size must be a static bound, not the
        # dynamic ixs.shape[0].
        class_keep = tf.image.non_max_suppression(
            tf.to_float(tf.gather(pre_nms_rois, ixs)),
            tf.gather(pre_nms_scores, ixs),
            max_output_size=config.DETECTION_MAX_INSTANCES,
            iou_threshold=config.DETECTION_NMS_THRESHOLD)
        # Map indices back to the global index space
        class_keep = tf.gather(keep, tf.gather(ixs, class_keep))
        # Pad with -1 so every per-class result has the same length
        gap = config.DETECTION_MAX_INSTANCES - tf.shape(class_keep)[0]
        class_keep = tf.pad(class_keep, [(0, gap)],
                            mode='CONSTANT', constant_values=-1)
        # Give the result a static shape so map_fn can infer its output
        class_keep.set_shape([config.DETECTION_MAX_INSTANCES])
        return class_keep

    # 2. Map over class IDs, then drop the -1 padding
    nms_keep = tf.map_fn(nms_keep_map, unique_pre_nms_class_ids,
                         dtype=tf.int64)
    nms_keep = tf.reshape(nms_keep, [-1])
    nms_keep = tf.gather(nms_keep, tf.where(nms_keep > -1)[:, 0])
    # 3. Keep only the indices that survived both the confidence
    #    filter and per-class NMS
    keep = tf.sets.set_intersection(tf.expand_dims(keep, 0),
                                    tf.expand_dims(nms_keep, 0))
    keep = tf.sparse_tensor_to_dense(keep)[0]

    # Keep top detections, sorted by score
    roi_count = config.DETECTION_MAX_INSTANCES
    class_scores_keep = tf.gather(class_scores, keep)
    num_keep = tf.minimum(tf.shape(class_scores_keep)[0], roi_count)
    top_ids = tf.nn.top_k(class_scores_keep, k=num_keep, sorted=True)[1]
    keep = tf.gather(keep, top_ids)

    # Arrange output as [N, (y1, x1, y2, x2, class_id, score)]
    # Coordinates are in image domain.
    # Concat along axis 1 (tf.stack would fail: the pieces have
    # different shapes [N,4], [N,1], [N,1]).
    detections = tf.concat([
        tf.to_float(tf.gather(refined_rois, keep)),
        tf.to_float(tf.gather(class_ids, keep))[..., tf.newaxis],
        tf.gather(class_scores, keep)[..., tf.newaxis]
        ], axis=1)

    # Pad with zeros if detections < DETECTION_MAX_INSTANCES.
    # gap is a dynamic tensor, so pass the paddings as a Python list
    # (tf.constant would require static values); gap >= 0 is guaranteed
    # by the tf.minimum above, so no tf.cond is needed.
    gap = config.DETECTION_MAX_INSTANCES - tf.shape(detections)[0]
    detections = tf.pad(detections, [(0, gap), (0, 0)], "CONSTANT")
    return detections

class DetectionLayer(KE.Layer):
"""Takes classified proposal boxes and their bounding box deltas and
Expand All @@ -751,30 +910,32 @@ def __init__(self, config=None, **kwargs):
self.config = config

def call(self, inputs):
    """Run detection refinement on a batch.

    inputs: [rois, mrcnn_class, mrcnn_bbox, image_meta]
        rois: [batch, N, (y1, x1, y2, x2)] in normalized coordinates
        mrcnn_class: [batch, N, num_classes] classifier probabilities
        mrcnn_bbox: [batch, N, num_classes, (dy, dx, log(dh), log(dw))]
        image_meta: [batch, meta size] packed image attributes

    Returns:
        [batch, DETECTION_MAX_INSTANCES, (y1, x1, y2, x2, class_id,
        class_score)] in image (pixel) coordinates.
    """
    # NOTE: the old tf.py_func wrapper was removed for cross-environment
    # serializability; everything below stays in the TF graph.
    rois = inputs[0]
    mrcnn_class = inputs[1]
    mrcnn_bbox = inputs[2]
    image_meta = inputs[3]

    # Window of the image that excludes the padding, in image coordinates.
    _, _, window, _ = parse_image_meta_graph(image_meta)

    # Refine detections per image. Lambda args map positionally:
    # x=rois, y=class probs, w=bbox deltas, z=window.
    detections_batch = utils.batch_slice(
        [rois, mrcnn_class, mrcnn_bbox, window],
        lambda x, y, w, z: refine_detections_graph(x, y, w, z, self.config),
        self.config.IMAGES_PER_GPU)

    # Reshape output to
    # [batch, num_detections, (y1, x1, y2, x2, class_id, class_score)]
    return tf.reshape(
        detections_batch,
        [self.config.BATCH_SIZE, self.config.DETECTION_MAX_INSTANCES, 6])


def compute_output_shape(self, input_shape):
    """Output shape: [batch, DETECTION_MAX_INSTANCES,
    (y1, x1, y2, x2, class_id, score)]; batch size is dynamic."""
    max_instances = self.config.DETECTION_MAX_INSTANCES
    return (None, max_instances, 6)

Expand Down