Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update 3D example #267

Merged
merged 2 commits into from
Sep 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 93 additions & 126 deletions examples/vision/3D_image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from tensorflow.keras import layers

"""
## Downloading the MosMedData:Chest CT Scans with COVID-19 Related Findings
## Downloading the MosMedData: Chest CT Scans with COVID-19 Related Findings

In this example, we use a subset of the
[MosMedData: Chest CT Scans with COVID-19 Related Findings](https://www.medrxiv.org/content/10.1101/2020.05.20.20100362v1).
Expand All @@ -51,8 +51,8 @@
keras.utils.get_file(filename, url)

# Download url of abnormal CT scans.
url = "https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-1.zip"
filename = os.path.join(os.getcwd(), "CT-1.zip")
url = "https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-23.zip"
filename = os.path.join(os.getcwd(), "CT-23.zip")
keras.utils.get_file(filename, url)

# Make a directory to store the data.
Expand All @@ -62,30 +62,33 @@
with zipfile.ZipFile("CT-0.zip", "r") as z_fp:
z_fp.extractall("./MosMedData/")

with zipfile.ZipFile("CT-1.zip", "r") as z_fp:
with zipfile.ZipFile("CT-23.zip", "r") as z_fp:
z_fp.extractall("./MosMedData/")

"""
## Load data
## Loading data and preprocessing

The files are provided in Nifti format with the extension .nii. To read the
scans, we use the `nibabel` package.
You can install the package via `pip install nibabel`.
You can install the package via `pip install nibabel`. CT scans store raw voxel
intensity in Hounsfield units (HU). They range from -1024 to above 2000 in this dataset.
Values above 400 correspond to bones with different radiointensity, so 400 is used as the upper bound. A threshold
between -1000 and 400 is commonly used to normalize CT scans.

To process the data, we do the following:

* We first rotate the volumes by 90 degrees, so the orientation is fixed.
* We scale the HU values to be between 0 and 1.
* We resize width, height and depth.

