Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A Vision Transformer without Attention example to Keras 3 #2000

Merged
merged 4 commits into from
Dec 6, 2024

Conversation

sitamgithub-MSIT
Copy link
Contributor

@sitamgithub-MSIT sitamgithub-MSIT commented Dec 3, 2024

This PR migrates the "A Vision Transformer without Attention" example to Keras 3.0.

For reference, here is a notebook with the updated example: https://colab.research.google.com/drive/1tgIPW6DOux-dNicegYcnG8lnWswrE0Ji?usp=sharing

cc: @fchollet

The following describes the Git difference for the changed files:

Changes:
diff --git a/examples/vision/shiftvit.py b/examples/vision/shiftvit.py
index e0c42d1d..bd053246 100644
--- a/examples/vision/shiftvit.py
+++ b/examples/vision/shiftvit.py
@@ -27,11 +27,7 @@ operation with a shifting operation.
 In this example, we minimally implement the paper with close alignement to the author's
 [official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py).
 
-This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can
-be installed using the following command:
-"""
-"""shell
-pip install -qq -U tensorflow-addons
+This example requires TensorFlow 2.9 or higher.
 """
 
 """
@@ -41,10 +37,10 @@ pip install -qq -U tensorflow-addons
 import numpy as np
 import matplotlib.pyplot as plt
 
+import keras
+from keras import ops
+from keras import layers
 import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
-import tensorflow_addons as tfa
 
 import pathlib
 import glob
@@ -276,13 +272,16 @@ class DropPath(layers.Layer):
     def __init__(self, drop_path_prob, **kwargs):
         super().__init__(**kwargs)
         self.drop_path_prob = drop_path_prob
+        self.seed_generator = keras.random.SeedGenerator(1337)
 
     def call(self, x, training=False):
         if training:
             keep_prob = 1 - self.drop_path_prob
-            shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
-            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
-            random_tensor = tf.floor(random_tensor)
+            shape = (ops.shape(x)[0],) + (1,) * (len(ops.shape(x)) - 1)
+            random_tensor = keep_prob + keras.random.uniform(
+                shape, 0, 1, seed=self.seed_generator
+            )
+            random_tensor = ops.floor(random_tensor)
             return (x / keep_prob) * random_tensor
         return x
 
@@ -396,17 +395,17 @@ class ShiftViTBlock(layers.Layer):
             offset_width = 0
             target_height = self.shift_pixel
             target_width = 0
