From adc6b6506879a224ed0c5d22fcd27110e6b663fd Mon Sep 17 00:00:00 2001
From: Baoguang Shi
Date: Thu, 21 Dec 2017 20:43:20 +0800
Subject: [PATCH] Uniform regularization; Visualization
---
core/feature_extractor.py | 22 +++++++++------
core/hyperparams.proto | 6 ++++
core/hyperparams.py | 4 +++
utils/visualization_utils.py | 46 +++++++++++++++++++++++++++++++
utils/visualization_utils_test.py | 34 +++++++++++++++++++++++
5 files changed, 104 insertions(+), 8 deletions(-)
create mode 100644 utils/visualization_utils.py
create mode 100644 utils/visualization_utils_test.py
diff --git a/core/feature_extractor.py b/core/feature_extractor.py
index ae3c21d..1d070fb 100644
--- a/core/feature_extractor.py
+++ b/core/feature_extractor.py
@@ -7,6 +7,7 @@
from rare.core import feature_extractor_pb2
from rare.core import hyperparams
+from rare.utils import visualization_utils
class FeatureExtractor(object):
@@ -22,6 +23,8 @@ def __init__(self,
def preprocess(self, resized_inputs, scope=None):
with tf.variable_scope(scope, 'FeatureExtractorPreprocess', [resized_inputs]):
preprocessed_inputs = (2.0 / 255.0) * resized_inputs - 1.0
+ if self._summarize_inputs:
+ tf.summary.image('preprocessed_inputs', preprocessed_inputs, max_outputs=1)  # optional debug summary; max_outputs=1 caps it at one image
return preprocessed_inputs
@abstractmethod
@@ -168,16 +171,14 @@ class BaselineFeatureExtractor(FeatureExtractor):
def __init__(self,
conv_hyperparams=None,
- summarize_inputs=False):
- super(BaselineFeatureExtractor, self).__init__()
+ summarize_inputs=False,
+ is_training=None):
+ super(BaselineFeatureExtractor, self).__init__(  # pass flags up; the now-inherited preprocess() checks self._summarize_inputs
+ summarize_inputs=summarize_inputs,
+ is_training=is_training)
self._conv_hyperparams = conv_hyperparams # FIXME: add it back
self._summarize_inputs = summarize_inputs
- def preprocess(self, resized_inputs, scope=None):
- with tf.variable_scope(scope, 'ModelPreprocess', [resized_inputs]):
- preprocessed_inputs = (2.0 / 255.0) * resized_inputs - 1.0
- return preprocessed_inputs
-
def extract_features(self, preprocessed_inputs, scope=None):
"""Extract features
Args:
@@ -215,7 +216,12 @@ def extract_features(self, preprocessed_inputs, scope=None):
if self._summarize_inputs:
for layer in [conv1, pool1, conv2, pool2, conv3,
conv4, pool4, conv5, conv6, pool6, conv7]:
- tf.summary.histogram(layer.name, layer)
+ tf.summary.histogram(layer.op.name, layer)
+ # tf.summary.image(
+ # layer.name,
+ # visualization_utils.tile_activation_maps_max_dimensions(layer, 512, 512),
+ # max_outputs=1
+ # )
return [conv7]
diff --git a/core/hyperparams.proto b/core/hyperparams.proto
index 755c647..f2236c8 100644
--- a/core/hyperparams.proto
+++ b/core/hyperparams.proto
@@ -66,6 +66,7 @@ message Initializer {
TruncatedNormalInitializer truncated_normal_initializer = 1;
VarianceScalingInitializer variance_scaling_initializer = 2;
OrthogonalInitializer orthogonal_initializer = 3;
+ UniformInitializer uniform_initializer = 4;
}
}
@@ -95,6 +96,11 @@ message OrthogonalInitializer {
optional int32 seed = 2;
}
+message UniformInitializer {
+ optional float minval = 1 [default = -0.1];  // forwarded as minval to tf.random_uniform_initializer
+ optional float maxval = 2 [default = 0.1];  // forwarded as maxval to tf.random_uniform_initializer
+}
+
// Configuration proto for batch norm to apply after convolution op. See
// https://www.tensorflow.org/api_docs/python/tf/contrib/layers/batch_norm
message BatchNorm {
diff --git a/core/hyperparams.py b/core/hyperparams.py
index b7ea884..2b9343f 100644
--- a/core/hyperparams.py
+++ b/core/hyperparams.py
@@ -150,6 +150,10 @@ def _build_initializer(initializer):
gain=initializer.orthogonal_initializer.gain,
seed=initializer.orthogonal_initializer.seed
)
+ if initializer_oneof == 'uniform_initializer':
+ return tf.random_uniform_initializer(
+ minval=initializer.uniform_initializer.minval,
+ maxval=initializer.uniform_initializer.maxval)
raise ValueError('Unknown initializer function: {}'.format(
initializer_oneof))
diff --git a/utils/visualization_utils.py b/utils/visualization_utils.py
new file mode 100644
index 0000000..56b64a0
--- /dev/null
+++ b/utils/visualization_utils.py
@@ -0,0 +1,46 @@
+import tensorflow as tf
+
+from rare.utils import shape_utils
+
+
+def tile_activation_maps_max_dimensions(maps, max_height, max_width):
+ batch_size, map_height, map_width, map_depth = \
+ shape_utils.combined_static_and_dynamic_shape(maps)
+ num_rows = max_height // map_height  # whole maps that fit within max_height
+ num_cols = max_width // map_width  # whole maps that fit within max_width
+ return tile_activation_maps_rows_cols(maps, num_rows, num_cols)
+
+
+def tile_activation_maps_rows_cols(maps, num_rows, num_cols):
+ """Tile the channels of each example into one num_rows x num_cols grid image.
+ Args:
+ maps: [batch_size, map_height, map_width, map_depth]
+ Return:
+ tiled_map: [batch_size, map_height * num_rows, map_width * num_cols, 1]
+ """
+ batch_size, map_height, map_width, map_depth = \
+ shape_utils.combined_static_and_dynamic_shape(maps)
+
+ # pad or slice the channel axis so it holds exactly num_rows * num_cols maps
+ num_maps = num_rows * num_cols
+ padded_map = tf.cond(
+ tf.greater(num_maps, map_depth),
+ true_fn=lambda: tf.pad(maps, [[0, 0], [0, 0], [0, 0], [0, tf.maximum(num_maps - map_depth, 0)]]),
+ false_fn=lambda: maps[:,:,:,:num_maps]
+ )
+
+ # reshape to [batch_size, map_height, map_width, num_rows, num_cols]; channel c goes to (c // num_cols, c % num_cols)
+ reshaped_map = tf.reshape(padded_map, [batch_size, map_height, map_width, num_rows, num_cols])
+
+ # unstack the column axis and concat along width
+ width_concated_maps = tf.concat(
+ tf.unstack(reshaped_map, axis=4), # => list of [batch_size, map_height, map_width, num_rows]
+ axis=2) # => [batch_size, map_height, map_width * num_cols, num_rows]
+
+ tiled_map = tf.concat(
+ tf.unstack(width_concated_maps, axis=3), # => list of [batch_size, map_height, map_width * num_cols]
+ axis=1) # => [batch_size, map_height * num_rows, map_width * num_cols]
+
+ tiled_map = tf.expand_dims(tiled_map, axis=3)  # append a singleton channel axis => 4-D result
+
+ return tiled_map
diff --git a/utils/visualization_utils_test.py b/utils/visualization_utils_test.py
new file mode 100644
index 0000000..330dfb1
--- /dev/null
+++ b/utils/visualization_utils_test.py
@@ -0,0 +1,34 @@
+import tensorflow as tf
+
+from rare.utils import visualization_utils
+
+class VisualizationUtilsTest(tf.test.TestCase):
+
+ def test_tile_activation_maps_with_padding(self):
+ test_maps = tf.random_uniform([64, 32, 100, 16])
+ tiled_map = visualization_utils.tile_activation_maps_rows_cols(test_maps, 5, 5)  # 25 tiles > 16 channels: padding branch
+
+ with self.test_session() as sess:
+ tiled_map_output = tiled_map.eval()
+ self.assertAllEqual(tiled_map_output.shape, [64, 32 * 5, 100 * 5, 1])
+
+ def test_tile_activation_maps_with_slicing(self):
+ test_maps = tf.random_uniform([64, 32, 100, 16])
+ tiled_map = visualization_utils.tile_activation_maps_rows_cols(test_maps, 5, 1)  # 5 tiles < 16 channels: slicing branch
+
+ with self.test_session() as sess:
+ tiled_map_output = tiled_map.eval()
+ self.assertAllEqual(tiled_map_output.shape, [64, 32 * 5, 100 * 1, 1])
+
+ def test_tile_activation_maps_max_sizes(self):
+ test_maps = tf.random_uniform([64, 32, 100, 16])
+ tiled_map = visualization_utils.tile_activation_maps_max_dimensions(
+ test_maps, 512, 512)
+
+ with self.test_session() as sess:
+ tiled_map_output = tiled_map.eval()
+ self.assertAllEqual(tiled_map_output.shape, [64, 512, 500, 1])  # 512 // 32 = 16 rows, 512 // 100 = 5 cols
+
+
+if __name__ == '__main__':
+ tf.test.main()