Code and pretrained models for TWIST:
```bibtex
@article{wang2021self,
  title={Self-Supervised Learning by Estimating Twin Class Distributions},
  author={Wang, Feng and Kong, Tao and Zhang, Rufeng and Liu, Huaping and Li, Hang},
  journal={arXiv preprint arXiv:2110.07402},
  year={2021}
}
```
TWIST is a self-supervised representation learning method that learns by classifying large-scale unlabeled datasets end to end. We employ a siamese network terminated by a softmax operation to produce twin class distributions for two augmented views of an image. Without supervision, we enforce the class distributions of the two augmentations to be consistent. At the same time, we regularize the class distributions to be sharp (each image is confidently assigned to a class) and diverse (images are spread across many classes). TWIST naturally avoids trivial solutions without specific designs such as an asymmetric network, a stop-gradient operation, or a momentum encoder.
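The three ingredients above (consistency between the twin distributions, sharpness of each distribution, diversity across the batch) can be combined into one objective. The following is a minimal pure-Python sketch under stated assumptions, not the repository's implementation: it assumes a symmetric KL divergence for consistency, the mean per-sample entropy for sharpness, and the entropy of the batch-averaged distribution for diversity, with hypothetical weights `w_sharp` and `w_div`. The training commands expose `--lam1` and `--lam2`, which presumably weight these terms, but their exact semantics are not reproduced here.

```python
import math

EPS = 1e-12  # guards against log(0)

def softmax(logits):
    """Numerically stable softmax over one sample's logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(p):
    return -sum(pi * math.log(pi + EPS) for pi in p)

def kl(p, q):
    """KL(p || q) for two discrete distributions."""
    return sum(pi * (math.log(pi + EPS) - math.log(qi + EPS)) for pi, qi in zip(p, q))

def mean_distribution(probs):
    """Average the per-sample class distributions over the batch."""
    k = len(probs[0])
    return [sum(p[j] for p in probs) / len(probs) for j in range(k)]

def twist_style_loss(logits1, logits2, w_sharp=1.0, w_div=1.0):
    """Sketch of a TWIST-style objective for two augmented views of a batch.

    logits1, logits2: per-sample logit lists (batch x classes), one per view.
    w_sharp, w_div: hypothetical weights, not the repo's --lam1/--lam2.
    """
    p1 = [softmax(z) for z in logits1]
    p2 = [softmax(z) for z in logits2]
    n = len(p1)
    # Consistency: symmetric KL between the twin distributions of each sample.
    consistency = sum(0.5 * (kl(a, b) + kl(b, a)) for a, b in zip(p1, p2)) / n
    # Sharpness: low per-sample entropy means confident class assignments.
    sharpness = sum(0.5 * (entropy(a) + entropy(b)) for a, b in zip(p1, p2)) / n
    # Diversity: high entropy of the batch-mean distribution means many classes are used.
    diversity = 0.5 * (entropy(mean_distribution(p1)) + entropy(mean_distribution(p2)))
    return consistency + w_sharp * sharpness - w_div * diversity

# Sharp, diverse, consistent predictions score lower than collapsed ones.
diverse = [[10.0, 0.0], [0.0, 10.0]]
collapsed = [[10.0, 0.0], [10.0, 0.0]]
print(twist_style_loss(diverse, diverse))      # about -0.69
print(twist_style_loss(collapsed, collapsed))  # about 0.0
```

In this sketch a collapsed solution (every image in one class) is sharp but earns no diversity bonus, while a uniform output earns the diversity bonus but pays the sharpness penalty; only predictions that are confident, spread out, and consistent across views minimize the objective.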
- 12/4/2022: The linear-evaluation accuracy of ViT-S (DeiT-S) and ViT-B improved by 0.6 and 1.1 points respectively. The gains come from changed hyper-parameters: the batch size was reduced from 2048 to 1024, and, for ViT-B, the drop path rate was raised from 0.0 to 0.1.
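Drop path (stochastic depth), whose rate the 12/4/2022 note raises to 0.1 for ViT-B, randomly skips a residual branch for a whole sample during training and rescales the surviving outputs so their expected value is unchanged. Below is a minimal pure-Python sketch for a single sample's residual-branch output; the function names are hypothetical, and the repository's ViT code (which receives `--drop_path 0.1` in the ViT-B command) operates on tensors rather than lists.

```python
import random

def drop_path(branch_out, drop_prob, training=True, rng=random):
    """Stochastic depth for one sample's residual-branch output (a list of floats).

    During training the whole branch is zeroed with probability drop_prob;
    otherwise it is scaled by 1 / keep_prob so its expectation is unchanged.
    At inference time the branch is returned untouched.
    """
    if not training or drop_prob == 0.0:
        return branch_out
    keep_prob = 1.0 - drop_prob
    if rng.random() < keep_prob:
        return [v / keep_prob for v in branch_out]
    return [0.0] * len(branch_out)

def residual_block(x, branch_fn, drop_prob, training=True, rng=random):
    """x + DropPath(branch(x)): only the branch is dropped, never the identity path."""
    out = drop_path(branch_fn(x), drop_prob, training, rng)
    return [a + b for a, b in zip(x, out)]

rng = random.Random(0)
samples = [drop_path([1.0], 0.1, rng=rng)[0] for _ in range(20000)]
print(sum(samples) / len(samples))  # close to 1.0: the expectation is preserved
```

Because the identity path always survives, a dropped branch simply makes that block act as an identity mapping for that sample, which is why the technique regularizes deep residual stacks such as ViT-B.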
arch | params | epochs | linear top-1 | download | | | | |
---|---|---|---|---|---|---|---|---|
Model with multi-crop and self-labeling | ||||||||
ResNet-50 | 24M | 850 | 75.5% | backbone only | full ckpt | args | log | eval logs |
ResNet-50w2 | 94M | 250 | 77.7% | backbone only | full ckpt | args | log | eval logs |
DeiT-S | 21M | 300 | 76.2% | backbone only | full ckpt | args | log | eval logs |
ViT-B | 86M | 300 | 78.4% | backbone only | full ckpt | args | log | eval logs |
Model without multi-crop and self-labeling | ||||||||
ResNet-50 | 24M | 800 | 72.6% | backbone only | full ckpt | args | log | eval logs |
arch | params | epochs | NMI | AMI | ARI | ACC | download | |||
---|---|---|---|---|---|---|---|---|---|---|
ResNet-50 | 24M | 800 | 74.4 | 57.7 | 30.1 | 40.5 | backbone only | full ckpt | args | log |
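For reference, NMI and clustering accuracy (ACC) compare predicted cluster indices with ground-truth labels without assuming that cluster indices coincide with class indices. The sketch below is illustrative and is not the evaluation code behind the table; it assumes arithmetic-mean normalization for NMI, and computes ACC by brute-force search over one-to-one cluster-to-label mappings, which is only feasible for a handful of clusters. At ImageNet scale the optimal mapping is normally found with the Hungarian algorithm (for example `scipy.optimize.linear_sum_assignment`).

```python
import itertools
import math
from collections import Counter

def nmi(y_true, y_pred):
    """Normalized mutual information with arithmetic-mean normalization."""
    n = len(y_true)
    joint = Counter(zip(y_true, y_pred))
    count_t = Counter(y_true)
    count_p = Counter(y_pred)
    mi = sum(
        (nij / n) * math.log(n * nij / (count_t[t] * count_p[p]))
        for (t, p), nij in joint.items()
    )
    h_t = -sum((c / n) * math.log(c / n) for c in count_t.values())
    h_p = -sum((c / n) * math.log(c / n) for c in count_p.values())
    denom = (h_t + h_p) / 2
    return mi / denom if denom > 0 else 1.0

def clustering_acc(y_true, y_pred):
    """Accuracy under the best one-to-one cluster-to-label mapping (brute force)."""
    labels = sorted(set(y_true))
    clusters = sorted(set(y_pred))
    # Pad with None so surplus clusters can stay unmatched (None never equals a label).
    pool = labels + [None] * max(0, len(clusters) - len(labels))
    best = 0
    for perm in itertools.permutations(pool, len(clusters)):
        mapping = dict(zip(clusters, perm))
        best = max(best, sum(mapping[p] == t for t, p in zip(y_true, y_pred)))
    return best / len(y_true)

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [2, 2, 0, 0, 1, 1]  # same partition, different cluster ids
print(nmi(y_true, y_pred), clustering_acc(y_true, y_pred))  # both approximately 1.0
```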
arch | 1% labels | 10% labels | 100% labels |
---|---|---|---|
ResNet-50 | 61.5% | 71.7% | 78.4% |
ResNet-50w2 | 67.2% | 75.3% | 80.3% |
Task | AP all | AP 50 | AP 75 |
---|---|---|---|
VOC07+12 detection | 58.1 | 84.2 | 65.4 |
COCO detection | 41.9 | 62.6 | 45.7 |
COCO instance segmentation | 37.9 | 59.7 | 40.6 |
ResNet-50 without multi-crop (requires 8 GPUs, Top-1 linear 72.6%)
```shell
python3 -m torch.distributed.launch --nproc_per_node=8 --use_env train.py \
  --data-path ${DATAPATH} \
  --output_dir ${OUTPUT} \
  --aug barlow \
  --batch-size 256 \
  --dim 32768 \
  --epochs 800
```
ResNet-50 (requires 16 GPUs split across 2 nodes for multi-crop training, Top-1 linear 75.5%)
```shell
python3 -m torch.distributed.launch --nproc_per_node=8 --use_env \
  --nnodes=${WORKER_NUM} \
  --node_rank=${MACHINE_ID} \
  --master_addr=${HOST} \
  --master_port=${PORT} train.py \
  --data-path ${DATAPATH} \
  --output_dir ${OUTPUT}
```
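The multi-node commands rely on `torch.distributed.launch` starting `--nproc_per_node` processes on each of `--nnodes` machines. As a sketch of that arithmetic (the helper name is hypothetical; the formulas are the launcher's standard ones), each process's global rank is `node_rank * nproc_per_node + local_rank`, and the world size is `nnodes * nproc_per_node`. With `--use_env`, the local rank reaches the script through the `LOCAL_RANK` environment variable rather than a `--local_rank` argument. Every node should therefore use the same `${HOST}` and `${PORT}` (the address of the rank-0 node) and a distinct `${MACHINE_ID}` from 0 to `${WORKER_NUM} - 1`.

```python
def launch_layout(nnodes, nproc_per_node):
    """Map every (node_rank, local_rank) pair to the global rank the launcher assigns."""
    world_size = nnodes * nproc_per_node
    layout = {}
    for node_rank in range(nnodes):
        for local_rank in range(nproc_per_node):
            layout[(node_rank, local_rank)] = node_rank * nproc_per_node + local_rank
    return world_size, layout

# ResNet-50 multi-crop: 2 nodes x 8 GPUs.
world_size, layout = launch_layout(nnodes=2, nproc_per_node=8)
print(world_size)      # 16
print(layout[(1, 0)])  # 8: the first GPU on the second node
```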
ResNet-50w2 (requires 32 GPUs split across 4 nodes for multi-crop training, Top-1 linear 77.7%)
```shell
python3 -m torch.distributed.launch --nproc_per_node=8 --use_env \
  --nnodes=${WORKER_NUM} \
  --node_rank=${MACHINE_ID} \
  --master_addr=${HOST} \
  --master_port=${PORT} train.py \
  --data-path ${DATAPATH} \
  --output_dir ${OUTPUT} \
  --backbone 'resnet50w2' \
  --batch-size 60 \
  --bunch-size 240 \
  --epochs 250 \
  --mme_epochs 200
```
DeiT-S (requires 16 GPUs split across 2 nodes for multi-crop training, Top-1 linear 76.2%)
```shell
python3 -m torch.distributed.launch --nproc_per_node=8 --use_env \
  --nnodes=${WORKER_NUM} \
  --node_rank=${MACHINE_ID} \
  --master_addr=${HOST} \
  --master_port=${PORT} train.py \
  --data-path ${DATAPATH} \
  --output_dir ${OUTPUT} \
  --backbone 'vit_s' \
  --batch-size 64 \
  --bunch-size 256 \
  --clip_norm 3.0 \
  --epochs 300 \
  --mme_epochs 300 \
  --lam1 -0.6 \
  --lam2 1.0 \
  --local_crops_number 6 \
  --lr 0.0003 \
  --momentum_start 0.996 \
  --momentum_end 1.0 \
  --optim admw \
  --use_momentum_encoder 1 \
  --weight_decay 0.06 \
  --weight_decay_end 0.12
```
ViT-B (requires 32 GPUs split across 4 nodes for multi-crop training, Top-1 linear 78.4%)
```shell
python3 -m torch.distributed.launch --nproc_per_node=8 --use_env \
  --nnodes=${WORKER_NUM} \
  --node_rank=${MACHINE_ID} \
  --master_addr=${HOST} \
  --master_port=${PORT} train.py \
  --data-path ${DATAPATH} \
  --output_dir ${OUTPUT} \
  --backbone 'vit_b' \
  --batch-size 32 \
  --bunch-size 256 \
  --clip_norm 3.0 \
  --epochs 300 \
  --mme_epochs 300 \
  --lam1 -0.6 \
  --lam2 1.0 \
  --local_crops_number 10 \
  --lr 0.00075 \
  --momentum_start 0.996 \
  --momentum_end 1.0 \
  --optim admw \
  --use_momentum_encoder 1 \
  --weight_decay 0.06 \
  --weight_decay_end 0.06 \
  --drop_path 0.1
```
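A quick consistency check for the training commands, assuming `--batch-size` is the per-process (per-GPU) batch size: multiplying it by the number of processes gives the global batch size. Under that assumption the DeiT-S and ViT-B commands both give 1024, which matches the batch size stated in the 12/4/2022 note. The run table below is transcribed from the commands; the ResNet-50 multi-crop command is omitted because it relies on a default batch size that is not shown here.

```python
# (nodes, GPUs per node, --batch-size) transcribed from the training commands.
RUNS = {
    "ResNet-50 (no multi-crop)": (1, 8, 256),
    "ResNet-50w2": (4, 8, 60),
    "DeiT-S": (2, 8, 64),
    "ViT-B": (4, 8, 32),
}

def global_batch_size(nodes, gpus_per_node, per_gpu_batch):
    """Assumes each process loads per_gpu_batch images per step."""
    return nodes * gpus_per_node * per_gpu_batch

for name, (nodes, gpus, per_gpu) in RUNS.items():
    print(f"{name}: {global_batch_size(nodes, gpus, per_gpu)}")
# ResNet-50 (no multi-crop): 2048
# ResNet-50w2: 1920
# DeiT-S: 1024
# ViT-B: 1024
```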
Linear evaluation for ResNet-50
```shell
python3 evaluate.py \
  ${DATAPATH} \
  ${OUTPUT}/checkpoint.pth \
  --weight-decay 0 \
  --checkpoint-dir ${OUTPUT}/linear_multihead/ \
  --batch-size 1024 \
  --val_epoch 1 \
  --lr-classifier 0.2
```
Linear evaluation for DeiT-S
```shell
python3 -m torch.distributed.launch --nproc_per_node=8 evaluate_vitlinear.py \
  --arch vit_s \
  --pretrained_weights ${OUTPUT}/checkpoint.pth \
  --lr 0.02 \
  --data_path ${DATAPATH} \
  --output_dir ${OUTPUT}
```
Linear evaluation for ViT-B
```shell
python3 -m torch.distributed.launch --nproc_per_node=8 evaluate_vitlinear.py \
  --arch vit_b \
  --pretrained_weights ${OUTPUT}/checkpoint.pth \
  --lr 0.0015 \
  --data_path ${DATAPATH} \
  --output_dir ${OUTPUT}
```
Semi-supervised fine-tuning with 1% of labels (ResNet-50, Top-1 61.5%)
```shell
python3 evaluate.py ${DATAPATH} ${MODELPATH} \
  --weights finetune \
  --lr-backbone 0.04 \
  --lr-classifier 0.2 \
  --train-percent 1 \
  --weight-decay 0 \
  --epochs 20 \
  --backbone 'resnet50'
```
Semi-supervised fine-tuning with 10% of labels (ResNet-50, Top-1 71.7%)
```shell
python3 evaluate.py ${DATAPATH} ${MODELPATH} \
  --weights finetune \
  --lr-backbone 0.02 \
  --lr-classifier 0.2 \
  --train-percent 10 \
  --weight-decay 0 \
  --epochs 20 \
  --backbone 'resnet50'
```
Fine-tuning with 100% of labels (ResNet-50, Top-1 78.4%)
```shell
python3 evaluate.py ${DATAPATH} ${MODELPATH} \
  --weights finetune \
  --lr-backbone 0.01 \
  --lr-classifier 0.2 \
  --train-percent 100 \
  --weight-decay 0 \
  --epochs 30 \
  --backbone 'resnet50'
```
- Install detectron2.
- Convert a pre-trained TWIST model to detectron2's format:
  ```shell
  python3 detection/convert-pretrain-to-detectron2.py ${MODELPATH} ${OUTPUTPKLPATH}
  ```
- Put the dataset under the "detection/datasets" directory, following the directory structure required by detectron2.
- Training:
  VOC
  ```shell
  cd detection/
  python3 train_net.py \
    --config-file voc_fpn_1fc/pascal_voc_R_50_FPN_24k_infomin.yaml \
    --num-gpus 8 \
    MODEL.WEIGHTS ../${OUTPUTPKLPATH}
  ```
  COCO (run from the same detection/ directory)
  ```shell
  python3 train_net.py \
    --config-file infomin_configs/R_50_FPN_1x_infomin.yaml \
    --num-gpus 8 \
    MODEL.WEIGHTS ../${OUTPUTPKLPATH}
  ```
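The conversion step renames the backbone's parameters from torchvision's ResNet naming to detectron2's. The sketch below shows only that key-renaming logic, modeled on the renaming rules in MoCo's detectron2 conversion script, which `convert-pretrain-to-detectron2.py` appears to derive from; it is an assumption that the TWIST script applies the same rules, and the `prefix` argument (a wrapper prefix such as DistributedDataParallel's `module.` stripped from checkpoint keys) is hypothetical. Loading the checkpoint, handling the classification head, and pickling the result are omitted so the example needs no PyTorch.

```python
def to_detectron2_key(key, prefix="module."):
    """Rename one torchvision ResNet parameter name to detectron2's ResNet naming.

    prefix: hypothetical wrapper prefix stripped before renaming.
    """
    if key.startswith(prefix):
        key = key[len(prefix):]
    # Stem parameters (conv1, bn1) live under "stem." in detectron2.
    if "layer" not in key:
        key = "stem." + key
    # torchvision layer1..layer4 correspond to detectron2 res2..res5.
    for t in [1, 2, 3, 4]:
        key = key.replace(f"layer{t}", f"res{t + 1}")
    # Each BatchNorm becomes the "norm" submodule of the matching conv.
    for t in [1, 2, 3]:
        key = key.replace(f"bn{t}", f"conv{t}.norm")
    # Projection shortcut: downsample.0 is the conv, downsample.1 its BatchNorm.
    key = key.replace("downsample.0", "shortcut")
    key = key.replace("downsample.1", "shortcut.norm")
    return key

for k in ["module.conv1.weight", "module.bn1.running_mean",
          "module.layer1.0.bn2.weight", "module.layer2.0.downsample.1.bias"]:
    print(k, "->", to_detectron2_key(k))
```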