-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit e934b1c
Showing
20 changed files
with
2,033 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
|
||
# Created by https://www.gitignore.io/api/linux,python,pycharm | ||
# Edit at https://www.gitignore.io/?templates=linux,python,pycharm | ||
|
||
### Linux ### | ||
*~ | ||
|
||
# temporary files which can be created if a process still has a handle open of a deleted file | ||
.fuse_hidden* | ||
|
||
# KDE directory preferences | ||
.directory | ||
|
||
# Linux trash folder which might appear on any partition or disk | ||
.Trash-* | ||
|
||
# .nfs files are created when an open file is removed but is still being accessed | ||
.nfs* | ||
|
||
### PyCharm ### | ||
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm | ||
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 | ||
/.idea/ | ||
|
||
# CMake | ||
cmake-build-*/ | ||
|
||
# File-based project format | ||
*.iws | ||
|
||
# IntelliJ | ||
out/ | ||
|
||
# JIRA plugin | ||
atlassian-ide-plugin.xml | ||
|
||
# Crashlytics plugin (for Android Studio and IntelliJ) | ||
com_crashlytics_export_strings.xml | ||
crashlytics.properties | ||
crashlytics-build.properties | ||
fabric.properties | ||
|
||
### Python ### | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
pip-wheel-metadata/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.nox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# pipenv | ||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
# install all needed dependencies. | ||
#Pipfile.lock | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# Mr Developer | ||
.mr.developer.cfg | ||
.project | ||
.pydevproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
.dmypy.json | ||
dmypy.json | ||
|
||
# Pyre type checker | ||
.pyre/ | ||
|
||
# End of https://www.gitignore.io/api/linux,python,pycharm | ||
|
||
/results/ | ||
/data/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
BSD 3-Clause License | ||
|
||
Copyright (c) 2019, Jerome Rony | ||
All rights reserved. | ||
|
||
Redistribution and use in source and binary forms, with or without | ||
modification, are permitted provided that the following conditions are met: | ||
|
||
1. Redistributions of source code must retain the above copyright notice, this | ||
list of conditions and the following disclaimer. | ||
|
||
2. Redistributions in binary form must reproduce the above copyright notice, | ||
this list of conditions and the following disclaimer in the documentation | ||
and/or other materials provided with the distribution. | ||
|
||
3. Neither the name of the copyright holder nor the names of its | ||
contributors may be used to endorse or promote products derived from | ||
this software without specific prior written permission. | ||
|
||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | ||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | ||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | ||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | ||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | ||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
## Requirements for the experiments | ||
|
||
- scikit-learn | ||
- pytorch >= 1.4 | ||
- sacred >= 0.8 | ||
- tqdm | ||
- visdom_logger https://github.com/luizgh/visdom_logger | ||
- faiss https://github.com/facebookresearch/faiss | ||
|
||
## Data management | ||
|
||
For In-Shop, you need to manually download the data from https://drive.google.com/drive/folders/0B7EVK8r0v71pVDZFQXRsMDZCX1E (at least the `img.zip` and `list_eval_partition.txt`), put them in `data/InShop` and extract `img.zip`. | ||
|
||
You can download and generate the `train.txt` and `test.txt` for every dataset using the `prepare_data.py` script with: | ||
```bash | ||
python prepare_data.py | ||
``` | ||
This will download and prepare all the necessary data for _CUB200_, _Cars-196_ and _Stanford Online Products_. | ||
|
||
## Usage | ||
|
||
This repo uses `sacred` to manage the experiments. | ||
To run an experiment (e.g. on CUB200): | ||
|
||
```bash | ||
python experiment.py with dataset.cub | ||
``` | ||
|
||
You can add an observer to save the metrics and files related to the expriment by adding `-F result_dir`: | ||
|
||
```bash | ||
python experiment.py -F result_dir with dataset.cub | ||
``` | ||
|
||
## Reproducing the results of the paper | ||
|
||
CUB200 | ||
```bash | ||
python experiment.py with dataset.cub model.resnet50 epochs=30 lr=0.02 | ||
``` | ||
|
||
CARS-196 | ||
```bash | ||
python experiment.py with dataset.cars model.resnet50 epochs=100 lr=0.05 model.norm_layer=batch | ||
``` | ||
|
||
Stanford Online Products | ||
```bash | ||
python experiment.py with dataset.sop model.resnet50 epochs=100 lr=0.01 momentum=0.9 nesterov=True model.norm_layer=batch | ||
``` | ||
|
||
In-Shop | ||
```bash | ||
python experiment.py with dataset.inshop model.resnet50 epochs=100 lr=0.01 momentum=0.9 nesterov=True model.norm_layer=batch | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import math | ||
import os | ||
from copy import deepcopy | ||
from functools import partial | ||
from pprint import pprint | ||
|
||
import sacred | ||
import torch | ||
import torch.nn as nn | ||
from sacred import SETTINGS | ||
from sacred.utils import apply_backspaces_and_linefeeds | ||
from torch.backends import cudnn | ||
from torch.optim import SGD, lr_scheduler | ||
from visdom_logger import VisdomLogger | ||
|
||
from models.ingredient import model_ingredient, get_model | ||
from utils import state_dict_to_cpu, SmoothCrossEntropy | ||
from utils.data.dataset_ingredient import data_ingredient, get_loaders | ||
from utils.training import train, evaluate | ||
|
||
ex = sacred.Experiment('Metric Learning', ingredients=[data_ingredient, model_ingredient]) | ||
# Filter backspaces and linefeeds | ||
SETTINGS.CAPTURE_MODE = 'sys' | ||
ex.captured_out_filter = apply_backspaces_and_linefeeds | ||
|
||
|
||
@ex.config | ||
def config(): | ||
epochs = 20 | ||
lr = 0.02 | ||
momentum = 0. | ||
nesterov = False | ||
weight_decay = 5e-4 | ||
scheduler = 'warmcos' | ||
|
||
visdom_port = None | ||
visdom_freq = 20 | ||
cpu = False # Force training on CPU | ||
cudnn_flag = 'benchmark' | ||
temp_dir = os.path.join('results', 'temp') | ||
|
||
no_bias_decay = True | ||
label_smoothing = 0.1 | ||
temperature = 1. | ||
|
||
|
||
@ex.capture | ||
def get_optimizer_scheduler(parameters, loader_length, epochs, lr, momentum, nesterov, weight_decay, scheduler, | ||
lr_step=None): | ||
optimizer = SGD(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay, | ||
nesterov=True if nesterov and momentum else False) | ||
if epochs == 0: | ||
scheduler = None | ||
elif scheduler == 'cos': | ||
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs * loader_length, eta_min=0) | ||
elif scheduler == 'warmcos': | ||
warm_cosine = lambda i: min((i + 1) / 100, (1 + math.cos(math.pi * i / (epochs * loader_length))) / 2) | ||
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_cosine) | ||
elif scheduler == 'step': | ||
scheduler = lr_scheduler.StepLR(optimizer, lr_step * loader_length) | ||
elif scheduler == 'warmstep': | ||
warm_step = lambda i: min((i + 1) / 100, 1) * 0.1 ** (i // (lr_step * loader_length)) | ||
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_step) | ||
else: | ||
scheduler = lr_scheduler.StepLR(optimizer, epochs * loader_length) | ||
return optimizer, scheduler | ||
|
||
|
||
@ex.automain | ||
def main(epochs, cpu, cudnn_flag, visdom_port, visdom_freq, temp_dir, seed, no_bias_decay, label_smoothing, | ||
temperature): | ||
device = torch.device('cuda:0' if torch.cuda.is_available() and not cpu else 'cpu') | ||
callback = VisdomLogger(port=visdom_port) if visdom_port else None | ||
if cudnn_flag == 'deterministic': | ||
setattr(cudnn, cudnn_flag, True) | ||
|
||
torch.manual_seed(seed) | ||
loaders, recall_ks = get_loaders() | ||
|
||
torch.manual_seed(seed) | ||
model = get_model(num_classes=loaders.num_classes) | ||
class_loss = SmoothCrossEntropy(epsilon=label_smoothing, temperature=temperature) | ||
|
||
model.to(device) | ||
if torch.cuda.device_count() > 1: | ||
model = nn.DataParallel(model) | ||
parameters = [] | ||
if no_bias_decay: | ||
parameters.append({'params': [par for par in model.parameters() if par.dim() != 1]}) | ||
parameters.append({'params': [par for par in model.parameters() if par.dim() == 1], 'weight_decay': 0}) | ||
else: | ||
parameters.append({'params': model.parameters()}) | ||
optimizer, scheduler = get_optimizer_scheduler(parameters=parameters, loader_length=len(loaders.train)) | ||
|
||
# setup partial function to simplify call | ||
eval_function = partial(evaluate, model=model, recall=recall_ks, query_loader=loaders.query, | ||
gallery_loader=loaders.gallery) | ||
|
||
# setup best validation logger | ||
metrics = eval_function() | ||
if callback is not None: | ||
callback.scalars(['l2', 'cosine'], 0, [metrics.recall['l2'][1], metrics.recall['cosine'][1]], | ||
title='Val Recall@1') | ||
pprint(metrics.recall) | ||
best_val = (0, metrics.recall, deepcopy(model.state_dict())) | ||
|
||
torch.manual_seed(seed) | ||
for epoch in range(epochs): | ||
if cudnn_flag == 'benchmark': | ||
setattr(cudnn, cudnn_flag, True) | ||
|
||
train(model=model, loader=loaders.train, class_loss=class_loss, optimizer=optimizer, | ||
scheduler=scheduler, epoch=epoch, callback=callback, freq=visdom_freq, ex=ex) | ||
|
||
# validation | ||
if cudnn_flag == 'benchmark': | ||
setattr(cudnn, cudnn_flag, False) | ||
metrics = eval_function() | ||
print('Validation [{:03d}]'.format(epoch)), pprint(metrics.recall) | ||
ex.log_scalar('val.recall_l2@1', metrics.recall['l2'][1], step=epoch + 1) | ||
ex.log_scalar('val.recall_cosine@1', metrics.recall['cosine'][1], step=epoch + 1) | ||
|
||
if callback is not None: | ||
callback.scalars(['l2', 'cosine'], epoch + 1, | ||
[metrics.recall['l2'][1], metrics.recall['cosine'][1]], title='Val Recall') | ||
|
||
# save model dict if the chosen validation metric is better | ||
if metrics.recall['cosine'][1] >= best_val[1]['cosine'][1]: | ||
best_val = (epoch + 1, metrics.recall, deepcopy(model.state_dict())) | ||
|
||
# logging | ||
ex.info['recall'] = best_val[1] | ||
|
||
# saving | ||
save_name = os.path.join(temp_dir, '{}_{}.pt'.format(ex.current_run.config['model']['arch'], | ||
ex.current_run.config['dataset']['name'])) | ||
torch.save(state_dict_to_cpu(best_val[2]), save_name) | ||
ex.add_artifact(save_name) | ||
|
||
if callback is not None: | ||
save_name = os.path.join(temp_dir, 'visdom_data.pt') | ||
callback.save(save_name) | ||
ex.add_artifact(save_name) | ||
|
||
return best_val[1]['cosine'][1] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from .bn_inception import bninception | ||
from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d, resnext101_32x8d, \ | ||
wide_resnet50_2, wide_resnet101_2 | ||
|
||
__all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', | ||
'wide_resnet50_2', 'wide_resnet101_2', 'bninception'] |
Oops, something went wrong.