Migrate video classification example to keras 3.0 (keras-team#1674)
* migrate video classification example to keras 3

* update ckpt path

* update last updated date

* unpinned the backend used in the example

* update last updated date

* ran autogen for md and ipynb files
GuichardVictor authored Dec 8, 2023
1 parent 55407cf commit a91b152
Showing 3 changed files with 16,507 additions and 169 deletions.
39 changes: 26 additions & 13 deletions examples/vision/ipynb/video_classification.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [Sayak Paul](https://twitter.com/RisingSayak)<br>\n",
"**Date created:** 2021/05/28<br>\n",
"**Last modified:** 2023/08/28<br>\n",
"**Last modified:** 2023/12/08<br>\n",
"**Description:** Training a video classifier with transfer learning and a recurrent model on the UCF101 dataset."
]
},
@@ -72,7 +72,7 @@
},
"outputs": [],
"source": [
"!wget -q https://github.com/sayakpaul/Action-Recognition-in-TensorFlow/releases/download/v1.0.0/ucf101_top5.tar.gz\n",
"!!wget -q https://github.com/sayakpaul/Action-Recognition-in-TensorFlow/releases/download/v1.0.0/ucf101_top5.tar.gz\n",
"!tar xf ucf101_top5.tar.gz"
]
},
@@ -93,17 +93,17 @@
},
"outputs": [],
"source": [
"from tensorflow_docs.vis import embed\n",
"from tensorflow import keras\n",
"import os\n",
"\n",
"import keras\n",
"from imutils import paths\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import tensorflow as tf\n",
"import pandas as pd\n",
"import numpy as np\n",
"import imageio\n",
"import cv2\n",
"import os"
"from IPython.display import Image"
]
},
{
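For context on the import hunk above: in Keras 3 the library is imported as a standalone package instead of through `tensorflow.keras`, and the backend is chosen before the first import. A minimal sketch of that pattern, assuming a `keras>=3.0` install (the explicit backend choice here is only illustrative, since this commit deliberately leaves the backend unpinned):

    import os

    # The backend must be selected before the first `import keras`.
    # "tensorflow" is an illustrative choice; "jax" or "torch" also work.
    os.environ["KERAS_BACKEND"] = "tensorflow"

    import keras

    print(keras.__version__)  # e.g. 3.x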
@@ -314,7 +314,7 @@
" num_samples = len(df)\n",
" video_paths = df[\"video_name\"].values.tolist()\n",
" labels = df[\"tag\"].values\n",
" labels = label_processor(labels[..., None]).numpy()\n",
" labels = keras.ops.convert_to_numpy(label_processor(labels[..., None]))\n",
"\n",
" # `frame_masks` and `frame_features` are what we will feed to our sequence model.\n",
" # `frame_masks` will contain a bunch of booleans denoting if a timestep is\n",
@@ -331,7 +331,13 @@
" frames = frames[None, ...]\n",
"\n",
" # Initialize placeholders to store the masks and features of the current video.\n",
" temp_frame_mask = np.zeros(shape=(1, MAX_SEQ_LENGTH,), dtype=\"bool\")\n",
" temp_frame_mask = np.zeros(\n",
" shape=(\n",
" 1,\n",
" MAX_SEQ_LENGTH,\n",
" ),\n",
" dtype=\"bool\",\n",
" )\n",
" temp_frame_features = np.zeros(\n",
" shape=(1, MAX_SEQ_LENGTH, NUM_FEATURES), dtype=\"float32\"\n",
" )\n",
@@ -388,6 +394,7 @@
},
"outputs": [],
"source": [
"\n",
"# Utility for our sequence model.\n",
"def get_sequence_model():\n",
" class_vocab = label_processor.get_vocabulary()\n",
@@ -415,7 +422,7 @@
"\n",
"# Utility for running experiments.\n",
"def run_experiment():\n",
" filepath = \"/tmp/video_classifier\"\n",
" filepath = \"/tmp/video_classifier/ckpt.weights.h5\"\n",
" checkpoint = keras.callbacks.ModelCheckpoint(\n",
" filepath, save_weights_only=True, save_best_only=True, verbose=1\n",
" )\n",
@@ -471,7 +478,13 @@
"\n",
"def prepare_single_video(frames):\n",
" frames = frames[None, ...]\n",
" frame_mask = np.zeros(shape=(1, MAX_SEQ_LENGTH,), dtype=\"bool\")\n",
" frame_mask = np.zeros(\n",
" shape=(\n",
" 1,\n",
" MAX_SEQ_LENGTH,\n",
" ),\n",
" dtype=\"bool\",\n",
" )\n",
" frame_features = np.zeros(shape=(1, MAX_SEQ_LENGTH, NUM_FEATURES), dtype=\"float32\")\n",
"\n",
" for i, batch in enumerate(frames):\n",
@@ -502,7 +515,7 @@
"def to_gif(images):\n",
" converted_images = images.astype(np.uint8)\n",
" imageio.mimsave(\"animation.gif\", converted_images, duration=100)\n",
" return embed.embed_file(\"animation.gif\")\n",
" return Image(\"animation.gif\")\n",
"\n",
"\n",
"test_video = np.random.choice(test_df[\"video_name\"].values.tolist())\n",
@@ -523,7 +536,7 @@
"from video frames. You could also fine-tune the pre-trained network to notice how that\n",
"affects the end results.\n",
"* For speed-accuracy trade-offs, you can try out other models present inside\n",
"`tf.keras.applications`.\n",
"`keras.applications`.\n",
"* Try different combinations of `MAX_SEQ_LENGTH` to observe how that affects the\n",
"performance.\n",
"* Train on a higher number of classes and see if you are able to get good performance.\n",
@@ -571,4 +584,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}