This project provides zero-shot classification on the ILSVRC (ImageNet) dataset using a large-scale multi-modal model pretrained on the Noah-Wukong dataset. The model variants are as follows:
Models | Embedding dimension | Image encoder | Similarity | # vis_token | Checkpoints
---|---|---|---|---|---
Wukong_ViT-B^G | 512 | ViT-B/32 | Global | / | download
Wukong_ViT-B^F | 512 | ViT-B/32 | Token-wise | / | download
Wukong_ViT-B | 512 | ViT-B/32 | Token-wise | 12 | download
Wukong_ViT-L^G | 768 | ViT-L/14 | Global | / | download
Wukong_ViT-L^F | 768 | ViT-L/14 | Token-wise | / | download
Wukong_ViT-L | 768 | ViT-L/14 | Token-wise | 24 | download
For more benchmarks of the multi-modal model, please refer to the Noah-Wukong Benchmark.
- Hardware
  - Ascend processor
- Framework
  - MindSpore
- Tutorial
- Download the ILSVRC dataset and organize the files as follows (a quick layout-check sketch follows the tree below):
.
└── data_root
├── class1
│ ├── 000000000001.jpg
│ ├── 000000000002.jpg
│ ├── ...
├── class2
│ ├── 000000000001.jpg
│ ├── 000000000002.jpg
│ ├── ...
├── class3
│ ├── 000000000001.jpg
│ ├── 000000000002.jpg
│ ├── ...
├── ...
└── classN
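To sanity-check the layout before running evaluation, a minimal sketch like the following can help. The `check_layout` helper is ours for illustration, not part of the repo:

```python
import os

def check_layout(data_root):
    """Print how many class folders and .jpg images sit under data_root."""
    classes = [d for d in os.listdir(data_root)
               if os.path.isdir(os.path.join(data_root, d))]
    if not classes:
        raise ValueError(f"no class folders found under {data_root}")
    n_images = sum(
        sum(1 for f in os.listdir(os.path.join(data_root, c))
            if f.lower().endswith(".jpg"))
        for c in classes)
    print(f"{len(classes)} classes, {n_images} images")

check_layout("/path/to/data_root")
```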
- Download the corresponding Chinese class name file imagenet_class_name_zh.json and place it in the same folder as eval.py.
- Download the following files and place them under src/tools/:
- English: bpe_simple_vocab_16e6.txt.gz
- Chinese: vocab_zh.txt
- Download the prompt file zh_templates.txt to src/tools/. This file defines the prompts used in the zero-shot classification task; the number of prompts can be adjusted to trade off evaluation time against accuracy. Custom prompts are also allowed (see the sketch below).
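As a rough sketch of how such templates are typically consumed, assuming each line of zh_templates.txt is one CLIP-style template containing a `{}` placeholder for the class name (check the actual file format before relying on this):

```python
# A minimal sketch, assuming one template per line with a `{}` placeholder.
with open("src/tools/zh_templates.txt", encoding="utf-8") as f:
    templates = [line.strip() for line in f if line.strip()]

class_name = "金鱼"  # hypothetical entry from imagenet_class_name_zh.json
prompts = [t.format(class_name) for t in templates]
print(len(prompts), prompts[:3])
```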
- Download the corresponding pretrained checkpoint files via the links in the table above.
- Run eval.py to perform zero-shot classification; each model has its own config file under the src/config/ folder.
python eval.py --config_path [config_path] --ckpt_path [ckpt_path] --dataset_path [/path/to/data_root] --batch_size [batch size]
The evaluation output looks like this:
INFO:main:correct @1: 51.51; correct @5: 78.33
Detailed zero-shot classification performance is shown below:
Model | single@1 | single@5 | embed(80)@1 | embed(80)@5
---|---|---|---|---
ViT-B-G | 44.68 | 71.19 | 47.32 | 74.3 |
ViT-B-F | 32.53 | 57.51 | 37.17 | 63.22 |
ViT-B | 45.22 | 70.69 | 48.24 | 73.43 |
ViT-L-G | 56.15 | 79.86 | 57.54 | 81.46 |
ViT-L-F | 49.74 | 76.3 | 52.83 | 78.88 |
ViT-L | 50.22 | 74.79 | 54.43 | 80.1 |
The Wukong 100m dataset files can be downloaded from Wukong; the file structure should look like this:
.
└── data_root
└─wukong_release
├─ wukong_100m_0.csv
├─ wukong_100m_1.csv
├─ wukong_100m_2.csv
├─ ....
└─ wukong_100m_255.csv
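To get a quick look at what one annotation shard contains before downloading images, something like this works; we make no assumption about the column names and simply print whatever header the csv declares:

```python
import pandas as pd

# Peek at the first few rows of one annotation shard.
df = pd.read_csv("/path/to/data_root/wukong_release/wukong_100m_0.csv", nrows=5)
print(df.columns.tolist())
print(df.head())
```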
We provide a multi-threaded Python script for downloading the images listed in the annotation files.
cd models/research/mm/wukong/src/dataset/
python wukong_download.py --csv_dir /path/to/data_root/wukong_release/ --img_dir IMG_DIR [--start_id 0] [--end_id -1] [--thread_num 4]
where IMG_DIR refers to the directory for downloaded images; the options start_id and end_id define the start and end indices of the csv files to download, and thread_num defines the number of threads used for parallel downloading. If these options are not provided, the default setting downloads images from all csv files. Each csv file corresponds to a subdirectory under IMG_DIR, and the final structure looks like this (a simplified sketch of the download loop follows the tree):
.
└── IMG_DIR
├─000
│ ├─ 00000.jpg
│ ├─ 00001.jpg
│ ├─ 00002.jpg
│ └─ ......
├─001
├─002
├─...
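For reference, the core of such a multi-threaded downloader boils down to the pattern below. This is a simplified sketch, not the actual wukong_download.py, and the csv column name `url` is an assumption; adjust it to the real header:

```python
import csv
import os
from concurrent.futures import ThreadPoolExecutor

import requests

def fetch(url, dest):
    """Download a single image, skipping files that already exist."""
    if os.path.exists(dest):
        return
    try:
        resp = requests.get(url, timeout=10)
        resp.raise_for_status()
        with open(dest, "wb") as f:
            f.write(resp.content)
    except requests.RequestException:
        pass  # dead links are common in web-crawled data; just skip them

def download_csv(csv_path, sub_dir, thread_num=4):
    """Download every image referenced by one annotation csv into sub_dir."""
    os.makedirs(sub_dir, exist_ok=True)
    with open(csv_path, newline="", encoding="utf-8") as f:
        rows = list(csv.DictReader(f))
    with ThreadPoolExecutor(max_workers=thread_num) as pool:
        for i, row in enumerate(rows):
            # "url" is an assumed column name; check the real csv header.
            pool.submit(fetch, row["url"], os.path.join(sub_dir, f"{i:05d}.jpg"))
```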
To use the data in MindSpore, we convert the raw data into MindRecord format. To do this, run:
cd models/research/mm/wukong/
python -m src.dataset.generate_dataset --csv_dir /path/to/data_root/wukong_release/ --img_dir IMG_DIR --data_record_dir DATA_RECORD_DIR [--shard_num 10] [--worker_num 4] [--block_size 2000]
Here DATA_RECORD_DIR refers to the path where the mindrecord files will be generated; shard_num refers to the number of files the mindrecord is split into; worker_num refers to the number of workers used for conversion, and block_size defines the block size of each write. After execution, the mindrecord files should look like this (a minimal FileWriter sketch follows the tree):
└─DATA_RECORD_DIR
├─ wukong100m.mindrecord0
├─ wukong100m.mindrecord0.db
├─ ....
├─ wukong100m.mindrecord9
└─ wukong100m.mindrecord9.db
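Under the hood, MindRecord conversion follows MindSpore's standard FileWriter pattern. The sketch below is illustrative; the schema fields are our assumption, not necessarily the repo's actual schema:

```python
from mindspore.mindrecord import FileWriter

# Illustrative schema: raw jpeg bytes plus the paired caption text.
schema = {"image": {"type": "bytes"}, "text": {"type": "string"}}

writer = FileWriter(file_name="wukong100m.mindrecord", shard_num=10)
writer.add_schema(schema, "wukong image-text pairs")

# Write one hypothetical sample; the real script loops over all images.
with open("/path/to/IMG_DIR/000/00000.jpg", "rb") as f:
    sample = {"image": f.read(), "text": "a hypothetical caption"}
writer.write_raw_data([sample])
writer.commit()
```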
Then you can load the dataset in the standard way, e.g. with the get_wukong_dataset function in models/research/mm/wukong/src/dataset/dataset.py.
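Alternatively, the converted files can be read directly with MindSpore's built-in MindDataset; a minimal sketch, assuming the column names above match the schema used during conversion:

```python
import mindspore.dataset as ds

# Pointing at one shard file loads the whole sharded mindrecord set.
dataset = ds.MindDataset("DATA_RECORD_DIR/wukong100m.mindrecord0",
                         columns_list=["image", "text"],
                         num_parallel_workers=4,
                         shuffle=True)

for sample in dataset.create_dict_iterator(output_numpy=True):
    print(sample["text"])
    break
```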