Port the "A Vision Transformer without Attention" example to Keras 3 #2000

Merged
merged 4 commits on Dec 6, 2024
Changes from all commits
119 changes: 49 additions & 70 deletions examples/vision/ipynb/shiftvit.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha), [Shivalika Singh](https://www.linkedin.com/in/shivalika-singh/)<br>\n",
"**Date created:** 2022/02/24<br>\n",
"**Last modified:** 2022/10/15<br>\n",
"**Last modified:** 2024/12/06<br>\n",
"**Description:** A minimal implementation of ShiftViT."
]
},
@@ -39,19 +39,7 @@
"In this example, we minimally implement the paper with close alignement to the author's\n",
"[official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py).\n",
"\n",
"This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can\n",
"be installed using the following command:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"!pip install -qq -U tensorflow-addons"
"This example requires TensorFlow 2.9 or higher."
]
},
{
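Since the `tensorflow-addons` install cell is dropped above, a minimal setup sketch for running the ported notebook could look like the following; the exact packages and the `--upgrade` flag are an assumption on my part, not something this PR pins.

```python
# Hypothetical setup for the Keras 3 port; package choice is an assumption.
# tensorflow-addons is no longer required by this example.
!pip install -q --upgrade keras tensorflow
```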
@@ -74,10 +74,10 @@
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import keras\n",
"from keras import ops\n",
"from keras import layers\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"import tensorflow_addons as tfa\n",
"\n",
"import pathlib\n",
"import glob\n",
@@ -237,8 +225,7 @@
" layers.Rescaling(1 / 255.0),\n",
" ]\n",
" )\n",
" return data_augmentation\n",
""
" return data_augmentation\n"
]
},
{
@@ -341,7 +328,7 @@
" [\n",
" layers.Dense(\n",
" units=initial_filters,\n",
" activation=tf.nn.gelu,\n",
" activation=\"gelu\",\n",
" ),\n",
" layers.Dropout(rate=self.mlp_dropout_rate),\n",
" layers.Dense(units=input_channels),\n",
@@ -351,8 +338,7 @@
"\n",
" def call(self, x):\n",
" x = self.mlp(x)\n",
" return x\n",
""
" return x\n"
]
},
{
@@ -389,16 +375,18 @@
" def __init__(self, drop_path_prob, **kwargs):\n",
" super().__init__(**kwargs)\n",
" self.drop_path_prob = drop_path_prob\n",
" self.seed_generator = keras.random.SeedGenerator(1337)\n",
"\n",
" def call(self, x, training=False):\n",
" if training:\n",
" keep_prob = 1 - self.drop_path_prob\n",
" shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)\n",
" random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)\n",
" random_tensor = tf.floor(random_tensor)\n",
" shape = (ops.shape(x)[0],) + (1,) * (len(ops.shape(x)) - 1)\n",
" random_tensor = keep_prob + keras.random.uniform(\n",
" shape, 0, 1, seed=self.seed_generator\n",
" )\n",
" random_tensor = ops.floor(random_tensor)\n",
" return (x / keep_prob) * random_tensor\n",
" return x\n",
""
" return x\n"
]
},
{
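To double-check the stochastic-depth rewrite above, the following self-contained sketch (my reconstruction, not taken verbatim from the notebook) exercises the same pattern: a stateful `keras.random.SeedGenerator` is handed to `keras.random.uniform` in place of `tf.random.uniform`, and `ops.floor` binarizes the keep mask.

```python
import keras
from keras import ops


class DropPathSketch(keras.layers.Layer):
    """Minimal stochastic-depth layer mirroring the DropPath change above."""

    def __init__(self, drop_path_prob=0.1, **kwargs):
        super().__init__(**kwargs)
        self.drop_path_prob = drop_path_prob
        # Stateful seed generator: each call draws fresh but reproducible noise.
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, x, training=False):
        if training:
            keep_prob = 1 - self.drop_path_prob
            # One Bernoulli draw per sample, broadcast across remaining dims.
            shape = (ops.shape(x)[0],) + (1,) * (len(ops.shape(x)) - 1)
            random_tensor = keep_prob + keras.random.uniform(
                shape, 0, 1, seed=self.seed_generator
            )
            return (x / keep_prob) * ops.floor(random_tensor)
        return x


out = DropPathSketch(0.2)(ops.ones((4, 8)), training=True)  # some rows zeroed
```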
@@ -523,25 +511,25 @@
" offset_width = 0\n",
" target_height = self.shift_pixel\n",
" target_width = 0\n",
" crop = tf.image.crop_to_bounding_box(\n",
" crop = ops.image.crop_images(\n",
" x,\n",
" offset_height=offset_height,\n",
" offset_width=offset_width,\n",
" top_cropping=offset_height,\n",
" left_cropping=offset_width,\n",
" target_height=self.H - target_height,\n",
" target_width=self.W - target_width,\n",
" )\n",
" shift_pad = tf.image.pad_to_bounding_box(\n",
" shift_pad = ops.image.pad_images(\n",
" crop,\n",
" offset_height=offset_height,\n",
" offset_width=offset_width,\n",
" top_padding=offset_height,\n",
" left_padding=offset_width,\n",
" target_height=self.H,\n",
" target_width=self.W,\n",
" )\n",
" return shift_pad\n",
"\n",
" def call(self, x, training=False):\n",
" # Split the feature maps\n",
" x_splits = tf.split(x, num_or_size_splits=self.C // self.num_div, axis=-1)\n",
" x_splits = ops.split(x, indices_or_sections=self.C // self.num_div, axis=-1)\n",
"\n",
" # Shift the feature maps\n",
" x_splits[0] = self.get_shift_pad(x_splits[0], mode=\"left\")\n",
@@ -550,13 +538,12 @@
" x_splits[3] = self.get_shift_pad(x_splits[3], mode=\"down\")\n",
"\n",
" # Concatenate the shifted and unshifted feature maps\n",
" x = tf.concat(x_splits, axis=-1)\n",
" x = ops.concatenate(x_splits, axis=-1)\n",
"\n",
" # Add the residual connection\n",
" shortcut = x\n",
" x = shortcut + self.drop_path(self.mlp(self.layer_norm(x)), training=training)\n",
" return x\n",
""
" return x\n"
]
},
{
@@ -622,8 +609,7 @@
" # Apply the patch merging algorithm on the feature maps\n",
" x = self.layer_norm(x)\n",
" x = self.reduction(x)\n",
" return x\n",
""
" return x\n"
]
},
{
@@ -737,8 +723,7 @@
" \"mlp_expand_ratio\": self.mlp_expand_ratio,\n",
" }\n",
" )\n",
" return config\n",
""
" return config\n"
]
},
{
@@ -903,8 +888,7 @@
" x = stage(x, training=False)\n",
" x = self.global_avg_pool(x)\n",
" logits = self.classifier(x)\n",
" return logits\n",
""
" return logits\n"
]
},
{
@@ -979,7 +963,7 @@
" self.lr_max = lr_max\n",
" self.warmup_steps = warmup_steps\n",
" self.total_steps = total_steps\n",
" self.pi = tf.constant(np.pi)\n",
" self.pi = ops.array(np.pi)\n",
"\n",
" def __call__(self, step):\n",
" # Check whether the total number of steps is larger than the warmup\n",
@@ -993,10 +977,10 @@
" # `cos_annealed_lr` is a graph that increases to 1 from the initial\n",
" # step to the warmup step. After that this graph decays to -1 at the\n",
" # final step mark.\n",
" cos_annealed_lr = tf.cos(\n",
" cos_annealed_lr = ops.cos(\n",
" self.pi\n",
" * (tf.cast(step, tf.float32) - self.warmup_steps)\n",
" / tf.cast(self.total_steps - self.warmup_steps, tf.float32)\n",
" * (ops.cast(step, dtype=\"float32\") - self.warmup_steps)\n",
" / ops.cast(self.total_steps - self.warmup_steps, dtype=\"float32\")\n",
" )\n",
"\n",
" # Shift the mean of the `cos_annealed_lr` graph to 1. Now the grpah goes\n",
@@ -1021,20 +1005,18 @@
"\n",
" # With the formula for a straight line (y = mx+c) build the warmup\n",
" # schedule\n",
" warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start\n",
" warmup_rate = slope * ops.cast(step, dtype=\"float32\") + self.lr_start\n",
"\n",
" # When the current step is lesser that warmup steps, get the line\n",
" # graph. When the current step is greater than the warmup steps, get\n",
" # the scaled cos graph.\n",
" learning_rate = tf.where(\n",
" learning_rate = ops.where(\n",
" step < self.warmup_steps, warmup_rate, learning_rate\n",
" )\n",
"\n",
" # When the current step is more that the total steps, return 0 else return\n",
" # the calculated graph.\n",
" return tf.where(\n",
" step > self.total_steps, 0.0, learning_rate, name=\"learning_rate\"\n",
" )\n",
" return ops.where(step > self.total_steps, 0.0, learning_rate)\n",
"\n",
" def get_config(self):\n",
" config = {\n",
@@ -1043,8 +1025,7 @@
" \"total_steps\": self.total_steps,\n",
" \"warmup_steps\": self.warmup_steps,\n",
" }\n",
" return config\n",
""
" return config\n"
]
},
{
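For reference, the schedule the ported `WarmUpCosine` implements can be written compactly as below, where s is the current step, w the number of warmup steps, and T the total steps. The notation is mine, and the closed form assumes the usual `0.5 * lr_max * (1 + cos(...))` scaling used by this example.

```latex
\mathrm{lr}(s) =
\begin{cases}
\mathrm{lr}_{\text{start}} + \dfrac{\mathrm{lr}_{\max} - \mathrm{lr}_{\text{start}}}{w}\, s, & s < w,\\[1ex]
\dfrac{\mathrm{lr}_{\max}}{2}\left(1 + \cos\left(\pi\, \dfrac{s - w}{T - w}\right)\right), & w \le s \le T,\\[1ex]
0, & s > T.
\end{cases}
```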
@@ -1085,7 +1066,7 @@
")\n",
"\n",
"# Get the optimizer.\n",
"optimizer = tfa.optimizers.AdamW(\n",
"optimizer = keras.optimizers.AdamW(\n",
" learning_rate=scheduled_lrs, weight_decay=config.weight_decay\n",
")\n",
"\n",
@@ -1142,7 +1123,7 @@
},
"outputs": [],
"source": [
"model.save(\"ShiftViT\")"
"model.export(\"ShiftViT\")"
]
},
{
@@ -1192,12 +1173,9 @@
},
"outputs": [],
"source": [
"# Custom objects are not included when the model is saved.\n",
"# At loading time, these objects need to be passed for reconstruction of the model\n",
"saved_model = tf.keras.models.load_model(\n",
" \"ShiftViT\",\n",
" custom_objects={\"WarmUpCosine\": WarmUpCosine, \"AdamW\": tfa.optimizers.AdamW},\n",
")"
"# Using TFSMLayer to reload the TF SavedModel as a Keras layer.\n",
"# This is not limited to SavedModels that originate from Keras – it will work with any SavedModel, e.g. TF-Hub models.\n",
"saved_model = keras.layers.TFSMLayer(\"ShiftViT\", call_endpoint=\"serving_default\")"
]
},
{
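The save/reload flow changes from `model.save(...)` plus `load_model(..., custom_objects=...)` to `model.export(...)` plus `keras.layers.TFSMLayer`. A minimal round trip with a toy model (a stand-in for ShiftViT; the path name is invented) looks like this:

```python
import numpy as np
import keras

# Toy stand-in for ShiftViT, used only to illustrate the export/reload round trip.
toy = keras.Sequential([keras.Input(shape=(32,)), keras.layers.Dense(5)])

# Export a TF SavedModel artifact (forward pass + weights, no Python class needed).
toy.export("toy_savedmodel")

# Reload it as an inference-only layer; no custom_objects are required.
reloaded = keras.layers.TFSMLayer("toy_savedmodel", call_endpoint="serving_default")
out = reloaded(np.random.rand(2, 32).astype("float32"))
# Depending on the serving signature, `out` may be a tensor or a dict of named outputs.
```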
@@ -1226,9 +1204,9 @@
" img = tf.io.decode_jpeg(img, channels=3)\n",
"\n",
" # resize image to match input size accepted by model\n",
" # use `method` as `nearest` to preserve dtype of input passed to `resize()`\n",
" img = tf.image.resize(\n",
" img, [config.input_shape[0], config.input_shape[1]], method=\"nearest\"\n",
" # use `interpolation` as `nearest` to preserve dtype of input passed to `resize()`\n",
" img = ops.image.resize(\n",
" img, [config.input_shape[0], config.input_shape[1]], interpolation=\"nearest\"\n",
" )\n",
" return img\n",
"\n",
@@ -1250,10 +1228,12 @@
"\n",
"def predict(predict_ds):\n",
" # ShiftViT model returns logits (non-normalized predictions)\n",
" logits = saved_model.predict(predict_ds)\n",
" model = keras.Sequential([saved_model])\n",
" output_dict = model.predict(predict_ds)\n",
" logits = list(output_dict.values())[0]\n",
"\n",
" # normalize predictions by calling softmax()\n",
" probabilities = tf.nn.softmax(logits)\n",
" probabilities = ops.softmax(logits)\n",
" return probabilities\n",
"\n",
"\n",
@@ -1270,8 +1250,7 @@
" config.label_map[label]: np.round((probabilities[label]) * 100, 2)\n",
" for label in labels\n",
" }\n",
" return confidences\n",
""
" return confidences\n"
]
},
{
@@ -1398,4 +1377,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}