Commit 56bc4ff: updating BASNet with prediction output

hamidriasat committed Jun 18, 2023
1 parent f396cc9, commit 56bc4ff
Showing 1 changed file (examples/vision/basnet_segmentation.py) with 60 additions and 53 deletions.
@@ -1,8 +1,8 @@
"""
Title: Highly accurate boundaries segmentation using BASNet
Author: [Hamid Ali](https://github.com/hamidriasat)
-Date created: 2022/05/30
-Last modified: 2022/05/30
+Date created: 2023/05/30
+Last modified: 2023/06/19
Description: Boundary-aware segmentation model trained on the DUTS dataset.
Accelerator: GPU
"""
@@ -19,14 +19,15 @@
### References:
- [Boundary-Aware Segmentation Network for Mobile and Web Applications](https://arxiv.org/abs/2101.04704)
- [BASNet Keras Implementation](https://github.com/hamidriasat/BASNet/tree/basnet_keras)
+- [Learning to Detect Salient Objects with Image-level Supervision](https://openaccess.thecvf.com/content_cvpr_2017/html/Wang_Learning_to_Detect_CVPR_2017_paper.html)
"""

"""
## Download the Data
We will use the [DUTS-TE](http://saliencydetection.net/duts/) dataset for training. It has 5,019
-images but we will use 100 for training and validation to save notebook running time. DUTS is
+images but we will use 140 for training and validation to save notebook running time. DUTS is
a relatively large salient object segmentation dataset, which contains diversified textures and
structures common to real-world images in both foreground and background.
"""
@@ -59,14 +60,14 @@
"""
## Create TensorFlow Dataset
-We will use `load_paths()` to load and split 100 paths into train and validation set, and
+We will use `load_paths()` to load and split 140 paths into train and validation sets, and
`load_dataset()` to convert paths into a `tf.data.Dataset` object.
"""


def load_paths(path, split_ratio):
-    images = sorted(glob(os.path.join(path, "DUTS-TE-Image/*")))[:100]
-    masks = sorted(glob(os.path.join(path, "DUTS-TE-Mask/*")))[:100]
+    images = sorted(glob(os.path.join(path, "DUTS-TE-Image/*")))[:140]
+    masks = sorted(glob(os.path.join(path, "DUTS-TE-Mask/*")))[:140]
len_ = int(len(images) * split_ratio)
return (images[:len_], masks[:len_]), (images[len_:], masks[len_:])

@@ -78,14 +79,14 @@ def read_image(path, size, mode):
return x


-def preprocess(x, y, img_size, out_classes):
+def preprocess(x_batch, y_batch, img_size, out_classes):
def f(_x, _y):
_x, _y = _x.decode(), _y.decode()
_x = read_image(_x, (img_size, img_size), mode="rgb") # image
_y = read_image(_y, (img_size, img_size), mode="grayscale") # mask
return _x, _y

-    images, masks = tf.numpy_function(f, [x, y], [tf.float32, tf.float32])
+    images, masks = tf.numpy_function(f, [x_batch, y_batch], [tf.float32, tf.float32])
images.set_shape([img_size, img_size, 3])
masks.set_shape([img_size, img_size, out_classes])
return images, masks
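The `load_dataset()` helper referenced above is collapsed in this diff. A minimal sketch of
what it might look like, assuming it simply wraps `preprocess` in a batched `tf.data`
pipeline (a hypothetical reconstruction, not the exact implementation):


def load_dataset(images, masks, img_size, out_classes, batch, shuffle=True):
    """Hypothetical reconstruction: pair paths, preprocess, shuffle, and batch."""
    dataset = tf.data.Dataset.from_tensor_slices((images, masks))
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(images))
    dataset = dataset.map(
        lambda x, y: preprocess(x, y, img_size, out_classes),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return dataset.batch(batch).prefetch(tf.data.AUTOTUNE)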
@@ -133,7 +134,7 @@ def display(display_list):


for image, mask in val_dataset.take(1):
-    display([image[2], mask[2]])
+    display([image[0], mask[0]])

"""
## Analyze Mask
@@ -145,9 +146,9 @@ def display(display_list):
boundaries and fine structures.
"""

print(f"Unique values count: {len(np.unique((mask[2] * 255)))}")
print(f"Unique values count: {len(np.unique((mask[0] * 255)))}")
print("Unique values:")
print(np.unique((mask[2] * 255)).astype(int))
print(np.unique((mask[0] * 255)).astype(int))

"""
## Building the BASNet Model
@@ -160,12 +161,12 @@ def display(display_list):
"""


-def basic_block(x_in, filters, stride=1, down_sample=None, no_relu=False):
+def basic_block(x_input, filters, stride=1, down_sample=None, activation=None):
    """Creates a residual (identity) block with two 3x3 convolutions."""
-    residual = x_in
+    residual = x_input

x = layers.Conv2D(filters, (3, 3), strides=stride, padding="same", use_bias=False)(
-        x_in
+        x_input
)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
@@ -180,22 +181,22 @@ def basic_block(x_in, filters, stride=1, down_sample=None, no_relu=False):

x = layers.Add()([x, residual])

-    if not no_relu:
-        x = layers.Activation("relu")(x)
+    if activation is not None:
+        x = layers.Activation(activation)(x)

return x
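As a quick illustration (not part of the diff), the block can be smoke-tested on a dummy
input; with `stride=1` and no `down_sample`, the spatial shape and channel count are
preserved:


# Illustrative only: arbitrary shapes, 64 input channels chosen to match `filters`.
demo_input = keras.Input(shape=(32, 32, 64))
demo_output = basic_block(demo_input, filters=64, activation="relu")
print(demo_output.shape)  # Expected: (None, 32, 32, 64)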


-def convolution_block(x_in, filters, dilation=1):
+def convolution_block(x_input, filters, dilation=1):
    """Apply convolution + batch normalization + relu layer."""
-    x = layers.Conv2D(filters, (3, 3), padding="same", dilation_rate=dilation)(x_in)
+    x = layers.Conv2D(filters, (3, 3), padding="same", dilation_rate=dilation)(x_input)
x = layers.BatchNormalization()(x)
return layers.Activation("relu")(x)


-def segmentation_head(x_in, out_classes, final_size):
+def segmentation_head(x_input, out_classes, final_size):
    """Map each decoder stage output to model output classes."""
-    x = layers.Conv2D(out_classes, kernel_size=(3, 3), padding="same")(x_in)
+    x = layers.Conv2D(out_classes, kernel_size=(3, 3), padding="same")(x_input)

if final_size is not None:
x = layers.Resizing(final_size[0], final_size[1])(x)
@@ -249,15 +250,15 @@ def basnet_predict(input_shape, out_classes):
x = layers.Activation("relu")(x)
else: # Last 2 stages consist of three basic resnet blocks.
x = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x)
-            x = basic_block(x, filters * 8)
-            x = basic_block(x, filters * 8)
-            x = basic_block(x, filters * 8)
+            x = basic_block(x, filters=filters * 8, activation="relu")
+            x = basic_block(x, filters=filters * 8, activation="relu")
+            x = basic_block(x, filters=filters * 8, activation="relu")
encoder_blocks.append(x)

# -------------Bridge-------------
-    x = convolution_block(x, filters * 8, dilation=2)
-    x = convolution_block(x, filters * 8, dilation=2)
-    x = convolution_block(x, filters * 8, dilation=2)
+    x = convolution_block(x, filters=filters * 8, dilation=2)
+    x = convolution_block(x, filters=filters * 8, dilation=2)
+    x = convolution_block(x, filters=filters * 8, dilation=2)
encoder_blocks.append(x)

# -------------Decoder-------------
@@ -268,18 +269,9 @@ def basnet_predict(input_shape, out_classes):
x = layers.Resizing(shape[1] * 2, shape[2] * 2)(x)

x = layers.concatenate([encoder_blocks[i], x], axis=-1)
-        x = convolution_block(
-            x,
-            filters * 8,
-        )
-        x = convolution_block(
-            x,
-            filters * 8,
-        )
-        x = convolution_block(
-            x,
-            filters * 8,
-        )
+        x = convolution_block(x, filters=filters * 8)
+        x = convolution_block(x, filters=filters * 8)
+        x = convolution_block(x, filters=filters * 8)
decoder_blocks.append(x)

decoder_blocks.reverse() # Change order from last to first decoder stage.
@@ -317,25 +309,19 @@ def basnet_rrm(base_model, out_classes):

encoder_blocks = []
for _ in range(num_stages):
-        x = convolution_block(
-            x,
-            filters,
-        )
+        x = convolution_block(x, filters=filters)
encoder_blocks.append(x)
x = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x)

# -------------Bridge--------------
-    x = convolution_block(
-        x,
-        filters,
-    )
+    x = convolution_block(x, filters=filters)

# -------------Decoder--------------
for i in reversed(range(num_stages)):
shape = keras.backend.int_shape(x)
x = layers.Resizing(shape[1] * 2, shape[2] * 2)(x)
x = layers.concatenate([encoder_blocks[i], x], axis=-1)
-        x = convolution_block(x, filters)
+        x = convolution_block(x, filters=filters)

x = segmentation_head(x, out_classes, None) # Segmentation head.

@@ -421,10 +407,7 @@ def call(self, y_true, y_pred):
) # Create model.
basnet_model.summary() # Show model summary.
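The body of `BasnetLoss` is mostly collapsed in this diff. Following the paper, it is a
hybrid of pixel-level (binary cross-entropy), patch-level (SSIM), and map-level (IoU)
supervision. A minimal sketch of that three-term combination (an illustration, not the
exact implementation from the file):


class HybridLossSketch(keras.losses.Loss):
    """Illustrative BCE + SSIM + IoU hybrid loss, as described in the BASNet paper."""

    def __init__(self, **kwargs):
        super().__init__(name="hybrid_loss_sketch", **kwargs)
        self.smooth = 1.0e-9
        self.bce = keras.losses.BinaryCrossentropy()

    def call(self, y_true, y_pred):
        bce_loss = self.bce(y_true, y_pred)  # Pixel-level supervision.
        # Patch-level (structural) supervision: 1 - SSIM.
        ssim_loss = 1.0 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=1.0))
        # Map-level supervision via a soft IoU.
        intersection = tf.reduce_sum(y_true * y_pred)
        union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
        iou_loss = 1.0 - (intersection + self.smooth) / (union + self.smooth)
        return bce_loss + ssim_loss + iou_loss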

-optimizer = keras.optimizers.Adam(
-    learning_rate=1e-4,
-    epsilon=1e-8,
-)
+optimizer = keras.optimizers.Adam(learning_rate=1e-4, epsilon=1e-8)
# Compile model.
basnet_model.compile(
loss=BasnetLoss(),
@@ -446,6 +429,30 @@ def call(self, y_true, y_pred):
evaluated on the DUTS-TE dataset and achieved a mean absolute error of `0.042`.
Since BASNet is a deep model and cannot be trained in a short amount of time, which is a
-requirement for keras example notebook, so we will display output from [paper](https://arxiv.org/abs/2101.04704)
-![](https://i.imgur.com/4r8crKm.png)
+requirement for a Keras example notebook, we will load pretrained weights from [here](https://github.com/hamidriasat/BASNet/tree/basnet_keras)
+to show the model's predictions. Due to compute limitations, this model was trained for 60k
+iterations, but it still demonstrates its capabilities. For further details about the
+training parameters, please check the given link.
"""

"""shell
!gdown 1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg
"""


def normalize_output(prediction):
max_value = np.max(prediction)
min_value = np.min(prediction)
return (prediction - min_value) / (max_value - min_value)


# Load weights.
basnet_model.load_weights("./basnet_weights.h5")

"""
### Make Predictions
"""

for image, mask in val_dataset.take(1):
pred_mask = basnet_model.predict(image)
display([image[0], mask[0], normalize_output(pred_mask[0][0])])
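As a rough cross-check of the `0.042` mean absolute error quoted above, the same metric can
be estimated on our small validation split (illustrative only; `preds[0]` is the first
output head, as visualized above):


# Illustrative only: approximate MAE over the validation split used in this notebook.
mae = keras.metrics.MeanAbsoluteError()
for images, masks in val_dataset:
    preds = basnet_model.predict(images, verbose=0)
    mae.update_state(masks, preds[0])
print(f"Validation MAE: {mae.result().numpy():.4f}")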
