Skip to content

Commit

Permalink
squeeze
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed Aug 17, 2017
1 parent e3d6846 commit b8018e2
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
7 changes: 7 additions & 0 deletions keras_rcnn/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,10 @@ def crop_and_resize(image, boxes, size):
boxes = keras.backend.reshape(boxes, [-1, 4])

return tensorflow.image.crop_and_resize(image, boxes, box_ind, size)


def squeeze(a, axis=None):
"""
Remove single-dimensional entries from the shape of an array.
"""
return tensorflow.squeeze(a, axis)
2 changes: 1 addition & 1 deletion keras_rcnn/layers/object_detection/_anchor_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def unmap(data, count, inds_inside, fill=0):

inds_nd = keras.backend.concatenate([inds_ii, inds_coords], 1)

inverse_ret = tensorflow.squeeze(tensorflow.gather_nd(-1 * ret, inds_nd))
inverse_ret = keras_rcnn.backend.squeeze(tensorflow.gather_nd(-1 * ret, inds_nd))

ret = keras_rcnn.backend.scatter_add_tensor(ret, inds_nd, inverse_ret + data)

Expand Down
18 changes: 17 additions & 1 deletion tests/backend/test_tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy
import tensorflow

import keras_rcnn.backend
import keras_rcnn.backend.tensorflow_backend
import keras_rcnn.backend.common


Expand Down Expand Up @@ -35,3 +35,19 @@ def test_crop_and_resize():
slices = keras_rcnn.backend.crop_and_resize(image, boxes, size)

assert keras.backend.eval(slices).shape == (2, 7, 7, 3)


def test_squeeze():
x = [[[0], [1], [2]]]

x = keras.backend.variable(x)

assert keras.backend.int_shape(x) == (1, 3, 1)

y = keras_rcnn.backend.tensorflow_backend.squeeze(x)

assert keras.backend.int_shape(y) == (3,)

y = keras_rcnn.backend.tensorflow_backend.squeeze(x, axis=0)

assert keras.backend.int_shape(y) == (3, 1)

0 comments on commit b8018e2

Please sign in to comment.