Uniform regularization; Visualization
bgshih committed Dec 21, 2017
1 parent 62af1ff commit adc6b65
Showing 5 changed files with 104 additions and 8 deletions.
22 changes: 14 additions & 8 deletions core/feature_extractor.py
@@ -7,6 +7,7 @@
 
 from rare.core import feature_extractor_pb2
 from rare.core import hyperparams
+from rare.utils import visualization_utils
 
 
 class FeatureExtractor(object):
@@ -22,6 +23,8 @@ def __init__(self,
   def preprocess(self, resized_inputs, scope=None):
     with tf.variable_scope(scope, 'FeatureExtractorPreprocess', [resized_inputs]):
       preprocessed_inputs = (2.0 / 255.0) * resized_inputs - 1.0
+      if self._summarize_inputs:
+        tf.summary.image('preprocessed_inputs', preprocessed_inputs, max_outputs=1)
       return preprocessed_inputs
 
   @abstractmethod
@@ -168,16 +171,14 @@ class BaselineFeatureExtractor(FeatureExtractor):
 
   def __init__(self,
                conv_hyperparams=None,
-               summarize_inputs=False):
-    super(BaselineFeatureExtractor, self).__init__()
+               summarize_inputs=False,
+               is_training=None):
+    super(BaselineFeatureExtractor, self).__init__(
+      summarize_inputs=summarize_inputs,
+      is_training=is_training)
     self._conv_hyperparams = conv_hyperparams # FIXME: add it back
-    self._summarize_inputs = summarize_inputs
 
-  def preprocess(self, resized_inputs, scope=None):
-    with tf.variable_scope(scope, 'ModelPreprocess', [resized_inputs]):
-      preprocessed_inputs = (2.0 / 255.0) * resized_inputs - 1.0
-      return preprocessed_inputs
 
   def extract_features(self, preprocessed_inputs, scope=None):
     """Extract features
     Args:
@@ -215,7 +216,12 @@ def extract_features(self, preprocessed_inputs, scope=None):
     if self._summarize_inputs:
       for layer in [conv1, pool1, conv2, pool2, conv3,
                     conv4, pool4, conv5, conv6, pool6, conv7]:
-        tf.summary.histogram(layer.name, layer)
+        tf.summary.histogram(layer.op.name, layer)
+        # tf.summary.image(
+        #   layer.name,
+        #   visualization_utils.tile_activation_maps_max_dimensions(layer, 512, 512),
+        #   max_outputs=1
+        # )
     return [conv7]


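Why the histogram tag changed from layer.name to layer.op.name: a tensor's .name carries an output-index suffix (e.g. 'conv1/BiasAdd:0'), and ':' is not a legal character in summary tags, so TF 1.x sanitizes it to an underscore and logs a "Summary name ... is illegal" warning; .op.name is already the clean form. A minimal sketch, using an illustrative layer that is not part of this repository:

import tensorflow as tf

inputs = tf.placeholder(tf.float32, [None, 32, 100, 3], name='inputs')
conv1 = tf.layers.conv2d(inputs, 64, 3, name='conv1')

print(conv1.name)     # 'conv1/BiasAdd:0' -- the ':0' would be rewritten to '_0'
print(conv1.op.name)  # 'conv1/BiasAdd'   -- a clean summary tag
tf.summary.histogram(conv1.op.name, conv1)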
6 changes: 6 additions & 0 deletions core/hyperparams.proto
@@ -66,6 +66,7 @@ message Initializer {
     TruncatedNormalInitializer truncated_normal_initializer = 1;
     VarianceScalingInitializer variance_scaling_initializer = 2;
     OrthogonalInitializer orthogonal_initializer = 3;
+    UniformInitializer uniform_initializer = 4;
   }
 }

@@ -95,6 +96,11 @@ message OrthogonalInitializer {
   optional int32 seed = 2;
 }
 
+message UniformInitializer {
+  optional float minval = 1 [default = -0.1];
+  optional float maxval = 2 [default = 0.1];
+}
+
 // Configuration proto for batch norm to apply after convolution op. See
 // https://www.tensorflow.org/api_docs/python/tf/contrib/layers/batch_norm
 message BatchNorm {
4 changes: 4 additions & 0 deletions core/hyperparams.py
@@ -150,6 +150,10 @@ def _build_initializer(initializer):
         gain=initializer.orthogonal_initializer.gain,
         seed=initializer.orthogonal_initializer.seed
     )
+  if initializer_oneof == 'uniform_initializer':
+    return tf.random_uniform_initializer(
+        minval=initializer.uniform_initializer.minval,
+        maxval=initializer.uniform_initializer.maxval)
   raise ValueError('Unknown initializer function: {}'.format(
       initializer_oneof))

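A minimal sketch of exercising the new branch end to end. It assumes the generated protobuf module is rare.core.hyperparams_pb2 and that calling the module-private _build_initializer directly is acceptable; both are inferred from this diff rather than confirmed by it:

import tensorflow as tf
from google.protobuf import text_format

from rare.core import hyperparams
from rare.core import hyperparams_pb2  # assumed name of the generated module

# minval/maxval fall back to the proto defaults (-0.1, 0.1) when omitted.
initializer_proto = text_format.Parse(
    'uniform_initializer { minval: -0.08 maxval: 0.08 }',
    hyperparams_pb2.Initializer())

init = hyperparams._build_initializer(initializer_proto)
weights = tf.get_variable('weights', shape=[3, 3, 32, 64], initializer=init)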
46 changes: 46 additions & 0 deletions utils/visualization_utils.py
@@ -0,0 +1,46 @@
import tensorflow as tf

from rare.utils import shape_utils


def tile_activation_maps_max_dimensions(maps, max_height, max_width):
  """Tiles as many activation maps as fit within max_height x max_width."""
  batch_size, map_height, map_width, map_depth = \
    shape_utils.combined_static_and_dynamic_shape(maps)
  num_rows = max_height // map_height
  num_cols = max_width // map_width
  return tile_activation_maps_rows_cols(maps, num_rows, num_cols)


def tile_activation_maps_rows_cols(maps, num_rows, num_cols):
  """Tiles activation maps into a num_rows x num_cols grid.

  Channels beyond num_rows * num_cols are dropped; missing ones are filled
  with zero-valued maps.

  Args:
    maps: a tensor of shape [batch_size, map_height, map_width, map_depth]
  Returns:
    tiled_map: a tensor of shape
      [batch_size, map_height * num_rows, map_width * num_cols, 1]
  """
batch_size, map_height, map_width, map_depth = \
shape_utils.combined_static_and_dynamic_shape(maps)

  # Pad with zero maps (or slice off extras) so that the channel count
  # equals num_rows * num_cols before tiling.
  num_maps = num_rows * num_cols
padded_map = tf.cond(
tf.greater(num_maps, map_depth),
true_fn=lambda: tf.pad(maps, [[0, 0], [0, 0], [0, 0], [0, tf.maximum(num_maps - map_depth, 0)]]),
false_fn=lambda: maps[:,:,:,:num_maps]
)

# reshape to [batch_size, map_height, map_width, num_rows, num_cols]
reshaped_map = tf.reshape(padded_map, [batch_size, map_height, map_width, num_rows, num_cols])

  # unstack the grid columns and concatenate them along the width axis
width_concated_maps = tf.concat(
tf.unstack(reshaped_map, axis=4), # => list of [batch_size, map_height, map_width, num_rows]
axis=2) # => [batch_size, map_height, map_width * num_cols, num_rows]

tiled_map = tf.concat(
tf.unstack(width_concated_maps, axis=3), # => list of [batch_size, map_height, map_width * num_cols]
axis=1) # => [batch_size, map_height * num_rows, map_width * num_cols]

tiled_map = tf.expand_dims(tiled_map, axis=3)

return tiled_map
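A usage sketch mirroring the commented-out summary call in feature_extractor.py above; the shapes are illustrative, not taken from the model:

import tensorflow as tf

from rare.utils import visualization_utils

# A stand-in activation tensor: batch of 2, 32x100 maps, 16 channels.
conv = tf.random_uniform([2, 32, 100, 16])

# 512 // 32 = 16 rows and 512 // 100 = 5 cols, i.e. an 80-tile grid;
# the 16 maps are zero-padded up to 80, yielding a [2, 512, 500, 1] image.
tiled = visualization_utils.tile_activation_maps_max_dimensions(conv, 512, 512)
tf.summary.image('conv_activations', tiled, max_outputs=1)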
34 changes: 34 additions & 0 deletions utils/visualization_utils_test.py
@@ -0,0 +1,34 @@
import tensorflow as tf

from rare.utils import visualization_utils

class VisualizationUtilsTest(tf.test.TestCase):

def test_tile_activation_maps_with_padding(self):
test_maps = tf.random_uniform([64, 32, 100, 16])
tiled_map = visualization_utils.tile_activation_maps_rows_cols(test_maps, 5, 5)

with self.test_session() as sess:
tiled_map_output = tiled_map.eval()
self.assertAllEqual(tiled_map_output.shape, [64, 32 * 5, 100 * 5, 1])

def test_tile_activation_maps_with_slicing(self):
test_maps = tf.random_uniform([64, 32, 100, 16])
tiled_map = visualization_utils.tile_activation_maps_rows_cols(test_maps, 5, 1)

with self.test_session() as sess:
tiled_map_output = tiled_map.eval()
self.assertAllEqual(tiled_map_output.shape, [64, 32 * 5, 100 * 1, 1])

def test_tile_activation_maps_max_sizes(self):
test_maps = tf.random_uniform([64, 32, 100, 16])
tiled_map = visualization_utils.tile_activation_maps_max_dimensions(
test_maps, 512, 512)

with self.test_session() as sess:
tiled_map_output = tiled_map.eval()
self.assertAllEqual(tiled_map_output.shape, [64, 512, 500, 1])


if __name__ == '__main__':
tf.test.main()
