[ECCV 2024] Improving 2D Feature Representations by 3D-Aware Fine-Tuning

Improving 2D Feature Representations by 3D-Aware Fine-Tuning

ECCV 2024

Yuanwen Yue 1, Anurag Das 2, Francis Engelmann 1,3, Siyu Tang 1, Jan Eric Lenssen 2

1ETH Zurich, 2Max Planck Institute for Informatics, 3Google

This is the official repository (under construction) for the paper Improving 2D Feature Representations by 3D-Aware Fine-Tuning.

Changelog

  • Add Colab Notebook and Hugging Face demo
  • Release ScanNet++ preprocessing code
  • Release feature Gaussian training code
  • Release fine-tuning code
  • Release evaluation code

Table of Contents

  1. Demo
  2. Preparation
  3. Training
  4. Evaluation
  5. Citation

Demo

We provide a Colab Notebook with step-by-step guides for running inference and visualizing the PCA features and K-Means clustering of the original 2D models and our fine-tuned models. We also provide an online Hugging Face demo 🤗 where users can upload their own images and inspect the visualizations. Alternatively, to run the demo locally, run python app.py.

Preparation

Environment

  • The code has been tested on Linux with Python 3.10.14, torch 2.0.0, and CUDA 11.8.
  • Create an environment and install PyTorch and the other required packages:
    git clone https://github.com/ywyue/FiT3D.git
    cd FiT3D
    conda create -n fit3d python=3.10
    conda activate fit3d
    pip install torch==2.0.0 torchvision==0.15.1 --index-url https://download.pytorch.org/whl/cu118
    pip install -r requirements.txt
  • Compile the feature rasterization modules and the knn module for feature lifting:
    cd submodules/diff-feature-gaussian-rasterization
    python setup.py install
    cd ../simple-knn/
    python setup.py install
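
After the install steps above, a quick check that the expected packages can be found may save a failed training run later. The following is a minimal sketch using only the standard library. The helper name `missing_modules` is hypothetical, and the module names in the usage note are assumptions about what the steps above install; the import name of the feature rasterizer is not stated in this README, so it is left out.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names in `names` that cannot be found."""
    missing = []
    for name in names:
        # find_spec returns None when a top-level module is not installed.
        if importlib.util.find_spec(name) is None:
            missing.append(name)
    return missing
```

For example, `missing_modules(["torch", "torchvision", "simple_knn"])` should return an empty list in a working environment, assuming `simple_knn` is the import name of the compiled knn module.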

Data

We train feature Gaussians and perform fine-tuning on ScanNet++ scenes. Preprocessing code and instructions are here. After preprocessing, the ScanNet++ data is expected to be organized as follows:

FiT3D/
└── db/
    └── scannetpp/
        ├── metadata/
        |    ├── nvs_sem_train.txt  # Training set for NVS and semantic tasks with 230 scenes
        |    ├── nvs_sem_val.txt # Validation set for NVS and semantic tasks with 50 scenes
        |    ├── train_samples.txt  # Training sample list, formatted as sceneID_imageID
        |    ├── val_samples.txt # Validation sample list, formatted as sceneID_imageID
        |    ├── train_view_info.npy  # Training sample camera info, e.g. projection matrices
        |    └── val_view_info.npy # Validation sample camera info, e.g. projection matrices
        └── scenes/
            ├── 0a5c013435  # scene id
            ├── ...
            └── 0a7cc12c0e
              ├── images  # undistorted and downscaled images
              ├── masks # undistorted and downscaled anonymized masks
              ├── points3D.txt  # 3D feature points used by COLMAP
              └── transforms_train.json # camera poses in the format used by Nerfstudio
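
The layout above can be checked before training with a short script. This is a sketch, not part of the repository: the function name `find_missing` is hypothetical, and the file names are copied from the tree above. It reports every expected entry that is absent under `db/scannetpp`.

```python
from pathlib import Path

# Entries taken from the directory tree above.
METADATA_FILES = [
    "nvs_sem_train.txt",
    "nvs_sem_val.txt",
    "train_samples.txt",
    "val_samples.txt",
    "train_view_info.npy",
    "val_view_info.npy",
]
SCENE_ENTRIES = ["images", "masks", "points3D.txt", "transforms_train.json"]


def find_missing(root="db/scannetpp"):
    """Return the expected paths under `root` that do not exist."""
    root = Path(root)
    metadata_dir = root / "metadata"
    missing = [metadata_dir / f for f in METADATA_FILES
               if not (metadata_dir / f).exists()]
    scenes_dir = root / "scenes"
    if not scenes_dir.is_dir():
        return missing + [scenes_dir]
    # Check every scene folder for the four expected entries.
    for scene in sorted(p for p in scenes_dir.iterdir() if p.is_dir()):
        missing += [scene / e for e in SCENE_ENTRIES if not (scene / e).exists()]
    return missing
```

Calling `find_missing()` from the FiT3D root returns an empty list when every metadata file and every scene folder is complete.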

For all other evaluation datasets (ScanNet, NYUd, NYUv2, ADE20k, Pascal VOC, KITTI), please follow their official websites for downloading instructions.

Training

Stage I: Lifting Features to 3D

Example command to train the feature Gaussians for a single scene:

python train_feat_gaussian.py --run_name=example_feature_gaussian_training \
                    --model_name=dinov2_small \
                    --source_path=db/scannetpp/scenes/0a5c013435 \
                    --low_sem_dim=64

model_name indicates the 2D feature extractor and can be one of dinov2_small, dinov2_base, dinov2_reg_small, clip_base, mae_base, deit3_base. low_sem_dim is the dimension of the semantic feature vector attached to each Gaussian. Note that it must have the same value as NUM_CHANNELS_FEAT in submodules/diff-feature-gaussian-rasterization/cuda_rasterizer/config.h.
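
Because a mismatch between low_sem_dim and NUM_CHANNELS_FEAT is easy to overlook, it may help to compare the two before launching training. The sketch below assumes the header defines the constant on a line of the form `#define NUM_CHANNELS_FEAT 64`; that format is an assumption, and both function names are hypothetical.

```python
import re
from pathlib import Path

CONFIG_H = "submodules/diff-feature-gaussian-rasterization/cuda_rasterizer/config.h"


def read_num_channels_feat(config_path=CONFIG_H):
    """Return the integer value of NUM_CHANNELS_FEAT, or None if not found."""
    text = Path(config_path).read_text()
    # Assumed format: "#define NUM_CHANNELS_FEAT <integer>"
    match = re.search(r"^\s*#\s*define\s+NUM_CHANNELS_FEAT\s+(\d+)",
                      text, re.MULTILINE)
    return int(match.group(1)) if match else None


def check_low_sem_dim(low_sem_dim, config_path=CONFIG_H):
    """Raise ValueError if low_sem_dim disagrees with the header value."""
    channels = read_num_channels_feat(config_path)
    if channels is None:
        raise ValueError(f"NUM_CHANNELS_FEAT not found in {config_path}")
    if channels != low_sem_dim:
        raise ValueError(
            f"low_sem_dim={low_sem_dim} but NUM_CHANNELS_FEAT={channels}"
        )
    return channels
```

Since the value lives in a C++ header, changing it presumably requires recompiling the rasterization module before the new value takes effect.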

To generate the commands for training Gaussians for all scenes in ScanNet++, run:

python gen_commands.py --train_fgs_commands_folder=train_fgs_commands --model_name=dinov2_small --low_sem_dim=64

Training commands for all scenes will be stored in train_fgs_commands.
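
To illustrate what per-scene command generation involves, here is a minimal sketch that builds one training command per scene from the flags shown in the single-scene example. It is not the repository's gen_commands.py: the function names and the run_name pattern are hypothetical, and it assumes that a scene list such as nvs_sem_train.txt holds one scene ID per line.

```python
from pathlib import Path


def read_scene_ids(list_file):
    """Read scene IDs from a text file, assuming one ID per line."""
    lines = Path(list_file).read_text().splitlines()
    return [line.strip() for line in lines if line.strip()]


def build_train_command(scene_id, model_name="dinov2_small", low_sem_dim=64,
                        scenes_root="db/scannetpp/scenes"):
    """Build a single-scene training command using the flags shown above."""
    return (
        "python train_feat_gaussian.py "
        f"--run_name={model_name}_{scene_id} "  # hypothetical naming scheme
        f"--model_name={model_name} "
        f"--source_path={scenes_root}/{scene_id} "
        f"--low_sem_dim={low_sem_dim}"
    )
```

A list of commands for the training split could then be produced with `[build_train_command(s) for s in read_scene_ids("db/scannetpp/metadata/nvs_sem_train.txt")]`.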

After training, the parameters of all feature Gaussians must be written to a single file, which is used in the second stage. To do that, run:

python write_feat_gaussian.py

After that, all the pretrained Gaussians of training scenes are stored as pretrained_feat_gaussians_train.pth and all the pretrained Gaussians of validation scenes are stored as pretrained_feat_gaussians_val.pth. Both files will be stored in db/scannetpp/metadata.

Stage II: Fine-Tuning

In this stage, we use the pretrained Gaussians to render features and use those features as targets to fine-tune the 2D feature extractor. To do that, run:

python finetune.py --model_name=dinov2_small \
                   --output_dir=output_finemodel \
                   --job_name=finetuning_dinov2_small \
                   --train_gaussian_list=db/scannetpp/metadata/pretrained_feat_gaussians_train.pth \
                   --val_gaussian_list=db/scannetpp/metadata/pretrained_feat_gaussians_val.pth

model_name indicates the 2D feature extractor and should be consistent with the feature extractor used in the first stage. By default, fine-tuning runs for 1 epoch, after which the weights of the fine-tuned model are saved in output_dir/date_job_name.
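
Before starting the fine-tuning run, a small pre-flight check can confirm that the inputs from the first stage are in place. This is a sketch under stated assumptions: the function name `preflight` is hypothetical, and the list of supported models is copied from the Stage I description.

```python
from pathlib import Path

# Feature extractors listed in the Stage I description.
SUPPORTED_MODELS = (
    "dinov2_small", "dinov2_base", "dinov2_reg_small",
    "clip_base", "mae_base", "deit3_base",
)


def preflight(model_name, train_gaussians, val_gaussians):
    """Return a list of human-readable problems; empty means ready to run."""
    problems = []
    if model_name not in SUPPORTED_MODELS:
        problems.append(f"unsupported model_name: {model_name}")
    # Both files are produced by write_feat_gaussian.py in Stage I.
    for path in (train_gaussians, val_gaussians):
        if not Path(path).is_file():
            problems.append(f"missing file: {path}")
    return problems
```

The check does not verify that the Gaussians were trained with the same model_name, which still has to be tracked by hand.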

Evaluation

Citation

If you find our code or paper useful, please cite:

@inproceedings{yue2024improving,
  title     = {{Improving 2D Feature Representations by 3D-Aware Fine-Tuning}},
  author    = {Yue, Yuanwen and Das, Anurag and Engelmann, Francis and Tang, Siyu and Lenssen, Jan Eric},
  booktitle = {European Conference on Computer Vision (ECCV)},
  year      = {2024}
}