Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding changes by Cpruce #1

Merged
merged 5 commits into from
Jan 10, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 170 additions & 35 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def clip_boxes_graph(boxes, window):
y2 = tf.maximum(tf.minimum(y2, wy2), wy1)
x2 = tf.maximum(tf.minimum(x2, wx2), wx1)
clipped = tf.concat([y1, x1, y2, x2], axis=1, name="clipped_boxes")
clipped.set_shape((clipped.shape[0], 4))
return clipped


Expand Down Expand Up @@ -737,6 +738,146 @@ def refine_detections(rois, probs, deltas, window, config):
class_scores[keep][..., np.newaxis]))
return result

def refine_detections_graph(rois, probs, deltas, window, config):
"""Refine classified proposals and filter overlaps and return final
detections.

Inputs:
rois: [N, (y1, x1, y2, x2)] in normalized coordinates
probs: [N, num_classes]. Class probabilities.
deltas: [N, num_classes, (dy, dx, log(dh), log(dw))]. Class-specific
bounding box deltas.
window: (y1, x1, y2, x2) in image coordinates. The part of the image
that contains the image excluding the padding.

Returns detections shaped: [N, (y1, x1, y2, x2, class_id, score)]
"""
# Class IDs per ROI
class_ids = tf.argmax(probs, axis=1, output_type=tf.int32)
# Class probability of the top class of each ROI
scores_select_size = class_ids.shape[0]
scores_select = tf.range(scores_select_size)
score_indices = tf.stack([scores_select,class_ids], axis=1)
class_scores = tf.gather_nd(probs, score_indices)
# Class-specific bounding box deltas
deltas_range_size = deltas.shape[0]
deltas_range = tf.range(deltas_range_size)
deltas_indices = tf.stack([deltas_range, class_ids], axis=1)
deltas_specific = tf.gather_nd(deltas, deltas_indices)
# Apply bounding box deltas
# Shape: [boxes, (y1, x1, y2, x2)] in normalized coordinates
refined_rois = apply_box_deltas_graph(
rois, deltas_specific * config.BBOX_STD_DEV)

# Convert coordiates to image domain
# TODO: better to keep them normalized until later
height, width = config.IMAGE_SHAPE[:2]
refined_rois *= tf.constant([height, width, height, width], dtype=tf.float32)
#np.array([height, width, height, width])
# Clip boxes to image window
refined_rois = clip_boxes_graph(refined_rois, window)
# Round and cast to int since we're deadling with pixels now
refined_rois = tf.to_int32(tf.rint(refined_rois))
# TODO: Filter out boxes with zero area

# Filter out background boxes
keep = tf.where(class_ids > 0)[:,0]
# Filter out low confidence boxes
if config.DETECTION_MIN_CONFIDENCE:
keep = tf.sets.set_intersection(
keep, tf.where(class_scores >= config.DETECTION_MIN_CONFIDENCE)[:,0])

# Apply per-class NMS
pre_nms_class_ids = tf.gather(class_ids, keep) #class_ids[keep]
pre_nms_scores = tf.gather(class_scores, keep) #class_scores[keep]
pre_nms_rois = tf.gather(refined_rois, keep) #refined_rois[keep]
print('pre_nms_class_ids = {}'.format(pre_nms_class_ids.shape))
print('pre_nms_scores = {}'.format(pre_nms_scores.shape))
print('pre_nms_rois = {}'.format(pre_nms_rois.shape))

uniq_pre_nms_class_ids = tf.unique(pre_nms_class_ids)[0]

nms_keep = []
def nms_keep_map(class_id):
print('pre_nms_class_ids.shape', pre_nms_class_ids.shape)
print('class_id', class_id.shape)
ixs = tf.where(pre_nms_class_ids == tf.expand_dims(class_id, -1))[0]

# Apply NMS
class_keep = tf.image.non_max_suppression(
tf.to_float(tf.gather(pre_nms_rois,ixs)),
tf.gather(pre_nms_scores, ixs),
ixs.shape[0],
iou_threshold=config.DETECTION_NMS_THRESHOLD)

# Map indicies
return tf.gather(keep, tf.gather(ixs, class_keep))

print('uniq_pre_nms_class_ids: {}'.format(uniq_pre_nms_class_ids.shape))
nms_keep = tf.to_int64(tf.unique(tf.concat(
tf.map_fn(nms_keep_map, uniq_pre_nms_class_ids), axis=0))[0])

print(keep.shape, nms_keep.shape)
print(keep.dtype, nms_keep.dtype)
"""keep = tf.sets.set_intersection(
tf.expand_dims(keep, 0),
tf.expand_dims(tf.sparse_to_dense(nms_keep), 0))[1]
"""#tf.to_int32(
#np.intersect1d(keep, nms_keep).astype(np.int32)

result_keep = tf.concat([keep,nms_keep], axis = 0)
print('result_keep: {}'.format(result_keep.shape))
output_keep, idx_keep, count_keep = tf.unique_with_counts(result_keep)
print('output_keep: {}, idx_keep: {}, count_keep: {}, 2const: {}'.format(output_keep.shape,
idx_keep.shape, count_keep.shape, tf.constant(2).shape))
new_idx_keep = tf.where(count_keep >= tf.constant(2))[1] # keep coordinates of true elems
print('keep bbefore: {}, new_idx_keep: {}'.format(keep.shape, new_idx_keep.shape))
keep = tf.gather(output_keep, new_idx_keep)

