Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
yinhaoxs authored Nov 23, 2019
1 parent c56552d commit 175eda8
Show file tree
Hide file tree
Showing 16 changed files with 2,452 additions and 0 deletions.
392 changes: 392 additions & 0 deletions cirtorch/networks/imageretrievalnet_cpu.py

Large diffs are not rendered by default.

605 changes: 605 additions & 0 deletions classify.py

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
websites:
host: 0.0.0.0
port: 15788

model:
network: /*.pth
model_dir: /*
type: [SA,SB]
43 changes: 43 additions & 0 deletions image_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# coding=utf-8
# /usr/bin/env pythpn

'''
Author: yinhao
Email: yinhao_x@163.com
Wechat: xss_yinhao
Github: http://github.com/yinhaoxs
data: 2019-11-23 18:25
desc:
'''

from retrieval_feature import *
from retrieval_index import *
from classify import class_results
from PIL import Image
from PIL import Image, ImageFile, TiffImagePlugin

'''
ImageFile.LOAD_TRUNCATED_IMAGES=True
TiffImagePlugin.READ_LIBTIFF=True
Image.DEBUG=True
'''

def main(img_dir, network, hash_size, input_dim, num_hashtables, feature_path, index_path, out_similar_dir, out_similar_file_dir, all_csv_file, num_results):
# classify
class_results(img_dir)

# extract feature
AntiFraudFeatureDataset(img_dir, network, feature_path, index_path).constructfeature(hash_size, input_dim, num_hashtables)

# similar index
EvaluteMap(out_similar_dir, out_similar_file_dir, all_csv_file, feature_path, index_path).retrieval_images()


if __name__ == "__main__":
pass





255 changes: 255 additions & 0 deletions interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
# coding=utf-8
# /usr/bin/env pythpn

'''
Author: yinhao
Email: yinhao_x@163.com
Wechat: xss_yinhao
Github: http://github.com/yinhaoxs
data: 2019-11-23 21:51
desc:
'''
import torch
from torch.utils.model_zoo import load_url
from torchvision import transforms
from cirtorch.datasets.testdataset import configdataset
from cirtorch.utils.download import download_train, download_test
from cirtorch.utils.evaluate import compute_map_and_print
from cirtorch.utils.general import get_data_root, htime
from cirtorch.networks.imageretrievalnet_cpu import init_network, extract_vectors
from cirtorch.datasets.datahelpers import imresize

from PIL import Image
import numpy as np
import pandas as pd
from flask import Flask, request
import json, io, sys, time, traceback, argparse, logging, subprocess, pickle, os, yaml,shutil
import cv2
import pdb
from werkzeug.utils import cached_property
from apscheduler.schedulers.background import BackgroundScheduler
from multiprocessing import Pool

app = Flask(__name__)

@app.route("/")
def index():
return ""

@app.route("/images/*", methods=['GET','POST'])
def accInsurance():
"""
flask request process handle
:return:
"""
try:
if request.method == 'GET':
return json.dumps({'err': 1, 'msg': 'POST only'})
else:
app.logger.debug("print headers------")
headers = request.headers
headers_info = ""
for k, v in headers.items():
headers_info += "{}: {}\n".format(k, v)
app.logger.debug(headers_info)

app.logger.debug("print forms------")
forms_info = ""
for k, v in request.form.items():
forms_info += "{}: {}\n".format(k, v)
app.logger.debug(forms_info)

if 'query' not in request.files:
return json.dumps({'err': 2, 'msg': 'query image is empty'})

if 'sig' not in request.form:
return json.dumps({'err': 3, 'msg': 'sig is empty'})

if 'q_no' not in request.form:
return json.dumps({'err': 4, 'msg': 'no is empty'})

if 'q_did' not in request.form:
return json.dumps({'err': 5, 'msg': 'did is empty'})

if 'q_id' not in request.form:
return json.dumps({'err': 6, 'msg': 'id is empty'})

if 'type' not in request.form:
return json.dumps({'err': 7, 'msg': 'type is empty'})

img_name = request.files['query'].filename
img_bytes = request.files['query'].read()
img = request.files['query']
sig = request.form['sig']
q_no = request.form['q_no']
q_did = request.form['q_did']
q_id = request.form['q_id']
type = request.form['type']

if str(type) not in types:
return json.dumps({'err': 8, 'msg': 'type is not exist'})

if img_bytes is None:
return json.dumps({'err': 10, 'msg': 'img is none'})

results = imageRetrieval().retrieval_online_v0(img, q_no, q_did, q_id, type)

data = dict()
data['query'] = img_name
data['sig'] = sig
data['type'] = type
data['q_no'] = q_no
data['q_did'] = q_did
data['q_id'] = q_id
data['results'] = results

return json.dumps({'err': 0, 'msg': 'success', 'data': data})

except:
app.logger.exception(sys.exc_info())
return json.dumps({'err': 9, 'msg': 'unknow error'})


class imageRetrieval():
def __init__(self):
pass

def cosine_dist(self, x, y):
return 100 * float(np.dot(x, y))/(np.dot(x,x)*np.dot(y,y)) ** 0.5

def inference(self, img):
try:
input = Image.open(img).convert("RGB")
input = imresize(input, 224)
input = transforms(input).unsqueeze()
with torch.no_grad():
vect = net(input)
return vect
except:
print('cannot indentify error')

def retrieval_online_v0(self, img, q_no, q_did, q_id, type):
# load model
query_vect = self.inference(img)
query_vect = list(query_vect.detach().numpy().T[0])

lsh = lsh_dict[str(type)]
response = lsh.query(query_vect, num_results=1, distance_func = "cosine")

try:
similar_path = response[0][0][1]
score = np.rint(self.cosine_dist(list(query_vect), list(response[0][0][0])))
rank_list = similar_path.split("/")
s_id, s_did, s_no = rank_list[-1].split("_")[-1].split(".")[0], rank_list[-1].split("_")[0], rank_list[-2]
results = [{"s_no": s_no, "r_did": s_did, "s_id": s_id, "score": score}]
except:
results = []

img_path = "/{}/{}_{}".format(q_no, q_did, q_id)
lsh.index(query_vect, extra_data=img_path)
lsh_dict[str(type)] = lsh

return results



class initModel():
def __init__(self):
pass

def init_model(self, network, model_dir, types):
print(">> Loading network:\n>>>> '{}'".format(network))
# state = load_url(PRETRAINED[args.network], model_dir=os.path.join(get_data_root(), 'networks'))
state = torch.load(network)
# parsing net params from meta
# architecture, pooling, mean, std required
# the rest has default values, in case that is doesnt exist
net_params = {}
net_params['architecture'] = state['meta']['architecture']
net_params['pooling'] = state['meta']['pooling']
net_params['local_whitening'] = state['meta'].get('local_whitening', False)
net_params['regional'] = state['meta'].get('regional', False)
net_params['whitening'] = state['meta'].get('whitening', False)
net_params['mean'] = state['meta']['mean']
net_params['std'] = state['meta']['std']
net_params['pretrained'] = False
# network initialization
net = init_network(net_params)
net.load_state_dict(state['state_dict'])
print(">>>> loaded network: ")
print(net.meta_repr())
# moving network to gpu and eval mode
# net.cuda()
net.eval()

# set up the transform
normalize = transforms.Normalize(
mean=net.meta['mean'],
std=net.meta['std']
)
transform = transforms.Compose([
transforms.ToTensor(),
normalize
])

lsh_dict = dict()
for type in types:
with open(os.path.join(model_dir, "dataset_index_{}.pkl".format(str(type))), "rb") as f:
lsh = pickle.load(f)

lsh_dict[str(type)] = lsh

return net, lsh_dict, transforms

def init(self):
with open('config.yaml', 'r') as f:
conf = yaml.load(f)

app.logger.info(conf)
host = conf['website']['host']
port = conf['website']['port']
network = conf['model']['network']
model_dir = conf['model']['model_dir']
types = conf['model']['type']

net, lsh_dict, transforms = self.init_model(network, model_dir, types)

return host, port, net, lsh_dict, transforms, model_dir, types


def job():
for type in types:
with open(os.path.join(model_dir, "dataset_index_{}_v0.pkl".format(str(type))), "wb") as f:
pickle.dump(lsh_dict[str(type)], f)


if __name__ == "__main__":
"""
start app from ssh
"""
scheduler = BackgroundScheduler()
host, port, net, lsh_dict, transforms, model_dir, types = initModel().init()
app.run(host=host, port=port, debug=True)
print("start server {}:{}".format(host, port))

scheduler.add_job(job, 'interval', seconds= 30)
scheduler.start()

else:
"""
start app from gunicorn
"""
scheduler = BackgroundScheduler()
gunicorn_logger = logging.getLogger("gunicorn.error")
app.logger.handlers = gunicorn_logger.handlers
app.logger.setLevel(gunicorn_logger.level)

host, port, net, lsh_dict, transforms, model_dir, types = initModel().init()
app.logger.info("started from gunicorn...")

scheduler.add_job(job, 'interval', seconds=30)
scheduler.start()



33 changes: 33 additions & 0 deletions nts/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# NTS-Net

This is a PyTorch implementation of the ECCV2018 paper "Learning to Navigate for Fine-grained Classification" (Ze Yang, Tiange Luo, Dong Wang, Zhiqiang Hu, Jun Gao, Liwei Wang).

## Requirements
- python 3+
- pytorch 0.4+
- numpy
- datetime

## Datasets
Download the [CUB-200-2011](http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz) datasets and put it in the root directory named **CUB_200_2011**, You can also try other fine-grained datasets.

## Train the model
If you want to train the NTS-Net, just run ``python train.py``. You may need to change the configurations in ``config.py``. The parameter ``PROPOSAL_NUM`` is ``M`` in the original paper and the parameter ``CAT_NUM`` is ``K`` in the original paper. During training, the log file and checkpoint file will be saved in ``save_dir`` directory. You can change the parameter ``resume`` to choose the checkpoint model to resume.

## Test the model
If you want to test the NTS-Net, just run ``python test.py``. You need to specify the ``test_model`` in ``config.py`` to choose the checkpoint model for testing.

## Model
We also provide the checkpoint model trained by ourselves, you can download it from [here](https://drive.google.com/file/d/1F-eKqPRjlya5GH2HwTlLKNSPEUaxCu9H/view?usp=sharing). If you test on our provided model, you will get a 87.6% test accuracy.

## Reference
If you are interested in our work and want to cite it, please acknowledge the following paper:

```
@inproceedings{Yang2018Learning,
author = {Yang, Ze and Luo, Tiange and Wang, Dong and Hu, Zhiqiang and Gao, Jun and Wang, Liwei},
title = {Learning to Navigate for Fine-grained Classification},
booktitle = {ECCV},
year = {2018}
}
```
10 changes: 10 additions & 0 deletions nts/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
BATCH_SIZE = 16
PROPOSAL_NUM = 6
CAT_NUM = 4
INPUT_SIZE = (448, 448) # (w, h)
LR = 0.001
WD = 1e-4
SAVE_FREQ = 1
resume = ''
test_model = 'model.ckpt'
save_dir = '/data_4t/yangz/models/'
Loading

0 comments on commit 175eda8

Please sign in to comment.