
Commit

numerically stable cross entropy computation
jakeret committed Jun 25, 2018
1 parent 1d4d552 commit 67bd0ba
Showing 2 changed files with 11 additions and 19 deletions.
18 changes: 6 additions & 12 deletions tf_unet/layers.py
@@ -59,17 +59,11 @@ def crop_and_concat(x1,x2):
 
 def pixel_wise_softmax(output_map):
     with tf.name_scope("pixel_wise_softmax"):
-        exponential_map = tf.exp(output_map)
-        evidence = tf.add(exponential_map,tf.reverse(exponential_map,[False,False,False,True]))
-        return tf.div(exponential_map,evidence, name="pixel_wise_softmax")
-
-def pixel_wise_softmax_2(output_map):
-    with tf.name_scope("pixel_wise_softmax_2"):
-        exponential_map = tf.exp(output_map)
-        sum_exp = tf.reduce_sum(exponential_map, 3, keepdims=True)
-        tensor_sum_exp = tf.tile(sum_exp, tf.stack([1, 1, 1, tf.shape(output_map)[3]]))
-        return tf.div(exponential_map,tensor_sum_exp)
+        max_axis = tf.reduce_max(output_map, axis=3, keepdims=True)
+        exponential_map = tf.exp(output_map - max_axis)
+        normalize = tf.reduce_sum(exponential_map, axis=3, keepdims=True)
+        return exponential_map / normalize
 
 def cross_entropy(y_,output_map):
-    return -tf.reduce_mean(y_*tf.log(tf.clip_by_value(output_map,1e-10,1.0)), name="cross_entropy")
-    #return tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(output_map), reduction_indices=[1]))
+    with tf.name_scope("xent"):
+        return -tf.reduce_mean(y_*tf.log(tf.clip_by_value(output_map,1e-10,1.0)), name="cross_entropy")
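
The replacement relies on the shift invariance of softmax: softmax(x) equals softmax(x - c) for any constant c, so subtracting the per-pixel maximum keeps every exponent at or below zero and avoids the overflow that tf.exp(output_map) hits for large logits. A minimal NumPy sketch (not part of the commit; the function names are illustrative) of the difference:

import numpy as np

def naive_softmax(logits, axis=-1):
    # Overflows for large logits: np.exp(1000.0) is inf, and inf/inf gives nan.
    e = np.exp(logits)
    return e / np.sum(e, axis=axis, keepdims=True)

def stable_softmax(logits, axis=-1):
    # Subtracting the per-pixel max shifts the largest exponent to exp(0) = 1;
    # the shift cancels in the ratio, so the result is mathematically unchanged.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)

# Per-pixel class scores shaped like the network output: (batch, ny, nx, n_class).
logits = np.array([[[[1000.0, 990.0]]]])

print(naive_softmax(logits, axis=3))   # nan everywhere (overflow warning from np.exp)
print(stable_softmax(logits, axis=3))  # finite probabilities that sum to 1
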
12 changes: 5 additions & 7 deletions tf_unet/unet.py
@@ -29,7 +29,7 @@
 
 from tf_unet import util
 from tf_unet.layers import (weight_variable, weight_variable_devonc, bias_variable,
-                            conv2d, deconv2d, max_pool, crop_and_concat, pixel_wise_softmax_2,
+                            conv2d, deconv2d, max_pool, crop_and_concat, pixel_wise_softmax,
                             cross_entropy)
 
 logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
@@ -200,13 +200,11 @@ def __init__(self, channels=3, n_class=2, cost="cross_entropy", cost_kwargs={},
 
         self.gradients_node = tf.gradients(self.cost, self.variables)
 
-
-        with tf.name_scope("xent"):
-            self.cross_entropy = cross_entropy(tf.reshape(self.y, [-1, n_class]),
-                                               tf.reshape(pixel_wise_softmax_2(logits), [-1, n_class]))
+        self.cross_entropy = cross_entropy(tf.reshape(self.y, [-1, n_class]),
+                                           tf.reshape(pixel_wise_softmax(logits), [-1, n_class]))
 
         with tf.name_scope("results"):
-            self.predicter = pixel_wise_softmax_2(logits)
+            self.predicter = pixel_wise_softmax(logits)
             self.correct_pred = tf.equal(tf.argmax(self.predicter, 3), tf.argmax(self.y, 3))
             self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32))
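
The self.cross_entropy node above is built from the softmax probabilities and fed through the cross_entropy helper in layers.py, which clips before the log: even with the stable softmax, exp(x - max) can underflow to an exact zero for a class far below the per-pixel maximum, and log(0) would poison the value. A small NumPy illustration (the values are made up, not from the repository):

import numpy as np

# One-hot label and a prediction where the true class received probability 0.
y_true = np.array([[0.0, 1.0]])
y_pred = np.array([[1.0, 0.0]])

unclipped = -np.mean(y_true * np.log(y_pred))                     # log(0) -> -inf -> inf loss
clipped = -np.mean(y_true * np.log(np.clip(y_pred, 1e-10, 1.0)))  # bounded by -log(1e-10)

print(unclipped)  # inf (with a divide-by-zero warning)
print(clipped)    # ~11.5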

@@ -241,7 +239,7 @@ def _get_cost(self, logits, cost_name, cost_kwargs):
                                                                           labels=flat_labels))
         elif cost_name == "dice_coefficient":
             eps = 1e-5
-            prediction = pixel_wise_softmax_2(logits)
+            prediction = pixel_wise_softmax(logits)
             intersection = tf.reduce_sum(prediction * self.y)
             union = eps + tf.reduce_sum(prediction) + tf.reduce_sum(self.y)
             loss = -(2 * intersection / (union))
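
In the dice_coefficient branch above, the softmax output is compared against the one-hot labels as loss = -2 * intersection / (sum(prediction) + sum(truth) + eps), where eps keeps the ratio defined when both masks are empty. A small NumPy sketch of the same arithmetic (the arrays are made up for illustration):

import numpy as np

eps = 1e-5

# Made-up softmax output and one-hot labels for a 2x2 image with 2 classes.
prediction = np.array([[[0.9, 0.1], [0.8, 0.2]],
                       [[0.3, 0.7], [0.1, 0.9]]])
y = np.array([[[1.0, 0.0], [1.0, 0.0]],
              [[0.0, 1.0], [0.0, 1.0]]])

intersection = np.sum(prediction * y)          # soft overlap between prediction and truth
union = eps + np.sum(prediction) + np.sum(y)
loss = -(2 * intersection / union)             # -1 at perfect overlap, 0 at no overlap

print(loss)  # ~ -0.825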
