Official implementation of the paper Isomorphic Pruning for Vision Models.
pip install -r requirements.txt
Please prepare the ImageNet dataset as follows:
data
└── imagenet
    ├── train
    │   ├── n01440764
    │   ├── n01443537
    │   ├── n01484850
    │   ├── n01491361
    │   └── ...
    └── val
        ├── n01440764
        ├── n01443537
        ├── n01484850
        ├── n01491361
        └── ...
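Before running any script, it can help to confirm the dataset layout. The following is a minimal sketch that assumes the structure above: one sub-folder per class under both `train` and `val`, with identical class folder names in the two splits. The helper name `check_imagenet_layout` is hypothetical and not part of this repository.

```python
from pathlib import Path


def check_imagenet_layout(root: str = "data/imagenet") -> dict:
    """Check that train/ and val/ exist and contain the same class folders.

    Assumes one sub-folder per class (e.g. n01440764) in each split.
    Returns the number of class folders per split; raises if they disagree.
    """
    root_path = Path(root)
    classes = {}
    for split in ("train", "val"):
        split_dir = root_path / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing directory: {split_dir}")
        # Collect class sub-folder names, ignoring stray files.
        classes[split] = {p.name for p in split_dir.iterdir() if p.is_dir()}
    if classes["train"] != classes["val"]:
        # Symmetric difference: folders present in only one of the splits.
        mismatched = sorted(classes["train"] ^ classes["val"])
        raise ValueError(f"train/val class folders differ, e.g. {mismatched[:5]}")
    return {split: len(names) for split, names in classes.items()}
```

For the full ImageNet-1k dataset, `check_imagenet_layout("data/imagenet")` should report 1000 class folders for each split.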
We provide scripts to reproduce the results in our paper. Our pruned models are also available for download:
mkdir pretrained && cd pretrained
wget https://www.dropbox.com/s/7z1z1z1z1z1z1z1/pruned_models.zip
After extracting the archive, the directory should look like this:
pretrained
├── deit_0.6G_isomorphic.pth
├── deit_1.2G_isomorphic.pth
├── deit_2.6G_isomorphic.pth
└── deit_4.2G_isomorphic.pth
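The downloaded archive must be extracted to obtain the `.pth` files. The following sketch does this with Python's standard `zipfile` module, assuming the file was saved as `pretrained/pruned_models.zip` and that it contains the checkpoints listed above (possibly inside a sub-folder, since the archive's internal layout is not documented here). `extract_checkpoints` is a hypothetical helper.

```python
import zipfile
from pathlib import Path


def extract_checkpoints(zip_path: str, out_dir: str = "pretrained") -> list:
    """Extract a checkpoint archive and list the .pth files found afterwards.

    Returns the .pth paths relative to out_dir, sorted by name.
    """
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as zf:
        # extractall recreates the archive's internal folders under out_dir.
        zf.extractall(out)
    # rglob also finds checkpoints stored inside a sub-folder of the archive.
    return sorted(str(p.relative_to(out)) for p in out.rglob("*.pth"))
```

Usage: `extract_checkpoints("pretrained/pruned_models.zip", "pretrained")`.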
You can evaluate the pruned models with the following command:
python evaluate.py --model pretrained/deit_4.2G_isomorphic.pth --interpolation bicubic
MACs: 4.1626 G, Params: 20.6943 M
Evaluating pretrained/deit_4.2G_isomorphic.pth...
100%|███████████████| 782/782 [01:57<00:00, 6.68it/s]
Accuracy: 0.8241, Loss: 0.8036
Evaluate the pre-trained timm models on the ImageNet validation set:
bash scripts/evaluation/deit_small_distilled_patch16_224.fb_in1k.sh
MACs: 4.6391 G, Params: 22.4364 M
Evaluating deit_small_distilled_patch16_224.fb_in1k...
100%|█████████████| 782/782 [02:00<00:00, 6.51it/s]
Accuracy: 0.8117, Loss: 0.7511
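To compare runs (for example, a pruned model against a timm baseline), the summary lines printed by `evaluate.py` can be parsed. This sketch assumes the output format shown in the two logs above (`MACs: <x> G, Params: <y> M` and `Accuracy: <a>, Loss: <l>`); `parse_eval_log` is a hypothetical helper and would need updating if the script's output format changes.

```python
import re

# Patterns for the two summary lines, matching the logs shown above.
_MACS_RE = re.compile(r"MACs:\s*([\d.]+)\s*G,\s*Params:\s*([\d.]+)\s*M")
_ACC_RE = re.compile(r"Accuracy:\s*([\d.]+),\s*Loss:\s*([\d.]+)")


def parse_eval_log(text: str) -> dict:
    """Extract MACs (G), Params (M), accuracy and loss from an evaluation log."""
    macs = _MACS_RE.search(text)
    acc = _ACC_RE.search(text)
    if macs is None or acc is None:
        raise ValueError("log does not contain the expected MACs/Accuracy lines")
    return {
        "macs_g": float(macs.group(1)),
        "params_m": float(macs.group(2)),
        "accuracy": float(acc.group(1)),
        "loss": float(acc.group(2)),
    }
```

Applied to the two logs above, the pruned `deit_4.2G_isomorphic.pth` (4.16 G MACs) reaches an accuracy 1.24 points higher than `deit_small_distilled_patch16_224.fb_in1k` (4.64 G MACs).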
Perform isomorphic pruning on the pre-trained models. We use a data-driven method to estimate parameter importance, accumulating the importance scores over multiple batches.
bash scripts/pruning/deit_4.2G.sh
Summary:
MACs: 17.69 G => 4.17 G
Params: 87.34 M => 20.72 M
Saving the pruned model to output/pruned/deit_4.2G.pth...
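The key idea of isomorphic pruning is to rank each sub-structure only against structurally identical peers instead of mixing scores from heterogeneous structures. Below is a minimal pure-Python sketch of that ranking step, not the repository's implementation, under these assumptions: each prunable unit carries a hypothetical `signature` string standing in for a real graph-isomorphism test; importance is a first-order Taylor score `|w * g|` summed over batches (a common data-driven criterion, which may differ from the one used here); and every group is pruned with the same ratio.

```python
from collections import defaultdict


def accumulate_importance(weights, grads_per_batch):
    """Sum first-order Taylor scores over several batches.

    weights: {unit_name: [w_1, ..., w_n]} parameters of each prunable unit.
    grads_per_batch: one {unit_name: [g_1, ..., g_n]} dict per batch.
    Returns {unit_name: sum over batches of sum_i |w_i * g_i|}.
    """
    scores = defaultdict(float)
    for grads in grads_per_batch:
        for name, w in weights.items():
            scores[name] += sum(abs(wi * gi) for wi, gi in zip(w, grads[name]))
    return dict(scores)


def isomorphic_prune(signatures, scores, ratio):
    """Choose units to remove, ranking each isomorphic group on its own.

    signatures: {unit_name: signature}; equal signatures are treated as
    isomorphic (a stand-in for a real graph-isomorphism check).
    ratio: fraction of units removed from every group (assumed uniform).
    """
    groups = defaultdict(list)
    for name, sig in signatures.items():
        groups[sig].append(name)
    removed = []
    for names in groups.values():
        # Rank only against structurally identical peers, lowest score first.
        ranked = sorted(names, key=lambda n: scores[n])
        removed.extend(ranked[: int(len(ranked) * ratio)])
    return sorted(removed)
```

For example, with two attention heads scored 2 and 4 and two MLP channels scored 20 and 40, a single global ranking at a 50% ratio would remove both heads, whereas `isomorphic_prune` removes the weakest head and the weakest channel, keeping scores from different kinds of structure apart.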
Fine-tune the pruned model. Intermediate, latest, and best checkpoints are saved under output/finetuned.
bash scripts/finetuning/deit_4.2G.sh
The pruned model is saved as a .pth file that contains the model definition, so you can load the .pth file directly to obtain the pruned model. You can also pass a timm model name to the script to download the pre-trained model and evaluate it.
# bilinear for ResNet and bicubic for other models
python evaluate.py --model PATH_TO_PRUNED_MODEL --interpolation bicubic
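When evaluating many checkpoints or timm models from a script, the interpolation rule in the comment above can be applied automatically. This is a sketch: `build_eval_cmd` is hypothetical, it uses only the `--model` and `--interpolation` flags shown above, and it assumes that any model whose name contains `resnet` is a ResNet.

```python
from pathlib import Path


def build_eval_cmd(model: str) -> list:
    """Build an evaluate.py command line, choosing interpolation by model name.

    Bilinear for ResNet, bicubic otherwise; detecting ResNets via a
    "resnet" substring in the file or model name is an assumption.
    """
    name = Path(model).name.lower()
    interpolation = "bilinear" if "resnet" in name else "bicubic"
    return ["python", "evaluate.py", "--model", model, "--interpolation", interpolation]
```

Usage: `subprocess.run(build_eval_cmd("pretrained/deit_4.2G_isomorphic.pth"), check=True)`. Returning an argument list rather than a single string avoids shell quoting issues with paths.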