Skip to content

Commit

Permalink
bug fixes for training with random patches
Browse files Browse the repository at this point in the history
  • Loading branch information
jr0th committed May 16, 2017
1 parent 4b17482 commit 73343ff
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 31 deletions.
39 changes: 35 additions & 4 deletions code/helper/data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import keras.preprocessing.image
import helper.external.SegDataGenerator

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import skimage.io

def data_from_array(data_dir):
Expand Down Expand Up @@ -78,7 +82,7 @@ def single_data_from_images(x_dir, y_dir, batch_size, bit_depth, dim1, dim2):
def single_data_from_images_1d_y(x_dir, y_dir, batch_size, bit_depth, dim1, dim2):

rescale_factor = 1./(2**bit_depth - 1)
rescale_labels = True
rescale_labels = False

if(rescale_labels):
rescale_factor_labels = rescale_factor
Expand Down Expand Up @@ -114,12 +118,25 @@ def single_data_from_images_1d_y(x_dir, y_dir, batch_size, bit_depth, dim1, dim2

def random_sample_generator(x_big_dir, y_big_dir, batch_size, bit_depth, dim1, dim2):

debug = False

# get images
x_big = skimage.io.imread_collection(x_big_dir + '*.png').concatenate()
print('Found',len(x_big), 'images.')
y_big = skimage.io.imread_collection(y_big_dir + '*.png').concatenate()
print('Found',len(y_big), 'annotations.')

if(debug):
fig = plt.figure()
plt.hist(y_big.flatten())
plt.savefig('/home/jr0th/github/segmentation/code/generated/y_hist')
plt.close(fig)

fig = plt.figure()
plt.hist(x_big.flatten())
plt.savefig('/home/jr0th/github/segmentation/code/generated/x_hist')
plt.close(fig)

# get dimensions right – understand data set
n_images = x_big.shape[0]
dim1_size = x_big.shape[1]
Expand All @@ -133,7 +150,7 @@ def random_sample_generator(x_big_dir, y_big_dir, batch_size, bit_depth, dim1, d
rescale_factor_labels = rescale_factor
else:
rescale_factor_labels = 1

while(True):

# buffers for a batch of data
Expand All @@ -151,11 +168,25 @@ def random_sample_generator(x_big_dir, y_big_dir, batch_size, bit_depth, dim1, d
start_dim2 = np.random.randint(low=0, high=dim2_size+1-dim2)

# save image to buffer
x[i, :, :, 0] = x_big[img_index, start_dim1, start_dim1 + dim1] * rescale_factor
y[i, :, :, 0] = x_big[img_index, start_dim1, start_dim1 + dim1] * rescale_factor_labels
x[i, :, :, 0] = x_big[img_index, start_dim1:start_dim1 + dim1, start_dim2:start_dim2 + dim2] * rescale_factor
y[i, :, :, 0] = y_big[img_index, start_dim1:start_dim1 + dim1, start_dim2:start_dim2 + dim2] * rescale_factor_labels

if(debug):
fig = plt.figure()
plt.imshow(x[i, :, :, 0])
plt.colorbar()
plt.savefig('/home/jr0th/github/segmentation/code/generated/x_' + str(i))
plt.close(fig)

fig = plt.figure()
plt.imshow(y[i, :, :, 0])
plt.colorbar()
plt.savefig('/home/jr0th/github/segmentation/code/generated/y_' + str(i))
plt.close(fig)

# return the buffer
yield(x, y)



def single_data_from_images_random(x_dir, y_dir, batch_size, bit_depth, dim1, dim2):
Expand Down
113 changes: 113 additions & 0 deletions code/helper/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,117 @@ def get_model_1_class(dim1, dim2):

model = keras.models.Model(x,y)

return model


def get_model_1_class_no_skip(dim1, dim2):
    """Build an encoder/decoder CNN for single-class (binary) segmentation
    WITHOUT skip connections (cf. get_model_1_class, which has them).

    Architecture: three conv/pool stages down (64 -> 128 -> 256 filters),
    a 512-filter bottleneck, three upsampling stages back up
    (256 -> 128 -> head at 64), then a 1x1 sigmoid conv producing a
    per-pixel foreground probability map.

    Parameters
    ----------
    dim1, dim2 : int
        Spatial height and width of the (single-channel) input images.
        Should be divisible by 8 so the three pool/upsample stages
        restore the original resolution — TODO confirm with callers.

    Returns
    -------
    keras.models.Model
        Uncompiled model mapping (dim1, dim2, 1) -> (dim1, dim2, 1).

    Relies on module-level configuration: option_dict_conv, option_dict_bn,
    FLAG_BN, FLAG_DO, FLAG_DO_LAST_LAYER, CONST_DO_RATE.
    """

    def _conv_block(tensor, n_filters, dropout_flag=None):
        """One 3x3 convolution, optionally followed by batch-norm and dropout.

        dropout_flag defaults to the module-level FLAG_DO; the final feature
        block passes FLAG_DO_LAST_LAYER instead (matching the original code).
        """
        tensor = keras.layers.Convolution2D(n_filters, 3, 3, **option_dict_conv)(tensor)
        if FLAG_BN:
            tensor = keras.layers.BatchNormalization(**option_dict_bn)(tensor)
        if (FLAG_DO if dropout_flag is None else dropout_flag):
            tensor = keras.layers.Dropout(CONST_DO_RATE)(tensor)
        return tensor

    x = keras.layers.Input(shape=(dim1, dim2, 1))

    # DOWN: three stages of two conv blocks + max-pool, doubling filters
    y = _conv_block(x, 64)
    y = _conv_block(y, 64)
    y = keras.layers.MaxPooling2D()(y)

    y = _conv_block(y, 128)
    y = _conv_block(y, 128)
    y = keras.layers.MaxPooling2D()(y)

    y = _conv_block(y, 256)
    y = _conv_block(y, 256)
    y = keras.layers.MaxPooling2D()(y)

    # bottleneck (no pooling afterwards)
    y = _conv_block(y, 512)
    y = _conv_block(y, 512)

    # UP: mirror the encoder with upsampling; no skip connections here
    y = keras.layers.UpSampling2D()(y)
    y = _conv_block(y, 256)
    y = _conv_block(y, 256)

    y = keras.layers.UpSampling2D()(y)
    y = _conv_block(y, 128)
    y = _conv_block(y, 128)

    y = keras.layers.UpSampling2D()(y)

    # HEAD: the last feature block uses the dedicated last-layer dropout flag
    y = _conv_block(y, 64)
    y = _conv_block(y, 64, dropout_flag=FLAG_DO_LAST_LAYER)

    # NOTE(review): this final conv uses Keras-2-style kwargs ("padding=")
    # while the blocks above use Keras-1 positional args — verify the
    # installed Keras version accepts both; preserved as-is from original.
    y = keras.layers.Convolution2D(1, 1, padding="same", activation="sigmoid")(y)

    model = keras.models.Model(x, y)

    return model
32 changes: 16 additions & 16 deletions code/helper/visualize.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@

import matplotlib
matplotlib.use('PDF')

matplotlib.use('Agg')
import matplotlib.pyplot as plt

import skimage.io
import sklearn.metrics

import numpy as np

out_format = 'svg'

def visualize(pred_y, true_x, true_y, out_dir='./', label=''):

# TODO
skimage.io.imsave(out_dir + label + '_' + 'img_probmap_boundary_test.png', pred_y[1,:,:,2])

plt.figure()
plt.hist(pred_y[1,:,:,2].flatten())
plt.savefig(out_dir + label + '_' + 'hist_probmap_boundary')
plt.savefig(out_dir + label + '_' + 'hist_probmap_boundary' + '.' + out_format, format=out_format)

# print all samples for visual inspection
nSamples = pred_y.shape[0]
Expand Down Expand Up @@ -66,7 +66,7 @@ def visualize(pred_y, true_x, true_y, out_dir='./', label=''):
horizontalalignment='center',
verticalalignment='center', fontsize = 15)

plt.savefig(out_dir + label + '_' + str(sampleIndex) + '_vis')
plt.savefig(out_dir + label + '_' + str(sampleIndex) + '_vis' + '.' + out_format, format=out_format)
classNames = ['background', 'interior', 'boundary']
f = open(out_dir + '/' + label + '_' + str(sampleIndex) + '.txt', 'w')
f.write(sklearn.metrics.classification_report(pred.flatten(), true.flatten(), target_names=classNames) + '\n')
Expand All @@ -83,7 +83,7 @@ def visualize_boundary_hard(pred_y, true_x, true_y, out_dir='./', label=''):
print('VISUALIZE', pred_y.shape, true_y.shape)
plt.figure()
plt.hist(pred_y.flatten())
plt.savefig(out_dir + label + '_' + 'hist_probmap_boundary')
plt.savefig(out_dir + label + '_' + 'hist_probmap_boundary' + '.' + out_format, format=out_format)

# print all samples for visual inspection
nSamples = pred_y.shape[0]
Expand All @@ -93,7 +93,7 @@ def visualize_boundary_hard(pred_y, true_x, true_y, out_dir='./', label=''):
nCols = 3
nRows = 2
figure, axes = plt.subplots(ncols=nCols, nrows=2, figsize=(nCols*5+2, nRows*5+2))
# figure.tight_layout(pad = 1)
figure.tight_layout(pad = 1)

origFig = axes[0,0]
trueFig = axes[0,1]
Expand Down Expand Up @@ -152,7 +152,7 @@ def visualize_boundary_hard(pred_y, true_x, true_y, out_dir='./', label=''):
horizontalalignment='center',
verticalalignment='center', fontsize = 15)

plt.savefig(out_dir + label + '_' + str(sampleIndex) + '_vis')
plt.savefig(out_dir + label + '_' + str(sampleIndex) + '_vis' + '.' + out_format, format=out_format)
classNames = ['background', 'boundary']

# write cross entropy
Expand All @@ -166,7 +166,7 @@ def visualize_boundary_soft(pred_y, true_x, true_y, out_dir='./', label=''):

plt.figure()
plt.hist(pred_y.flatten())
plt.savefig(out_dir + label + '_' + 'hist_probmap_boundary')
plt.savefig(out_dir + label + '_' + 'hist_probmap_boundary' + '.' + out_format, format=out_format)

# print all samples for visual inspection
nSamples = pred_y.shape[0]
Expand Down Expand Up @@ -201,7 +201,7 @@ def visualize_boundary_soft(pred_y, true_x, true_y, out_dir='./', label=''):
trueFig.axis('off')
compFig.axis('off')

plt.savefig(out_dir + label + '_' + str(sampleIndex) + '_vis')
plt.savefig(out_dir + label + '_' + str(sampleIndex) + '_vis' + '.' + out_format, format=out_format)
classNames = ['background', 'boundary']

# write mean squared error
Expand All @@ -220,7 +220,7 @@ def visualize_learning_stats(statistics, out_dir, metrics):
plt.plot(statistics.history["val_loss"])
plt.legend(["Training", "Validation"])

plt.savefig(out_dir + "plot_loss")
plt.savefig(out_dir + "plot_loss" + '.' + out_format, format=out_format)

plt.figure()

Expand All @@ -230,18 +230,18 @@ def visualize_learning_stats(statistics, out_dir, metrics):
plt.plot(statistics.history["val_categorical_accuracy"])
plt.legend(["Training", "Validation"])

plt.savefig(out_dir + "plot_accuracy")
plt.savefig(out_dir + "plot_accuracy" + '.' + out_format, format=out_format)

def visualize_learning_stats_boundary(statistics, out_dir, metrics):
plt.figure()

plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.plot(statistics.history["loss"])
# plt.plot(statistics.history["val_loss"])
plt.legend(["Training"])#, "Validation"])
plt.plot(statistics.history["val_loss"])
plt.legend(["Training", "Validation"])

plt.savefig(out_dir + "plot_loss")
plt.savefig(out_dir + "plot_loss" + '.' + out_format, format=out_format)

plt.figure()

Expand All @@ -251,4 +251,4 @@ def visualize_learning_stats_boundary(statistics, out_dir, metrics):
plt.plot(statistics.history["val_binary_accuracy"])
plt.legend(["Training", "Validation"])

plt.savefig(out_dir + "plot_accuracy")
plt.savefig(out_dir + "plot_accuracy" + '.' + out_format, format=out_format)
10 changes: 4 additions & 6 deletions code/predict_generator-boundary.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
import helper.visualize
import helper.data_provider
import helper.model_builder
import helper.visualize

import skimage.io

import matplotlib.pyplot as plt

import sys

data_dir_x = "/home/jr0th/github/segmentation/data/BBBC022_hand_200/test/x/"
data_dir_y = "/home/jr0th/github/segmentation/data/BBBC022_hand_200/test/y_boundary_4/"


out_label = 'pred_generator'
out_dir = '/home/jr0th/github/segmentation/out_boundary_4_generator/'
out_dir = '/home/jr0th/github/segmentation/out_boundary_4_generator_no_skip/'

weights_path = '/home/jr0th/github/segmentation/checkpoints/checkpoint_boundary_4_generator.hdf5'
weights_path = '/home/jr0th/github/segmentation/checkpoints/checkpoint_boundary_4_generator_no_skip.hdf5'
batch_size = 10
bit_depth = 8

Expand All @@ -33,7 +31,7 @@
)

# build model and laod weights
model = helper.model_builder.get_model_1_class(dim1, dim2)
model = helper.model_builder.get_model_1_class_no_skip(dim1, dim2)
model.load_weights(weights_path)

# get one batch of data from the generator
Expand Down
Loading

0 comments on commit 73343ff

Please sign in to comment.