{ "cells": [ { "cell_type": "markdown", "id": "lovely-budapest", "metadata": {}, "source": [ "# This is a notebook that shows how to produce Grad-CAM visualizations for ALBEF" ] }, { "cell_type": "markdown", "id": "czech-surprise", "metadata": {}, "source": [ "# 1. Set the paths for model checkpoint and configuration" ] }, { "cell_type": "code", "execution_count": 37, "id": "institutional-sarah", "metadata": {}, "outputs": [], "source": [ "model_path = '../VL/Example/refcoco.pth'\n", "bert_config_path = 'configs/config_bert.json'\n", "use_cuda = False" ] }, { "cell_type": "markdown", "id": "lovely-passage", "metadata": {}, "source": [ "# 2. Model defination" ] }, { "cell_type": "code", "execution_count": 38, "id": "documented-symbol", "metadata": {}, "outputs": [], "source": [ "from functools import partial\n", "from models.vit import VisionTransformer\n", "from models.xbert import BertConfig, BertModel\n", "from models.tokenization_bert import BertTokenizer\n", "\n", "import torch\n", "from torch import nn\n", "from torchvision import transforms\n", "\n", "import json\n", "\n", "class VL_Transformer_ITM(nn.Module):\n", " def __init__(self, \n", " text_encoder = None,\n", " config_bert = ''\n", " ):\n", " super().__init__()\n", " \n", " bert_config = BertConfig.from_json_file(config_bert)\n", "\n", " self.visual_encoder = VisionTransformer(\n", " img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, \n", " mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) \n", "\n", " self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False) \n", " \n", " self.itm_head = nn.Linear(768, 2) \n", "\n", " \n", " def forward(self, image, text):\n", " image_embeds = self.visual_encoder(image) \n", "\n", " image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)\n", "\n", " output = self.text_encoder(text.input_ids, \n", " attention_mask = text.attention_mask,\n", " encoder_hidden_states = image_embeds,\n", " encoder_attention_mask = image_atts, \n", " return_dict = True,\n", " ) \n", " \n", " vl_embeddings = output.last_hidden_state[:,0,:]\n", " vl_output = self.itm_head(vl_embeddings) \n", " return vl_output" ] }, { "cell_type": "markdown", "id": "renewable-eight", "metadata": {}, "source": [ "# 3. Text Preprocessing" ] }, { "cell_type": "code", "execution_count": 39, "id": "optional-brooklyn", "metadata": {}, "outputs": [], "source": [ "import re\n", "\n", "def pre_caption(caption,max_words=30):\n", " caption = re.sub(\n", " r\"([,.'!?\\\"()*#:;~])\",\n", " '',\n", " caption.lower(),\n", " ).replace('-', ' ').replace('/', ' ')\n", "\n", " caption = re.sub(\n", " r\"\\s{2,}\",\n", " ' ',\n", " caption,\n", " )\n", " caption = caption.rstrip('\\n') \n", " caption = caption.strip(' ')\n", "\n", " #truncate caption\n", " caption_words = caption.split(' ')\n", " if len(caption_words)>max_words:\n", " caption = ' '.join(caption_words[:max_words]) \n", " return caption" ] }, { "cell_type": "markdown", "id": "based-roads", "metadata": {}, "source": [ "# 4. 
{ "cell_type": "markdown", "id": "based-roads", "metadata": {}, "source": [ "# 4. Image Preprocessing and Postprocessing" ] },
{ "cell_type": "code", "execution_count": 40, "id": "subsequent-flesh", "metadata": {}, "outputs": [], "source": [ "from PIL import Image\n", "\n", "import cv2\n", "import numpy as np\n", "\n", "from skimage import transform as skimage_transform\n", "from scipy.ndimage import filters\n", "from matplotlib import pyplot as plt\n", "\n", "def getAttMap(img, attMap, blur=True, overlap=True):\n", "    # normalize the attention map to [0, 1]\n", "    attMap -= attMap.min()\n", "    if attMap.max() > 0:\n", "        attMap /= attMap.max()\n", "    # upsample the 24x24 map to the image resolution\n", "    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode='constant')\n", "    if blur:\n", "        attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))\n", "        attMap -= attMap.min()\n", "        attMap /= attMap.max()\n", "    # colorize with the 'jet' colormap and drop the alpha channel\n", "    cmap = plt.get_cmap('jet')\n", "    attMapV = cmap(attMap)\n", "    attMapV = np.delete(attMapV, 3, 2)\n", "    if overlap:\n", "        # alpha-blend the heatmap onto the image, weighted by attention strength\n", "        attMap = (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV\n", "    return attMap\n", "\n", "\n", "normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n", "\n", "transform = transforms.Compose([\n", "    transforms.Resize((384, 384), interpolation=Image.BICUBIC),\n", "    transforms.ToTensor(),\n", "    normalize,\n", "])" ] },
{ "cell_type": "markdown", "id": "occasional-trace", "metadata": {}, "source": [ "# 5. Load model and tokenizer" ] },
{ "cell_type": "code", "execution_count": 41, "id": "qualified-sleep", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['bert.pooler.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'bert.pooler.dense.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.10.crossattention.self.key.weight', 'bert.encoder.layer.6.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.10.crossattention.output.dense.bias', 'bert.encoder.layer.8.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.9.crossattention.output.dense.bias', 'bert.encoder.layer.6.crossattention.output.dense.bias', 'bert.encoder.layer.9.crossattention.self.value.bias', 'bert.encoder.layer.9.crossattention.self.key.bias', 'bert.encoder.layer.8.crossattention.self.key.weight', 'bert.encoder.layer.6.crossattention.self.query.weight', 'bert.encoder.layer.8.crossattention.self.value.weight', 'bert.encoder.layer.9.crossattention.self.query.weight', 'bert.encoder.layer.9.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.11.crossattention.self.value.weight', 'bert.encoder.layer.7.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.8.crossattention.self.key.bias', 'bert.encoder.layer.7.crossattention.self.key.weight', 'bert.encoder.layer.7.crossattention.output.dense.weight', 'bert.encoder.layer.11.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.9.crossattention.self.key.weight', 'bert.encoder.layer.9.crossattention.self.value.weight', 'bert.encoder.layer.8.crossattention.self.query.bias', 'bert.encoder.layer.11.crossattention.self.key.bias', 'bert.encoder.layer.10.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.8.crossattention.output.dense.bias', 'bert.encoder.layer.7.crossattention.output.dense.bias', 'bert.encoder.layer.7.crossattention.self.value.weight', 'bert.encoder.layer.11.crossattention.self.query.bias', 'bert.encoder.layer.9.crossattention.self.query.bias', 'bert.encoder.layer.6.crossattention.output.dense.weight', 'bert.encoder.layer.10.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.6.crossattention.self.value.bias', 'bert.encoder.layer.8.crossattention.self.query.weight', 'bert.encoder.layer.11.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.11.crossattention.self.value.bias', 'bert.encoder.layer.6.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.9.crossattention.output.dense.weight', 'bert.encoder.layer.6.crossattention.self.query.bias', 'bert.encoder.layer.8.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.11.crossattention.self.key.weight', 'bert.encoder.layer.7.crossattention.self.key.bias', 'bert.encoder.layer.6.crossattention.self.value.weight', 'bert.encoder.layer.7.crossattention.self.query.bias', 'bert.encoder.layer.8.crossattention.self.value.bias', 'bert.encoder.layer.6.crossattention.self.key.bias', 'bert.encoder.layer.7.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.10.crossattention.self.value.weight', 'bert.encoder.layer.10.crossattention.self.query.weight', 'bert.encoder.layer.11.crossattention.output.dense.weight', 'bert.encoder.layer.11.crossattention.output.dense.bias', 'bert.encoder.layer.9.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.7.crossattention.self.query.weight', 'bert.encoder.layer.10.crossattention.output.dense.weight', 
'bert.encoder.layer.6.crossattention.self.key.weight', 'bert.encoder.layer.10.crossattention.self.key.bias', 'bert.encoder.layer.8.crossattention.output.dense.weight', 'bert.encoder.layer.7.crossattention.self.value.bias', 'bert.encoder.layer.11.crossattention.self.query.weight', 'bert.encoder.layer.10.crossattention.self.value.bias', 'bert.encoder.layer.10.crossattention.self.query.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", "\n", "model = VL_Transformer_ITM(text_encoder='bert-base-uncased', config_bert=bert_config_path)\n", "\n", "# load the fine-tuned ALBEF checkpoint; this fills in the cross-attention\n", "# weights reported as newly initialized in the warning above\n", "checkpoint = torch.load(model_path, map_location='cpu')\n", "msg = model.load_state_dict(checkpoint, strict=False)\n", "model.eval()\n", "\n", "# save the attention maps (and their gradients) of this cross-attention block\n", "block_num = 8\n", "\n", "model.text_encoder.base_model.base_model.encoder.layer[block_num].crossattention.self.save_attention = True\n", "\n", "if use_cuda:\n", "    model.cuda()" ] },
{ "cell_type": "markdown", "id": "apparent-captain", "metadata": {}, "source": [ "# 6. Load Image and Text" ] },
{ "cell_type": "code", "execution_count": 42, "id": "finite-angle", "metadata": {}, "outputs": [], "source": [ "image_path = 'examples/image0.jpg'\n", "image_pil = Image.open(image_path).convert('RGB')\n", "image = transform(image_pil).unsqueeze(0)\n", "\n", "caption = 'the woman is working on her computer at the desk'\n", "text = pre_caption(caption)\n", "text_input = tokenizer(text, return_tensors=\"pt\")\n", "\n", "if use_cuda:\n", "    image = image.cuda()\n", "    text_input = text_input.to(image.device)" ] },
{ "cell_type": "markdown", "id": "gorgeous-matrix", "metadata": {}, "source": [ "# 7. Compute Grad-CAM" ] },
{ "cell_type": "code", "execution_count": 43, "id": "driven-termination", "metadata": {}, "outputs": [], "source": [ "output = model(image, text_input)\n", "# backpropagate from the 'match' logit of the ITM head\n", "loss = output[:, 1].sum()\n", "\n", "model.zero_grad()\n", "loss.backward()\n", "\n", "with torch.no_grad():\n", "    mask = text_input.attention_mask.view(text_input.attention_mask.size(0), 1, -1, 1, 1)\n", "\n", "    grads = model.text_encoder.base_model.base_model.encoder.layer[block_num].crossattention.self.get_attn_gradients()\n", "    cams = model.text_encoder.base_model.base_model.encoder.layer[block_num].crossattention.self.get_attention_map()\n", "\n", "    # drop the [CLS] image token and reshape the 576 patch scores into a 24x24 grid\n", "    cams = cams[:, :, :, 1:].reshape(image.size(0), 12, -1, 24, 24) * mask\n", "    grads = grads[:, :, :, 1:].clamp(0).reshape(image.size(0), 12, -1, 24, 24) * mask\n", "\n", "    # Grad-CAM: attention weighted by its positive gradient, averaged over the 12 heads\n", "    gradcam = cams * grads\n", "    gradcam = gradcam[0].mean(0).cpu().detach()" ] },
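{ "cell_type": "markdown", "id": "gradcam-shape-note", "metadata": {}, "source": [ "A quick shape check (an added illustration, not part of the original notebook): after averaging over the 12 attention heads, `gradcam` should contain one 24x24 map per text token, including [CLS] and [SEP]." ] },
{ "cell_type": "code", "execution_count": null, "id": "gradcam-shape-check", "metadata": {}, "outputs": [], "source": [ "# expected: torch.Size([num_tokens, 24, 24])\n", "print(gradcam.shape)\n", "# the tokens each map corresponds to\n", "print(tokenizer.convert_ids_to_tokens(text_input.input_ids[0]))" ] },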
{ "cell_type": "markdown", "id": "abroad-northern", "metadata": {}, "source": [ "# 8. Visualize Grad-CAM for each word" ] },
{ "cell_type": "code", "execution_count": null, "id": "fourth-cache", "metadata": {}, "outputs": [], "source": [ "num_image = len(text_input.input_ids[0])\n", "fig, ax = plt.subplots(num_image, 1, figsize=(15, 5 * num_image))\n", "\n", "# load the raw image with OpenCV (BGR) and convert to RGB floats in [0, 1]\n", "rgb_image = cv2.imread(image_path)[:, :, ::-1]\n", "rgb_image = np.float32(rgb_image) / 255\n", "\n", "ax[0].imshow(rgb_image)\n", "ax[0].set_yticks([])\n", "ax[0].set_xticks([])\n", "ax[0].set_xlabel(\"Image\")\n", "\n", "# skip the [CLS] token and plot one Grad-CAM overlay per remaining token\n", "for i, token_id in enumerate(text_input.input_ids[0][1:]):\n", "    word = tokenizer.decode([token_id])\n", "    gradcam_image = getAttMap(rgb_image, gradcam[i + 1])\n", "    ax[i + 1].imshow(gradcam_image)\n", "    ax[i + 1].set_yticks([])\n", "    ax[i + 1].set_xticks([])\n", "    ax[i + 1].set_xlabel(word)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }