Skip to content

Commit

Permalink
simplifing image provider usage
Browse files Browse the repository at this point in the history
  • Loading branch information
jakeret committed Dec 29, 2018
1 parent 804158e commit 9b08e82
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 18 deletions.
4 changes: 4 additions & 0 deletions scripts/ultrasound_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def create_training_path(output_path):
@click.option('--features_root', default=64)
def launch(data_root, output_path, training_iters, epochs, restore, layers, features_root):
print("Using data from: %s"%data_root)

if not os.path.exists(data_root):
raise IOError("Kaggle Ultrasound Dataset not found")

data_provider = ultrasound_util.DataProvider(data_root + "/*.tif",
a_min=0,
a_max=210)
Expand Down
32 changes: 17 additions & 15 deletions tf_unet/image_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,16 @@ class SimpleDataProvider(BaseDataProvider):
:param label: label numpy array. Shape=[n, X, Y, classes]
:param a_min: (optional) min value used for clipping
:param a_max: (optional) max value used for clipping
:param channels: (optional) number of channels, default=1
:param n_class: (optional) number of classes, default=2
"""

def __init__(self, data, label, a_min=None, a_max=None, channels=1, n_class = 2):
def __init__(self, data, label, a_min=None, a_max=None):
super(SimpleDataProvider, self).__init__(a_min, a_max)
self.data = data
self.label = label
self.file_count = data.shape[0]
self.n_class = n_class
self.channels = channels
self.n_class = label.shape[-1]
self.channels = data.shape[-1]

def _next_data(self):
idx = np.random.choice(self.file_count)
Expand All @@ -149,30 +147,34 @@ class ImageDataProvider(BaseDataProvider):
:param data_suffix: suffix pattern for the data images. Default '.tif'
:param mask_suffix: suffix pattern for the label images. Default '_mask.tif'
:param shuffle_data: if the order of the loaded file path should be randomized. Default 'True'
:param channels: (optional) number of channels, default=1
:param n_class: (optional) number of classes, default=2
"""

def __init__(self, search_path, a_min=None, a_max=None, data_suffix=".tif", mask_suffix='_mask.tif', shuffle_data=True, n_class = 2):
def __init__(self, search_path, a_min=None, a_max=None, data_suffix=".tif", mask_suffix='_mask.tif', shuffle_data=True):
super(ImageDataProvider, self).__init__(a_min, a_max)
self.data_suffix = data_suffix
self.mask_suffix = mask_suffix
self.file_idx = -1
self.shuffle_data = shuffle_data
self.n_class = n_class


self.data_files = self._find_data_files(search_path)

if self.shuffle_data:
np.random.shuffle(self.data_files)

assert len(self.data_files) > 0, "No training files"
print("Number of files used: %s" % len(self.data_files))

img = self._load_file(self.data_files[0])

image_path = self.data_files[0]
label_path = image_path.replace(self.data_suffix, self.mask_suffix)
img = self._load_file(image_path)
mask = self._load_file(label_path)
self.channels = 1 if len(img.shape) == 2 else img.shape[-1]

self.n_class = 2 if len(mask.shape) == 2 else mask.shape[-1]

print("Number of channels: %s"%self.channels)
print("Number of classes: %s"%self.n_class)

def _find_data_files(self, search_path):
all_files = glob.glob(search_path)
return [name for name in all_files if self.data_suffix in name and not self.mask_suffix in name]
Expand Down
6 changes: 3 additions & 3 deletions tf_unet/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,13 @@ class Unet(object):
"""
A unet implementation
:param channels: (optional) number of channels in the input image
:param n_class: (optional) number of output labels
:param channels: number of channels in the input image
:param n_class: number of output labels
:param cost: (optional) name of the cost function. Default is 'cross_entropy'
:param cost_kwargs: (optional) kwargs passed to the cost function. See Unet._get_cost for more options
"""

def __init__(self, channels=3, n_class=2, cost="cross_entropy", cost_kwargs={}, **kwargs):
def __init__(self, channels, n_class, cost="cross_entropy", cost_kwargs={}, **kwargs):
tf.reset_default_graph()

self.n_class = n_class
Expand Down

0 comments on commit 9b08e82

Please sign in to comment.