Commit

Merge branch 'tensorboard_cleanup' of https://github.com/wkeithvan/tf_unet into wkeithvan-tensorboard_cleanup

# Conflicts:
#	tf_unet/layers.py
#	tf_unet/unet.py
jakeret committed Jun 23, 2018
2 parents 62f9820 + 2eb2e49 commit 2851bda
Showing 2 changed files with 140 additions and 128 deletions.
61 changes: 32 additions & 29 deletions tf_unet/layers.py
@@ -21,52 +21,55 @@

 import tensorflow as tf
 
-def weight_variable(shape, stddev=0.1):
+def weight_variable(shape, stddev=0.1, name="weight"):
     initial = tf.truncated_normal(shape, stddev=stddev)
-    return tf.Variable(initial)
+    return tf.Variable(initial, name=name)
 
-def weight_variable_devonc(shape, stddev=0.1):
-    return tf.Variable(tf.truncated_normal(shape, stddev=stddev))
+def weight_variable_devonc(shape, stddev=0.1, name="weight_devonc"):
+    return tf.Variable(tf.truncated_normal(shape, stddev=stddev), name=name)
 
-def bias_variable(shape):
+def bias_variable(shape, name="bias"):
     initial = tf.constant(0.1, shape=shape)
-    return tf.Variable(initial)
+    return tf.Variable(initial, name=name)
 
 def conv2d(x, W, b, keep_prob_):
-    conv_2d = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID')
-    conv_2d_b = tf.nn.bias_add(conv_2d, b)
-    return tf.nn.dropout(conv_2d_b, keep_prob_)
+    with tf.name_scope("conv2d"):
+        conv_2d = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID')
+        conv_2d_b = tf.nn.bias_add(conv_2d, b)
+        return tf.nn.dropout(conv_2d_b, keep_prob_)
 
 def deconv2d(x, W, stride):
-    x_shape = tf.shape(x)
-    output_shape = tf.stack([x_shape[0], x_shape[1]*2, x_shape[2]*2, x_shape[3]//2])
-    return tf.nn.conv2d_transpose(x, W, output_shape, strides=[1, stride, stride, 1], padding='VALID')
+    with tf.name_scope("deconv2d"):
+        x_shape = tf.shape(x)
+        output_shape = tf.stack([x_shape[0], x_shape[1]*2, x_shape[2]*2, x_shape[3]//2])
+        return tf.nn.conv2d_transpose(x, W, output_shape, strides=[1, stride, stride, 1], padding='VALID', name="conv2d_transpose")
 
 def max_pool(x, n):
     return tf.nn.max_pool(x, ksize=[1, n, n, 1], strides=[1, n, n, 1], padding='VALID')
 
 def crop_and_concat(x1, x2):
-    x1_shape = tf.shape(x1)
-    x2_shape = tf.shape(x2)
-    # offsets for the top left corner of the crop
-    offsets = [0, (x1_shape[1] - x2_shape[1]) // 2, (x1_shape[2] - x2_shape[2]) // 2, 0]
-    size = [-1, x2_shape[1], x2_shape[2], -1]
-    x1_crop = tf.slice(x1, offsets, size)
-    return tf.concat([x1_crop, x2], 3)
+    with tf.name_scope("crop_and_concat"):
+        x1_shape = tf.shape(x1)
+        x2_shape = tf.shape(x2)
+        # offsets for the top left corner of the crop
+        offsets = [0, (x1_shape[1] - x2_shape[1]) // 2, (x1_shape[2] - x2_shape[2]) // 2, 0]
+        size = [-1, x2_shape[1], x2_shape[2], -1]
+        x1_crop = tf.slice(x1, offsets, size)
+        return tf.concat([x1_crop, x2], 3)
 
 def pixel_wise_softmax(output_map):
-    exponential_map = tf.exp(output_map)
-    evidence = tf.add(exponential_map, tf.reverse(exponential_map, [False, False, False, True]))
-    return tf.div(exponential_map, evidence, name="pixel_wise_softmax")
+    with tf.name_scope("pixel_wise_softmax"):
+        exponential_map = tf.exp(output_map)
+        evidence = tf.add(exponential_map, tf.reverse(exponential_map, [False, False, False, True]))
+        return tf.div(exponential_map, evidence, name="pixel_wise_softmax")
 
 def pixel_wise_softmax_2(output_map):
-    exponential_map = tf.exp(output_map)
-    sum_exp = tf.reduce_sum(exponential_map, 3, keepdims=True)
-    tensor_sum_exp = tf.tile(sum_exp, tf.stack([1, 1, 1, tf.shape(output_map)[3]]))
-    return tf.div(exponential_map, tensor_sum_exp)
-
-
+    with tf.name_scope("pixel_wise_softmax_2"):
+        exponential_map = tf.exp(output_map)
+        sum_exp = tf.reduce_sum(exponential_map, 3, keepdims=True)
+        tensor_sum_exp = tf.tile(sum_exp, tf.stack([1, 1, 1, tf.shape(output_map)[3]]))
+        return tf.div(exponential_map, tensor_sum_exp)
 
 def cross_entropy(y_, output_map):
     return -tf.reduce_mean(y_*tf.log(tf.clip_by_value(output_map, 1e-10, 1.0)), name="cross_entropy")
-    # return tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(output_map), reduction_indices=[1]))
+    #return tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(output_map), reduction_indices=[1]))
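
The common thread of these changes is that every variable gets an explicit name and every helper wraps its ops in a tf.name_scope, so the graph is grouped sensibly when rendered in TensorBoard. A minimal sketch of how the effect can be checked (not part of the commit; it assumes TensorFlow 1.x and an arbitrary log directory called logs/):

import tensorflow as tf
from tf_unet import layers

# Build a tiny graph with the helpers above so the new names and scopes exist.
x = tf.placeholder(tf.float32, shape=[None, 32, 32, 1], name="x")
keep_prob = tf.placeholder(tf.float32, name="dropout_probability")
w1 = layers.weight_variable([3, 3, 1, 16], stddev=0.1, name="w1")
b1 = layers.bias_variable([16], name="b1")
conv = layers.conv2d(x, w1, b1, keep_prob)  # ops land under the "conv2d" scope

# Dump the graph definition and inspect it with: tensorboard --logdir logs
writer = tf.summary.FileWriter("logs", graph=tf.get_default_graph())
writer.close()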
207 changes: 108 additions & 99 deletions tf_unet/unet.py
@@ -57,12 +57,14 @@ def create_conv_net(x, keep_prob, channels, n_class, layers=3, features_root=16,
                              features=features_root,
                              filter_size=filter_size,
                              pool_size=pool_size))
 
     # Placeholder for the input image
-    nx = tf.shape(x)[1]
-    ny = tf.shape(x)[2]
-    x_image = tf.reshape(x, tf.stack([-1, nx, ny, channels]))
-    in_node = x_image
-    batch_size = tf.shape(x_image)[0]
+    with tf.name_scope("preprocessing"):
+        nx = tf.shape(x)[1]
+        ny = tf.shape(x)[2]
+        x_image = tf.reshape(x, tf.stack([-1, nx, ny, channels]))
+        in_node = x_image
+        batch_size = tf.shape(x_image)[0]
 
     weights = []
     biases = []
@@ -76,69 +78,72 @@ def create_conv_net(x, keep_prob, channels, n_class, layers=3, features_root=16,
     size = in_size
     # down layers
     for layer in range(0, layers):
-        features = 2 ** layer * features_root
-        stddev = np.sqrt(2 / (filter_size ** 2 * features))
-        if layer == 0:
-            w1 = weight_variable([filter_size, filter_size, channels, features], stddev)
-        else:
-            w1 = weight_variable([filter_size, filter_size, features // 2, features], stddev)
-
-        w2 = weight_variable([filter_size, filter_size, features, features], stddev)
-        b1 = bias_variable([features])
-        b2 = bias_variable([features])
-
-        conv1 = conv2d(in_node, w1, b1, keep_prob)
-        tmp_h_conv = tf.nn.relu(conv1)
-        conv2 = conv2d(tmp_h_conv, w2, b2, keep_prob)
-        dw_h_convs[layer] = tf.nn.relu(conv2)
-
-        weights.append((w1, w2))
-        biases.append((b1, b2))
-        convs.append((conv1, conv2))
-
-        size -= 4
-        if layer < layers - 1:
-            pools[layer] = max_pool(dw_h_convs[layer], pool_size)
-            in_node = pools[layer]
-            size /= 2
-
-    in_node = dw_h_convs[layers - 1]
-
-    # up layers
-    for layer in range(layers - 2, -1, -1):
-        features = 2 ** (layer + 1) * features_root
-        stddev = np.sqrt(2 / (filter_size ** 2 * features))
-
-        wd = weight_variable_devonc([pool_size, pool_size, features // 2, features], stddev)
-        bd = bias_variable([features // 2])
-        h_deconv = tf.nn.relu(deconv2d(in_node, wd, pool_size) + bd)
-        h_deconv_concat = crop_and_concat(dw_h_convs[layer], h_deconv)
-        deconv[layer] = h_deconv_concat
-
-        w1 = weight_variable([filter_size, filter_size, features, features // 2], stddev)
-        w2 = weight_variable([filter_size, filter_size, features // 2, features // 2], stddev)
-        b1 = bias_variable([features // 2])
-        b2 = bias_variable([features // 2])
-
-        conv1 = conv2d(h_deconv_concat, w1, b1, keep_prob)
-        h_conv = tf.nn.relu(conv1)
-        conv2 = conv2d(h_conv, w2, b2, keep_prob)
-        in_node = tf.nn.relu(conv2)
-        up_h_convs[layer] = in_node
-
-        weights.append((w1, w2))
-        biases.append((b1, b2))
-        convs.append((conv1, conv2))
-
-        size *= 2
-        size -= 4
+        with tf.name_scope("down_conv_{}".format(str(layer))):
+            features = 2 ** layer * features_root
+            stddev = np.sqrt(2 / (filter_size ** 2 * features))
+            if layer == 0:
+                w1 = weight_variable([filter_size, filter_size, channels, features], stddev, name="w1")
+            else:
+                w1 = weight_variable([filter_size, filter_size, features // 2, features], stddev, name="w1")
+
+            w2 = weight_variable([filter_size, filter_size, features, features], stddev, name="w2")
+            b1 = bias_variable([features], name="b1")
+            b2 = bias_variable([features], name="b2")
+
+            conv1 = conv2d(in_node, w1, b1, keep_prob)
+            tmp_h_conv = tf.nn.relu(conv1)
+            conv2 = conv2d(tmp_h_conv, w2, b2, keep_prob)
+            dw_h_convs[layer] = tf.nn.relu(conv2)
+
+            weights.append((w1, w2))
+            biases.append((b1, b2))
+            convs.append((conv1, conv2))
+
+            size -= 4
+            if layer < layers - 1:
+                pools[layer] = max_pool(dw_h_convs[layer], pool_size)
+                in_node = pools[layer]
+                size /= 2
+
+    in_node = dw_h_convs[layers - 1]
+
+    # up layers
+    for layer in range(layers - 2, -1, -1):
+        with tf.name_scope("up_conv_{}".format(str(layer))):
+            features = 2 ** (layer + 1) * features_root
+            stddev = np.sqrt(2 / (filter_size ** 2 * features))
+
+            wd = weight_variable_devonc([pool_size, pool_size, features // 2, features], stddev, name="wd")
+            bd = bias_variable([features // 2], name="bd")
+            h_deconv = tf.nn.relu(deconv2d(in_node, wd, pool_size) + bd)
+            h_deconv_concat = crop_and_concat(dw_h_convs[layer], h_deconv)
+            deconv[layer] = h_deconv_concat
+
+            w1 = weight_variable([filter_size, filter_size, features, features // 2], stddev, name="w1")
+            w2 = weight_variable([filter_size, filter_size, features // 2, features // 2], stddev, name="w2")
+            b1 = bias_variable([features // 2], name="b1")
+            b2 = bias_variable([features // 2], name="b2")
+
+            conv1 = conv2d(h_deconv_concat, w1, b1, keep_prob)
+            h_conv = tf.nn.relu(conv1)
+            conv2 = conv2d(h_conv, w2, b2, keep_prob)
+            in_node = tf.nn.relu(conv2)
+            up_h_convs[layer] = in_node
+
+            weights.append((w1, w2))
+            biases.append((b1, b2))
+            convs.append((conv1, conv2))
+
+            size *= 2
+            size -= 4
 
     # Output Map
-    weight = weight_variable([1, 1, features_root, n_class], stddev)
-    bias = bias_variable([n_class])
-    conv = conv2d(in_node, weight, bias, tf.constant(1.0))
-    output_map = tf.nn.relu(conv + bias)
-    up_h_convs["out"] = output_map
+    with tf.name_scope("output_map"):
+        weight = weight_variable([1, 1, features_root, n_class], stddev)
+        bias = bias_variable([n_class], name="bias")
+        conv = conv2d(in_node, weight, bias, tf.constant(1.0))
+        output_map = tf.nn.relu(conv + bias)
+        up_h_convs["out"] = output_map
 
     if summaries:
         for i, (c1, c2) in enumerate(convs):
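
As an aside on the size bookkeeping above: each down or up block applies two unpadded convolutions (size -= 4 with the default 3x3 filters), each pooling step halves the resolution and each deconvolution doubles it, which is why the output map is smaller than the input. A stand-alone sketch of that arithmetic (hypothetical helper, not part of the repository, assuming filter_size=3 and pool_size=2):

def simulate_output_size(in_size, layers=3):
    # mirrors the `size` arithmetic in create_conv_net
    size = in_size
    for layer in range(0, layers):           # down path
        size -= 4                            # two valid 3x3 convolutions
        if layer < layers - 1:
            size /= 2                        # max pooling
    for layer in range(layers - 2, -1, -1):  # up path
        size *= 2                            # deconvolution
        size -= 4                            # two valid 3x3 convolutions
    return size

# e.g. simulate_output_size(572) shows how much border is lost for a 3-layer net
print(simulate_output_size(572))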
@@ -185,22 +190,25 @@ def __init__(self, channels=3, n_class=2, cost="cross_entropy", cost_kwargs={},
         self.n_class = n_class
         self.summaries = kwargs.get("summaries", True)
 
-        self.x = tf.placeholder("float", shape=[None, None, None, channels])
-        self.y = tf.placeholder("float", shape=[None, None, None, n_class])
-        self.keep_prob = tf.placeholder(tf.float32)  # dropout (keep probability)
+        self.x = tf.placeholder("float", shape=[None, None, None, channels], name="x")
+        self.y = tf.placeholder("float", shape=[None, None, None, n_class], name="y")
+        self.keep_prob = tf.placeholder(tf.float32, name="dropout_probability")  # dropout (keep probability)
 
         logits, self.variables, self.offset = create_conv_net(self.x, self.keep_prob, channels, n_class, **kwargs)
 
         self.cost = self._get_cost(logits, cost, cost_kwargs)
 
         self.gradients_node = tf.gradients(self.cost, self.variables)
 
-        self.cross_entropy = tf.reduce_mean(cross_entropy(tf.reshape(self.y, [-1, n_class]),
-                                                          tf.reshape(pixel_wise_softmax_2(logits), [-1, n_class])))
-
-        self.predicter = pixel_wise_softmax_2(logits)
-        self.correct_pred = tf.equal(tf.argmax(self.predicter, 3), tf.argmax(self.y, 3))
-        self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32))
+        with tf.name_scope("xent"):
+            self.cross_entropy = cross_entropy(tf.reshape(self.y, [-1, n_class]),
+                                               tf.reshape(pixel_wise_softmax_2(logits), [-1, n_class]))
+
+        with tf.name_scope("results"):
+            self.predicter = pixel_wise_softmax_2(logits)
+            self.correct_pred = tf.equal(tf.argmax(self.predicter, 3), tf.argmax(self.y, 3))
+            self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32))
 
     def _get_cost(self, logits, cost_name, cost_kwargs):
         """
@@ -210,42 +218,43 @@ def _get_cost(self, logits, cost_name, cost_kwargs):
        regularizer: power of the L2 regularizers added to the loss function
        """
 
-        flat_logits = tf.reshape(logits, [-1, self.n_class])
-        flat_labels = tf.reshape(self.y, [-1, self.n_class])
-        if cost_name == "cross_entropy":
-            class_weights = cost_kwargs.pop("class_weights", None)
-
-            if class_weights is not None:
-                class_weights = tf.constant(np.array(class_weights, dtype=np.float32))
-
-                weight_map = tf.multiply(flat_labels, class_weights)
-                weight_map = tf.reduce_sum(weight_map, axis=1)
-
-                loss_map = tf.nn.softmax_cross_entropy_with_logits_v2(logits=flat_logits,
-                                                                      labels=flat_labels)
-                weighted_loss = tf.multiply(loss_map, weight_map)
-
-                loss = tf.reduce_mean(weighted_loss)
-
-            else:
-                loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=flat_logits,
-                                                                                 labels=flat_labels))
-        elif cost_name == "dice_coefficient":
-            eps = 1e-5
-            prediction = pixel_wise_softmax_2(logits)
-            intersection = tf.reduce_sum(prediction * self.y)
-            union = eps + tf.reduce_sum(prediction) + tf.reduce_sum(self.y)
-            loss = -(2 * intersection / (union))
-
-        else:
-            raise ValueError("Unknown cost function: " % cost_name)
-
-        regularizer = cost_kwargs.pop("regularizer", None)
-        if regularizer is not None:
-            regularizers = sum([tf.nn.l2_loss(variable) for variable in self.variables])
-            loss += (regularizer * regularizers)
-
-        return loss
+        with tf.name_scope("cost"):
+            flat_logits = tf.reshape(logits, [-1, self.n_class])
+            flat_labels = tf.reshape(self.y, [-1, self.n_class])
+            if cost_name == "cross_entropy":
+                class_weights = cost_kwargs.pop("class_weights", None)
+
+                if class_weights is not None:
+                    class_weights = tf.constant(np.array(class_weights, dtype=np.float32))
+
+                    weight_map = tf.multiply(flat_labels, class_weights)
+                    weight_map = tf.reduce_sum(weight_map, axis=1)
+
+                    loss_map = tf.nn.softmax_cross_entropy_with_logits_v2(logits=flat_logits,
+                                                                          labels=flat_labels)
+                    weighted_loss = tf.multiply(loss_map, weight_map)
+
+                    loss = tf.reduce_mean(weighted_loss)
+
+                else:
+                    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=flat_logits,
+                                                                                     labels=flat_labels))
+            elif cost_name == "dice_coefficient":
+                eps = 1e-5
+                prediction = pixel_wise_softmax_2(logits)
+                intersection = tf.reduce_sum(prediction * self.y)
+                union = eps + tf.reduce_sum(prediction) + tf.reduce_sum(self.y)
+                loss = -(2 * intersection / (union))
+
+            else:
+                raise ValueError("Unknown cost function: %s" % cost_name)
+
+            regularizer = cost_kwargs.pop("regularizer", None)
+            if regularizer is not None:
+                regularizers = sum([tf.nn.l2_loss(variable) for variable in self.variables])
+                loss += (regularizer * regularizers)
+
+            return loss
 
     def predict(self, model_path, x_test):
         """
@@ -332,7 +341,7 @@ def _get_optimizer(self, training_iters, global_step):
                                                                 global_step=global_step)
         elif self.optimizer == "adam":
             learning_rate = self.opt_kwargs.pop("learning_rate", 0.001)
-            self.learning_rate_node = tf.Variable(learning_rate)
+            self.learning_rate_node = tf.Variable(learning_rate, name="learning_rate")
 
             optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate_node,
                                                **self.opt_kwargs).minimize(self.net.cost,
@@ -341,9 +350,9 @@ def _get_optimizer(self, training_iters, global_step):
         return optimizer
 
     def _initialize(self, training_iters, output_path, restore, prediction_path):
-        global_step = tf.Variable(0)
+        global_step = tf.Variable(0, name="global_step")
 
-        self.norm_gradients_node = tf.Variable(tf.constant(0.0, shape=[len(self.net.gradients_node)]))
+        self.norm_gradients_node = tf.Variable(tf.constant(0.0, shape=[len(self.net.gradients_node)]), name="norm_gradients")
 
         if self.net.summaries and self.norm_grads:
             tf.summary.histogram('norm_grads', self.norm_gradients_node)
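
Taken together with the layers.py changes, the named placeholders ("x", "y", "dropout_probability"), the per-block name scopes, and the named training variables should make both the graph view and the histograms in TensorBoard considerably easier to read. A rough usage sketch built only from the signatures visible in this diff (the checkpoint path, input shape, and cost_kwargs values are placeholders, not something defined by the commit):

import numpy as np
from tf_unet import unet

# Weighted cross entropy plus L2 regularization, as handled in _get_cost above.
net = unet.Unet(channels=1, n_class=2,
                cost="cross_entropy",
                cost_kwargs=dict(class_weights=[0.3, 0.7], regularizer=1e-4),
                layers=3, features_root=16)

# Run inference from a previously trained checkpoint (hypothetical path).
x_test = np.zeros((1, 572, 572, 1), dtype=np.float32)
prediction = net.predict("./unet_trained/model.ckpt", x_test)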
