Training Code for SAM 2

This folder contains the training code for SAM 2, a foundation model for promptable visual segmentation in images and videos. The code allows users to train and fine-tune SAM 2 on their own datasets (image, video, or both).

Structure

The training code is organized into the following subfolders:

  • dataset: This folder contains the image and video dataset and dataloader classes, as well as their transforms.
  • model: This folder contains the main model class (SAM2Train) for training/fine-tuning. SAM2Train inherits from the SAM2Base model and provides functions to enable training or fine-tuning SAM 2. It also accepts all training-time parameters used for simulating user prompts (e.g. iterative point sampling).
  • utils: This folder contains training utilities such as loggers and distributed training helpers.
  • scripts: This folder contains the script to extract the frames of the SA-V dataset for use in training.
  • loss_fns.py: This file has the main loss class (MultiStepMultiMasksAndIous) used for training.
  • optimizer.py: This file contains all optimizer utils that support arbitrary schedulers.
  • trainer.py: This file contains the Trainer class that accepts all the Hydra configurable modules (model, optimizer, datasets, etc.) and implements the main train/eval loop. A minimal sketch of the Hydra configuration pattern these modules share follows this list.
  • train.py: This script is used to launch training jobs. It supports single and multi-node jobs. For usage, please check the Getting Started section or run python training/train.py -h.
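
All of these modules are composed through Hydra configs: each configurable component is described by a _target_ class path plus keyword arguments, and Hydra builds the object at runtime. Below is a minimal, hedged sketch of that pattern (the optimizer class and learning rate are placeholders for illustration, not values taken from the provided configs):

    # Minimal sketch of the Hydra `_target_` pattern used throughout the training
    # configs; the class and values below are placeholders, not repo defaults.
    import torch.nn as nn
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "_target_": "torch.optim.AdamW",  # class to construct
            "lr": 5.0e-5,                     # kwargs passed to its constructor
        }
    )

    model = nn.Linear(16, 4)  # stand-in for the actual SAM2Train model
    # Keyword arguments given at call time are merged with the config's kwargs.
    optimizer = instantiate(cfg, params=model.parameters())
    print(type(optimizer).__name__)  # AdamW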

Getting Started

To get started with the training code, we provide a simple example of fine-tuning our checkpoints on the MOSE dataset, which can be extended to your custom datasets.

Requirements:

  • We assume training on A100 GPUs with 80 GB of memory.
  • Download the MOSE dataset using one of the provided links from here.

Steps to fine-tune on MOSE:

  • Install the packages required for training by running pip install -e ".[dev]".

  • Set the paths for MOSE dataset in configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml.

    dataset:
        # PATHS to Dataset
        img_folder: null # PATH to MOSE JPEGImages folder
        gt_folder: null # PATH to MOSE Annotations folder
        file_list_txt: null # Optional PATH to filelist containing a subset of videos to be used for training

  • To fine-tune the base model on MOSE using 8 GPUs, run

    python training/train.py \
        -c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \
        --use-cluster 0 \
        --num-gpus 8

    We also support multi-node training on a cluster using SLURM. For example, you can train on 2 nodes by running

    python training/train.py \
        -c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \
        --use-cluster 1 \
        --num-gpus 8 \
        --num-nodes 2 \
        --partition $PARTITION \
        --qos $QOS \
        --account $ACCOUNT

    where partition, qos, and account are optional and depend on your SLURM configuration. By default, checkpoints and logs are saved under the sam2_logs directory in the root of the repo. Alternatively, you can set the experiment log directory in the config file as follows:

      experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}

    The training losses can be monitored using tensorboard logs stored under tensorboard/ in the experiment log directory. We also provide a sample validation split for evaluation purposes. To generate predictions, follow this guide on how to use our vos_inference.py script. After generating the predictions, you can run the sav_evaluator.py as detailed here. The expected MOSE J&F after fine-tuning the Base plus model is 79.4.

    After training/fine-tuning, you can then use the new checkpoint (saved in checkpoints/ in the experiment log directory) in the same way as the released SAM 2 checkpoints (as illustrated here); a minimal loading sketch is shown below.
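
    For reference, here is a minimal loading sketch. It assumes the default experiment log directory and that the fine-tuned weights were saved as checkpoints/checkpoint.pt; check your run's log directory for the actual paths, and add prompts and propagate exactly as with the released checkpoints.

        # Minimal sketch; the paths below are assumptions based on the default
        # log directory, so adjust them to match your run.
        import torch
        from sam2.build_sam import build_sam2_video_predictor

        checkpoint = "./sam2_logs/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml/checkpoints/checkpoint.pt"
        model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
        predictor = build_sam2_video_predictor(model_cfg, checkpoint)

        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            state = predictor.init_state(video_path="<path_to_jpeg_frames_dir>")
            # add point/box prompts and propagate as with the released checkpoints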

Training on images and videos

The code supports training on images and videos (similar to how SAM 2 is trained). We provide classes for loading SA-1B as a sample image dataset, SA-V as a sample video dataset, and any DAVIS-style video dataset (e.g. MOSE). Note that to train on SA-V, you must first extract all videos to JPEG frames using the provided extraction script. Below is an example of how to set up the datasets in your config to train on a mix of image and video datasets (a short note on the collate_fn entry follows the config):

data:
  train:
    _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset 
    phases_per_epoch: ${phases_per_epoch} # Chunks a single epoch into smaller phases
    batch_sizes: # List of batch sizes corresponding to each dataset
    - ${bs1} # Batch size of dataset 1
    - ${bs2} # Batch size of dataset 2
    datasets:
    # SA1B as an example of an image dataset
    - _target_: training.dataset.vos_dataset.VOSDataset
      training: true
      video_dataset:
        _target_: training.dataset.vos_raw_dataset.SA1BRawDataset
        img_folder: ${path_to_img_folder}
        gt_folder: ${path_to_gt_folder}
        file_list_txt: ${path_to_train_filelist} # Optional
      sampler:
        _target_: training.dataset.vos_sampler.RandomUniformSampler
        num_frames: 1
        max_num_objects: ${max_num_objects_per_image}
      transforms: ${image_transforms}
    # SA-V as an example of a video dataset
    - _target_: training.dataset.vos_dataset.VOSDataset
      training: true
      video_dataset:
        _target_: training.dataset.vos_raw_dataset.JSONRawDataset
        img_folder: ${path_to_img_folder}
        gt_folder: ${path_to_gt_folder}
        file_list_txt: ${path_to_train_filelist} # Optional
        ann_every: 4
      sampler:
        _target_: training.dataset.vos_sampler.RandomUniformSampler
        num_frames: 8 # Number of frames per video
        max_num_objects: ${max_num_objects_per_video}
        reverse_time_prob: ${reverse_time_prob} # probability to reverse video
      transforms: ${video_transforms}
    shuffle: True
    num_workers: ${num_train_workers}
    pin_memory: True
    drop_last: True
    collate_fn:
      _target_: training.utils.data_utils.collate_fn
      _partial_: true
      dict_key: all
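
The collate_fn entry above uses Hydra's partial instantiation: with _partial_: true, Hydra does not call the function when the config is resolved but instead returns it with dict_key pre-bound, so the dataloader can invoke it once per batch. A minimal sketch of this behavior (it needs the repo's training package importable, e.g. run from the repo root):

    # Sketch of partial instantiation for the collate_fn entry above.
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "_target_": "training.utils.data_utils.collate_fn",
            "_partial_": True,  # return a functools.partial instead of calling now
            "dict_key": "all",  # keyword argument pre-bound into the partial
        }
    )
    collate = instantiate(cfg)
    # The dataloader later calls `collate(list_of_samples)` to assemble each batch.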