A Vision Transformer without Attention example to Keras 3 #2000

Merged: 4 commits, Dec 6, 2024
Changes from 1 commit
shiftvit migrated to keras 3
sitamgithub-MSIT committed Dec 3, 2024
commit 66864fea0e145474a2a4aa7d120c3aaa34fbd3e5
78 changes: 37 additions & 41 deletions examples/vision/shiftvit.py
@@ -27,11 +27,7 @@
In this example, we minimally implement the paper with close alignment to the author's
[official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py).

-This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can
-be installed using the following command:
-"""
-"""shell
-pip install -qq -U tensorflow-addons
+This example requires TensorFlow 2.9 or higher.
"""

"""
@@ -41,10 +37,10 @@
import numpy as np
import matplotlib.pyplot as plt

+import keras
+from keras import ops
+from keras import layers
import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
-import tensorflow_addons as tfa

import pathlib
import glob
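Note: Keras 3 is multi-backend, and the backend is chosen via the `KERAS_BACKEND` environment variable before `keras` is first imported. This example keeps TensorFlow (the default) because its input pipeline still relies on `tf.data` and `tf.io`. A minimal sketch, not part of this diff:

```python
import os

# Must be set before the first `import keras`; "tensorflow" is the default
# backend and is assumed here because the example still uses tf.data and tf.io.
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
```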
@@ -276,13 +272,16 @@ class DropPath(layers.Layer):
    def __init__(self, drop_path_prob, **kwargs):
        super().__init__(**kwargs)
        self.drop_path_prob = drop_path_prob
+        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, x, training=False):
        if training:
            keep_prob = 1 - self.drop_path_prob
-            shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
-            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
-            random_tensor = tf.floor(random_tensor)
+            shape = (ops.shape(x)[0],) + (1,) * (len(ops.shape(x)) - 1)
+            random_tensor = keep_prob + keras.random.uniform(
+                shape, 0, 1, seed=self.seed_generator
+            )
+            random_tensor = ops.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x
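The rewritten `DropPath` draws its random mask from a `keras.random.SeedGenerator`, giving reproducible, backend-agnostic randomness. A minimal sketch of the layer's behavior, assuming the `DropPath` class defined above (shapes and values are illustrative):

```python
from keras import ops

# With drop_path_prob = 0.5, roughly half the samples in a batch are zeroed
# during training; survivors are scaled by 1 / keep_prob to keep the mean.
drop = DropPath(drop_path_prob=0.5)
x = ops.ones((8, 4))
y = drop(x, training=True)
# Each row is either all 0.0 (dropped) or all 2.0 (kept and rescaled).
print(ops.convert_to_numpy(y))
```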

@@ -396,25 +395,25 @@ def get_shift_pad(self, x, mode):
            offset_width = 0
            target_height = self.shift_pixel
            target_width = 0
-        crop = tf.image.crop_to_bounding_box(
+        crop = ops.image.crop_images(
            x,
-            offset_height=offset_height,
-            offset_width=offset_width,
+            top_cropping=offset_height,
+            left_cropping=offset_width,
            target_height=self.H - target_height,
            target_width=self.W - target_width,
        )
-        shift_pad = tf.image.pad_to_bounding_box(
+        shift_pad = ops.image.pad_images(
            crop,
-            offset_height=offset_height,
-            offset_width=offset_width,
+            top_padding=offset_height,
+            left_padding=offset_width,
            target_height=self.H,
            target_width=self.W,
        )
        return shift_pad

    def call(self, x, training=False):
        # Split the feature maps
-        x_splits = tf.split(x, num_or_size_splits=self.C // self.num_div, axis=-1)
+        x_splits = ops.split(x, indices_or_sections=self.C // self.num_div, axis=-1)

        # Shift the feature maps
        x_splits[0] = self.get_shift_pad(x_splits[0], mode="left")
@@ -423,7 +422,7 @@ def call(self, x, training=False):
        x_splits[3] = self.get_shift_pad(x_splits[3], mode="down")

        # Concatenate the shifted and unshifted feature maps
-        x = tf.concat(x_splits, axis=-1)
+        x = ops.concatenate(x_splits, axis=-1)

        # Add the residual connection
        shortcut = x
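The shift itself is just a crop followed by a zero-pad on the opposite side, which translates the feature map by `shift_pixel`. A small sketch of that trick with the new ops, using a hypothetical 4x4 map shifted one pixel left (the exact per-mode offsets follow the code above):

```python
from keras import ops

# A 1x4x4x1 "feature map" holding the values 0..15.
x = ops.reshape(ops.arange(16, dtype="float32"), (1, 4, 4, 1))

# Drop the leftmost column, then zero-pad on the right: a one-pixel left shift.
crop = ops.image.crop_images(
    x, top_cropping=0, left_cropping=1, target_height=4, target_width=3
)
shifted = ops.image.pad_images(
    crop, top_padding=0, left_padding=0, target_height=4, target_width=4
)
print(ops.convert_to_numpy(shifted)[0, 0, :, 0])  # [1. 2. 3. 0.]
```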
@@ -779,7 +778,7 @@ def __init__(self, lr_start, lr_max, warmup_steps, total_steps):
        self.lr_max = lr_max
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
-        self.pi = tf.constant(np.pi)
+        self.pi = ops.array(np.pi)

    def __call__(self, step):
        # Check whether the total number of steps is larger than the warmup
@@ -793,10 +792,10 @@ def __call__(self, step):
        # `cos_annealed_lr` is a graph that increases to 1 from the initial
        # step to the warmup step. After that this graph decays to -1 at the
        # final step mark.
-        cos_annealed_lr = tf.cos(
+        cos_annealed_lr = ops.cos(
            self.pi
-            * (tf.cast(step, tf.float32) - self.warmup_steps)
-            / tf.cast(self.total_steps - self.warmup_steps, tf.float32)
+            * (ops.cast(step, dtype="float32") - self.warmup_steps)
+            / ops.cast(self.total_steps - self.warmup_steps, dtype="float32")
        )

        # Shift the mean of the `cos_annealed_lr` graph to 1. Now the graph goes
@@ -821,20 +820,18 @@ def __call__(self, step):

        # With the formula for a straight line (y = mx+c) build the warmup
        # schedule
-        warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start
+        warmup_rate = slope * ops.cast(step, dtype="float32") + self.lr_start

        # When the current step is less than the warmup steps, get the line
        # graph. When the current step is greater than the warmup steps, get
        # the scaled cos graph.
-        learning_rate = tf.where(
+        learning_rate = ops.where(
            step < self.warmup_steps, warmup_rate, learning_rate
        )

        # When the current step is more than the total steps, return 0, else return
        # the calculated graph.
-        return tf.where(
-            step > self.total_steps, 0.0, learning_rate, name="learning_rate"
-        )
+        return ops.where(step > self.total_steps, 0.0, learning_rate)

    def get_config(self):
        config = {
@@ -871,7 +868,7 @@ def get_config(self):
)

# Get the optimizer.
-optimizer = tfa.optimizers.AdamW(
+optimizer = keras.optimizers.AdamW(
    learning_rate=scheduled_lrs, weight_decay=config.weight_decay
)
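`keras.optimizers.AdamW` is a drop-in replacement for `tfa.optimizers.AdamW` and accepts a learning-rate schedule object directly. A usage sketch with illustrative hyperparameters (the real `warmup_steps` and `total_steps` are derived from the dataset size earlier in the example):

```python
# Illustrative numbers only; WarmUpCosine is the schedule defined above.
schedule = WarmUpCosine(
    lr_start=1e-5, lr_max=1e-3, warmup_steps=500, total_steps=5000
)

# The schedule is a callable mapping step -> learning rate.
print(float(schedule(0)))     # lr_start at step 0
print(float(schedule(500)))   # peaks near lr_max when warmup ends
print(float(schedule(5001)))  # 0.0 once total_steps is exceeded

optimizer = keras.optimizers.AdamW(learning_rate=schedule, weight_decay=1e-4)
```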

@@ -913,7 +910,7 @@ def get_config(self):

It can only be saved in the TF SavedModel format, which is in general the recommended format for saving models.
"""
-model.save("ShiftViT")
+model.export("ShiftViT")

"""
## Model inference
@@ -932,12 +929,9 @@ def get_config(self):
"""
**Load saved model**
"""
-# Custom objects are not included when the model is saved.
-# At loading time, these objects need to be passed for reconstruction of the model
-saved_model = tf.keras.models.load_model(
-    "ShiftViT",
-    custom_objects={"WarmUpCosine": WarmUpCosine, "AdamW": tfa.optimizers.AdamW},
-)
+# Using TFSMLayer to reload the TF SavedModel as a Keras layer.
+# This is not limited to SavedModels that originate from Keras – it will work with any SavedModel, e.g. TF-Hub models.
+saved_model = keras.layers.TFSMLayer("ShiftViT", call_endpoint="serving_default")
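`TFSMLayer` wraps the exported SavedModel's serving signature as an inference-only layer, and calling it returns a dict keyed by output name. A quick sketch, where the input shape assumes the example's 32x32 CIFAR-style images:

```python
import numpy as np

# Calling the layer runs the "serving_default" signature.
batch = np.random.rand(4, 32, 32, 3).astype("float32")  # assumed input shape
outputs = saved_model(batch)        # dict: {output_name: tensor}
logits = list(outputs.values())[0]  # single-output model: take the only entry
```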

"""
**Utility functions for inference**
@@ -952,9 +946,9 @@ def process_image(img_path):
    img = tf.io.decode_jpeg(img, channels=3)

    # resize image to match input size accepted by model
-    # use `method` as `nearest` to preserve dtype of input passed to `resize()`
-    img = tf.image.resize(
-        img, [config.input_shape[0], config.input_shape[1]], method="nearest"
+    # use `interpolation` as `nearest` to preserve dtype of input passed to `resize()`
+    img = ops.image.resize(
+        img, [config.input_shape[0], config.input_shape[1]], interpolation="nearest"
    )
    return img
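The `method=` argument becomes `interpolation=` in `ops.image.resize`, and `"nearest"` still preserves the integer dtype of the decoded JPEG on the TensorFlow backend, so the comment changes but the behavior does not. A tiny sketch of the dtype point (values illustrative):

```python
import numpy as np
from keras import ops

img = np.arange(4, dtype="uint8").reshape(1, 2, 2, 1)
out = ops.image.resize(img, (4, 4), interpolation="nearest")
# Nearest-neighbor copies source pixels, so no fractional values are created
# and the uint8 dtype survives the resize (on the TensorFlow backend).
print(ops.convert_to_numpy(out).dtype)
```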

@@ -976,10 +970,12 @@ def create_tf_dataset(image_dir):

def predict(predict_ds):
    # ShiftViT model returns logits (non-normalized predictions)
-    logits = saved_model.predict(predict_ds)
+    model = keras.Sequential([saved_model])
+    output_dict = model.predict(predict_ds)
+    logits = list(output_dict.values())[0]

    # normalize predictions by calling softmax()
-    probabilities = tf.nn.softmax(logits)
+    probabilities = ops.softmax(logits)
    return probabilities
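Putting the inference pieces together, a hypothetical end-to-end run; the directory path is a placeholder, and `create_tf_dataset` and `predict` are the helpers from this example:

```python
from keras import ops

# "test_images/" is a placeholder path containing JPEGs to classify.
predict_ds = create_tf_dataset("test_images/")
probabilities = predict(predict_ds)

# Highest-probability class per image.
top1 = ops.argmax(probabilities, axis=-1)
print(ops.convert_to_numpy(top1))
```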

