Prototypical VoteNet for Few-Shot 3D Point Cloud Object Detection (NeurIPS 2022)
By Shizhen Zhao and Xiaojuan Qi.
Most existing 3D point cloud object detection approaches heavily rely on large amounts of labeled training data. However, the labeling process is costly and time-consuming. This paper considers few-shot 3D point cloud object detection, where only a few annotated samples of novel classes are needed with abundant samples of base classes. To this end, we propose Prototypical VoteNet to recognize and localize novel instances, which incorporates two new modules: Prototypical Vote Module (PVM) and Prototypical Head Module (PHM). Specifically, as the 3D basic geometric structures can be shared among categories, PVM is designed to leverage class-agnostic geometric prototypes, which are learned from base classes, to refine local features of novel categories. Then PHM, which is trained with the episodic training strategy, is proposed to utilize class prototypes to enhance the global feature of each object, facilitating subsequent object localization and classification. To evaluate the model in this new setting, we contribute two new benchmark datasets, FS-ScanNet and FS-SUNRGBD. We conduct extensive experiments to demonstrate the effectiveness of Prototypical VoteNet, and our proposed method shows significant and consistent improvements compared to baselines on two benchmark datasets.
Please make sure that all dependencies are installed. Our implementation has been tested on a single NVIDIA 3090 GPU with CUDA 11.2.
Step 1. (Create virtual env using conda)
conda create --name prototypical_votenet python=3.8 -y
conda activate prototypical_votenet
Step 2. (Install PyTorch)
pip3 install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio===0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
You may need to change the above command to match your CUDA version. Please refer to the official PyTorch website.
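To confirm which CUDA build of PyTorch ended up in the environment, a short check like the one below may help. It is a minimal sketch, not part of this repository: the helper name `describe_cuda` is hypothetical, and it only reads `torch.__version__`, `torch.version.cuda`, and `torch.cuda.is_available()`. If PyTorch is not installed, it reports that instead of failing.

```python
def describe_cuda():
    """Summarize the installed PyTorch build (hypothetical helper, not from this repo)."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed in the active environment.
        return {"torch": None, "cuda_build": None, "cuda_available": False}
    return {
        "torch": torch.__version__,          # e.g. a version string with a +cuXXX suffix
        "cuda_build": torch.version.cuda,    # CUDA version the wheel was built against (None for CPU builds)
        "cuda_available": bool(torch.cuda.is_available()),  # True only if a usable GPU is visible
    }


if __name__ == "__main__":
    for key, value in describe_cuda().items():
        print(f"{key}: {value}")
```

If `cuda_available` is reported as false while a GPU is present, the wheel's CUDA build and the installed driver are likely mismatched, and the install command in Step 2 should be adjusted.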
Step 3. (Install mmdet, mmcv and mmsegmentation)
pip install mmdet==2.19.0
pip install mmcv-full==1.3.18 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.10.0/index.html
pip install mmsegmentation==0.20.0
You may need to change the above commands to match your PyTorch version. Please refer to the official MMDetection3D website.
Step 4. (Setup the code base in your env)
pip install setuptools==58.0.4
pip install -v -e .
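After Step 4, it can be useful to compare the installed package versions against the pins used in the steps above. The following is a minimal sketch using only the standard library (`importlib.metadata`). The names `EXPECTED_PINS` and `check_pins` are hypothetical, and the pins are copied from the commands in this README. If you changed the CUDA or PyTorch version, adjust the expected values accordingly.

```python
from importlib import metadata

# Version pins copied from the installation commands in Steps 2-4.
EXPECTED_PINS = {
    "torch": "1.10.1+cu113",
    "torchvision": "0.11.2+cu113",
    "torchaudio": "0.10.1+cu113",
    "mmdet": "2.19.0",
    "mmcv-full": "1.3.18",
    "mmsegmentation": "0.20.0",
    "setuptools": "58.0.4",
}


def check_pins(pins):
    """Return {package: (expected, installed, matches)}; installed is None if missing."""
    report = {}
    for name, expected in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        report[name] = (expected, installed, installed == expected)
    return report


if __name__ == "__main__":
    for name, (expected, installed, ok) in check_pins(EXPECTED_PINS).items():
        status = "ok" if ok else "MISMATCH"
        print(f"{name}: expected {expected}, installed {installed} [{status}]")
```

A mismatch is not necessarily an error, since other version combinations may also work, but the pinned set is the one the steps above install.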
Request the datasets (FS-ScanNet and FS-SUNRGBD) from zhaosz@eee.hku.hk (academic only). Due to licensing issues, please send me your request using your university email.
After downloading FS-ScanNet and FS-SUNRGBD, unzip them and place them in your project folder. The datasets have already been processed, so you can use them directly to train your own models.
For example, to train and test the model on FS-SUNRGBD in the 1-shot, 2-shot, 3-shot, 4-shot, and 5-shot settings, run:
CUDA_VISIBLE_DEVICES=0 python tools/train.py \
./configs/prototypical_votenet/train_together_sun/prototypical_votenet_16x8_sunrgbd-3d-10class_1_1.py --sample_num 16 --work_path work_path/sunrgbd_split1_shot1_1
CUDA_VISIBLE_DEVICES=1 python tools/train.py \
./configs/prototypical_votenet/train_together_sun/prototypical_votenet_16x8_sunrgbd-3d-10class_1_2.py --sample_num 16 --work_path work_path/sunrgbd_split1_shot2_1
CUDA_VISIBLE_DEVICES=2 python tools/train.py \
./configs/prototypical_votenet/train_together_sun/prototypical_votenet_16x8_sunrgbd-3d-10class_1_3.py --sample_num 16 --work_path work_path/sunrgbd_split1_shot3_1
CUDA_VISIBLE_DEVICES=3 python tools/train.py \
./configs/prototypical_votenet/train_together_sun/prototypical_votenet_16x8_sunrgbd-3d-10class_1_4.py --sample_num 16 --work_path work_path/sunrgbd_split1_shot4_1
CUDA_VISIBLE_DEVICES=4 python tools/train.py \
./configs/prototypical_votenet/train_together_sun/prototypical_votenet_16x8_sunrgbd-3d-10class_1_5.py --sample_num 16 --work_path work_path/sunrgbd_split1_shot5_1
You can find the commands in the folder named "tools/".
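The five FS-SUNRGBD commands above differ only in the shot number and the GPU index, so they can also be generated programmatically. The sketch below rebuilds them from the file-name pattern. The helper names are hypothetical. Reading the `_1_<shot>` suffix of the config name as split and shot, and the trailing `_1` of the work path as a run index, is an assumption inferred from the paths shown above. Configs for other splits may not exist under the same pattern.

```python
import shlex

# Config directory taken from the FS-SUNRGBD commands in this README.
CONFIG_DIR = "./configs/prototypical_votenet/train_together_sun"


def build_train_command(shot, split=1, sample_num=16, run=1):
    """Build the argument list for one training run (naming pattern is an assumption)."""
    config = (
        f"{CONFIG_DIR}/prototypical_votenet_16x8_sunrgbd-3d-10class_{split}_{shot}.py"
    )
    work_path = f"work_path/sunrgbd_split{split}_shot{shot}_{run}"
    return [
        "python", "tools/train.py", config,
        "--sample_num", str(sample_num),
        "--work_path", work_path,
    ]


def format_command(shot, gpu_id, **kwargs):
    """Render one run as a shell line, pinning it to a single GPU."""
    return f"CUDA_VISIBLE_DEVICES={gpu_id} " + shlex.join(build_train_command(shot, **kwargs))


if __name__ == "__main__":
    # Mirror the README: shot k runs on GPU k-1.
    for shot in range(1, 6):
        print(format_command(shot, gpu_id=shot - 1))
```

The script only prints the commands. To launch them, the argument list from `build_train_command` could be passed to `subprocess.run` with `CUDA_VISIBLE_DEVICES` set in the environment. On a single-GPU machine, use `gpu_id=0` for every shot and run the jobs one after another.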
1-shot | 3-shot | 5-shot |
---|---|---|
model | model | model |
1-shot | 3-shot | 5-shot |
---|---|---|
model | model | model |
1-shot | 2-shot | 3-shot | 4-shot | 5-shot |
---|---|---|---|---|
model | model | model | model | model |
Please consider 😬 starring this repository and citing the following paper if you find this repository useful.
@inproceedings{zhao2022fs3d,
title={Prototypical VoteNet for Few-Shot 3D Point Cloud Object Detection},
author={Zhao, Shizhen and Qi, Xiaojuan},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}
Our code is largely based on MMDetection3D, and we thank the authors for their implementation. Please also consider citing their wonderful code base.
@misc{mmdet3d2020,
title={{MMDetection3D: OpenMMLab} next-generation platform for general {3D} object detection},
author={MMDetection3D Contributors},
howpublished = {\url{https://github.com/open-mmlab/mmdetection3d}},
year={2020}
}
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
If you have any questions, you can email me (zhaosz@eee.hku.hk).