Here we define several helper functions to process the data. These functions
will be used when building training and validation datasets.
"""

import numpy as np

import nibabel as nib
import cv2

from scipy.ndimage import zoom
from scipy import ndimage


def read_nifti_file(filepath):
Expand All @@ -94,45 +97,52 @@ def read_nifti_file(filepath):
scan = nib.load(filepath)
# Get raw data
scan = scan.get_fdata()
# Rotate
scan = np.rot90(np.array(scan))
return scan


def resize_slices(img, size=(128, 128)):
    """Resize every 2D slice of a volume to a common size.

    Args:
        img: Volume indexed as `img[:, :, i]`, with one 2D slice per index
            along the last axis.
        size: Target size handed to `cv2.resize` for each slice. Defaults to
            `(128, 128)`, the value that was previously hard-coded.

    Returns:
        The resized slices stacked back together along the last axis.
    """
    # Resize all slices
    resized = [
        cv2.resize(img[:, :, i], size, interpolation=cv2.INTER_CUBIC)
        for i in range(img.shape[-1])
    ]
    # Stack along the z-axis; np.dstack already returns a fresh array, so no
    # additional np.array() copy is needed.
    return np.dstack(resized)
def normalize(volume):
    """Clip a CT volume to a fixed HU window and rescale it to [0, 1].

    Values below -1000 are mapped to 0 and values above 400 are mapped to 1;
    everything in between is scaled linearly.

    Args:
        volume: Array-like of raw voxel intensities.

    Returns:
        A new `float32` array with the same shape as `volume`. The input is
        not modified.
    """
    hu_min = -1000
    hu_max = 400
    # np.clip returns a new array, so the caller's data is left untouched
    # (boolean-mask assignment would overwrite it in place).
    volume = np.clip(volume, hu_min, hu_max)
    volume = (volume - hu_min) / (hu_max - hu_min)
    return volume.astype("float32")


def resize_depth(img):
def resize_volume(img):
"""Resize across z-axis"""
# Set the desired depth
desired_depth = 64
desired_width = 128
desired_height = 128
# Get current depth
current_depth = img.shape[-1]
current_width = img.shape[0]
current_height = img.shape[1]
# Compute depth factor
depth = current_depth / desired_depth
width = current_width / desired_width
height = current_height / desired_height
depth_factor = 1 / depth
width_factor = 1 / width
height_factor = 1 / height
# Rotate
img = ndimage.rotate(img, 90, reshape=False)
# Resize across z-axis
img_new = zoom(img, (1, 1, depth_factor), mode="nearest")
return img_new
img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=1)
return img


def process_scan(path):
"""Read and resize volume"""
# Read scan
volume = read_nifti_file(path)
# Resize width and height
volume = resize_slices(volume)
# Resize across z-axis
volume = resize_depth(volume)
# Normalize
volume = normalize(volume)
# Resize width, height and depth
volume = resize_volume(volume)
return volume


Expand All @@ -146,70 +156,26 @@ def process_scan(path):
os.path.join(os.getcwd(), "MosMedData/CT-0", x)
for x in os.listdir("MosMedData/CT-0")
]
# Folder "CT-1" consists of CT scans having several ground-glass opacifications,
# Folder "CT-23" consists of CT scans having several ground-glass opacifications,
# involvement of lung parenchyma.
abnormal_scan_paths = [
os.path.join(os.getcwd(), "MosMedData/CT-1", x)
for x in os.listdir("MosMedData/CT-1")
os.path.join(os.getcwd(), "MosMedData/CT-23", x)
for x in os.listdir("MosMedData/CT-23")
]

print("CT scans with normal lung tissue: " + str(len(normal_scan_paths)))
print("CT scans with abnormal lung tissue: " + str(len(abnormal_scan_paths)))

"""
Let's visualize a CT scan and its shape.
"""

import matplotlib.pyplot as plt

# Read a scan.
img = read_nifti_file(normal_scan_paths[15])
print("Dimension of the CT scan is:", img.shape)
plt.imshow(img[:, :, 15], cmap="gray")

"""
Since a CT scan has many slices, let's visualize a montage of the slices.
"""


def plot_slices(num_rows, num_columns, width, height, data):
    """Plot a montage of CT slices on a `num_rows` x `num_columns` grid.

    Args:
        num_rows: Number of rows in the montage.
        num_columns: Number of columns in the montage.
        width: Size of the first axis of each slice after reshaping.
        height: Size of the second axis of each slice after reshaping.
        data: Array-like holding the slices; it must contain exactly
            `num_rows * num_columns * width * height` values.
    """
    data = np.rot90(np.array(data))
    data = np.transpose(data)
    data = np.reshape(data, (num_rows, num_columns, width, height))
    rows_data, columns_data = data.shape[0], data.shape[1]
    heights = [slc[0].shape[0] for slc in data]
    widths = [slc.shape[1] for slc in data[0]]
    fig_width = 12.0
    fig_height = fig_width * sum(heights) / sum(widths)
    # squeeze=False keeps `axarr` two-dimensional even when the grid has a
    # single row or column, so the `axarr[i, j]` indexing below always works.
    _, axarr = plt.subplots(
        rows_data,
        columns_data,
        figsize=(fig_width, fig_height),
        gridspec_kw={"height_ratios": heights},
        squeeze=False,
    )
    for i in range(rows_data):
        for j in range(columns_data):
            axarr[i, j].imshow(data[i][j], cmap="gray")
            axarr[i, j].axis("off")
    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    plt.show()


# Display 20 slices from the CT scan.
# Here we visualize 20 slices, 2 rows and 10 columns
# adapt it according to your need.
plot_slices(2, 10, 512, 512, img[:, :, :20])

"""
## Build train and validation datasets
Read the scans from the class directories and assign labels. Downsample the scans to have
shape of 128x128x64.
shape of 128x128x64. Rescale the raw HU values to the range 0 to 1.
Lastly, split the dataset into train and validation subsets.
"""

# Read and process the scans.
# Each scan is resized across width, height, and depth.
# Each scan is resized across height, width, and depth and rescaled.
abnormal_scans = np.array([process_scan(path) for path in abnormal_scan_paths])
normal_scans = np.array([process_scan(path) for path in normal_scan_paths])

Expand All @@ -229,34 +195,19 @@ def plot_slices(num_rows, num_columns, width, height, data):
)

"""
## Preprocessing and data augmentation
## Data augmentation

