Skip to content

Commit

Permalink
Update imageretrievalnet.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yinhaoxs authored May 12, 2020
1 parent 23c570e commit 9dffc65
Showing 1 changed file with 123 additions and 86 deletions.
209 changes: 123 additions & 86 deletions cirtorch/networks/imageretrievalnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,81 +11,85 @@
from cirtorch.layers.normalization import L2N, PowerLaw
from cirtorch.datasets.genericdataset import ImagesFromList
from cirtorch.utils.general import get_data_root
from PIL import Image
from ModelHelper.Common.CommonUtils.ImageAugmentation import Padding
import cv2

# For some models we use convolutional features imported from Caffe, because
# image-retrieval performance is higher with them than with torchvision weights.
# Maps architecture name -> URL of the pre-trained feature-extractor weights.
FEATURES = {
    'vgg16': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-vgg16-features-d369c8e.pth',
    'resnet50': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet50-features-ac468af.pth',
    'resnet101': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet101-features-10a101d.pth',
    'resnet152': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet152-features-1011020.pth',
}

# TODO: pre-compute for more architectures and properly test variations (pre l2norm, post l2norm)
# Pre-computed local PCA whitening that can be applied before the pooling layer.
L_WHITENING = {
    'resnet101': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-lwhiten-9f830ef.pth',  # no pre l2 norm
    # 'resnet101' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-lwhiten-da5c935.pth', # with pre l2 norm
}

# Possible global pooling layers; each one of these can be made regional.
# Maps pooling name -> pooling-layer class (instantiated in init_network).
POOLING = {
    'mac': MAC,
    'spoc': SPoC,
    'gem': GeM,
    'gemmp': GeMmp,
    'rmac': RMAC,
}

# TODO: pre-compute for: resnet50-gem-r, resnet50-mac-r, vgg16-mac-r, alexnet-mac-r
# Pre-computed regional whitening, for the most commonly used architectures and pooling methods.
# Keys are '{architecture}-{pooling}-r'.
R_WHITENING = {
    'alexnet-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-r-rwhiten-c8cf7e2.pth',
    'vgg16-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-r-rwhiten-19b204e.pth',
    'resnet101-mac-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-mac-r-rwhiten-7f1ed8c.pth',
    'resnet101-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-r-rwhiten-adace84.pth',
}

# TODO: pre-compute for more architectures
# Pre-computed final (global) whitening, for the most commonly used architectures
# and pooling methods. Keys are '{architecture}-{pooling}' (+'-r' for regional).
WHITENING = {
    'alexnet-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-whiten-454ad53.pth',
    'alexnet-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-r-whiten-4c9126b.pth',
    'vgg16-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-whiten-eaa6695.pth',
    'vgg16-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-r-whiten-83582df.pth',
    'resnet50-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet50-gem-whiten-f15da7b.pth',
    'resnet101-mac-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-mac-r-whiten-9df41d3.pth',
    'resnet101-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-whiten-22ab0c1.pth',
    'resnet101-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-r-whiten-b379c0a.pth',
    'resnet101-gemmp': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gemmp-whiten-770f53c.pth',
    'resnet152-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet152-gem-whiten-abe7b93.pth',
    'densenet121-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet121-gem-whiten-79e3eea.pth',
    'densenet169-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet169-gem-whiten-6b2a76a.pth',
    'densenet201-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet201-gem-whiten-22ea45c.pth',
}

# Output dimensionality (number of channels of the last feature map) for
# supported backbone architectures.
OUTPUT_DIM = {
    'alexnet': 256,
    'vgg11': 512,
    'vgg13': 512,
    'vgg16': 512,
    'vgg19': 512,
    'resnet18': 512,
    'resnet34': 512,
    'resnet50': 2048,
    'resnet101': 2048,
    'resnet152': 2048,
    'densenet121': 1024,
    'densenet169': 1664,
    'densenet201': 1920,
    'densenet161': 2208,  # largest densenet
    'squeezenet1_0': 512,
    'squeezenet1_1': 512,
}


