Skip to content

Commit

Permalink
Merge pull request broadinstitute#71 from delftrobotics-forks/feature…
Browse files Browse the repository at this point in the history
…s/data-generator

Refactor data generator
  • Loading branch information
0x00b1 authored Aug 25, 2017
2 parents c7b51b1 + c33bf09 commit c6afb5d
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 90 deletions.
2 changes: 2 additions & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ Jane Hung <jyenhung@gmail.com> @jhung0
Jihong Ju <daniel.jihong.ju@gmail.com> @jujihong

Claire McQuin <mcquincl@gmail.com> @mcquin

Mihai Morariu <mihaimorariu@gmail.com> @mihaimorariu
154 changes: 64 additions & 90 deletions keras_rcnn/preprocessing/_object_detection.py
Original file line number Diff line number Diff line change
@@ -1,118 +1,92 @@
import threading

import keras.backend
import keras.preprocessing.image
import numpy
import numpy.random
import skimage.transform
import skimage.io
import keras.utils
import sklearn.preprocessing


class Iterator:
def __init__(self, n, batch_size, shuffle, seed):
self.batch_index = 0

self.batch_size = batch_size

self.index_generator = self._flow_index(n, batch_size, shuffle, seed)

self.lock = threading.Lock()

self.n = n

self.shuffle = shuffle

self.total_batches_seen = 0

def reset(self):
self.batch_index = 0

def _flow_index(self, n, batch_size=32, shuffle=False, seed=None):
self.reset()

while True:
if seed is not None:
numpy.random.seed(seed + self.total_batches_seen)

if self.batch_index == 0:
index_array = numpy.arange(n)
import time

if shuffle:
index_array = numpy.random.permutation(n)
def scale_size(size, min_size=224, max_size=224):
"""
Rescales a given image size such that the larger axis is
no larger than max_size and the smallest axis is as close
as possible to min_size.
"""
assert(len(size) == 2)

current_index = (self.batch_index * batch_size) % n
scale = min_size / numpy.min(size)

if n > current_index + batch_size:
current_batch_size = batch_size
# Prevent the biggest axis from being larger than max_size.
if numpy.round(scale * numpy.max(size)) > max_size:
scale = max_size / numpy.max(size)

self.batch_index += 1
else:
current_batch_size = n - current_index
rows, cols = size
rows *= scale
cols *= scale

self.batch_index = 0
return (int(rows), int(cols)), scale

self.total_batches_seen += 1

yield index_array[current_index:current_index + current_batch_size], current_index, current_batch_size

def __iter__(self):
return self

def __next__(self, *args, **kwargs):
return self.next(*args, **kwargs)

def next(self, *args, **kwargs):
pass


class DictionaryIterator(Iterator):
def __init__(self, dictionary, generator, shuffle=False, seed=None):
class DictionaryIterator(keras.preprocessing.image.Iterator):
def __init__(self, dictionary, classes, generator, batch_size=1, shuffle=False, seed=None):
self.dictionary = dictionary
self.classes = classes
self.generator = generator

assert(len(self.dictionary) != 0)

self.encoder = sklearn.preprocessing.LabelEncoder()
# Compute and store the target image shape.
cols, rows, channels = dictionary[0]["shape"]
self.image_shape = (rows, cols, channels)

self.generator = generator
self.target_shape, self.scale = scale_size(self.image_shape[0:2])
self.target_shape = self.target_shape + (self.image_shape[2],)

self.encoder.fit(generator.classes)
# Metadata needs to be computed only once.
rows, cols, channels = self.target_shape
self.metadata = numpy.array([[rows, cols, self.scale]])

Iterator.__init__(self, len(dictionary), 1, shuffle, seed)
super().__init__(len(self.dictionary), batch_size, shuffle, seed)

def next(self):
# Lock indexing to prevent race conditions.
with self.lock:
index_array, current_index, current_batch_size = next(
self.index_generator)

index = index_array[0]
selection, _, batch_size = next(self.index_generator)

pathname = self.dictionary[index]["filename"]
# Labels has num_classes + 1 elements, since 0 is reserved for background.
num_classes = len(self.classes)
images = numpy.zeros((batch_size,) + self.target_shape, dtype=keras.backend.floatx())
boxes = numpy.zeros((batch_size, 0, 4), dtype=keras.backend.floatx())
labels = numpy.zeros((batch_size, 0, num_classes + 1), dtype=numpy.uint8)

image = skimage.io.imread(pathname)
for batch_index, image_index in enumerate(selection):
path = self.dictionary[image_index]["filename"]
image = skimage.io.imread(path)

image = numpy.expand_dims(image, 0)
# Assert that the loaded image has the predefined image shape.
if image.shape != self.image_shape:
raise Exception("All input images need to be of the same shape.")

ds = self.dictionary[index]["boxes"]
# Copy image to batch blob.
images[batch_index] = skimage.transform.rescale(image, scale=self.scale, mode="reflect")

boxes = numpy.asarray([[d[k] for k in ['y1', 'x1', 'y2', 'x2']] for d in ds])
# Set ground truth boxes.
for i, b in enumerate(self.dictionary[image_index]["boxes"]):
if b["class"] not in self.classes:
raise Exception("Class {} not found in '{}'.".format(b["class"], self.classes))

labels = numpy.asarray([d['class'] for d in ds])
box = [b["y1"], b["x1"], b["y2"], b["x2"]]
boxes = numpy.append(boxes, [[box]], axis=1)

labels = self.encoder.transform(labels)
# Store the labels in one-hot form.
label = [0] * (num_classes + 1)
label[self.classes[b["class"]]] = 1
labels = numpy.append(labels, [[label]], axis = 1)

labels = keras.utils.to_categorical(labels)

metadata = list(image.shape[1:-1]) + [1]

# boxes = numpy.expand_dims(boxes, 0)

# labels = numpy.expand_dims(labels, 0)

# metadata = numpy.expand_dims(metadata, 0)

return [image, boxes, labels, metadata], [boxes, labels]
# Scale the ground truth boxes to the selected image scale.
boxes[batch_index, :, :4] *= self.scale

return [images, boxes, self.metadata, labels], None

class ObjectDetectionGenerator:
def __init__(self, classes):
self.classes = classes

def flow(self, dictionary, shuffle=True, seed=None):
return DictionaryIterator(dictionary, self, shuffle=shuffle, seed=seed)
def flow(self, dictionary, classes):
return DictionaryIterator(dictionary, classes, self)
16 changes: 16 additions & 0 deletions tests/preprocessing/test_object_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import keras_rcnn.preprocessing._object_detection
import numpy


def test_scale_shape():
min_size = 200
max_size = 300
size = (600, 1000)

size, scale = keras_rcnn.preprocessing._object_detection.scale_size(size, min_size, max_size)

expected = (180, 300)
numpy.testing.assert_equal(size, expected)

expected = 0.3
assert numpy.isclose(scale, expected)

0 comments on commit c6afb5d

Please sign in to comment.