This repository provides the official source code and model weights for our WACV 2023 paper, Cross-task Attention Mechanism for Dense Multi-task Learning. The implementation uses PyTorch.
DenseMTL: Cross-task Attention Mechanism for Dense Multi-task Learning
Ivan Lopes1,
Tuan-Hung Vu1,2,
Raoul de Charette1
1 Inria, Paris, France.
2 Valeo.ai, Paris, France.
To cite our paper, please use:
```bibtex
@inproceedings{lopes2023densemtl,
    title={Cross-task Attention Mechanism for Dense Multi-task Learning},
    author={Lopes, Ivan and Vu, Tuan-Hung and de Charette, Raoul},
    booktitle={WACV},
    year={2023}
}
```
- Cross-task Attention Mechanism for Dense Multi-task Learning (DenseMTL)
- Table of contents
- Overview
- Installation
- Running DenseMTL
- Project structure
- Credit
- License
DenseMTL is a cross-attention-based multi-task architecture which leverages multiple attention mechanisms to extract and enrich task features. As seen in the figure above, each xTAM module receives a pair of differing task features to better assess cross-task interactions and allow for efficient cross-talk distillation.
In total, this work covers a wide range of experiments, which we summarize as:
- 3 settings: fully-supervised (`FS`), semi-supervised auxiliary depth (`SDE`), and domain adaptation (`DA`).
- 4 datasets: Cityscapes, Virtual KITTI 2, Synthia, and NYU-Depth v2.
- 4 tasks: semantic segmentation (`S`), depth regression (`D`), surface normals estimation (`N`), and edge detection (`E`).
- 3 task sets: `{S, D}`, `{S, D, N}`, and `{S, D, N, E}`.
First, create a new conda environment with the required packages found in `environment.yml`. You can do so with the following line:

```
>>> conda env create -n densemtl -f environment.yml
```
Then activate the `densemtl` environment using:

```
>>> conda activate densemtl
```
- CITYSCAPES: Follow the instructions on the Cityscapes website to download the images and validation ground truths. Please follow the dataset directory structure:

  ```
  <CITYSCAPES_DIR>/          % Cityscapes dataset root
  ├── leftImg8bit/           % input images (leftImg8bit_trainvaltest.zip)
  ├── leftImg8bit_sequence/  % sequences needed for monodepth (leftImg8bit_sequence_trainvaltest.zip)
  ├── disparity/             % stereo depth (disparity_trainvaltest.zip)
  ├── camera/                % camera parameters (camera_trainvaltest.zip)
  └── gtFine/                % semantic segmentation labels (gtFine_trainvaltest.zip)
  ```
- SYNTHIA: Follow the instructions here to download the images from the SYNTHIA-RAND-CITYSCAPES (CVPR16) split. Download the segmentation labels from CTRL-UDA using the link here. Please follow the dataset directory structure:

  ```
  <SYNTHIA_DIR>/  % Synthia dataset root
  ├── RGB/        % input images
  ├── GT/         % semseg labels
  └── Depth/      % depth labels
  ```
- VKITTI2: Follow the instructions here to download the images from the Virtual KITTI 2 dataset. Please follow the dataset directory structure:

  ```
  <VKITTI2_DIR>/          % VKITTI 2 dataset root
  ├── rgb/                % input images (vkitti_2.0.3_rgb.tar)
  ├── classSegmentation/  % semseg labels (vkitti_2.0.3_classSegmentation.tar)
  └── depth/              % depth labels (vkitti_2.0.3_depth.tar)
  ```
- NYUDv2: Follow the instructions here to download the NYUDv2 dataset along with its semantic segmentation, normals, depth, and edge labels. Please follow the dataset directory structure:

  ```
  <NYUDV2_DIR>/      % NYUDv2 dataset root
  ├── images/        % input images
  ├── segmentation/  % semseg labels
  ├── depth/         % depth labels
  ├── edge/          % edge labels
  └── normals/       % normals labels
  ```
Update `configs/env_config.yml` with the paths to the different directories by defining:

- Saved models path: `MODELS`.
- Dataset paths: `CITYSCAPES_DIR`, `SYNTHIA_DIR`, `VKITTI2_DIR`, and `NYUDV2_DIR`.
- Logs path: `LOG_DIR`.
All constants provided in this file are loaded as environment variables and are accessible at runtime via `os.environ`. Alternatively, these constants can be defined on the command line before running the project.
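For reference, the loading step is conceptually equivalent to the sketch below (an illustration only, assuming PyYAML and a flat key/value file; the project's own loading code is the authority):

```python
import os
import yaml

# Minimal sketch (not the repository's actual code): expose the constants
# defined in configs/env_config.yml through os.environ. We assume the file
# is a flat mapping such as CITYSCAPES_DIR: /data/cityscapes, LOG_DIR: /logs.
with open("configs/env_config.yml") as f:
    env_config = yaml.safe_load(f)

for key, value in env_config.items():
    # In this sketch, values already exported in the shell take precedence.
    os.environ.setdefault(str(key), str(value))

print(os.environ["CITYSCAPES_DIR"])  # accessible anywhere at runtime
```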
The following are the command line interface arguments and options:

- `--env-config ENV_CONFIG`: Path to the file containing the environment paths; defaults to configs/env_config.yml.
- `--base CONFIG`: Optional path to a base configuration YAML file; can be left unused if the `--config` file contains all keys.
- `--config CONFIG`: Path to the main configuration YAML file.
- `--project PROJECT`: Project name for logging, also used as the wandb project.
- `--resume`: Flag to resume training; this will look for the last available model checkpoint from the same setup.
- `--evaluate PATH`: Loads the model provided at the file path and performs evaluation using the config setup.
- `-s SEED, --seed SEED`: Seed for training and dataset.
- `-d, --debug`: Flag to perform a single validation inference for debugging purposes.
- `-w, --disable-wandb`: Flag to disable Weights & Biases logging.
Experiments are based on configuration files. Overall, each configuration file must follow this structure:
```yaml
setup:
  name: exp-name
  model:
    └── model args
  loss:
    └── loss args
  lr:
    └── learning rates
data:
  └── data module args
training:
  └── training args
optim:
  └── optimizer args
scheduler:
  └── scheduler args
```
For arguments which recur across experiments, such as `data`, `training`, `optim`, and `scheduler`, we use a base configuration file that we pass to the process via the `--base` option. The two configuration files (provided with `--base` and `--config`) are merged together at the top level (`config` can overwrite `base`). See more details in main.py.
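Conceptually, the merge behaves like the following sketch (an illustration under the assumption that both files are YAML mappings and that `--config` wins on conflicting top-level keys; main.py holds the actual logic):

```python
import yaml

# Minimal sketch of the top-level merge between --base and --config
# (illustration only; see main.py for the repository's implementation).
with open("configs/vkitti2/fs_bs2.yml") as f:             # passed via --base
    base = yaml.safe_load(f)
with open("configs/vkitti2/resnet101_ours_SD.yml") as f:  # passed via --config
    config = yaml.safe_load(f)

# Top-level keys from --config overwrite those from --base; nested keys
# are not merged recursively in this sketch.
merged = {**base, **config}
```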
Environment variables can be referenced inside the configuration file by using the `$ENV:` prefix, e.g. `path: $ENV:CITYSCAPES_DIR`.
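Resolving such a reference boils down to a simple substitution, roughly like the sketch below (`resolve_env` is a hypothetical helper, not the repository's function):

```python
import os

# Hypothetical helper: values prefixed with "$ENV:" are replaced by the
# matching environment variable, e.g. "$ENV:CITYSCAPES_DIR" becomes the
# value of os.environ["CITYSCAPES_DIR"].
def resolve_env(value):
    if isinstance(value, str) and value.startswith("$ENV:"):
        return os.environ[value[len("$ENV:"):]]
    return value

print(resolve_env("$ENV:CITYSCAPES_DIR"))
```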
To reproduce the experiments, you can run the following scripts.
- Single-task learning baseline:

  ```
  python main.py \
      --base=configs/<dataset>/fs_bs2.yml \
      --config=configs/<dataset>/resnet101_STL_<task>.yml
  ```
  Where `<dataset>` $\in$ {`cityscapes`, `synthia`, `vkitti2`, `nyudv2`} and `<task>` $\in$ {`S`, `D`, `N`, `E`}.

- Our method in the fully supervised setting (FS):

  ```
  python main.py \
      --base=configs/<dataset>/fs_bs2.yml \
      --config=configs/<dataset>/resnet101_ours_<set>.yml
  ```
  Where `<dataset>` $\in$ {`cityscapes`, `synthia`, `vkitti2`, `nyudv2`} and `<set>` $\in$ {`SD`, `SDN`, `SDNE`}. Do note, however, that experiments with the edge estimation task are only performed on the NYUDv2 dataset.
- Our method on semi-supervised depth estimation (SDE):

  ```
  python main.py --config configs/cityscapes/monodepth/resnet101_ours_SD.yml
  ```
- Our method on domain adaptation (DA):

  ```
  python main.py \
      --base=configs/da/<dataset>/fs_bs2.yml \
      --config configs/da/<dataset>/resnet101_ours_SD.yml
  ```
  Where `<dataset>` $\in$ {`sy2cs` (for Synthia $\mapsto$ Cityscapes), `vk2cs` (for VKITTI2 $\mapsto$ Cityscapes)}.
Our models trained in the fully supervised setting:
| Setup | Set | Link |
|---|---|---|
| Synthia | SD | sy_densemtl_SD.pkl |
| Synthia | SDN | sy_densemtl_SDN.pkl |
| Virtual KITTI 2 | SD | vk_densemtl_SD.pkl |
| Virtual KITTI 2 | SDN | vk_densemtl_SDN.pkl |
| Cityscapes | SD | cs_densemtl_SD.pkl |
| Cityscapes | SDN | cs_densemtl_SDN.pkl |
To evaluate a model, the `--evaluate` option can be set with a path to the state dictionary `.pkl` file. This weight file will be loaded onto the model and the evaluation loop launched. Keep in mind that you also need to provide valid configuration files. For example, to evaluate our method with the weights located in `weights/vkitti2_densemtl_SD.pkl`, simply run:
```
python main.py \
    --config=configs/vkitti2/resnet101_ours_SD.yml \
    --base=configs/vkitti2/fs_bs2.yml \
    --evaluate=weights/vkitti2_densemtl_SD.pkl
```
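If you want to sanity-check a downloaded weight file before launching the evaluation, something along these lines works (a minimal sketch, assuming the released `.pkl` files are plain PyTorch state dictionaries):

```python
import torch

# Minimal sketch (assumption: the .pkl file holds a plain PyTorch state
# dictionary mapping parameter names to tensors).
state_dict = torch.load("weights/vkitti2_densemtl_SD.pkl", map_location="cpu")
print(len(state_dict), "entries")         # number of parameters/buffers
print(next(iter(state_dict.items()))[0])  # name of the first entry
```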
By default, visualizations, losses, and metrics are logged using Weights & Biases. If you do not wish to log your trainings and evaluations through this tool, you can disable it with the `--disable-wandb` flag. In all cases, the loss values and metrics are logged via the standard output.
Checkpoints, models, and configuration files are saved under the `LOG_DIR` directory. More specifically, they are located under `<LOG_DIR>/<dataset>/<config-name>/s<seed>/<timestamp>`. For example, you could end up with `<LOG_DIR>/vkitti2/resnet101_ours_SD/s42/2022-04-19_10-09-49` for an `SD` training of our method on VKITTI2 with a seed equal to `42`.
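For illustration, assembling such a run directory roughly amounts to the following (a sketch only, mirroring the layout described above rather than the repository's logging code):

```python
import os
from datetime import datetime

# Sketch of the layout <LOG_DIR>/<dataset>/<config-name>/s<seed>/<timestamp>
# (illustration only; the repository's training code builds the real path).
run_dir = os.path.join(
    os.environ["LOG_DIR"],                         # from configs/env_config.yml
    "vkitti2",                                     # dataset
    "resnet101_ours_SD",                           # configuration name
    "s42",                                         # seed
    datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),  # timestamp
)
print(run_dir)
```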
The `main.py` file is the entry point to perform training and evaluation on the different setups.
```
root
├── configs/   % Configuration files to run the experiments
├── training/  % Training loops for all settings
├── dataset/   % PyTorch dataset definitions and semantic segmentation encoding logic
├── models/    % Neural network modules, inference logic, and method implementation
├── optim/     % Optimizer-related code
├── loss/      % Loss modules for each task type, including the metric and visualization calls
├── metrics/   % Task metric implementations
├── vendor/    % Third-party source code
└── utils/     % Utility code shared by the rest of the codebase
```
This repository contains code taken from Valeo.ai's ADVENT, Simon Vandenhende's MTL-survey, Niantic Labs' Monodepth 2, and Lukas Hoyer's 3-Ways.
DenseMTL is released under the Apache 2.0 license.