# Keep top detections
roi_count = tf.convert_to_tensor(config.DETECTION_MAX_INSTANCES)
print('class scores: {}'.format(class_scores.shape))
class_scores_keep = tf.gather(class_scores, keep)
num_keep = tf.minimum(tf.shape(class_scores_keep)[0], roi_count)
top_ids = tf.nn.top_k(class_scores_keep, k=num_keep, sorted=True)[1]

#np.argsort(class_scores[keep])[::-1][:roi_count]
print('keep before: {}'.format(keep.shape))
keep = tf.gather(keep, top_ids)
print('keep after: {}'.format(keep.shape))

refined_rois_keep = tf.gather(tf.to_float(refined_rois), keep)
class_ids_keep = tf.gather(tf.to_float(class_ids), keep)[..., tf.newaxis]
class_scores_keep = tf.gather(class_scores, keep)[..., tf.newaxis]
print('refined_rois_keep = ', refined_rois_keep.shape)
print('class_ids_keep = ', class_ids_keep.shape)
print('class_scores_keep = ', class_scores_keep.shape)

# Arrange output as [N, (y1, x1, y2, x2, class_id, score)]
# Coordinates are in image domain.
detections = tf.concat((refined_rois_keep, class_ids_keep,
class_scores_keep), axis=1)
print('detections.shape = ', detections.shape)
#np.hstack((refined_rois[keep],
# class_ids[keep][..., np.newaxis],
# class_scores[keep][..., np.newaxis]))

# Pad with zeros if detections < DETECTION_MAX_INSTANCES
num_detections = tf.shape(detections)[0]
gap = roi_count - num_detections
print(gap, roi_count, num_detections)
pred = tf.less(tf.constant(0), gap)
#assert gap >= 0
#if gap > 0:
# paddings = tf.constant([[0, gap], [0, 0]])
# detections = tf.pad(detections, paddings, "CONSTANT")
def pad_detections():
print(detections.shape)
return tf.pad(detections, [(0, gap), (0, 0)], "CONSTANT")

detections = tf.cond(pred, pad_detections, lambda: detections)

return tf.to_float(detections)

class DetectionLayer(KE.Layer):
"""Takes classified proposal boxes and their bounding box deltas and
Expand All @@ -751,29 +892,32 @@ def __init__(self, config=None, **kwargs):
self.config = config

def call(self, inputs):
def wrapper(rois, mrcnn_class, mrcnn_bbox, image_meta):
detections_batch = []
_, _, window, _ = parse_image_meta(image_meta)
for b in range(self.config.BATCH_SIZE):
detections = refine_detections(
rois[b], mrcnn_class[b], mrcnn_bbox[b], window[b], self.config)
# Pad with zeros if detections < DETECTION_MAX_INSTANCES
gap = self.config.DETECTION_MAX_INSTANCES - detections.shape[0]
assert gap >= 0
if gap > 0:
detections = np.pad(
detections, [(0, gap), (0, 0)], 'constant', constant_values=0)
detections_batch.append(detections)

# Stack detections and cast to float32
# TODO: track where float64 is introduced
detections_batch = np.array(detections_batch).astype(np.float32)
# Reshape output
# [batch, num_detections, (y1, x1, y2, x2, class_score)] in pixels
return np.reshape(detections_batch, [self.config.BATCH_SIZE, self.config.DETECTION_MAX_INSTANCES, 6])

# Return wrapped function
return tf.py_func(wrapper, inputs, tf.float32)
config = self.config
rois = inputs[0]
mrcnn_class = inputs[1]
mrcnn_bbox = inputs[2]
image_meta = inputs[3]
print(rois.shape, mrcnn_class.shape, mrcnn_bbox.shape, image_meta.shape)

#parse_image_meta can be reused as slicing works same way in TF & numpy
window = get_image_meta_window(image_meta)
print('window after: ', window.shape)
detections_batch = utils.batch_slice(
[rois, mrcnn_class, mrcnn_bbox, window],
lambda x, y, w, z: refine_detections_graph(x, y, w, z, self.config),
self.config.IMAGES_PER_GPU)

# Stack detections and cast to float32
# TODO: track where float64 is introduced
#detections_batch = tf.stack(detections_batch)
#detections_batch = np.array(detections_batch).astype(np.float32)
# Reshape output
# [batch, num_detections, (y1, x1, y2, x2, class_score)] in pixels

return tf.reshape(
detections_batch,
[self.config.BATCH_SIZE, self.config.DETECTION_MAX_INSTANCES, 6])


def compute_output_shape(self, input_shape):
return (None, self.config.DETECTION_MAX_INSTANCES, 6)
Expand Down Expand Up @@ -2476,7 +2620,6 @@ def compose_image_meta(image_id, image_shape, window, active_class_ids):
return meta


# Two functions (for Numpy and TF) to parse image_meta tensors.
def parse_image_meta(meta):
"""Parses an image info Numpy array to its components.
See compose_image_meta() for more details.
Expand All @@ -2487,19 +2630,11 @@ def parse_image_meta(meta):
active_class_ids = meta[:, 8:]
return image_id, image_shape, window, active_class_ids


def parse_image_meta_graph(meta):
"""Parses a tensor that contains image attributes to its components.
def get_image_meta_window(meta):
"""Parses an image info Numpy array to its components.
See compose_image_meta() for more details.

meta: [batch, meta length] where meta length depends on NUM_CLASSES
"""
image_id = meta[:, 0]
image_shape = meta[:, 1:4]
window = meta[:, 4:8]
active_class_ids = meta[:, 8:]
return [image_id, image_shape, window, active_class_ids]

return meta[:, 4:8] # (y1, x1, y2, x2) window of image in in pixels

def mold_image(images, config):
"""Takes RGB images with 0-255 values and subtraces
Expand Down