class ImageRetrievalNet(nn.Module):

def __init__(self, features, lwhiten, pool, whiten, meta):
super(ImageRetrievalNet, self).__init__()
self.features = nn.Sequential(*features)
Expand All @@ -94,7 +98,7 @@ def __init__(self, features, lwhiten, pool, whiten, meta):
self.whiten = whiten
self.norm = L2N()
self.meta = meta

def forward(self, x):
# x -> features
o = self.features(x)
Expand All @@ -104,9 +108,9 @@ def forward(self, x):
if self.lwhiten is not None:
# o = self.norm(o)
s = o.size()
o = o.permute(0,2,3,1).contiguous().view(-1, s[1])
o = o.permute(0, 2, 3, 1).contiguous().view(-1, s[1])
o = self.lwhiten(o)
o = o.view(s[0],s[2],s[3],self.lwhiten.out_features).permute(0,3,1,2)
o = o.view(s[0], s[2], s[3], self.lwhiten.out_features).permute(0, 3, 1, 2)
# o = self.norm(o)

# features -> pool -> norm
Expand All @@ -117,7 +121,7 @@ def forward(self, x):
o = self.norm(self.whiten(o))

# permute so that it is Dx1 column vector per image (DxN if many images)
return o.permute(1,0)
return o.permute(1, 0)

def __repr__(self):
tmpstr = super(ImageRetrievalNet, self).__repr__()[:-1]
Expand All @@ -126,7 +130,7 @@ def __repr__(self):
return tmpstr

def meta_repr(self):
tmpstr = ' (' + 'meta' + '): dict( \n' # + self.meta.__repr__() + '\n'
tmpstr = ' (' + 'meta' + '): dict( \n' # + self.meta.__repr__() + '\n'
tmpstr += ' architecture: {}\n'.format(self.meta['architecture'])
tmpstr += ' local_whitening: {}\n'.format(self.meta['local_whitening'])
tmpstr += ' pooling: {}\n'.format(self.meta['pooling'])
Expand All @@ -140,7 +144,6 @@ def meta_repr(self):


def init_network(params):

# parse params with default values
architecture = params.get('architecture', 'resnet101')
local_whitening = params.get('local_whitening', False)
Expand Down Expand Up @@ -192,22 +195,22 @@ def init_network(params):
lw = architecture
if lw in L_WHITENING:
print(">> {}: for '{}' custom computed local whitening '{}' is used"
.format(os.path.basename(__file__), lw, os.path.basename(L_WHITENING[lw])))
.format(os.path.basename(__file__), lw, os.path.basename(L_WHITENING[lw])))
whiten_dir = os.path.join(get_data_root(), 'whiten')
lwhiten.load_state_dict(model_zoo.load_url(L_WHITENING[lw], model_dir=whiten_dir))
else:
print(">> {}: for '{}' there is no local whitening computed, random weights are used"
.format(os.path.basename(__file__), lw))
.format(os.path.basename(__file__), lw))

else:
lwhiten = None

# initialize pooling
if pooling == 'gemmp':
pool = POOLING[pooling](mp=dim)
else:
pool = POOLING[pooling]()

# initialize regional pooling
if regional:
rpool = pool
Expand All @@ -218,12 +221,12 @@ def init_network(params):
rw = '{}-{}-r'.format(architecture, pooling)
if rw in R_WHITENING:
print(">> {}: for '{}' custom computed regional whitening '{}' is used"
.format(os.path.basename(__file__), rw, os.path.basename(R_WHITENING[rw])))
.format(os.path.basename(__file__), rw, os.path.basename(R_WHITENING[rw])))
whiten_dir = os.path.join(get_data_root(), 'whiten')
rwhiten.load_state_dict(model_zoo.load_url(R_WHITENING[rw], model_dir=whiten_dir))
else:
print(">> {}: for '{}' there is no regional whitening computed, random weights are used"
.format(os.path.basename(__file__), rw))
.format(os.path.basename(__file__), rw))

