Skip to content

Commit

Permalink
Update retrieval_feature.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yinhaoxs authored May 12, 2020
1 parent ad42b29 commit a8a9b48
Showing 1 changed file with 132 additions and 85 deletions.
217 changes: 132 additions & 85 deletions utils/retrieval_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
Email: yinhao_x@163.com
Wechat: xss_yinhao
Github: http://github.com/yinhaoxs
date: 2019-11-23 18:26
desc:
'''
Expand All @@ -23,7 +22,7 @@
from lshash.lshash import LSHash
import math
from sklearn.externals import joblib
from classify import class_results
# from classify import class_results

import torch
from torch.utils.model_zoo import load_url
Expand All @@ -34,96 +33,144 @@
from cirtorch.utils.evaluate import compute_map_and_print
from cirtorch.utils.general import get_data_root, htime


# setting up the visible GPU
os.environ['CUDA_VISIBLE_DEVICES'] = "0"


class ImageProcess():
    """Recursively collect image paths under a directory, skipping files
    that cannot be opened and images with an extreme aspect ratio."""

    def __init__(self, img_dir):
        # Root directory that is scanned recursively for image files.
        self.img_dir = img_dir

    def process(self):
        """Walk ``img_dir`` and return a list of usable image paths.

        An image is kept when max(side) / min(side) < 5; unreadable files
        (non-images, corrupt data) are reported and skipped instead of
        aborting the whole scan.

        Returns:
            list[str]: absolute/relative paths of the accepted images.
        """
        imgs = []
        for root, _dirs, files in os.walk(self.img_dir):
            for name in files:
                img_path = os.path.join(root, name)
                try:
                    image = Image.open(img_path)
                    # Reject extremely elongated images (banner-like strips);
                    # guard against degenerate zero-sized dimensions.
                    if min(image.size) > 0 and max(image.size) / min(image.size) < 5:
                        imgs.append(img_path)
                except Exception as exc:  # not an image / corrupt file
                    print("skipping unreadable image {}: {}".format(img_path, exc))
        return imgs


class AntiFraudFeatureDataset():
    """Extract CNN retrieval features for every image under a directory and
    optionally build a locality-sensitive-hash (LSH) index over them.

    The ``network`` argument is the path to a cirtorch-style checkpoint whose
    ``meta`` block carries architecture, pooling and normalization settings.
    """

    def __init__(self, img_dir, network, feature_path='', index_path=''):
        self.img_dir = img_dir            # directory scanned for images
        self.network = network            # path to the network checkpoint
        self.feature_path = feature_path  # pickle target for features ('' = don't save)
        self.index_path = index_path      # pickle target for LSH index ('' = don't save)

    def _load_model(self):
        """Load the checkpoint, rebuild the network, and put it in eval mode.

        Returns:
            tuple: (net, ms) — the network and the multi-scale list used
            during feature extraction.
        """
        print(">> Loading network:\n>>>> '{}'".format(self.network))
        state = torch.load(self.network)
        # architecture, pooling, mean and std are required; the remaining
        # meta entries fall back to defaults when absent from the checkpoint.
        net_params = {
            'architecture': state['meta']['architecture'],
            'pooling': state['meta']['pooling'],
            'local_whitening': state['meta'].get('local_whitening', False),
            'regional': state['meta'].get('regional', False),
            'whitening': state['meta'].get('whitening', False),
            'mean': state['meta']['mean'],
            'std': state['meta']['std'],
            'pretrained': False,
        }
        net = init_network(net_params)
        net.load_state_dict(state['state_dict'])
        print(">>>> loaded network: ")
        print(net.meta_repr())
        # Single-scale extraction; replaces the needless eval('[1]').
        ms = [1]
        print(">>>> Evaluating scales: {}".format(ms))
        if torch.cuda.is_available():  # use the GPU only when one is present
            net.cuda()
        net.eval()
        return net, ms

    def _extract(self, net, ms):
        """Run feature extraction over all images; map path -> feature vector."""
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=net.meta['mean'], std=net.meta['std']),
        ])
        print('>> database images...')
        images = ImageProcess(self.img_dir).process()
        vecs, img_paths = extract_vectors(net, images, 1024, transform, ms=ms)
        return dict(zip(img_paths, list(vecs.detach().cpu().numpy().T)))

    def constructfeature(self, hash_size, input_dim, num_hashtables):
        """Extract features and index them with locality-sensitive hashing.

        Args:
            hash_size: number of bits per LSH hash.
            input_dim: dimensionality of the feature vectors.
            num_hashtables: number of independent hash tables.

        Returns:
            tuple: (feature_dict, lsh) — {image_path: vector} and the
            populated :class:`LSHash` index.
        """
        net, ms = self._load_model()
        feature_dict = self._extract(net, ms)
        lsh = LSHash(hash_size=int(hash_size), input_dim=int(input_dim),
                     num_hashtables=int(num_hashtables))
        for img_path, vec in feature_dict.items():
            lsh.index(vec.flatten(), extra_data=img_path)

        # Persist only when target paths were supplied; the defaults ('')
        # keep the previous no-save behavior.
        if self.feature_path:
            with open(self.feature_path, "wb") as f:
                pickle.dump(feature_dict, f)
        if self.index_path:
            with open(self.index_path, "wb") as f:
                pickle.dump(lsh, f)

        print("extract feature is done")
        return feature_dict, lsh

    def test_feature(self):
        """Extract features only (no LSH indexing).

        Returns:
            dict: {image_path: feature vector}.
        """
        net, ms = self._load_model()
        return self._extract(net, ms)


# Module is meant to be imported; no standalone CLI behavior.
# (The duplicated stray `pass` left over from the diff is removed.)
if __name__ == '__main__':
    pass

0 comments on commit a8a9b48

Please sign in to comment.