Skip to content

Commit

Permalink
minor changes to improve visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
jr0th committed May 8, 2017
1 parent 3878ba2 commit 4360b34
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ data/
logs/
checkpoints/
out/
out_boundary/
out_boundary_*/
results/

experiments/BBBC022_hand/GT_labels
Expand Down
85 changes: 75 additions & 10 deletions code/helper/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,19 @@ def visualize_boundary_hard(pred_y, true_x, true_y, out_dir='./', label=''):

for sampleIndex in range(nSamples):

nCols = 4
figure, axes = plt.subplots(ncols=nCols, nrows=1, figsize=(nCols*5, 5))
figure.tight_layout(pad = 1)
nCols = 3
nRows = 2
figure, axes = plt.subplots(ncols=nCols, nrows=2, figsize=(nCols*5+2, nRows*5+2))
# figure.tight_layout(pad = 1)

predFig = axes[0]
trueFig = axes[1]
compFig = axes[2]
cmatFig = axes[3]
origFig = axes[0,0]
trueFig = axes[0,1]

predProbMapFig = axes[1,0]
predFig = axes[1,1]

compFig = axes[0,2]
cmatFig = axes[1,2]

pred_prob_map = pred_y[sampleIndex,:,:,0]
pred = pred_prob_map >= 0.5
Expand All @@ -113,14 +118,25 @@ def visualize_boundary_hard(pred_y, true_x, true_y, out_dir='./', label=''):

cmat = sklearn.metrics.confusion_matrix(y_true = true.flatten(), y_pred = pred.flatten()) #, labels=[0,1])

predFig.imshow(skimage.color.label2rgb(pred, image=true_x[sampleIndex,:,:,0]))
mappable = origFig.imshow(true_x[sampleIndex,:,:,0])
figure.colorbar(mappable, ax=origFig)

trueFig.imshow(skimage.color.label2rgb(true, image=true_x[sampleIndex,:,:,0]))

mappable = predProbMapFig.imshow(pred_prob_map)
figure.colorbar(mappable, ax=predProbMapFig)

predFig.imshow(skimage.color.label2rgb(pred, image=true_x[sampleIndex,:,:,0]))

compFig.imshow(skimage.color.label2rgb(comp, image=true_x[sampleIndex,:,:,0]))
cmatFig.matshow(cmat, cmap = "cool")

predFig.set_title('Prediction')

predProbMapFig.set_title('Prediction (not thresholded)')
origFig.set_title('Image')
predFig.set_title('Prediction (thresholded)')
trueFig.set_title('Truth')
compFig.set_title('Errors')
cmatFig.set_title('Confusion Matrix')

predFig.axis('off')
trueFig.axis('off')
Expand All @@ -145,6 +161,55 @@ def visualize_boundary_hard(pred_y, true_x, true_y, out_dir='./', label=''):
f = open(out_dir + '/' + label + '_' + str(sampleIndex) + '.txt', 'w')
f.write('Cross Entropy: ' + str(ce) + '\n')
f.close()

def visualize_boundary_soft(pred_y, true_x, true_y, out_dir='./', label=''):

plt.figure()
plt.hist(pred_y.flatten())
plt.savefig(out_dir + label + '_' + 'hist_probmap_boundary')

# print all samples for visual inspection
nSamples = pred_y.shape[0]

for sampleIndex in range(nSamples):

nCols = 4
figure, axes = plt.subplots(ncols=nCols, nrows=1, figsize=(nCols*5+2, 5+2))

origFig = axes[0]
predFig = axes[1]
trueFig = axes[2]
compFig = axes[3]

pred_prob_map = pred_y[sampleIndex,:,:,0]

true_prob_map = true_y[sampleIndex,:,:,0]

comp = pred_prob_map - true_prob_map

origFig.imshow(true_x[sampleIndex,:,:,0])
predFig.imshow(pred_prob_map)
trueFig.imshow(true_prob_map)
compFig.imshow(comp)

origFig.set_title('Image')
predFig.set_title('Prediction')
trueFig.set_title('Truth')
compFig.set_title('Errors (MSE)')

predFig.axis('off')
trueFig.axis('off')
compFig.axis('off')

plt.savefig(out_dir + label + '_' + str(sampleIndex) + '_vis')
classNames = ['background', 'boundary']

# write mean squared error
ce = sklearn.metrics.mean_squared_error(y_pred = pred_prob_map.flatten(), y_true = true_prob_map.flatten())

f = open(out_dir + '/' + label + '_' + str(sampleIndex) + '.txt', 'w')
f.write('Mean Squared Error: ' + str(ce) + '\n')
f.close()

def visualize_learning_stats(statistics, out_dir, metrics):
plt.figure()
Expand Down
12 changes: 6 additions & 6 deletions code/training-hand-200-boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,22 @@
import numpy as np

# constants
const_lr = 1e-3
const_lr = 1e-4

chkpt_file = "../checkpoints/checkpoint_boundary_2.hdf5"
chkpt_file = "../checkpoints/checkpoint_boundary_6.hdf5"

out_dir = "../out_boundary_2/"
out_dir = "../out_boundary_6/"
tb_log_dir = "../logs/logs_tensorboard_boundary/"

train_dir_x = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/training/x'
train_dir_y = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/training/y_boundary_2'
train_dir_y = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/training/y_boundary_6'

val_dir_x = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/validation/x'
val_dir_y = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/validation/y_boundary_2'
val_dir_y = '/home/jr0th/github/segmentation/data/BBBC022_hand_200/validation/y_boundary_6'

data_type = "images" # "images" or "array"

nb_epoch = 10 # 500
nb_epoch = 30 # 500
batch_size = 10
nb_batches = int(400 / batch_size) # 100 images, 400 patches

Expand Down

0 comments on commit 4360b34

Please sign in to comment.