pool = Rpool(rpool, rwhiten)

Expand All @@ -241,25 +244,25 @@ def init_network(params):
w += '-r'
if w in WHITENING:
print(">> {}: for '{}' custom computed whitening '{}' is used"
.format(os.path.basename(__file__), w, os.path.basename(WHITENING[w])))
.format(os.path.basename(__file__), w, os.path.basename(WHITENING[w])))
whiten_dir = os.path.join(get_data_root(), 'whiten')
whiten.load_state_dict(model_zoo.load_url(WHITENING[w], model_dir=whiten_dir))
else:
print(">> {}: for '{}' there is no whitening computed, random weights are used"
.format(os.path.basename(__file__), w))
.format(os.path.basename(__file__), w))
else:
whiten = None

# create meta information to be stored in the network
meta = {
'architecture' : architecture,
'local_whitening' : local_whitening,
'pooling' : pooling,
'regional' : regional,
'whitening' : whitening,
'mean' : mean,
'std' : std,
'outputdim' : dim,
'architecture': architecture,
'local_whitening': local_whitening,
'pooling': pooling,
'regional': regional,
'whitening': whitening,
'mean': mean,
'std': std,
'outputdim': dim,
}

# create a generic image retrieval network
Expand All @@ -268,57 +271,89 @@ def init_network(params):
# initialize features with custom pretrained network if needed
if pretrained and architecture in FEATURES:
print(">> {}: for '{}' custom pretrained features '{}' are used"
.format(os.path.basename(__file__), architecture, os.path.basename(FEATURES[architecture])))
.format(os.path.basename(__file__), architecture, os.path.basename(FEATURES[architecture])))
model_dir = os.path.join(get_data_root(), 'networks')
net.features.load_state_dict(model_zoo.load_url(FEATURES[architecture], model_dir=model_dir))

return net


# def img2tensor(img_path, imsize, transform):
# img = cv2.imread(img_path)
# padding = Padding((imsize, imsize))
# img = padding(img)
# if transform is not None:
# img = transform(img)
#
# return img


def extract_vectors(net, images, image_size, transform, bbxs=None, ms=(1,), msp=1, print_freq=10):
    """Extract one global descriptor per image.

    Args:
        net: retrieval network; must be callable on an image batch and expose
            ``net.meta['outputdim']``.
        images: list of image identifiers/paths forwarded to ImagesFromList.
        image_size: image size forwarded to ImagesFromList as ``imsize``.
        transform: per-image transform forwarded to ImagesFromList.
        bbxs: optional bounding boxes forwarded to ImagesFromList.
        ms: iterable of scales; ``(1,)`` selects single-scale extraction,
            anything else triggers multi-scale extraction via extract_ms.
            (Default is a tuple rather than a mutable list.)
        msp: power used for multi-scale descriptor averaging (see extract_ms).
        print_freq: progress is printed every ``print_freq`` images.

    Returns:
        Tuple ``(vecs, img_paths)`` where ``vecs`` is an
        ``outputdim x len(images)`` tensor of descriptors and ``img_paths``
        collects the second element yielded by the loader for each sample
        (presumably the image path; with batch_size=1 it is the collated
        per-batch value -- TODO confirm against ImagesFromList).
    """
    # Move the network to GPU only when one is available, then switch to eval mode.
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # Dataset loader over the image list; batch_size=1 so one image per step.
    loader = torch.utils.data.DataLoader(
        ImagesFromList(root='', images=images, imsize=image_size, bbxs=bbxs, transform=transform),
        batch_size=1, shuffle=False, num_workers=1, pin_memory=True
    )

    # Extract one column vector per image.
    with torch.no_grad():
        vecs = torch.zeros(net.meta['outputdim'], len(images))
        img_paths = []
        for i, (input, path) in enumerate(loader):
            if torch.cuda.is_available():
                input = input.cuda()

            # Single-scale fast path when ms is exactly [1]; otherwise multi-scale.
            if len(ms) == 1 and ms[0] == 1:
                vecs[:, i] = extract_ss(net, input)
            else:
                vecs[:, i] = extract_ms(net, input, ms, msp)
            img_paths.append(path)

            if (i + 1) % print_freq == 0 or (i + 1) == len(images):
                print('\r>>>> {}/{} done...'.format((i + 1), len(images)), end='')
        print('')

    return vecs, img_paths


def extract_ss(net, input):
    """Single-scale extraction: forward ``input`` through ``net``, move the
    result to CPU, and drop all singleton dimensions."""
    output = net(input)
    return output.cpu().data.squeeze()


def extract_ms(net, input, ms, msp):
    """Multi-scale extraction: run ``net`` on ``input`` resized to each scale
    in ``ms``, accumulate descriptors raised to power ``msp``, then average,
    take the 1/msp root (generalized-mean over scales) and L2-normalize.

    Args:
        net: network callable on a 4D image batch; exposes meta['outputdim'].
        input: 4D image tensor (N, C, H, W) -- required by interpolate.
        ms: iterable of scale factors; scale 1 uses the input unchanged.
        msp: power for the generalized mean over scales.

    Returns:
        1D L2-normalized descriptor of length ``net.meta['outputdim']``.
    """
    v = torch.zeros(net.meta['outputdim'])

    for s in ms:
        if s == 1:
            # No resampling needed; clone so the caller's tensor is untouched.
            input_t = input.clone()
        else:
            input_t = nn.functional.interpolate(input, scale_factor=s, mode='bilinear', align_corners=False)
        v += net(input_t).pow(msp).cpu().data.squeeze()

    # Generalized mean over scales, then L2 normalization.
    v /= len(ms)
    v = v.pow(1. / msp)
    v /= v.norm()

    return v
Expand Down Expand Up @@ -348,14 +383,15 @@ def extract_regional_vectors(net, images, image_size, transform, bbxs=None, ms=[
# vecs.append(extract_msr(net, input, ms, msp))
raise NotImplementedError

if (i+1) % print_freq == 0 or (i+1) == len(images):
print('\r>>>> {}/{} done...'.format((i+1), len(images)), end='')
if (i + 1) % print_freq == 0 or (i + 1) == len(images):
print('\r>>>> {}/{} done...'.format((i + 1), len(images)), end='')
print('')

return vecs


def extract_ssr(net, input):
    """Single-scale regional extraction: pool features without aggregating
    regions, drop the batch and trailing singleton spatial dims, and return
    one row per region (regions x dim) on CPU."""
    return net.pool(net.features(input), aggregate=False).squeeze(0).squeeze(-1).squeeze(-1).permute(1, 0).cpu().data


def extract_local_vectors(net, images, image_size, transform, bbxs=None, ms=[1], msp=1, print_freq=10):
Expand All @@ -382,11 +418,12 @@ def extract_local_vectors(net, images, image_size, transform, bbxs=None, ms=[1],
# vecs.append(extract_msl(net, input, ms, msp))
raise NotImplementedError

if (i+1) % print_freq == 0 or (i+1) == len(images):
print('\r>>>> {}/{} done...'.format((i+1), len(images)), end='')
if (i + 1) % print_freq == 0 or (i + 1) == len(images):
print('\r>>>> {}/{} done...'.format((i + 1), len(images)), end='')
print('')

return vecs


def extract_ssl(net, input):
    """Single-scale local extraction: normalize the raw feature map and return
    it flattened to (outputdim, H*W) on CPU -- one column per spatial location."""
    return net.norm(net.features(input)).squeeze(0).view(net.meta['outputdim'], -1).cpu().data

0 comments on commit 9dffc65

Please sign in to comment.