Commit: Added RTS tracker

Matthieu Paul committed Sep 18, 2022
1 parent 5f9fc81 commit d6cbd67
Showing 29 changed files with 3,628 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -10,6 +10,7 @@ ltr/train_settings/*/debug.py
pytracking/parameter/*/debug.py
pytracking/networks/
pytracking/tracking_results/
pytracking/segmentation_results/
pytracking/result_plots/
pytracking/evaluation/local.py
pytracking/run_local.py
23 changes: 19 additions & 4 deletions MODEL_ZOO.md
@@ -163,6 +163,19 @@ of the models on standard tracking datasets.
<td>50.9</td>
<td><a href="https://drive.google.com/file/d/1XQAtrM9n_PHQn-B2i8y6Q-PQFcAoKObA">model</a></td>
</tr>
<tr>
<td>RTS</td>
<td>-</td>
<td>-</td>
<td>65.4</td>
<td>67.9</td>
<td>69.7</td>
<td>-</td>
<td>81.6</td>
<td>-</td>
<td>-</td>
<td><a href="https://drive.google.com/drive/folders/1uMQqeRN4RbaeF3IsRpvjvk6D5K8khpSW?usp=sharing">model</a></td>
</tr>
</table>

### Raw Results
@@ -181,10 +194,12 @@ The raw results on the AVisT benchmark for the trackers in this repository (see
## VOS

### Models
| Model | YouTube-VOS 2018 (Overall Score) | YouTube-VOS 2019 (Overall Score) | DAVIS 2017 val (J&F score) | Links |
|:----------------------------------------------------:|:--------------------------------:|:--------------------------------:|:--------------------------:|:-----:|
| [LWL_ytvos](ltr/train_settings/lwl/lwl_stage2.py) | 81.5 | 81.0 | -- | [model](https://drive.google.com/file/d/1Xnm4A2BRBliDBKO4EEFHAQfGyfOMsVyY/view?usp=sharing) |
| [LWL_boxinit](ltr/train_settings/lwl/lwl_boxinit.py) | 70.4 | -- | 70.8 | [model](https://drive.google.com/file/d/1aAsj_N1LAMpmmcb1iOxo2z66tJM6MEuM/view?usp=sharing) |
| RTS | -- | 79.7 | 80.2 | [model](https://drive.google.com/drive/folders/1uMQqeRN4RbaeF3IsRpvjvk6D5K8khpSW?usp=sharing) |
| RTS (Box) | -- | 70.8 | 72.6 | [model](https://drive.google.com/drive/folders/1uMQqeRN4RbaeF3IsRpvjvk6D5K8khpSW?usp=sharing) |


### Raw Results
22 changes: 19 additions & 3 deletions README.md
@@ -1,6 +1,9 @@
# PyTracking
A general python framework for visual object tracking and video object segmentation, based on **PyTorch**.

### :fire: One tracking paper accepted at ECCV 2022! 👇
* [Robust Visual Tracking by Segmentation](https://arxiv.org/abs/2203.11191) | **Code available!**

### :fire: We released AVisT, a new tracking dataset for adverse visibility! 👇
* [AVisT: A Benchmark for Visual Object Tracking in Adverse Visibility](https://arxiv.org/abs/2208.06888) | The [dataset](https://sites.google.com/view/avist-benchmark), the integration [avistdataset.py](pytracking/evaluation/avistdataset.py) and the evaluation code [analyze_avist_results.ipynb](pytracking/notebooks/analyze_avist_results.ipynb) are available!
@@ -19,9 +22,9 @@ A general python framework for visual object tracking and video object segmentation

## Highlights

### ToMP, KeepTrack, LWL, KYS, PrDiMP, DiMP and ATOM Trackers
### RTS, ToMP, KeepTrack, LWL, KYS, PrDiMP, DiMP and ATOM Trackers

Official implementation of the **ToMP** (CVPR 2022), **KeepTrack** (ICCV 2021), **LWL** (ECCV 2020), **KYS** (ECCV 2020), **PrDiMP** (CVPR 2020),
Official implementation of the **RTS** (ECCV 2022), **ToMP** (CVPR 2022), **KeepTrack** (ICCV 2021), **LWL** (ECCV 2020), **KYS** (ECCV 2020), **PrDiMP** (CVPR 2020),
**DiMP** (ICCV 2019), and **ATOM** (CVPR 2019) trackers, including complete **training code** and trained models.

### [Tracking Libraries](pytracking)
@@ -50,13 +53,25 @@ benchmarks are provided in the [model zoo](MODEL_ZOO.md).
## Trackers
The toolkit contains the implementation of the following trackers.

### RTS (ECCV 2022)

**[[Paper]](https://arxiv.org/abs/2203.11191) [[Raw results]](MODEL_ZOO.md#Raw-Results-1)
[[Models]](MODEL_ZOO.md#Models-1) [[Training Code]](./ltr/README.md#RTS) [[Tracker Code]](./pytracking/README.md#RTS)**

Official implementation of **RTS**. RTS is a robust, end-to-end trainable, segmentation-centric pipeline that internally
works with segmentation masks instead of bounding boxes. Thus, it can learn a better target representation that clearly
differentiates the target from the background. To achieve the necessary robustness for challenging tracking scenarios,
a separate instance localization component is used to condition the segmentation decoder when producing the output mask.

![RTS_teaser_figure](pytracking/.figs/rts_overview.png)
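
As with the other trackers in this repository, RTS can be run through the standard pytracking evaluation scripts. A minimal sketch, assuming the parameter file is named `rts50` (tracker and parameter names are assumptions not confirmed by this diff):

```bash
# Run RTS on a single sequence from within the pytracking directory
# (tracker and parameter names are assumed).
python run_tracker.py rts rts50 --dataset_name otb --sequence Soccer
```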

### ToMP (CVPR 2022)

**[[Paper]](https://arxiv.org/abs/2203.11192) [[Raw results]](MODEL_ZOO.md#Raw-Results-1)
[[Models]](MODEL_ZOO.md#Models-1) [[Training Code]](./ltr/README.md#ToMP) [[Tracker Code]](./pytracking/README.md#ToMP)**

Official implementation of **ToMP**. ToMP employs a Transformer-based
model prediction module in order to localize the target. The model predictor is further exteneded to estimate a second set
model prediction module in order to localize the target. The model predictor is further extended to estimate a second set
of weights that are applied for accurate bounding box regression.
The resulting tracker ToMP relies on training and on test frame information in order to predict all weights transductively.

@@ -191,6 +206,7 @@ python run_webcam.py dimp dimp50
* [Martin Danelljan](https://martin-danelljan.github.io/)
* [Goutam Bhat](https://goutamgmb.github.io/)
* [Christoph Mayer](https://2006pmach.github.io/)
* [Matthieu Paul](https://github.com/mattpfr)

### Guest Contributors
* [Felix Järemo-Lawin](https://liu.se/en/employee/felja34) [LWL]
15 changes: 14 additions & 1 deletion ltr/README.md
@@ -6,6 +6,8 @@ A general PyTorch based framework for learning tracking representations.
* [Quick Start](#quick-start)
* [Overview](#overview)
* [Trackers](#trackers)
* [RTS](#RTS)
* [ToMP](#ToMP)
* [KeepTrack](#KeepTrack)
* [LWL](#LWL)
* [KYS](#KYS)
@@ -30,7 +32,7 @@ python run_training.py bbreg atom_default


## Overview
The framework consists of the following sub-modules.
The framework consists of the following submodules.
- [actors](actors): Contains the actor classes for different trainings. The actor class is responsible for passing the input data through the network and calculating the losses.
- [admin](admin): Includes functions for loading networks, tensorboard etc. and also contains environment settings.
- [dataset](dataset): Contains integration of a number of training datasets, namely [TrackingNet](https://tracking-net.org/), [GOT-10k](http://got-10k.aitestunion.com/), [LaSOT](http://vision.cs.stonybrook.edu/~lasot/),
@@ -45,6 +47,17 @@ The framework consists of the following sub-modules.
## Trackers
The framework currently contains the training code for the following trackers.

### RTS
Three steps are required to train RTS:
- Download [lasot_got10k_pregenerated_masks.zip](https://drive.google.com/file/d/1p3vSWd_kcwoLdiw1fg24Qd4V4eiRFJp4/view?usp=sharing).

Unzip the archive into the `pregenerated_masks` directory set in `ltr/admin/local.py`.
- Download the pretrained LWL weights [lwl_stage2.pth](https://drive.google.com/file/d/1Xnm4A2BRBliDBKO4EEFHAQfGyfOMsVyY/view?usp=sharing).

Save the weights in the `pretrained_networks` directory set in `ltr/admin/local.py`.
- Use this setting for training with the ResNet50 backbone: [rts.rts50](train_settings/rts/rts50.py). A sketch of the corresponding `local.py` entries follows below.
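
For reference, here is a minimal sketch of the relevant `ltr/admin/local.py` entries, assuming the default `EnvironmentSettings` layout generated by `ltr/admin/environment.py`; the paths are illustrative assumptions, only the attribute names appear in this commit:

```python
# Sketch of the ltr/admin/local.py entries used by RTS training (paths are placeholders).
class EnvironmentSettings:
    def __init__(self):
        self.workspace_dir = '/home/user/pytracking_workspace'
        # Save the downloaded lwl_stage2.pth here:
        self.pretrained_networks = self.workspace_dir + '/pretrained_networks/'
        # Unzip lasot_got10k_pregenerated_masks.zip here:
        self.pregenerated_masks = '/home/user/data/pregenerated_masks'
```

Training is then launched like for the other trackers, e.g. `python run_training.py rts rts50`.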


### ToMP
The following setting files can be used to train the ToMP tracker. By default we omit training with a separate test encoding, since training is more stable and achieves comparable performance. Set the corresponding flag to false to use the same setup as in the paper.
- [tomp.tomp50](train_settings/tomp/tomp50.py): The default setting used for training with the ResNet50 backbone.
178 changes: 177 additions & 1 deletion ltr/actors/segmentation.py
@@ -1,7 +1,7 @@
from . import BaseActor
import torch
import torch.nn as nn

import numpy as np
from pytracking.analysis.vos_utils import davis_jaccard_measure


@@ -138,3 +138,179 @@ def __call__(self, data):
stats['Stats/acc_box_train'] = acc_box/cnt_box

return loss, stats


class RTSActor(BaseActor):
"""Based on the LWL and DiMP actor, supporting both segmentation and classification terms """
def __init__(self, net, objective, loss_weight=None,
num_refinement_iter=3,
disable_backbone_bn=False,
disable_all_bn=False):
"""
args:
net - The network model to train
objective - Loss functions
loss_weight - Weights for each training loss
num_refinement_iter - Number of update iterations N^{train}_{update} used to update the target model in
each frame
disable_backbone_bn - If True, all batch norm layers in the backbone feature extractor are disabled, i.e.
set to eval mode.
            disable_all_bn - If True, all the batch norm layers in the network are disabled, i.e. set to eval mode.
"""
super().__init__(net, objective)
if loss_weight is None:
loss_weight = {'segm': 1.0}
self.loss_weight = loss_weight

self.num_refinement_iter = num_refinement_iter
self.disable_backbone_bn = disable_backbone_bn
self.disable_all_bn = disable_all_bn

def train(self, mode=True):
""" Set whether the network is in train mode.
args:
mode (True) - Bool specifying whether in training mode.
"""
self.net.train(mode)

if self.disable_all_bn:
self.net.eval()
elif self.disable_backbone_bn:
for m in self.net.feature_extractor.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()

def __call__(self, data):
"""
args:
            data - The input data, should contain the fields 'train_images', 'test_images', 'train_masks',
                   'test_masks', 'train_anno', 'train_label' and 'test_label'
returns:
loss - the training loss
stats - dict containing detailed losses
"""
segm_pred, target_scores = self.net(train_imgs=data['train_images'],
test_imgs=data['test_images'],
train_masks=data['train_masks'],
test_masks=data['test_masks'],
train_bb=data['train_anno'],
train_label=data['train_label'],
test_label=data['test_label'],
num_refinement_iter=self.num_refinement_iter)

# Segmentation Loss
################################################################
acc = 0
cnt = 0

segm_pred = segm_pred.view(-1, 1, *segm_pred.shape[-2:])
gt_segm = data['test_masks']
gt_segm = gt_segm.view(-1, 1, *gt_segm.shape[-2:])

loss_segm = self.loss_weight['segm'] * self.objective['segm'](segm_pred, gt_segm)

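        # Per-frame Jaccard index (IoU) of the binarized mask predictions,
        # following the DAVIS evaluation protocol.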
acc_l = [davis_jaccard_measure(torch.sigmoid(rm.detach()).cpu().numpy() > 0.5, lb.cpu().numpy()) for
rm, lb in zip(segm_pred.view(-1, *segm_pred.shape[-2:]), gt_segm.view(-1, *segm_pred.shape[-2:]))]
acc += sum(acc_l)
cnt += len(acc_l)


if torch.isinf(loss_segm) or torch.isnan(loss_segm):
raise Exception('ERROR: Segm Loss was nan or inf!!!')


# Classification Loss
#################################################################

threshold = 0.25

last_target_score = target_scores[-1]
label = data['test_label']

last_target_score_reshaped = last_target_score.view(-1, last_target_score.shape[-2] * last_target_score.shape[-1])
label_reshaped = label.view(-1, label.shape[-2] * label.shape[-1])
prediction_max_vals, pred_argmax_ids = last_target_score_reshaped.max(dim=1)
label_max_vals, label_argmax_ids = label_reshaped.max(dim=1)

label_val_at_peak = label_reshaped[torch.arange(len(pred_argmax_ids)), pred_argmax_ids]
label_val_at_peak = torch.max(label_val_at_peak, torch.zeros_like(label_val_at_peak))

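        # A prediction counts as correct if the label at the predicted peak is high
        # while a target is present, or low while no target is present.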
        prediction_correct = ((label_val_at_peak >= threshold) & (label_max_vals > threshold)) | ((label_val_at_peak < threshold) & (label_max_vals < threshold))
prediction_accuracy = prediction_correct.float().mean()

# Peak to Peak distance
n_samples = label.shape[0] * label.shape[1]
n_pixels = label.shape[2] * label.shape[3]

peak_dist = np.zeros(n_samples)

        assert last_target_score.shape == label.shape

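        # Convert flat argmax indices to (row, col) coordinates and measure the
        # Euclidean distance between predicted and ground-truth score-map peaks.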
for sample in range(0, n_samples):
pred_idx = pred_argmax_ids[sample].cpu()
label_idx = label_argmax_ids[sample].cpu()
peak_pred = np.array([pred_idx // label.shape[-1], pred_idx % label.shape[-1]])
peak_label = np.array([label_idx // label.shape[-1], label_idx % label.shape[-1]])
peak_dist[sample] = np.linalg.norm(peak_label - peak_pred)

peak_dist = peak_dist.mean()

# Compute loss
loss_test_init_clf = 0
loss_test_iter_clf = 0

# Classification losses for the different optimization iterations
clf_losses_test = [self.objective['test_clf'](
s, data['test_label'], data['test_anno']) for s in target_scores]

# Loss of the final filter
clf_loss_test = clf_losses_test[-1]
loss_target_classifier = self.loss_weight['test_clf'] * clf_loss_test

# Loss for the initial filter iteration
if 'test_init_clf' in self.loss_weight.keys():
loss_test_init_clf = self.loss_weight['test_init_clf'] * clf_losses_test[0]

# Loss for the intermediate filter iterations
if 'test_iter_clf' in self.loss_weight.keys():
test_iter_weights = self.loss_weight['test_iter_clf']
if isinstance(test_iter_weights, list):
loss_test_iter_clf = sum([a * b for a, b in zip(test_iter_weights, clf_losses_test[1:-1])])
else:
loss_test_iter_clf = (test_iter_weights / (len(clf_losses_test) - 2)) * sum(clf_losses_test[1:-1])

# Total loss
loss_classifier = loss_target_classifier + loss_test_init_clf + loss_test_iter_clf

if torch.isinf(loss_classifier) or torch.isnan(loss_classifier):
raise Exception('ERROR: Classifier Loss was nan or inf!!!')

# TOTAL LOSS
loss = loss_segm + loss_classifier

# Log stats Segmentation
stats = {
'Loss/total': loss.item(),
'Loss/segm': loss_segm.item(),
'Stats/acc': acc / cnt,
            'Stats/clf_acc': prediction_accuracy.item(),
'Stats/clf_peak_dist': peak_dist,
}

# Log stats Classification
if 'test_clf' in self.loss_weight.keys():
stats['Loss/target_clf'] = loss_target_classifier.item()
if 'test_init_clf' in self.loss_weight.keys():
stats['Loss/test_init_clf'] = loss_test_init_clf.item()
if 'test_iter_clf' in self.loss_weight.keys():
stats['Loss/test_iter_clf'] = loss_test_iter_clf.item()

if 'test_clf' in self.loss_weight.keys():
stats['ClfTrain/test_loss'] = clf_loss_test.item()
if len(clf_losses_test) > 0:
stats['ClfTrain/test_init_loss'] = clf_losses_test[0].item()
if len(clf_losses_test) > 2:
stats['ClfTrain/test_iter_loss'] = sum(clf_losses_test[1:-1]).item() / (len(clf_losses_test) - 2)

return loss, stats

1 change: 1 addition & 0 deletions ltr/admin/environment.py
@@ -11,6 +11,7 @@ def create_default_local_file():
'workspace_dir': empty_str,
'tensorboard_dir': 'self.workspace_dir + \'/tensorboard/\'',
'pretrained_networks': 'self.workspace_dir + \'/pretrained_networks/\'',
'pregenerated_masks': empty_str,
'lasot_dir': empty_str,
'got10k_dir': empty_str,
'trackingnet_dir': empty_str,
