Commit
yinhaoxs authored and committed on Sep 14, 2020
1 parent 0f8ae1d · commit f28220d
Showing 7 changed files with 326 additions and 44 deletions.
@@ -0,0 +1,160 @@
# coding=utf-8
# /usr/bin/env python

'''
Author: yinhao
Email: yinhao_x@163.com
Wechat: xss_yinhao
Github: http://github.com/yinhaoxs
date: 2019-11-23 18:26
desc:
'''

import os
from PIL import Image
from lshash.lshash import LSHash
import torch
from torchvision import transforms
from cirtorch.networks.imageretrievalnet import init_network, extract_vectors

# set the visible GPU
os.environ['CUDA_VISIBLE_DEVICES'] = "0"


class ImageProcess():
    def __init__(self, img_dir):
        self.img_dir = img_dir

    def process(self):
        # collect images whose long-side/short-side ratio is below 5; skip unreadable files
        imgs = list()
        for root, dirs, files in os.walk(self.img_dir):
            for file in files:
                img_path = os.path.join(root, file)
                try:
                    image = Image.open(img_path)
                    if max(image.size) / min(image.size) < 5:
                        imgs.append(img_path)
                    else:
                        continue
                except Exception:
                    print("skip unreadable image: {}".format(img_path))

        return imgs


class AntiFraudFeatureDataset():
    def __init__(self, img_dir, network, feature_path='', index_path=''):
        self.img_dir = img_dir
        self.network = network
        self.feature_path = feature_path
        self.index_path = index_path

    def constructfeature(self, hash_size, input_dim, num_hashtables):
        multiscale = '[1]'
        print(">> Loading network:\n>>>> '{}'".format(self.network))
        # state = load_url(PRETRAINED[args.network], model_dir=os.path.join(get_data_root(), 'networks'))
        state = torch.load(self.network)
        # parse net params from meta:
        # architecture, pooling, mean and std are required;
        # the rest fall back to defaults in case they don't exist
        net_params = {}
        net_params['architecture'] = state['meta']['architecture']
        net_params['pooling'] = state['meta']['pooling']
        net_params['local_whitening'] = state['meta'].get('local_whitening', False)
        net_params['regional'] = state['meta'].get('regional', False)
        net_params['whitening'] = state['meta'].get('whitening', False)
        net_params['mean'] = state['meta']['mean']
        net_params['std'] = state['meta']['std']
        net_params['pretrained'] = False
        # network initialization
        net = init_network(net_params)
        net.load_state_dict(state['state_dict'])
        print(">>>> loaded network: ")
        print(net.meta_repr())
        # set up the multi-scale parameters
        ms = list(eval(multiscale))
        print(">>>> Evaluating scales: {}".format(ms))
        # move network to GPU (if available) and switch to eval mode
        if torch.cuda.is_available():
            net.cuda()
        net.eval()

        # set up the transform
        normalize = transforms.Normalize(
            mean=net.meta['mean'],
            std=net.meta['std']
        )
        transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])

        # extract database vectors
        print('>> database images...')
        images = ImageProcess(self.img_dir).process()
        vecs, img_paths = extract_vectors(net, images, 1024, transform, ms=ms)
        feature_dict = dict(zip(img_paths, list(vecs.detach().cpu().numpy().T)))
        # build the LSH index over the extracted descriptors
        lsh = LSHash(hash_size=int(hash_size), input_dim=int(input_dim), num_hashtables=int(num_hashtables))
        for img_path, vec in feature_dict.items():
            lsh.index(vec.flatten(), extra_data=img_path)

        # ## save the index model
        # with open(self.feature_path, "wb") as f:
        #     pickle.dump(feature_dict, f)
        # with open(self.index_path, "wb") as f:
        #     pickle.dump(lsh, f)

        print("feature extraction is done")
        return feature_dict, lsh

    def test_feature(self):
        multiscale = '[1]'
        print(">> Loading network:\n>>>> '{}'".format(self.network))
        # state = load_url(PRETRAINED[args.network], model_dir=os.path.join(get_data_root(), 'networks'))
        state = torch.load(self.network)
        # parse net params from meta:
        # architecture, pooling, mean and std are required;
        # the rest fall back to defaults in case they don't exist
        net_params = {}
        net_params['architecture'] = state['meta']['architecture']
        net_params['pooling'] = state['meta']['pooling']
        net_params['local_whitening'] = state['meta'].get('local_whitening', False)
        net_params['regional'] = state['meta'].get('regional', False)
        net_params['whitening'] = state['meta'].get('whitening', False)
        net_params['mean'] = state['meta']['mean']
        net_params['std'] = state['meta']['std']
        net_params['pretrained'] = False
        # network initialization
        net = init_network(net_params)
        net.load_state_dict(state['state_dict'])
        print(">>>> loaded network: ")
        print(net.meta_repr())
        # set up the multi-scale parameters
        ms = list(eval(multiscale))
        print(">>>> Evaluating scales: {}".format(ms))
        # move network to GPU (if available) and switch to eval mode
        if torch.cuda.is_available():
            net.cuda()
        net.eval()

        # set up the transform
        normalize = transforms.Normalize(
            mean=net.meta['mean'],
            std=net.meta['std']
        )
        transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])

        # extract query vectors only; no index is built here
        print('>> database images...')
        images = ImageProcess(self.img_dir).process()
        vecs, img_paths = extract_vectors(net, images, 1024, transform, ms=ms)
        feature_dict = dict(zip(img_paths, list(vecs.detach().cpu().numpy().T)))
        return feature_dict


if __name__ == '__main__':
    pass
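The two classes above are only defined in this file; nothing in the diff exercises them. A minimal usage sketch, assuming a cirtorch-style checkpoint at ./model_best.pth and a database folder ./images (both hypothetical), with input_dim matching the descriptor length the network produces:

# Hypothetical driver for the file above (paths and LSH sizes are illustrative assumptions)
dataset = AntiFraudFeatureDataset(img_dir="./images", network="./model_best.pth")
# input_dim must equal the descriptor length produced by the network
# (e.g. 2048 for a ResNet-101 GeM model from cirtorch)
feature_dict, lsh = dataset.constructfeature(hash_size=10, input_dim=2048, num_hashtables=5)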
@@ -0,0 +1,155 @@
# coding=utf-8
# /usr/bin/env python

'''
Author: yinhao
Email: yinhao_x@163.com
Wechat: xss_yinhao
Github: http://github.com/yinhaoxs
date: 2019-11-23 18:27
desc:
'''

import os
import shutil
import numpy as np
import pandas as pd


class EvaluteMap():
    def __init__(self, out_similar_dir='', out_similar_file_dir='', all_csv_file='', feature_path='', index_path=''):
        self.out_similar_dir = out_similar_dir
        self.out_similar_file_dir = out_similar_file_dir
        self.all_csv_file = all_csv_file
        self.feature_path = feature_path
        self.index_path = index_path

    def get_dict(self, query_no, query_id, simi_no, simi_id, num, score):
        new_dict = {
            'index': str(num),
            'id1': str(query_id),
            'id2': str(simi_id),
            'no1': str(query_no),
            'no2': str(simi_no),
            'score': score
        }
        return new_dict

    def find_similar_img_gyz(self, feature_dict, lsh, num_results):
        # print the top-3 matches for every query vector
        for q_path, q_vec in feature_dict.items():
            try:
                response = lsh.query(q_vec.flatten(), num_results=int(num_results), distance_func="cosine")
                query_img_path0 = response[0][0][1]
                query_img_path1 = response[1][0][1]
                query_img_path2 = response[2][0][1]
                # score0 = response[0][1]
                # score0 = np.rint(100 * (1 - score0))
                print('**********************************************')
                print('input img: {}'.format(q_path))
                print('query0 img: {}'.format(query_img_path0))
                print('query1 img: {}'.format(query_img_path1))
                print('query2 img: {}'.format(query_img_path2))
            except Exception:
                continue

    def find_similar_img(self, feature_dict, lsh, num_results):
        num = 0
        result_list = list()
        for q_path, q_vec in feature_dict.items():
            response = lsh.query(q_vec.flatten(), num_results=int(num_results), distance_func="cosine")
            s_path_list, s_vec_list, s_id_list, s_no_list, score_list = list(), list(), list(), list(), list()
            q_path = q_path[0]
            q_no, q_id = q_path.split("\\")[-2], q_path.split("\\")[-1]
            try:
                for i in range(int(num_results)):
                    s_path, s_vec = response[i][0][1], response[i][0][0]
                    s_path = s_path[0]
                    s_no, s_id = s_path.split("\\")[-2], s_path.split("\\")[-1]
                    if str(s_no) != str(q_no):
                        # convert cosine distance to a 0-100 similarity score
                        score = np.rint(100 * (1 - response[i][1]))
                        s_path_list.append(s_path)
                        s_vec_list.append(s_vec)
                        s_id_list.append(s_id)
                        s_no_list.append(s_no)
                        score_list.append(score)
                    else:
                        continue

                if len(s_path_list) != 0:
                    # keep the highest-scoring match from a different folder
                    index = score_list.index(max(score_list))
                    s_path, s_vec, s_id, s_no, score = s_path_list[index], s_vec_list[index], s_id_list[index], \
                                                       s_no_list[index], score_list[index]
                else:
                    s_path, s_vec, s_id, s_no, score = None, None, None, None, None
            except Exception:
                s_path, s_vec, s_id, s_no, score = None, None, None, None, None

            try:
                # copy the query/match pair into a numbered output folder
                num += 1
                des_path = os.path.join(self.out_similar_dir, str(num))
                if not os.path.exists(des_path):
                    os.makedirs(des_path)
                shutil.copy(q_path, des_path)
                os.rename(os.path.join(des_path, q_id), os.path.join(des_path, "query_" + q_no + "_" + q_id))
                if s_path is not None:
                    shutil.copy(s_path, des_path)
                    os.rename(os.path.join(des_path, s_id), os.path.join(des_path, s_no + "_" + s_id))

                new_dict = self.get_dict(q_no, q_id, s_no, s_id, num, score)
                result_list.append(new_dict)
            except Exception:
                continue

        try:
            result_s = pd.DataFrame.from_dict(result_list)
            result_s.to_csv(self.all_csv_file, encoding="gbk", index=False)
        except Exception:
            print("write error")

    def filter_gap_score(self):
        # split the result CSV by score (90-100) and copy the matching image pairs
        for value in range(90, 101):
            try:
                pd_df = pd.read_csv(self.all_csv_file, encoding="gbk", error_bad_lines=False)
                pd_tmp = pd_df[pd_df["score"] == int(value)]
                if not os.path.exists(self.out_similar_file_dir):
                    os.makedirs(self.out_similar_file_dir)

                try:
                    results_split_csv = os.path.join(self.out_similar_file_dir, "filter_{}.csv".format(str(value)))
                    pd_tmp.to_csv(results_split_csv, encoding="gbk", index=False)
                except Exception:
                    print("write part error")

                lines = pd_df[pd_df["score"] == int(value)]["index"]
                num = 0
                for line in lines:
                    des_path_temp = os.path.join(self.out_similar_file_dir, str(value), str(line))
                    if not os.path.exists(des_path_temp):
                        os.makedirs(des_path_temp)
                    pairs_path = os.path.join(self.out_similar_dir, str(line))
                    for img_id in os.listdir(pairs_path):
                        img_path = os.path.join(pairs_path, img_id)
                        shutil.copy(img_path, des_path_temp)
            except Exception:
                print("error")

    def retrieval_images(self, feature_dict, lsh, num_results=1):
        # optionally load a persisted feature dict and LSH index (disabled)
        # with open(self.feature_path, "rb") as f:
        #     feature_dict = pickle.load(f)
        # with open(self.index_path, "rb") as f:
        #     lsh = pickle.load(f)

        self.find_similar_img_gyz(feature_dict, lsh, num_results)
        # self.filter_gap_score()


if __name__ == "__main__":
    pass
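As a rough end-to-end sketch of how the two new modules fit together (paths and LSH parameters are assumptions, not part of the commit; num_results=3 because find_similar_img_gyz reads the top three hits):

# Hypothetical glue code combining the two files in this commit
feature_dict, lsh = AntiFraudFeatureDataset(img_dir="./images", network="./model_best.pth") \
    .constructfeature(hash_size=10, input_dim=2048, num_hashtables=5)
searcher = EvaluteMap(out_similar_dir="./pairs", out_similar_file_dir="./filtered", all_csv_file="./result.csv")
searcher.retrieval_images(feature_dict, lsh, num_results=3)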