-        crop = tf.image.crop_to_bounding_box(
+        crop = ops.image.crop_images(
             x,
-            offset_height=offset_height,
-            offset_width=offset_width,
+            top_cropping=offset_height,
+            left_cropping=offset_width,
             target_height=self.H - target_height,
             target_width=self.W - target_width,
         )
-        shift_pad = tf.image.pad_to_bounding_box(
+        shift_pad = ops.image.pad_images(
             crop,
-            offset_height=offset_height,
-            offset_width=offset_width,
+            top_padding=offset_height,
+            left_padding=offset_width,
             target_height=self.H,
             target_width=self.W,
         )
@@ -414,7 +413,7 @@ class ShiftViTBlock(layers.Layer):
 
     def call(self, x, training=False):
         # Split the feature maps
-        x_splits = tf.split(x, num_or_size_splits=self.C // self.num_div, axis=-1)
+        x_splits = ops.split(x, indices_or_sections=self.C // self.num_div, axis=-1)
 
         # Shift the feature maps
         x_splits[0] = self.get_shift_pad(x_splits[0], mode="left")
@@ -423,7 +422,7 @@ class ShiftViTBlock(layers.Layer):
         x_splits[3] = self.get_shift_pad(x_splits[3], mode="down")
 
         # Concatenate the shifted and unshifted feature maps
-        x = tf.concat(x_splits, axis=-1)
+        x = ops.concatenate(x_splits, axis=-1)
 
         # Add the residual connection
         shortcut = x
@@ -779,7 +778,7 @@ class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
         self.lr_max = lr_max
         self.warmup_steps = warmup_steps
         self.total_steps = total_steps
-        self.pi = tf.constant(np.pi)
+        self.pi = ops.array(np.pi)
 
     def __call__(self, step):
         # Check whether the total number of steps is larger than the warmup
@@ -793,10 +792,10 @@ class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
         # `cos_annealed_lr` is a graph that increases to 1 from the initial
         # step to the warmup step. After that this graph decays to -1 at the
         # final step mark.
-        cos_annealed_lr = tf.cos(
+        cos_annealed_lr = ops.cos(
             self.pi
-            * (tf.cast(step, tf.float32) - self.warmup_steps)
-            / tf.cast(self.total_steps - self.warmup_steps, tf.float32)
+            * (ops.cast(step, dtype="float32") - self.warmup_steps)
+            / ops.cast(self.total_steps - self.warmup_steps, dtype="float32")
         )
 
         # Shift the mean of the `cos_annealed_lr` graph to 1. Now the grpah goes
@@ -821,20 +820,18 @@ class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
 
             # With the formula for a straight line (y = mx+c) build the warmup
             # schedule
-            warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start
+            warmup_rate = slope * ops.cast(step, dtype="float32") + self.lr_start
 
             # When the current step is lesser that warmup steps, get the line
             # graph. When the current step is greater than the warmup steps, get
             # the scaled cos graph.
-            learning_rate = tf.where(
+            learning_rate = ops.where(
                 step < self.warmup_steps, warmup_rate, learning_rate
             )
 
         # When the current step is more that the total steps, return 0 else return
         # the calculated graph.
-        return tf.where(
-            step > self.total_steps, 0.0, learning_rate, name="learning_rate"
-        )
+        return ops.where(step > self.total_steps, 0.0, learning_rate)
 
     def get_config(self):
         config = {
@@ -871,7 +868,7 @@ scheduled_lrs = WarmUpCosine(
 )
 
 # Get the optimizer.
-optimizer = tfa.optimizers.AdamW(
+optimizer = keras.optimizers.AdamW(
     learning_rate=scheduled_lrs, weight_decay=config.weight_decay
 )
 
@@ -913,7 +910,7 @@ Since we created the model by Subclassing, we can't save the model in HDF5 forma
 
 It can be saved in TF SavedModel format only. In general, this is the recommended format for saving models as well.
 """
-model.save("ShiftViT")
+model.export("ShiftViT")
 
 """
 ## Model inference
@@ -932,12 +929,9 @@ unzip -q inference_set.zip
 """
 **Load saved model**
 """
-# Custom objects are not included when the model is saved.
-# At loading time, these objects need to be passed for reconstruction of the model
-saved_model = tf.keras.models.load_model(
-    "ShiftViT",
-    custom_objects={"WarmUpCosine": WarmUpCosine, "AdamW": tfa.optimizers.AdamW},
-)
+# Using TFSMLayer to reload the TF SavedModel as a Keras layer.
+# This is not limited to SavedModels that originate from Keras – it will work with any SavedModel, e.g. TF-Hub models.
+saved_model = keras.layers.TFSMLayer("ShiftViT", call_endpoint="serving_default")
 
 """
 **Utility functions for inference**
@@ -952,9 +946,9 @@ def process_image(img_path):
     img = tf.io.decode_jpeg(img, channels=3)
 
     # resize image to match input size accepted by model
-    # use `method` as `nearest` to preserve dtype of input passed to `resize()`
-    img = tf.image.resize(
-        img, [config.input_shape[0], config.input_shape[1]], method="nearest"
+    # use `interpolation` as `nearest` to preserve dtype of input passed to `resize()`
+    img = ops.image.resize(
+        img, [config.input_shape[0], config.input_shape[1]], interpolation="nearest"
     )
     return img
 
@@ -976,10 +970,12 @@ def create_tf_dataset(image_dir):
 
 def predict(predict_ds):
     # ShiftViT model returns logits (non-normalized predictions)
-    logits = saved_model.predict(predict_ds)
+    model = keras.Sequential([saved_model])
+    output_dict = model.predict(predict_ds)
+    logits = list(output_dict.values())[0]
 
     # normalize predictions by calling softmax()
-    probabilities = tf.nn.softmax(logits)
+    probabilities = ops.softmax(logits)
     return probabilities
 
 
(END)

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks great, thank you!

Note that you can also use the built-in Keras gelu instead of tf.nn.gelu.

Aside from that, you can add the generated files.

@fchollet
Copy link
Contributor

fchollet commented Dec 5, 2024

Also please update scripts/examples_master.py as well.

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM -- please add the generated files 👍

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

@fchollet fchollet merged commit 4a08dfb into keras-team:master Dec 6, 2024
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants