Skip to content

Commit

Permalink
Merge pull request jakeret#265 from siavashk/master
Browse files Browse the repository at this point in the history
Performing a check to make sure that the segmentation map from the ch…
  • Loading branch information
jakeret authored May 7, 2019
2 parents 2cc5e58 + 24314c1 commit f533550
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 45 deletions.
60 changes: 33 additions & 27 deletions tf_unet/image_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
#
# tf_unet is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
#
# You should have received a copy of the GNU General Public License
# along with tf_unet. If not, see <http://www.gnu.org/licenses/>.

Expand All @@ -34,7 +34,7 @@ class BaseDataProvider(object):
:param a_max: (optional) max value used for clipping
"""

channels = 1
n_class = 2

Expand All @@ -44,28 +44,34 @@ def __init__(self, a_min=None, a_max=None):

def _load_data_and_label(self):
data, label = self._next_data()

train_data = self._process_data(data)
labels = self._process_labels(label)

train_data, labels = self._post_process(train_data, labels)

nx = train_data.shape[1]
ny = train_data.shape[0]

return train_data.reshape(1, ny, nx, self.channels), labels.reshape(1, ny, nx, self.n_class),

def _process_labels(self, label):
if self.n_class == 2:
nx = label.shape[1]
ny = label.shape[0]
labels = np.zeros((ny, nx, self.n_class), dtype=np.float32)

# It is the responsibility of the child class to make sure that the label
# is a boolean array, but we a chech here just in case.
if label.dtype != 'bool':
label = label.astype(np.bool)

labels[..., 1] = label
labels[..., 0] = ~label
return labels

return label

def _process_data(self, data):
# normalization
data = np.clip(np.fabs(data), self.a_min, self.a_max)
Expand All @@ -75,37 +81,37 @@ def _process_data(self, data):
data /= np.amax(data)

return data

def _post_process(self, data, labels):
"""
Post processing hook that can be used for data augmentation
:param data: the data array
:param labels: the label array
"""
return data, labels

def __call__(self, n):
train_data, labels = self._load_data_and_label()
nx = train_data.shape[1]
ny = train_data.shape[2]

X = np.zeros((n, nx, ny, self.channels))
Y = np.zeros((n, nx, ny, self.n_class))

X[0] = train_data
Y[0] = labels
for i in range(1, n):
train_data, labels = self._load_data_and_label()
X[i] = train_data
Y[i] = labels

return X, Y


class SimpleDataProvider(BaseDataProvider):
"""
A simple data provider for numpy arrays.
A simple data provider for numpy arrays.
Assumes that the data and label are numpy array with the dimensions
data `[n, X, Y, channels]`, label `[n, X, Y, classes]`. Where
`n` is the number of images, `X`, `Y` the size of the image.
Expand All @@ -116,7 +122,7 @@ class SimpleDataProvider(BaseDataProvider):
:param a_max: (optional) max value used for clipping
"""

def __init__(self, data, label, a_min=None, a_max=None):
super(SimpleDataProvider, self).__init__(a_min, a_max)
self.data = data
Expand All @@ -134,13 +140,13 @@ class ImageDataProvider(BaseDataProvider):
"""
Generic data provider for images, supports gray scale and colored images.
Assumes that the data images and label images are stored in the same folder
and that the labels have a different file suffix
and that the labels have a different file suffix
e.g. 'train/fish_1.tif' and 'train/fish_1_mask.tif'
Number of pixels in x and y of the images and masks should be even.
Usage:
data_provider = ImageDataProvider("..fishes/train/*.tif")
:param search_path: a glob search pattern to find all data and label images
:param a_min: (optional) min value used for clipping
:param a_max: (optional) max value used for clipping
Expand All @@ -149,7 +155,7 @@ class ImageDataProvider(BaseDataProvider):
:param shuffle_data: if the order of the loaded file path should be randomized. Default 'True'
"""

def __init__(self, search_path, a_min=None, a_max=None, data_suffix=".tif", mask_suffix='_mask.tif', shuffle_data=True):
super(ImageDataProvider, self).__init__(a_min, a_max)
self.data_suffix = data_suffix
Expand All @@ -158,10 +164,10 @@ def __init__(self, search_path, a_min=None, a_max=None, data_suffix=".tif", mask
self.shuffle_data = shuffle_data

self.data_files = self._find_data_files(search_path)

if self.shuffle_data:
np.random.shuffle(self.data_files)

assert len(self.data_files) > 0, "No training files"
print("Number of files used: %s" % len(self.data_files))

Expand All @@ -178,23 +184,23 @@ def __init__(self, search_path, a_min=None, a_max=None, data_suffix=".tif", mask
def _find_data_files(self, search_path):
all_files = glob.glob(search_path)
return [name for name in all_files if self.data_suffix in name and not self.mask_suffix in name]

def _load_file(self, path, dtype=np.float32):
return np.array(Image.open(path), dtype)

def _cylce_file(self):
self.file_idx += 1
if self.file_idx >= len(self.data_files):
self.file_idx = 0
self.file_idx = 0
if self.shuffle_data:
np.random.shuffle(self.data_files)

def _next_data(self):
self._cylce_file()
image_name = self.data_files[self.file_idx]
label_name = image_name.replace(self.data_suffix, self.mask_suffix)

img = self._load_file(image_name, np.float32)
label = self._load_file(label_name, np.bool)

return img,label
56 changes: 38 additions & 18 deletions tf_unet/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
#
# tf_unet is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
#
# You should have received a copy of the GNU General Public License
# along with tf_unet. If not, see <http://www.gnu.org/licenses/>.

Expand All @@ -27,13 +27,13 @@
def plot_prediction(x_test, y_test, prediction, save=False):
import matplotlib
import matplotlib.pyplot as plt

test_size = x_test.shape[0]
fig, ax = plt.subplots(test_size, 3, figsize=(12,12), sharey=True, sharex=True)

x_test = crop_to_shape(x_test, prediction.shape)
y_test = crop_to_shape(y_test, prediction.shape)

ax = np.atleast_2d(ax)
for i in range(test_size):
cax = ax[i, 0].imshow(x_test[i])
Expand All @@ -50,7 +50,7 @@ def plot_prediction(x_test, y_test, prediction, save=False):
ax[i, 1].set_title("y")
ax[i, 2].set_title("pred")
fig.tight_layout()

if save:
fig.savefig(save)
else:
Expand All @@ -61,17 +61,17 @@ def to_rgb(img):
"""
Converts the given array into a RGB image. If the number of channels is not
3 the array is tiled such that it has 3 channels. Finally, the values are
rescaled to [0,255)
rescaled to [0,255)
:param img: the array to convert [nx, ny, channels]
:returns img: the rgb image [nx, ny, 3]
"""
img = np.atleast_3d(img)
channels = img.shape[2]
if channels < 3:
img = np.tile(img, 3)

img[np.isnan(img)] = 0
img -= np.amin(img)
if np.amax(img) != 0:
Expand All @@ -83,7 +83,7 @@ def to_rgb(img):
def crop_to_shape(data, shape):
"""
Crops the array to the given image shape by removing the border (expects a tensor of shape [batches, nx, ny, channels].
:param data: the array to crop
:param shape: the target shape
"""
Expand All @@ -101,27 +101,47 @@ def crop_to_shape(data, shape):
assert cropped.shape[2] == shape[2]
return cropped

def expand_to_shape(data, shape, border=0):
"""
Expands the array to the given image shape by padding it with a border (expects a tensor of shape [batches, nx, ny, channels].
:param data: the array to expand
:param shape: the target shape
"""
diff_nx = shape[1] - data.shape[1]
diff_ny = shape[2] - data.shape[2]

offset_nx_left = diff_nx // 2
offset_nx_right = diff_nx - offset_nx_left
offset_ny_left = diff_ny // 2
offset_ny_right = diff_ny - offset_ny_left

expanded = np.full(shape, border, dtype=np.float32)
expanded[:, offset_nx_left:(-offset_nx_right), offset_ny_left:(-offset_ny_right)] = data

return expanded

def combine_img_prediction(data, gt, pred):
"""
Combines the data, grouth thruth and the prediction into one rgb image
:param data: the data tensor
:param gt: the ground thruth tensor
:param pred: the prediction tensor
:returns img: the concatenated rgb image
:returns img: the concatenated rgb image
"""
ny = pred.shape[2]
ch = data.shape[3]
img = np.concatenate((to_rgb(crop_to_shape(data, pred.shape).reshape(-1, ny, ch)),
to_rgb(crop_to_shape(gt[..., 1], pred.shape).reshape(-1, ny, 1)),
img = np.concatenate((to_rgb(crop_to_shape(data, pred.shape).reshape(-1, ny, ch)),
to_rgb(crop_to_shape(gt[..., 1], pred.shape).reshape(-1, ny, 1)),
to_rgb(pred[..., 1].reshape(-1, ny, 1))), axis=1)
return img

def save_image(img, path):
"""
Writes the image to disk
:param img: the rgb image to save
:param path: the target path
"""
Expand All @@ -140,4 +160,4 @@ def create_training_path(output_path, prefix="run_"):
while os.path.exists(path):
idx += 1
path = os.path.join(output_path, "{:}{:03d}".format(prefix, idx))
return path
return path

0 comments on commit f533550

Please sign in to comment.