Unofficial PyTorch Lightning implementation of Contrastive Syn-to-Real Generalization (ICLR 2021).
Based on:
Tested in a Python 3.8 environment on Linux and Windows with:
- PyTorch: 1.8.1
- PyTorch Lightning: 1.3.1
- Lightning Bolts: 0.3.3
- Torchmetrics: 0.3.2
- 1x RTX 3070
Installing the dependencies:
pip install pytorch-lightning lightning-bolts torchmetrics
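To check that the installed packages match the tested versions listed above, you can read the installed package metadata. This is a minimal sketch: `TESTED` and `find_mismatches` are hypothetical helpers, and the distribution names are assumed to match the pip package names (PyTorch is distributed as `torch`). The comparison is an exact string match, so a local build such as `1.8.1+cu111` is also reported.

```python
from importlib.metadata import PackageNotFoundError, version

# Tested versions from the list above (this mapping is a hypothetical helper).
TESTED = {
    "torch": "1.8.1",
    "pytorch-lightning": "1.3.1",
    "lightning-bolts": "0.3.3",
    "torchmetrics": "0.3.2",
}


def find_mismatches(expected, lookup=version):
    """Return {package: installed_version_or_None} for every package whose
    installed version string differs from `expected` (exact string match)."""
    mismatches = {}
    for name, wanted in expected.items():
        try:
            installed = lookup(name)
        except PackageNotFoundError:
            installed = None  # the package is not installed at all
        if installed != wanted:
            mismatches[name] = installed
    return mismatches


if __name__ == "__main__":
    for name, installed in find_mismatches(TESTED).items():
        print(f"{name}: tested {TESTED[name]}, found {installed}")
```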
Download the VisDA17 dataset from the official website, or use the provided script for convenience.
# The script downloads and extracts the VisDA17 dataset.
# Note: downloading the full dataset takes a very long time.
python datasets/prepare_visda17.py
If you downloaded the dataset manually, extract it and place it as shown below.
📂 datasets
┗ 📂 visda17
  ┣ 📂 train
  ┣ 📂 validation
  ┗ 📂 test
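Before training, a quick check that the folders above are in place can save a failed run. A minimal sketch, assuming the `datasets` root and the `visda17/{train,validation,test}` names shown in the tree; `missing_visda17_dirs` is a hypothetical helper, not part of this repository.

```python
from pathlib import Path

# Split directories expected under datasets/visda17 (names taken from the tree above).
VISDA17_SPLITS = ("train", "validation", "test")


def missing_visda17_dirs(root="datasets"):
    """Return the expected VisDA17 split directories that do not exist under `root`."""
    base = Path(root) / "visda17"
    return [str(base / split) for split in VISDA17_SPLITS if not (base / split).is_dir()]
```

An empty list means the layout matches; otherwise the returned paths are the folders still to be created or extracted.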
Training
Simply run:
python run.py
or, with options:
usage: run.py [-h] [-o OUTPUT] [-r ROOT] [-e EPOCHS] [-lr LEARNING_RATE] [-bs BATCH_SIZE] [-wd WEIGHT_DECAY] [--task {classification,segmentation}] [--encoder {resnet101,deeplab50,deeplab101}] [--momentum MOMENTUM] [--num-classes NUM_CLASSES] [--eval-only] [--gpus GPUS]
[--resume RESUME] [--dev-run] [--exp-name EXP_NAME] [--augmentation AUGMENTATION] [--seed SEED] [--fc-dim FC_DIM] [--no-apool] [--single-network] [--stages STAGES [STAGES ...]] [--emb-dim EMB_DIM] [--emb-depth EMB_DEPTH] [--num-patches NUM_PATCHES]
[--moco-weight MOCO_WEIGHT] [--moco-queue-size MOCO_QUEUE_SIZE] [--moco-momentum MOCO_MOMENTUM] [--moco-temperature MOCO_TEMPERATURE]
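The `--moco-*` options suggest a MoCo-style contrastive setup: a momentum-updated key encoder, a fixed-size queue of past keys used as negatives, a temperature-scaled InfoNCE loss, and a weight on that loss. The pure-Python sketch below only illustrates what those quantities usually mean in MoCo; it is an assumption about their role, not this repository's code, and all names in it are hypothetical.

```python
import math
from collections import deque


def momentum_update(key_params, query_params, m):
    """MoCo-style exponential moving average, element-wise:
    key <- m * key + (1 - m) * query (m plays the role of --moco-momentum)."""
    return [m * k + (1.0 - m) * q for k, q in zip(key_params, query_params)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def info_nce(query, positive_key, negatives, temperature):
    """InfoNCE loss for one query: negative log-softmax of the positive logit,
    with all logits divided by `temperature` (cf. --moco-temperature)."""
    logits = [dot(query, positive_key) / temperature]
    logits += [dot(query, n) / temperature for n in negatives]
    # Log-sum-exp with max subtraction for numerical stability.
    top = max(logits)
    log_sum = top + math.log(sum(math.exp(z - top) for z in logits))
    return log_sum - logits[0]


# A bounded FIFO queue of past keys used as negatives; maxlen stands in for
# --moco-queue-size, so the oldest key is dropped once the queue is full.
queue = deque(maxlen=3)
for key in ([1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]):
    queue.append(key)

loss = info_nce([1.0, 0.0], [1.0, 0.0], list(queue), temperature=0.07)
```

In such a setup, `--moco-weight` would typically scale this contrastive term before it is added to the task loss.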
Evaluation
python run.py --eval-only --resume https://github.com/ryanking13/CSG/releases/download/v0.2/csg_resnet101.ckpt
| Model | Accuracy (%) |
|---|---|
| CSG (from paper) | 64.1 |
| CSG (reimpl) | 67.1 |
For semantic segmentation, download the GTA5 and Cityscapes datasets.
Place them as shown below.
📂 datasets
┣ 📂 GTA5
┃ ┣ 📂 images
┃ ┃ ┣ 📜 00001.png
┃ ┃ ┣ ...
┃ ┃ ┗ 📜 24966.png
┃ ┗ 📂 labels
┃   ┣ 📜 00001.png
┃   ┣ ...
┃   ┗ 📜 24966.png
┗ 📂 cityscapes
  ┣ 📂 leftImg8bit
  ┃ ┣ 📂 train
  ┃ ┣ 📂 val
  ┃ ┗ 📂 test
  ┗ 📂 gtFine
    ┣ 📂 train
    ┣ 📂 val
    ┗ 📂 test
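Since each GTA5 image is expected to have a label file with the same name (`00001.png` through `24966.png` in both folders, as in the tree), a mismatch usually means an incomplete download or extraction. A minimal sketch, assuming that shared-filename convention; `unmatched_gta5_files` is a hypothetical helper.

```python
from pathlib import Path


def unmatched_gta5_files(root="datasets"):
    """Return (images without labels, labels without images) for the GTA5 layout,
    assuming images/ and labels/ share file names such as 00001.png."""
    gta5 = Path(root) / "GTA5"
    images = {p.name for p in (gta5 / "images").glob("*.png")}
    labels = {p.name for p in (gta5 / "labels").glob("*.png")}
    return sorted(images - labels), sorted(labels - images)
```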
Training
Simply run:
./run_seg.sh
Evaluation
./run_seg.sh --eval-only --resume https://github.com/ryanking13/CSG/releases/download/v0.2/csg_deeplab50.ckpt
| Model | mIoU (%) |
|---|---|
| CSG (from paper) | 35.27 |
| CSG (reimpl) | 34.71 |
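The segmentation score is a mean intersection-over-union. Its standard definition, IoU_c = TP / (TP + FP + FN) averaged over classes, is sketched below from a confusion matrix. This is an illustration of the metric only; the repository's exact settings (class list, ignore label, how absent classes are treated) are not shown here and may differ, and `mean_iou` is a hypothetical helper.

```python
def mean_iou(confusion):
    """Mean IoU from a square confusion matrix where confusion[t][p] counts
    pixels of true class t predicted as class p.

    IoU_c = TP / (TP + FP + FN); classes with a zero denominator (absent from
    both targets and predictions) are skipped in the average.
    """
    n = len(confusion)
    ious = []
    for c in range(n):
        tp = confusion[c][c]
        fn = sum(confusion[c]) - tp                        # true c, predicted otherwise
        fp = sum(confusion[t][c] for t in range(n)) - tp   # predicted c, true otherwise
        denom = tp + fp + fn
        if denom > 0:
            ious.append(tp / denom)
    return sum(ious) / len(ious) if ious else 0.0
```

For example, `mean_iou([[3, 1], [1, 5]])` averages 3/5 and 5/7.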
- Warmup LR scheduler
- No layerwise LR modification
- RandAugment augmentation types
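The list does not specify the shape of the warmup schedule. A common choice is linear warmup; the sketch below assumes a linear ramp followed by a constant rate, which may differ from what this repository actually uses. `warmup_lr` and its parameters are hypothetical.

```python
def warmup_lr(step, base_lr, warmup_steps):
    """Linearly ramp the learning rate from base_lr / warmup_steps up to base_lr
    over the first `warmup_steps` steps, then hold it constant (assumed schedule)."""
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return base_lr
```

In PyTorch, the same shape could be expressed with `torch.optim.lr_scheduler.LambdaLR`, passing `lambda s: warmup_lr(s, 1.0, warmup_steps)` as the multiplier on the optimizer's initial learning rate.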
- I get the error "Distributed package doesn't have NCCL built in".
NCCL is not supported on Windows. Try:
set PL_TORCH_DISTRIBUTED_BACKEND=gloo
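If you launch training from your own Python script instead of a shell, the same variable can be set in the process environment before the Lightning `Trainer` is created. A minimal sketch, assuming that setting it via `os.environ` early in the script is equivalent to the shell `set` command; `configure_windows_backend` is a hypothetical helper.

```python
import os
import platform


def configure_windows_backend(system=None, environ=os.environ):
    """On Windows, fall back to the gloo backend because NCCL is unavailable.

    A value already present in the environment is left untouched.
    Returns the backend value in effect, or None if it is not set.
    """
    system = system or platform.system()
    if system == "Windows":
        environ.setdefault("PL_TORCH_DISTRIBUTED_BACKEND", "gloo")
    return environ.get("PL_TORCH_DISTRIBUTED_BACKEND")
```

Call it at the top of the script, before any distributed setup runs.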