bgshih committed Jan 5, 2018
1 parent d436454 commit 58eb91e
Showing 22 changed files with 271 additions and 130 deletions.
27 changes: 23 additions & 4 deletions builders/convnet_builder.py
@@ -1,6 +1,7 @@
from rare.builders import hyperparams_builder
from rare.protos import convnet_pb2
from rare.convnets import crnn_net
from rare.convnets import resnet


def build(config, is_training):
@@ -9,8 +10,8 @@ def build(config, is_training):
convnet_oneof = config.WhichOneof('convnet_oneof')
if convnet_oneof == 'crnn_net':
return _build_crnn_net(config.crnn_net, is_training)
elif convnet_oneof == 'res_net':
return _build_res_net(config.resnet, is_training)
elif convnet_oneof == 'resnet':
return _build_resnet(config.resnet, is_training)
else:
raise ValueError('Unknown convnet_oneof: {}'.format(convnet_oneof))

@@ -34,5 +35,23 @@ def _build_crnn_net(config, is_training):
summarize_activations=config.summarize_activations,
is_training=is_training)

def _build_res_net(config, is_training):
raise NotImplementedError

def _build_resnet(config, is_training):
if not isinstance(config, convnet_pb2.ResNet):
raise ValueError('config is not of type convnet_pb2.ResNet')

if config.net_type != convnet_pb2.ResNet.SINGLE_BRANCH:
raise ValueError('Only SINGLE_BRANCH is supported for ResNet')

resnet_depth = config.net_depth
if resnet_depth == convnet_pb2.ResNet.RESNET_50:
resnet_class = resnet.Resnet50Layer
else:
raise ValueError('Unknown resnet depth: {}'.format(resnet_depth))

conv_hyperparams = hyperparams_builder.build(config.conv_hyperparams, is_training)
return resnet_class(
conv_hyperparams=conv_hyperparams,
summarize_activations=config.summarize_activations,
is_training=is_training,
)
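
A minimal usage sketch of the new branch, assuming only the proto fields exercised by the test added below: the resnet oneof is parsed from a text-format config and dispatched to _build_resnet, which returns a resnet.Resnet50Layer.

from google.protobuf import text_format
from rare.builders import convnet_builder
from rare.protos import convnet_pb2

# Sketch only: build the ResNet-50 convnet from a text-format Convnet proto.
config_text = """
resnet {
  net_type: SINGLE_BRANCH
  net_depth: RESNET_50
  conv_hyperparams {
    op: CONV
    regularizer { l2_regularizer { weight: 1e-4 } }
    initializer { variance_scaling_initializer { } }
    batch_norm { }
  }
  summarize_activations: false
}
"""
convnet_config = convnet_pb2.Convnet()
text_format.Merge(config_text, convnet_config)
convnet = convnet_builder.build(convnet_config, True)  # -> resnet.Resnet50Layer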
34 changes: 33 additions & 1 deletion builders/convnet_builder_test.py
@@ -4,7 +4,7 @@
from rare.builders import convnet_builder
from rare.protos import convnet_pb2
from rare.convnets import crnn_net
# from rare.convnets import resnet
from rare.convnets import resnet

class FeatureExtractorTest(tf.test.TestCase):

@@ -100,6 +100,38 @@ def test_crnn_net_three_branches(self):
feature_maps = convnet_object.extract_features(test_input_image)
self.assertTrue(len(feature_maps) == 3)
print('Outputs of test_crnn_net_three_branches: {}'.format(feature_maps))

def test_resnet_50layer(self):
feature_extractor_text_proto = """
resnet {
net_type: SINGLE_BRANCH
net_depth: RESNET_50
conv_hyperparams {
op: CONV
regularizer { l2_regularizer { weight: 1e-4 } }
initializer { variance_scaling_initializer { } }
batch_norm { }
}
summarize_activations: false
}
"""
convnet_proto = convnet_pb2.Convnet()
text_format.Merge(feature_extractor_text_proto, convnet_proto)
convnet_object = convnet_builder.build(convnet_proto, True)
self.assertTrue(
isinstance(convnet_object, resnet.Resnet50Layer))
test_image_shape = [2, 32, 128, 3]
test_input_image = tf.random_uniform(
test_image_shape,
minval=0,
maxval=255.0,
dtype=tf.float32,
seed=1
)
feature_maps = convnet_object.extract_features(test_input_image)
self.assertTrue(len(feature_maps) == 1)
print('Outputs of test_resnet_single_branch: {}'.format(feature_maps))


if __name__ == '__main__':
tf.test.main()
17 changes: 13 additions & 4 deletions builders/label_map_builder.py
@@ -28,11 +28,20 @@ def _build_character_set(config):
file_path = config.text_file
with open(file_path, 'r') as f:
character_set_string = f.read()
character_set = character_set_string.split('\n')
elif source_oneof == 'text_string':
character_set_string = config.text_string

if not config.delimiter:
character_set = list(character_set_string)
else:
character_set = character_set_string.split()
elif source_oneof == 'built_in_set':
if config.built_in_set == label_map_pb2.CharacterSet.LOWERCASE:
character_set = list(string.digits + string.ascii_lowercase)
elif config.built_in_set == label_map_pb2.CharacterSet.ALLCASES:
character_set = list(string.digits + string.ascii_letters)
elif config.built_in_set == label_map_pb2.CharacterSet.ALLCASES_SYMBOLS:
character_set = list(string.printable[:-6])
else:
raise ValueError('Unknown built_in_set')
else:
raise ValueError('Unknown source_oneof: {}'.format(source_oneof))

return character_set
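
Concretely, the three built-in sets resolve to fixed character lists from the string module; a small reference sketch (string.printable ends with six whitespace characters, so slicing off the last six keeps digits, letters and punctuation):

import string

lowercase_set = list(string.digits + string.ascii_lowercase)  # LOWERCASE, 36 chars
allcases_set = list(string.digits + string.ascii_letters)     # ALLCASES, 62 chars
allcases_symbols_set = list(string.printable[:-6])            # ALLCASES_SYMBOLS, 94 chars
                                                              # (drops ' \t\n\r\x0b\x0c')
assert len(lowercase_set) == 36
assert len(allcases_set) == 62
assert len(allcases_symbols_set) == 94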
13 changes: 6 additions & 7 deletions builders/label_map_builder_test.py
@@ -10,8 +10,7 @@ class LabelMapTest(tf.test.TestCase):
def test_build_label_map(self):
label_map_text_proto = """
character_set {
text_string: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
delimiter: ""
built_in_set: LOWERCASE
}
label_offset: 3
unk_label: -2
@@ -36,19 +35,19 @@ def test_build_label_map(self):
})
self.assertAllEqual(
outputs['test_labels'],
[[3, -1, -1],
[4, -1, -1],
[[13, -1, -1],
[14, -1, -1],
[-1, -1, -1],
[3, 4, 28],
[-2, 3, -2]]
[13, 14, 38],
[3, 13, -2]]
)
self.assertAllEqual(
outputs['text_lengths'],
[1, 1, 0, 3, 3]
)
self.assertAllEqual(
outputs['text_from_labels'],
[b'a', b'b', b'', b'abz', b'a']
[b'a', b'b', b'', b'abz', b'0a']
)

if __name__ == '__main__':
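
The updated expectations follow from the LOWERCASE set ordering (ten digits first, then the lowercase letters) combined with label_offset 3; a small sketch of the arithmetic, assuming that mapping:

import string

label_offset = 3
charset = string.digits + string.ascii_lowercase  # the LOWERCASE built-in set

def char_to_label(c):
  # A character maps to its index in the set, shifted by label_offset.
  return charset.index(c) + label_offset

assert char_to_label('a') == 13  # ten digits precede 'a'
assert char_to_label('b') == 14
assert char_to_label('z') == 38
assert char_to_label('0') == 3   # label 3 decodes back to '0', hence b'0a' above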
25 changes: 12 additions & 13 deletions builders/model_builder_test.py
@@ -41,7 +41,7 @@
}
predictor {
bahdanau_attention_predictor {
attention_predictor {
rnn_cell {
lstm_cell {
num_units: 256
@@ -57,8 +57,7 @@
reverse: false
label_map {
character_set {
text_string: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
delimiter: ""
built_in_set: ALLCASES
}
label_offset: 2
}
@@ -109,7 +108,8 @@
}
predictor {
bahdanau_attention_predictor {
name: "Forward"
attention_predictor {
rnn_cell {
lstm_cell {
num_units: 256
@@ -125,8 +125,7 @@
reverse: false
label_map {
character_set {
text_string: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
delimiter: ""
built_in_set: ALLCASES
}
label_offset: 2
}
@@ -140,7 +139,8 @@
}
predictor {
bahdanau_attention_predictor {
name: "Backward"
attention_predictor {
rnn_cell {
lstm_cell {
num_units: 256
@@ -156,8 +156,7 @@
reverse: true
label_map {
character_set {
text_string: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
delimiter: ""
built_in_set: ALLCASES
}
label_offset: 2
}
@@ -182,7 +181,7 @@ def test_single_predictor_model_training(self):
test_groundtruth_text_list = [
tf.constant(b'hello', dtype=tf.string),
tf.constant(b'world', dtype=tf.string)]
model_object.provide_groundtruth(test_groundtruth_text_list)
model_object.provide_groundtruth({'groundtruth_text': test_groundtruth_text_list})
test_input_image = tf.random_uniform(
shape=[2, 32, 100, 3], minval=0, maxval=255,
dtype=tf.float32, seed=1)
@@ -202,7 +201,7 @@ def test_single_predictor_model_inference(self):
test_groundtruth_text_list = [
tf.constant(b'hello', dtype=tf.string),
tf.constant(b'world', dtype=tf.string)]
model_object.provide_groundtruth(test_groundtruth_text_list)
model_object.provide_groundtruth({'groundtruth_text': test_groundtruth_text_list})
test_input_image = tf.random_uniform(
shape=[2, 32, 100, 3], minval=0, maxval=255,
dtype=tf.float32, seed=1)
@@ -222,7 +221,7 @@ def test_multi_predictors_model_training(self):
test_groundtruth_text_list = [
tf.constant(b'hello', dtype=tf.string),
tf.constant(b'world', dtype=tf.string)]
model_object.provide_groundtruth(test_groundtruth_text_list)
model_object.provide_groundtruth({'groundtruth_text': test_groundtruth_text_list})
test_input_image = tf.random_uniform(
shape=[2, 32, 100, 3], minval=0, maxval=255,
dtype=tf.float32, seed=1)
@@ -242,7 +241,7 @@ def test_multi_predictor_model_inference(self):
test_groundtruth_text_list = [
tf.constant(b'hello', dtype=tf.string),
tf.constant(b'world', dtype=tf.string)]
model_object.provide_groundtruth(test_groundtruth_text_list)
model_object.provide_groundtruth({'groundtruth_text': test_groundtruth_text_list})
test_input_image = tf.random_uniform(
shape=[2, 32, 100, 3], minval=0, maxval=255,
dtype=tf.float32, seed=1)
5 changes: 5 additions & 0 deletions builders/optimizer_builder.py
@@ -38,6 +38,11 @@ def build(optimizer_config, global_summaries):
config = optimizer_config.adam_optimizer
optimizer = tf.train.AdamOptimizer(
_create_learning_rate(config.learning_rate, global_summaries))

if optimizer_type == 'nadam_optimizer':
config = optimizer_config.nadam_optimizer
optimizer = tf.contrib.opt.NadamOptimizer(
_create_learning_rate(config.learning_rate, global_summaries))

if optimizer_type == 'adadelta_optimizer':
config = optimizer_config.adadelta_optimizer
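
The new branch forwards the constructed learning rate to TensorFlow's contrib Nadam implementation; a minimal standalone equivalent, assuming TF 1.x and a scalar loss tensor named loss:

import tensorflow as tf

# Direct construction of the optimizer the builder wires up above (TF 1.x contrib API).
optimizer = tf.contrib.opt.NadamOptimizer(learning_rate=0.002)
# train_op = optimizer.minimize(loss)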
17 changes: 17 additions & 0 deletions builders/optimizer_builder_test.py
@@ -144,6 +144,23 @@ def testBuildAdamOptimizer(self):
optimizer_object = optimizer_builder.build(optimizer_proto, global_summaries)
self.assertTrue(isinstance(optimizer_object, tf.train.AdamOptimizer))

def testBuildNadamOptimizer(self):
optimizer_text_proto = """
nadam_optimizer: {
learning_rate: {
constant_learning_rate {
learning_rate: 0.002
}
}
}
use_moving_average: false
"""
global_summaries = set([])
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer_object = optimizer_builder.build(optimizer_proto, global_summaries)
self.assertTrue(isinstance(optimizer_object, tf.contrib.opt.NadamOptimizer))

def testBuildMovingAverageOptimizer(self):
optimizer_text_proto = """
adam_optimizer: {
13 changes: 12 additions & 1 deletion builders/preprocessor_builder.py
@@ -2,6 +2,7 @@

from rare.core import preprocessor
from rare.protos import preprocessor_pb2
from rare.builders import label_map_builder


def _get_step_config_from_proto(preprocessor_step_config, step_name):
@@ -60,7 +61,7 @@ def _get_dict_from_proto(config):
'image_to_float': preprocessor.image_to_float,
'subtract_channel_mean': preprocessor.subtract_channel_mean,
'rgb_to_gray': preprocessor.rgb_to_gray,
'string_filtering': preprocessor.string_filtering,
# 'string_filtering': preprocessor.string_filtering,
}


@@ -113,4 +114,14 @@ def build(preprocessor_step_config):
'method': method
})

if step_type == 'string_filtering':
config = preprocessor_step_config.string_filtering
include_charset_list = label_map_builder._build_character_set(config.include_charset)
include_charset = ''.join(include_charset_list)
return (preprocessor.string_filtering,
{
'lower_case': config.lower_case,
'include_charset': include_charset
})

raise ValueError('Unknown preprocessing step: {}'.format(step_type))
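
End to end, the builder returns a (function, kwargs) pair that is then applied to a groundtruth string; a minimal sketch, assuming the include_charset message shown in the updated tests below:

from google.protobuf import text_format
from rare.builders import preprocessor_builder
from rare.protos import preprocessor_pb2

step_text_proto = """
string_filtering {
  lower_case: false
  include_charset { built_in_set: ALLCASES }
}
"""
step_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(step_text_proto, step_proto)
function, kwargs = preprocessor_builder.build(step_proto)
# filtered = function(text, **kwargs)  # `text`: a tf.string value, as in the tests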
12 changes: 8 additions & 4 deletions builders/preprocessor_builder_test.py
@@ -139,7 +139,9 @@ def test_string_filtering(self):
preprocessor_text_proto = """
string_filtering {
lower_case: true
include_charset: "abc"
include_charset {
text_string: "abc"
}
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
@@ -162,7 +164,9 @@ def test_string_filtering_2(self):
preprocessor_text_proto = """
string_filtering {
lower_case: false
include_charset: "abcdABCD"
include_charset {
built_in_set: ALLCASES
}
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
@@ -171,11 +175,11 @@ def test_string_filtering_2(self):
self.assertEqual(function, preprocessor.string_filtering)
self.assert_dictionary_close(args, {
'lower_case': False,
'include_charset': "abcdABCD"
'include_charset': "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
})

test_input_strings = [t.encode('utf-8') for t in ['abc', 'abcde', '!=ABC DE~']]
expected_output_string = [t.encode('utf-8') for t in ['abc', 'abcd', 'ABCD']]
expected_output_string = [t.encode('utf-8') for t in ['abc', 'abcde', 'ABCDE']]
test_processed_strings = [function(t, **args) for t in test_input_strings]
with self.test_session() as sess:
outputs = sess.run(test_processed_strings)