Cross-task Attention Mechanism for Dense Multi-task Learning (WACV 2023)

astra-vision/DenseMTL

Cross-task Attention Mechanism for Dense Multi-task Learning (DenseMTL)

This repository provides the official source code and model weights for our Cross-task Attention Mechanism for Dense Multi-task Learning paper (WACV 2023). The implementation uses PyTorch.

DenseMTL: Cross-task Attention Mechanism for Dense Multi-task Learning
Ivan Lopes¹, Tuan-Hung Vu¹,², Raoul de Charette¹
¹ Inria, Paris, France. ² Valeo.ai, Paris, France.

To cite our paper, please use:

@inproceedings{lopes2023densemtl,
  title={Cross-task Attention Mechanism for Dense Multi-task Learning},
  author={Lopes, Ivan and Vu, Tuan-Hung and de Charette, Raoul},
  booktitle={WACV},
  year={2023}
}

Overview

DenseMTL is a cross-attention-based multi-task architecture that leverages multiple attention mechanisms to extract and enrich task features. As seen in the figure above, xTAM modules each receive a pair of features from two different tasks, to better assess cross-task interactions and enable efficient cross-talk distillation.
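To make the idea of one task attending to another concrete, here is a minimal NumPy sketch of generic scaled dot-product cross-attention between two task feature maps: queries come from the task being enriched, keys and values come from the other task, and the result is added back through a residual connection. This is not the xTAM module from the paper; the function names, projection matrices, and shapes are hypothetical, and in a real model the projections would be learned PyTorch layers.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_task_attention(feat_a, feat_b, w_q, w_k, w_v):
    """Enrich task-A features with context from task B (hypothetical sketch, not xTAM).

    feat_a, feat_b: (C, H, W) feature maps from two different task branches.
    w_q, w_k, w_v:  (C, C) projection matrices (learned layers in a real model).
    Returns the enriched (C, H, W) task-A features and the (H*W, H*W) attention map.
    """
    c, h, w = feat_a.shape
    a = feat_a.reshape(c, h * w).T  # (H*W, C): one token per pixel of task A
    b = feat_b.reshape(c, h * w).T  # (H*W, C): one token per pixel of task B
    q = a @ w_q                     # queries come from the task being enriched
    k = b @ w_k                     # keys and values come from the other task
    v = b @ w_v
    attn = softmax(q @ k.T / np.sqrt(c), axis=-1)  # each row sums to 1
    out = a + attn @ v              # residual connection keeps the original task-A features
    return out.T.reshape(c, h, w), attn


rng = np.random.default_rng(0)
C, H, W = 8, 4, 5
seg_feat = rng.standard_normal((C, H, W))    # e.g. semantic segmentation features
depth_feat = rng.standard_normal((C, H, W))  # e.g. depth features
w_q, w_k, w_v = (rng.standard_normal((C, C)) / np.sqrt(C) for _ in range(3))
enriched_seg, attn = cross_task_attention(seg_feat, depth_feat, w_q, w_k, w_v)
print(enriched_seg.shape, attn.shape)  # (8, 4, 5) (20, 20)
```

Full attention over all H·W positions grows quadratically with resolution, which is why dense-prediction attention blocks are commonly applied to downsampled features; this is a general observation, not a description of DenseMTL's design.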

In total, this work covers a wide range of experiments, summarized as follows:

  • 3 settings: fully-supervised (FS), semi-supervised auxiliary depth (SDE), and domain adaptation (DA).
  • 4 datasets: Cityscapes, Virtual Kitti 2, Synthia, and NYU-Depth v2.
  • 4 tasks: semantic segmentation (S), depth regression (D), surface normals estimation (N), and edge detection (E).
  • 3 task sets: {S, D}, {S, D, N}, and {S, D, N, E}.

Installation

1. Dependencies

First, create a new conda environment with the required packages listed in environment.yml:

>>> conda env create -n densemtl -f environment.yml

Then activate the densemtl environment:

>>> conda activate densemtl

2. Datasets

  • CITYSCAPES: Follow the instructions on the Cityscapes website to download the images and validation ground truths. Use the following dataset directory structure:

    <CITYSCAPES_DIR>/             % Cityscapes dataset root
    ├── leftImg8bit/              % input image (leftImg8bit_trainvaltest.zip)
    ├── leftImg8bit_sequence/     % sequences needed for monodepth (leftImg8bit_sequence_trainvaltest.zip)
    ├── disparity/                % stereo depth (disparity_trainvaltest.zip)
    ├── camera/                   % camera parameters (camera_trainvaltest.zip)
    └── gtFine/                   % semantic segmentation labels (gtFine_trainvaltest.zip)
  • SYNTHIA: Follow the SYNTHIA instructions to download the images from the SYNTHIA-RAND-CITYSCAPES (CVPR16) split, and download the segmentation labels provided by CTRL-UDA. Use the following dataset directory structure:

    <SYNTHIA_DIR>/                % Synthia dataset root
    ├── RGB/                      % input images
    ├── GT/                       % semseg labels
    └── Depth/                    % depth labels
  • VKITTI2: Follow the official instructions to download the images from the Virtual KITTI 2 dataset. Use the following dataset directory structure:

    <VKITTI2_DIR>/                % VKITTI 2 dataset root
    ├── rgb/                      % input images (vkitti_2.0.3_rgb.tar)
    ├── classSegmentation/        % semseg labels (vkitti_2.0.3_classSegmentation.tar)
    └── depth/                    % depth labels (vkitti_2.0.3_depth.tar)
  • NYUDv2: Follow the official instructions to download the NYUDv2 dataset along with its semantic segmentation, normals, depth, and edge labels. Use the following dataset directory structure:

    <NYUDV2_DIR>/                 % NYUDv2 dataset root
    ├── images/                   % input images
    ├── segmentation/             % semseg labels
    ├── depth/                    % depth labels
    ├── edge/                     % edge labels
    └── normals/                  % normals labels
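Before training, it can help to check that each dataset root matches the trees above. The sketch below is a hypothetical helper, not part of the repository: it reads the dataset roots from the environment variables described in the next section (CITYSCAPES_DIR, SYNTHIA_DIR, VKITTI2_DIR, NYUDV2_DIR), skips datasets whose variable is unset, and reports missing sub-folders. It treats leftImg8bit_sequence as required even though only the monodepth (SDE) setting needs it.

```python
import os
from pathlib import Path

# Sub-folders expected under each dataset root, copied from the trees above.
EXPECTED_SUBDIRS = {
    "CITYSCAPES_DIR": ["leftImg8bit", "leftImg8bit_sequence", "disparity", "camera", "gtFine"],
    "SYNTHIA_DIR": ["RGB", "GT", "Depth"],
    "VKITTI2_DIR": ["rgb", "classSegmentation", "depth"],
    "NYUDV2_DIR": ["images", "segmentation", "depth", "edge", "normals"],
}


def missing_subdirs(root, expected):
    """Return the expected sub-folders that do not exist under `root`."""
    root = Path(root)
    return [name for name in expected if not (root / name).is_dir()]


def check_datasets(env=os.environ):
    """Map each configured dataset variable to its missing sub-folders (unset ones are skipped)."""
    report = {}
    for var, expected in EXPECTED_SUBDIRS.items():
        root = env.get(var)
        if root is None:
            continue
        missing = missing_subdirs(root, expected)
        if missing:
            report[var] = missing
    return report


if __name__ == "__main__":
    for var, missing in check_datasets().items():
        print(f"{var}: missing {', '.join(missing)}")
```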

3. Environment variables

Update configs/env_config.yml with the paths to the different directories by defining:

  • Saved models path: MODELS.
  • Dataset paths: CITYSCAPES_DIR, SYNTHIA_DIR, VKITTI2_DIR, and NYUDV2_DIR.
  • Logs path: LOG_DIR.

All constants in this file are loaded as environment variables and are accessible at runtime via os.environ. Alternatively, these constants can be defined on the command line before running the project.
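A minimal sketch of how such constants can be exposed through os.environ. It assumes env_config.yml is a flat key/value mapping already parsed into a dict (for example with yaml.safe_load), and it assumes that values defined on the command line should take precedence over the file; export_env is a hypothetical name, and the repository's own loader may behave differently.

```python
import os


def export_env(constants, override=False):
    """Expose configuration constants as environment variables (illustrative sketch).

    `constants` is assumed to be the flat mapping parsed from configs/env_config.yml.
    With override=False, a variable already defined on the command line keeps its
    value; whether main.py behaves this way is an assumption.
    """
    for key, value in constants.items():
        if override or key not in os.environ:
            os.environ[key] = str(value)  # environment values must be strings


# Example: LOG_DIR set in the shell (LOG_DIR=/scratch/logs python main.py ...) wins.
export_env({"MODELS": "/data/models", "LOG_DIR": "/data/logs"})
print(os.environ["LOG_DIR"])
```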

Running DenseMTL

1. Command Line Interface

The command line interface accepts the following arguments and options:

--env-config ENV_CONFIG
  Path to the file containing the environment paths; defaults to configs/env_config.yml.
--base CONFIG
  Optional path to a base configuration YAML file; can be omitted if the --config file contains all keys.
--config CONFIG
  Path to the main configuration YAML file.
--project PROJECT
  Project name used for logging, also used as the wandb project name.
--resume
  Flag to resume training from the last available model checkpoint of the same setup.
--evaluate PATH
  Load the model weights at PATH and run evaluation using the config setup.
-s SEED, --seed SEED
  Seed for training and dataset.
-d, --debug
  Flag to run a single validation inference, for debugging purposes.
-w, --disable-wandb
  Flag to disable Weights & Biases logging.
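The options above map naturally onto Python's argparse. The parser below is a hypothetical reconstruction for illustration only: the flag names come from the list above, while the types (an integer seed), the required --config, and the help strings are assumptions that may differ from main.py.

```python
import argparse


def build_parser():
    # Hypothetical reconstruction of the documented CLI; main.py may differ in details.
    p = argparse.ArgumentParser(description="DenseMTL training / evaluation (sketch)")
    p.add_argument("--env-config", default="configs/env_config.yml",
                   help="file containing the environment paths")
    p.add_argument("--base", help="optional base configuration YAML file")
    p.add_argument("--config", required=True, help="main configuration YAML file")
    p.add_argument("--project", help="project name, also used as the wandb project")
    p.add_argument("--resume", action="store_true", help="resume from the last checkpoint")
    p.add_argument("--evaluate", metavar="PATH", help="weights to load for evaluation")
    p.add_argument("-s", "--seed", type=int, help="seed for training and dataset")
    p.add_argument("-d", "--debug", action="store_true", help="single validation inference")
    p.add_argument("-w", "--disable-wandb", action="store_true", help="disable W&B logging")
    return p


args = build_parser().parse_args(
    ["--config", "configs/vkitti2/resnet101_ours_SD.yml", "-s", "42", "-w"])
print(args.seed, args.disable_wandb, args.env_config)  # 42 True configs/env_config.yml
```

argparse turns dashed option names into underscored attributes (--env-config becomes args.env_config), and it accepts both --config=FILE and --config FILE, which is why both spellings work in the commands below.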

Experiments are defined by configuration files. Each configuration file must follow this structure:

setup:
  name: exp-name
  model:
    └── model args
  loss:
    └── loss args
  lr:
    └── learning rates
data:
  └── data module args
training:
  └── training args
optim:
  └── optimizer args
scheduler:
  └── scheduler args

For arguments shared across experiments, such as data, training, optim, and scheduler, we use a base configuration file passed to the process via the --base option. The two configuration files (provided with --base and --config) are merged at the top level, with keys from --config overriding those from --base. See main.py for more details.
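A minimal sketch of this merge, assuming "merged at the top level" means a shallow merge: a section present in the --config file replaces the whole section from the --base file rather than being merged key by key. The helper names and every value inside the example sections (batch_size, iterations, sgd, poly) are hypothetical; only the required top-level keys come from the structure above.

```python
REQUIRED_TOP_LEVEL = ("setup", "data", "training", "optim", "scheduler")
REQUIRED_SETUP = ("name", "model", "loss", "lr")


def merge_configs(base, config):
    # Shallow merge: a top-level section in `config` replaces the whole section in `base`.
    merged = dict(base or {})
    merged.update(config)
    return merged


def missing_keys(cfg):
    """List required keys absent from a merged config (structure taken from the README)."""
    missing = [k for k in REQUIRED_TOP_LEVEL if k not in cfg]
    setup = cfg.get("setup", {})
    missing += [f"setup.{k}" for k in REQUIRED_SETUP if k not in setup]
    return missing


# Hypothetical contents standing in for fs_bs2.yml (base) and resnet101_ours_SD.yml (config).
base = {"data": {"batch_size": 2}, "training": {"iterations": 1000},
        "optim": {"name": "sgd"}, "scheduler": {"name": "poly"}}
config = {"setup": {"name": "resnet101_ours_SD", "model": {}, "loss": {}, "lr": {}},
          "training": {"iterations": 500}}
merged = merge_configs(base, config)
print(merged["training"], missing_keys(merged))  # {'iterations': 500} []
```

A shallow merge keeps the override rule easy to reason about: to change one training argument, the --config file must restate the whole training section.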

Environment variables can be referenced inside a configuration file with the $ENV: prefix, e.g. path: $ENV:CITYSCAPES_DIR.
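The $ENV: substitution can be sketched as a recursive walk over the parsed configuration. This assumes the whole string value is the reference (no interpolation inside longer strings) and that an undefined variable should raise an error; resolve_env is a hypothetical name, and the split key in the example is illustrative.

```python
import os

ENV_PREFIX = "$ENV:"


def resolve_env(node):
    """Recursively replace string values '$ENV:NAME' with os.environ['NAME'] (sketch)."""
    if isinstance(node, dict):
        return {key: resolve_env(value) for key, value in node.items()}
    if isinstance(node, list):
        return [resolve_env(value) for value in node]
    if isinstance(node, str) and node.startswith(ENV_PREFIX):
        return os.environ[node[len(ENV_PREFIX):]]  # KeyError if the variable is undefined
    return node


os.environ["CITYSCAPES_DIR"] = "/data/cityscapes"
cfg = {"data": {"path": "$ENV:CITYSCAPES_DIR", "split": "train"}}
print(resolve_env(cfg))  # {'data': {'path': '/data/cityscapes', 'split': 'train'}}
```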

2. Experiments

To reproduce the experiments, run the following commands.

  • Single task learning baseline:

    python main.py \
      --base=configs/<dataset>/fs_bs2.yml \
      --config=configs/<dataset>/resnet101_STL_<task>.yml
    

    Where <dataset> ∈ {cityscapes, synthia, vkitti2, nyudv2} and <task> ∈ {S, D, N, E}.

  • Our method on the fully supervised setting (FS):

    python main.py \
      --base=configs/<dataset>/fs_bs2.yml \
      --config=configs/<dataset>/resnet101_ours_<set>.yml
    

    Where <dataset> ∈ {cityscapes, synthia, vkitti2, nyudv2} and <set> ∈ {SD, SDN, SDNE}.

    Note that experiments involving the edge detection task (E) are only performed on the NYUDv2 dataset.

  • Our method on semi-supervised depth estimation (SDE):

    python main.py --config configs/cityscapes/monodepth/resnet101_ours_SD.yml
    
  • Our method on domain adaptation (DA):

    python main.py \
      --base=configs/da/<dataset>/fs_bs2.yml \
      --config=configs/da/<dataset>/resnet101_ours_SD.yml
    

    Where <dataset> ∈ {sy2cs (for Synthia → Cityscapes), vk2cs (for VKITTI2 → Cityscapes)}.
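The fully supervised commands above can be enumerated programmatically, applying the rule that edge detection (E) is only run on NYUDv2. This hypothetical helper only builds and prints the command strings; it assumes a configuration file exists for every remaining dataset/task combination, which should be checked against the configs/ folder.

```python
from itertools import product

DATASETS = ["cityscapes", "synthia", "vkitti2", "nyudv2"]
TASKS = ["S", "D", "N", "E"]       # single-task baselines
TASK_SETS = ["SD", "SDN", "SDNE"]  # multi-task sets for our method


def allowed(dataset, tasks):
    # Edge detection (E) experiments are only performed on NYUDv2.
    return "E" not in tasks or dataset == "nyudv2"


def fs_commands():
    """Build the fully supervised commands (STL baselines first, then our method)."""
    cmds = []
    for dataset, task in product(DATASETS, TASKS):
        if allowed(dataset, task):
            cmds.append(f"python main.py --base=configs/{dataset}/fs_bs2.yml "
                        f"--config=configs/{dataset}/resnet101_STL_{task}.yml")
    for dataset, task_set in product(DATASETS, TASK_SETS):
        if allowed(dataset, task_set):
            cmds.append(f"python main.py --base=configs/{dataset}/fs_bs2.yml "
                        f"--config=configs/{dataset}/resnet101_ours_{task_set}.yml")
    return cmds


if __name__ == "__main__":
    for cmd in fs_commands():
        print(cmd)
```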

3. Models

Our models on fully supervised training:

Dataset          Set   Weights
Synthia          SD    sy_densemtl_SD.pkl
Synthia          SDN   sy_densemtl_SDN.pkl
Virtual Kitti 2  SD    vk_densemtl_SD.pkl
Virtual Kitti 2  SDN   vk_densemtl_SDN.pkl
Cityscapes       SD    cs_densemtl_SD.pkl
Cityscapes       SDN   cs_densemtl_SDN.pkl

4. Evaluation

To evaluate a model, set the --evaluate option to the path of a state dictionary (.pkl file). The weights are loaded onto the model and the evaluation loop is launched. Keep in mind that you also need to provide valid configuration files. For example, to evaluate our method with the weights located at weights/vkitti2_densemtl_SD.pkl, run:

python main.py \
  --config=configs/vkitti2/resnet101_ours_SD.yml \
  --base=configs/vkitti2/fs_bs2.yml \
  --evaluate=weights/vkitti2_densemtl_SD.pkl

5. Visualization & Logging

By default, visualizations, losses, and metrics are logged using Weights & Biases. If you do not wish to log your training and evaluation runs with this tool, disable it with the --disable-wandb flag. In all cases, loss values and metrics are also printed to standard output.

Checkpoints, models, and configuration files are saved under the LOG_DIR directory, more specifically under <LOG_DIR>/<dataset>/<config-name>/s<seed>/<timestamp>. For example, an SD training of our method on VKITTI2 with seed 42 could be saved to <LOG_DIR>/vkitti2/resnet101_ours_SD/s42/2022-04-19_10-09-49.
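A minimal sketch of this directory layout, plus a helper that locates the most recent run of a setup (the kind of lookup --resume implies). The timestamp format %Y-%m-%d_%H-%M-%S is inferred from the example path; run_dir and latest_run are hypothetical names, not functions from the repository.

```python
from datetime import datetime
from pathlib import Path

TIMESTAMP_FMT = "%Y-%m-%d_%H-%M-%S"  # inferred from the example 2022-04-19_10-09-49


def run_dir(log_dir, dataset, config_name, seed, when=None):
    """Build <LOG_DIR>/<dataset>/<config-name>/s<seed>/<timestamp> (sketch)."""
    when = when or datetime.now()
    return Path(log_dir) / dataset / config_name / f"s{seed}" / when.strftime(TIMESTAMP_FMT)


def latest_run(log_dir, dataset, config_name, seed):
    """Return the most recent run folder of a setup, or None if there is none.

    Zero-padded timestamps sort chronologically as plain strings.
    """
    setup = Path(log_dir) / dataset / config_name / f"s{seed}"
    if not setup.is_dir():
        return None
    runs = sorted(p for p in setup.iterdir() if p.is_dir())
    return runs[-1] if runs else None


path = run_dir("logs", "vkitti2", "resnet101_ours_SD", 42, datetime(2022, 4, 19, 10, 9, 49))
print(path.as_posix())  # logs/vkitti2/resnet101_ours_SD/s42/2022-04-19_10-09-49
```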

Project structure

The main.py file is the entry point for training and evaluation in all setups.

root
  ├── configs/    % Configuration files to run the experiments
  ├── training/   % Training loops for all settings
  ├── dataset/    % PyTorch dataset definitions as well as semantic segmentation encoding logic
  ├── models/     % Neural network modules, inference logic and method implementation
  ├── optim/      % Optimizer-related code
  ├── loss/       % Loss modules for each task type, including metric and visualization calls
  ├── metrics/    % Task metric implementations
  ├── vendor/     % Third-party source code
  └── utils/      % Utility code shared by other parts of the codebase

Credit

This repository contains code taken from Valeo.ai's ADVENT, Simon Vandenhende's MTL-survey, Niantic Labs' Monodepth2, and Lukas Hoyer's 3-Ways.

License

DenseMTL is released under the Apache 2.0 license.