CT scans store raw voxel intensity in Hounsfield units (HU). They range from
-1024 to above 2000 in this dataset. Above 400 are bones with different
radiointensity, so this is used as a higher bound. A threshold between
-1000 and 400 is commonly used to normalize CT scans. The CT scans are
also augmented by rotating and blurring. There are different kinds of
preprocessing and augmentation techniques out there, this example shows a few
simple ones to get started.
The CT scans are also augmented by rotating at random angles during training. Since
each scan is a rank-3 tensor and the data has shape `(samples, height, width, depth)`,
we add a dimension of size 1 at axis 4 to be able to perform 3D convolutions on
the data. The new shape is thus `(samples, height, width, depth, 1)`. There are
different kinds of preprocessing and augmentation techniques out there;
this example shows a few simple ones to get started.
"""

import random

from scipy import ndimage
from scipy.ndimage import gaussian_filter


@tf.function
def normalize(volume):
    """Clip the volume to a fixed HU window, rescale it and add a channel axis.

    Values are clipped to the window [-1000, 400] and scaled linearly so the
    window maps onto [0, 1].

    Args:
        volume: Tensor of raw voxel intensities.

    Returns:
        The rescaled volume with an extra axis of size 1 inserted at axis 3.
    """
    hu_min = -1000
    hu_max = 400
    # Apply the HU window explicitly. Written without parentheses,
    # `volume - min / max - min` only shifts the volume by a constant and
    # never applies the window.
    volume = tf.clip_by_value(volume, hu_min, hu_max)
    normalized_volume = (volume - hu_min) / (hu_max - hu_min)
    normalized_volume = tf.expand_dims(normalized_volume, axis=3)
    return normalized_volume


@tf.function
Expand All @@ -270,46 +221,32 @@ def scipy_rotate(volume):
angle = random.choice(angles)
# rotate volume
volume = ndimage.rotate(volume, angle, reshape=False)
volume[volume < 0] = 0
volume[volume > 1] = 1
return volume

augmented_volume = tf.numpy_function(scipy_rotate, [volume], tf.float64)
return augmented_volume


@tf.function
def blur(volume):
"""Blur the volume"""

def scipy_blur(volume):
# gaussian blur
volume = gaussian_filter(volume, sigma=1)
return volume

augmented_volume = tf.numpy_function(scipy_blur, [volume], tf.float64)
augmented_volume = tf.numpy_function(scipy_rotate, [volume], tf.float32)
return augmented_volume


def train_preprocessing(volume, label):
"""Process training data by rotating, blur and normalizing."""
# rotate data
"""Process training data by rotating and adding a channel."""
# Rotate volume
volume = rotate(volume)
# blur data
volume = blur(volume)
# normalize
volume = normalize(volume)
volume = tf.expand_dims(volume, axis=3)
return volume, label


def validation_preprocessing(volume, label):
"""Process validation data by only normalizing."""
volume = normalize(volume)
"""Process validation data by only adding a channel."""
volume = tf.expand_dims(volume, axis=3)
return volume, label


"""
While defining the train and validation data loader, the training data is passed through
an augmentation function which randomly rotates or blurs the volume and finally normalizes
it to have values between 0 and 1. For the validation data, the volumes are only normalized.
an augmentation function which randomly rotates the volume at different angles. Note that both
training and validation data are already rescaled to have values between 0 and 1.
"""

# Define data loaders.
Expand Down Expand Up @@ -345,8 +282,38 @@ def validation_preprocessing(volume, label):
print("Dimension of the CT scan is:", image.shape)
plt.imshow(np.squeeze(image[:, :, 30]), cmap="gray")


"""
Since a CT scan has many slices, let's visualize a montage of the slices.
"""


def plot_slices(num_rows, num_columns, width, height, data):
    """Plot a montage of CT slices on a `num_rows` x `num_columns` grid.

    Args:
        num_rows: Number of rows in the montage.
        num_columns: Number of columns in the montage.
        width: Size of the first axis of each slice after reshaping.
        height: Size of the second axis of each slice after reshaping.
        data: Array-like holding the slices; it must contain exactly
            `num_rows * num_columns * width * height` values.
    """
    data = np.rot90(np.array(data))
    data = np.transpose(data)
    data = np.reshape(data, (num_rows, num_columns, width, height))
    rows_data, columns_data = data.shape[0], data.shape[1]
    heights = [slc[0].shape[0] for slc in data]
    widths = [slc.shape[1] for slc in data[0]]
    fig_width = 12.0
    fig_height = fig_width * sum(heights) / sum(widths)
    # squeeze=False keeps `axarr` two-dimensional even when the grid has a
    # single row or column, so the `axarr[i, j]` indexing below always works.
    _, axarr = plt.subplots(
        rows_data,
        columns_data,
        figsize=(fig_width, fig_height),
        gridspec_kw={"height_ratios": heights},
        squeeze=False,
    )
    for i in range(rows_data):
        for j in range(columns_data):
            axarr[i, j].imshow(data[i][j], cmap="gray")
            axarr[i, j].axis("off")
    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    plt.show()


# Visualize montage of slices.
# 10 rows and 10 columns for 100 slices of the CT scan.
# 4 rows and 10 columns for 40 slices of the CT scan.
plot_slices(4, 10, 128, 128, image[:, :, :40])

"""
Expand All @@ -359,7 +326,7 @@ def validation_preprocessing(volume, label):


def get_model(width=128, height=128, depth=64):
"""build a 3D convolutional neural network model"""
"""Build a 3D convolutional neural network model."""

inputs = keras.Input((width, height, depth, 1))

Expand Down Expand Up @@ -413,7 +380,7 @@ def get_model(width=128, height=128, depth=64):
checkpoint_cb = keras.callbacks.ModelCheckpoint(
"3d_image_classification.h5", save_best_only=True
)
early_stopping_cb = keras.callbacks.EarlyStopping(monitor="val_acc", patience=10)
early_stopping_cb = keras.callbacks.EarlyStopping(monitor="val_acc", patience=15)

# Train the model, doing validation at the end of each epoch
epochs = 100
Expand Down
Binary file not shown.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Loading