Skip to content

Commit

Permalink
cleaned up training script, minor adaptions due to runs of experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
jr0th committed May 16, 2017
1 parent c6e55b5 commit 6107fba
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 24 deletions.
6 changes: 3 additions & 3 deletions code/predict_generator-boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@


out_label = 'pred_generator'
out_dir = '/home/jr0th/github/segmentation/out_boundary_4_generator_no_skip/'
out_dir = '/home/jr0th/github/segmentation/out_boundary_4_generator/'

weights_path = '/home/jr0th/github/segmentation/checkpoints/checkpoint_boundary_4_generator_no_skip.hdf5'
weights_path = '/home/jr0th/github/segmentation/checkpoints/checkpoint_boundary_4_generator.hdf5'
batch_size = 10
bit_depth = 8

Expand All @@ -31,7 +31,7 @@
)

# build model and load weights
model = helper.model_builder.get_model_1_class_no_skip(dim1, dim2)
model = helper.model_builder.get_model_1_class(dim1, dim2)
model.load_weights(weights_path)

# get one batch of data from the generator
Expand Down
4 changes: 2 additions & 2 deletions code/predict_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@


out_label = 'pred_generator'
out_dir = '/home/jr0th/github/segmentation/out/'
out_dir = '/home/jr0th/github/segmentation/out_3class_random/'
# data_dir = '/home/jr0th/github/segmentation/data/BBBC022/'
weights_path = '/home/jr0th/github/segmentation/checkpoints/checkpoint.hdf5'
weights_path = '/home/jr0th/github/segmentation/checkpoints/checkpoint_3class_random.hdf5'
batch_size = 10
bit_depth = 8

Expand Down
35 changes: 16 additions & 19 deletions code/training-hand-200.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,27 @@
# constants
const_lr = 1e-4

out_dir = "../out/"
tb_log_dir = "../logs/logs_tensorboard/"
out_dir = "../out_3class_random_4/"
tb_log_dir = "../logs/logs_tensorboard_3class_random_4/"

train_dir_x = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/training/x'
train_dir_y = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/training/y_label_binary'
out_dir_log = "../logs/log_3class_random_4.csv"
checkpoint_path = "../checkpoints/checkpoint_3class_random_4.hdf5"

train_dir_x = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/training/x_big/'
train_dir_y = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/training/y_big_label_binary_4/'

val_dir_x = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/validation/x'
val_dir_y = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/validation/y_label_binary'
val_dir_y = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/validation/y_label_binary_4'

data_type = "images" # "images" or "array"

nb_epoch = 500
nb_epoch = 100
batch_size = 10
nb_batches = int(400 / batch_size) # 100 images, 400 patches

# images and masks are in 8 bit
bit_depth = 8

# SEGMENTATION DATA GENERATOR
file_path = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/all_files_wo_ext.txt'
classes = 3

# make sure these match the numbers for the validation set
val_steps = int(200 / batch_size) # 50 images, 200 patches

Expand All @@ -57,19 +56,15 @@
# build session running on GPU 1
configuration = tf.ConfigProto()
configuration.gpu_options.allow_growth = True
configuration.gpu_options.visible_device_list = "1"
configuration.gpu_options.visible_device_list = "0"
session = tf.Session(config = configuration)

# apply session
keras.backend.set_session(session)

# get training generator

#train_gen = helper.data_provider.data_from_images_segmentation(file_path, images_dir, labels_dir, classes, batch_size, dim1, dim2)
train_gen = helper.data_provider.single_data_from_images(train_dir_x, train_dir_y, batch_size, bit_depth, dim1, dim2)
# get data generators
train_gen = helper.data_provider.random_sample_generator(train_dir_x, train_dir_y, batch_size, bit_depth, dim1, dim2)
val_gen = helper.data_provider.single_data_from_images(val_dir_x, val_dir_y, batch_size, bit_depth, dim1, dim2)

callback_splits_and_merges = helper.callbacks.SplitsAndMergesLogger(data_type, val_gen, gen_calls = val_steps, log_dir='../logs/logs_tensorboard')

# build model
model = helper.model_builder.get_model_3_class(dim1, dim2)
Expand All @@ -81,8 +76,10 @@

# CALLBACKS
# save model after each epoch
callback_model_checkpoint = keras.callbacks.ModelCheckpoint(filepath="../checkpoints/checkpoint.hdf5", save_weights_only=True, save_best_only=True)
callback_csv = keras.callbacks.CSVLogger(filename="../logs/log.csv")
callback_splits_and_merges = helper.callbacks.SplitsAndMergesLogger3Class(data_type, val_gen, gen_calls = val_steps, log_dir=tb_log_dir)
callback_model_checkpoint = keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, save_best_only=True)
callback_csv = keras.callbacks.CSVLogger(filename=out_dir_log)

callback_tensorboard = keras.callbacks.TensorBoard(log_dir=tb_log_dir, histogram_freq=1)

callbacks=[callback_model_checkpoint, callback_csv, callback_splits_and_merges]
Expand Down

0 comments on commit 6107fba

Please sign in to comment.