Skip to content

Commit

Permalink
changed training script, saving checkpoints after each epoch
Browse files Browse the repository at this point in the history
  • Loading branch information
jr0th committed May 22, 2017
1 parent 343445c commit fc654a6
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions code/training-hand-200-boundary-random-patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
# constants
const_lr = 1e-4

chkpt_file = "../checkpoints/checkpoint_boundary_4_generator.hdf5"
csv_log_file = "../logs/log_boundary_4_generator.csv"
chkpt_file = "../checkpoints/boundary_4_random_augment/checkpoint_{epoch:04d}.hdf5"
csv_log_file = "../logs/log_boundary_4_random_augment.csv"

out_dir = "../out_boundary_4_generator/"
tb_log_dir = "../logs/logs_tensorboard_boundary_4_generator/"
out_dir = "../out_boundary_4_random_augment/"
tb_log_dir = "../logs/logs_tensorboard_boundary_4_random_augment/"

train_dir_x = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/random_patches/training/x_big/all/'
train_dir_y = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/random_patches/training/y_big_boundary_4/all/'
train_dir_x = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/training/x_big/'
train_dir_y = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/training/y_big_boundary_4/'

val_dir_x = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/validation/x/'
val_dir_y = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/validation/y_boundary_4/'
Expand All @@ -58,7 +58,7 @@
# build session running on a specific GPU
configuration = tf.ConfigProto()
configuration.gpu_options.allow_growth = True
configuration.gpu_options.visible_device_list = "1"
configuration.gpu_options.visible_device_list = "2"
session = tf.Session(config = configuration)

# apply session
Expand All @@ -83,7 +83,7 @@

# CALLBACKS
# save model after each epoch
callback_model_checkpoint = keras.callbacks.ModelCheckpoint(filepath=chkpt_file, save_weights_only=True, save_best_only=True)
callback_model_checkpoint = keras.callbacks.ModelCheckpoint(filepath=chkpt_file, save_weights_only=True, save_best_only=False)
callback_csv = keras.callbacks.CSVLogger(filename=csv_log_file)
callback_splits_and_merges = helper.callbacks.SplitsAndMergesLoggerBoundary(data_type, val_gen, gen_calls = val_steps, log_dir=tb_log_dir)
callback_tensorboard = keras.callbacks.TensorBoard(log_dir=tb_log_dir, histogram_freq=1)
Expand Down

0 comments on commit fc654a6

Please sign in to comment.