Refactor data generator.
Mihai Morariu committed Aug 24, 2017
1 parent d782b80 commit 17b30b4
Showing 1 changed file with 60 additions and 94 deletions.
154 changes: 60 additions & 94 deletions keras_rcnn/preprocessing/_object_detection.py
@@ -1,118 +1,84 @@
-import threading
+import time
 
 import keras.backend
 import keras.preprocessing.image
 import numpy
-import numpy.random
 import skimage.transform
 import skimage.io
-import keras.utils
-import sklearn.preprocessing
 
 
-class Iterator:
-    def __init__(self, n, batch_size, shuffle, seed):
-        self.batch_index = 0
-
-        self.batch_size = batch_size
-
-        self.index_generator = self._flow_index(n, batch_size, shuffle, seed)
-
-        self.lock = threading.Lock()
-
-        self.n = n
-
-        self.shuffle = shuffle
-
-        self.total_batches_seen = 0
-
-    def reset(self):
-        self.batch_index = 0
-
-    def _flow_index(self, n, batch_size=32, shuffle=False, seed=None):
-        self.reset()
-
-        while True:
-            if seed is not None:
-                numpy.random.seed(seed + self.total_batches_seen)
-
-            if self.batch_index == 0:
-                index_array = numpy.arange(n)
-
-                if shuffle:
-                    index_array = numpy.random.permutation(n)
-
-            current_index = (self.batch_index * batch_size) % n
-
-            if n > current_index + batch_size:
-                current_batch_size = batch_size
-
-                self.batch_index += 1
-            else:
-                current_batch_size = n - current_index
-
-                self.batch_index = 0
-
-            self.total_batches_seen += 1
-
-            yield index_array[current_index:current_index + current_batch_size], current_index, current_batch_size
-
-    def __iter__(self):
-        return self
-
-    def __next__(self, *args, **kwargs):
-        return self.next(*args, **kwargs)
-
-    def next(self, *args, **kwargs):
-        pass
-
-
-class DictionaryIterator(Iterator):
-    def __init__(self, dictionary, generator, shuffle=False, seed=None):
-        self.dictionary = dictionary
-
-        self.encoder = sklearn.preprocessing.LabelEncoder()
-
-        self.generator = generator
-
-        self.encoder.fit(generator.classes)
-
-        Iterator.__init__(self, len(dictionary), 1, shuffle, seed)
+def scale_shape(shape, min_size=224, max_size=224):
+    """
+    Rescales a given shape such that the larger axis is no
+    larger than max_size and the smallest axis is as close
+    as possible to min_size.
+    """
+    min_shape = numpy.min(shape[0:2])
+    max_shape = numpy.max(shape[0:2])
+    scale = min_size / min_shape
+
+    # Prevent the biggest axis from being larger than max_size
+    if numpy.round(scale * max_shape) > max_size:
+        scale = max_size / max_shape
+
+    return (int(shape[0] * scale), int(shape[1] * scale), shape[2]), scale
+
+
+def scale_image(image, min_size=224, max_size=224):
+    """
+    Rescales an image according to the heuristics from 'scale_shape'.
+    """
+    target_shape, scale = scale_shape(image.shape, min_size, max_size)
+
+    return skimage.transform.rescale(image, scale=scale, mode="reflect"), scale
+
+
+class DictionaryIterator(keras.preprocessing.image.Iterator):
+    def __init__(self, data, classes, image_data_generator, image_shape, batch_size=1,
+                 shuffle=True, seed=numpy.uint32(time.time() * 1000)):
+        self.data = data
+        self.classes = classes
+        self.image_data_generator = image_data_generator
+        self.image_shape = image_shape
+        self.target_shape, _ = scale_shape(image_shape)
+
+        super().__init__(len(self.data), batch_size, shuffle, seed)
 
     def next(self):
         # Lock indexing to prevent race conditions
         with self.lock:
-            index_array, current_index, current_batch_size = next(
-                self.index_generator)
-
-        index = index_array[0]
-
-        pathname = self.dictionary[index]["filename"]
-
-        image = skimage.io.imread(pathname)
-
-        image = numpy.expand_dims(image, 0)
-
-        ds = self.dictionary[index]["boxes"]
-
-        boxes = numpy.asarray([[d[k] for k in ['y1', 'x1', 'y2', 'x2']] for d in ds])
-
-        labels = numpy.asarray([d['class'] for d in ds])
-
-        labels = self.encoder.transform(labels)
-
-        labels = keras.utils.to_categorical(labels)
-
-        metadata = list(image.shape[1:-1]) + [1]
-
-        # boxes = numpy.expand_dims(boxes, 0)
-
-        # labels = numpy.expand_dims(labels, 0)
-
-        # metadata = numpy.expand_dims(metadata, 0)
-
-        return [image, boxes, labels, metadata], [boxes, labels]
+            selection, _, batch_size = next(self.index_generator)
+
+        # Transformation of images is not under thread lock so it can be done in parallel
+        image_batch = numpy.zeros((batch_size,) + self.target_shape, dtype=keras.backend.floatx())
+        gt_boxes_batch = numpy.zeros((batch_size, 0, 5), dtype=keras.backend.floatx())
+        metadata = numpy.zeros((batch_size, 3), dtype=keras.backend.floatx())
+
+        for batch_index, image_index in enumerate(selection):
+            path = self.data[image_index]["filename"]
+            image = skimage.io.imread(path, as_grey=(self.target_shape[2] == 1))
+            image, scale = scale_image(image)
+            image = self.image_data_generator.random_transform(image)
+            image = self.image_data_generator.standardize(image)
+
+            # Copy image to batch blob
+            image_batch[batch_index] = image
+
+            # Set ground truth boxes
+            boxes = self.data[image_index]["boxes"]
+            for i, b in enumerate(boxes):
+                if b["class"] not in self.classes:
+                    raise Exception("Class {} not found in '{}'.".format(b["class"], self.classes))
+
+                gt_data = [b["y1"], b["x1"], b["y2"], b["x2"], self.classes[b["class"]]]
+                gt_boxes_batch = numpy.append(gt_boxes_batch, [[gt_data]], axis=1)
+
+            # Scale the ground truth boxes to the selected image scale
+            gt_boxes_batch[batch_index, :, :4] *= scale
+
+            # Create metadata
+            metadata[batch_index, :] = [image.shape[0], image.shape[1], scale]
+
+        return [image_batch, gt_boxes_batch, metadata], None
 
 
 class ObjectDetectionGenerator:
-    def __init__(self, classes):
-        self.classes = classes
-
-    def flow(self, dictionary, shuffle=True, seed=None):
-        return DictionaryIterator(dictionary, self, shuffle=shuffle, seed=seed)
+    def flow(self, data, classes, image_shape):
+        return DictionaryIterator(data, classes, keras.preprocessing.image.ImageDataGenerator(), image_shape)
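
For reference, the rescaling heuristic introduced here picks a single factor per image: it first scales the smaller axis towards min_size, then caps the factor so the larger axis cannot exceed max_size. A minimal sketch of that arithmetic, assuming the keras-rcnn package is importable; the 480x640 input shape is made up for illustration:

# Hypothetical example; scale_shape comes from the module changed in this commit.
from keras_rcnn.preprocessing._object_detection import scale_shape

# For a 480x640 RGB image, scaling the 480 axis to 224 (factor ~0.467) would push
# the 640 axis to ~299 > 224, so the factor is capped at 224 / 640 = 0.35.
target_shape, scale = scale_shape((480, 640, 3), min_size=224, max_size=224)

print(target_shape)  # (168, 224, 3)
print(scale)         # 0.35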

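After the refactor, ObjectDetectionGenerator.flow takes the annotation dictionaries, a class-name-to-index mapping, and the shape of the images on disk; each batch yields [image_batch, gt_boxes_batch, metadata], None, with the ground-truth boxes rescaled to the same scale as the image. A rough usage sketch, with invented file name, class name, and box coordinates; it assumes ObjectDetectionGenerator is re-exported from keras_rcnn.preprocessing and that package versions contemporary with this commit are installed (the loader relies on skimage.io.imread(..., as_grey=...) and a scalar skimage.transform.rescale):

import numpy
import skimage.io

import keras_rcnn.preprocessing

# Write a dummy 300x400 RGB image so the sketch has something to read.
skimage.io.imsave("example.png", numpy.random.randint(0, 256, (300, 400, 3)).astype(numpy.uint8))

# One image with a single labelled box (all values invented).
data = [
    {
        "filename": "example.png",
        "boxes": [
            {"y1": 24, "x1": 36, "y2": 100, "x2": 128, "class": "nucleus"},
        ],
    },
]

# Class name -> index stored in the fifth column of each ground-truth box.
classes = {"nucleus": 1}

generator = keras_rcnn.preprocessing.ObjectDetectionGenerator()

iterator = generator.flow(data, classes, image_shape=(300, 400, 3))

[image_batch, gt_boxes_batch, metadata], _ = next(iterator)

print(image_batch.shape)     # expected (1, 168, 224, 3): images rescaled towards 224
print(gt_boxes_batch.shape)  # expected (1, 1, 5): y1, x1, y2, x2, class index, scaled
print(metadata)              # per image: rescaled height, width and the scale factor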