Skip to content

Commit

Permalink
tmp
Browse files Browse the repository at this point in the history
hitokun-s committed Apr 2, 2016
1 parent 7d46b4e commit 3c0770e
Showing 3 changed files with 53 additions and 9 deletions.
24 changes: 24 additions & 0 deletions imagenet/food_hack/nin.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import math

import chainer
import chainer.functions as F
import chainer.links as L
import numpy as np


class NIN(chainer.Chain):
@@ -40,3 +44,23 @@ def __call__(self, x, t):
self.loss = F.softmax_cross_entropy(h, t)
self.accuracy = F.accuracy(h, t)
return self.loss

def predict(self, x):
    """Run the network forward on ``x`` and return the arg-max index per sample.

    As a side effect this prints the predicted indices, followed by the
    softmax cross-entropy of the scores measured against those same
    predictions.

    Args:
        x: chainer.Variable holding the input batch; the batch size is read
            from ``x.data.shape[0]``.

    Returns:
        Array of predicted class indices, one per sample.
    """
    self.clear()
    batch_size = x.data.shape[0]

    # Three mlpconv stages, each followed by ReLU and max pooling
    # (kernel size 3, stride 2).
    feat = x
    for mlpconv in (self.mlpconv1, self.mlpconv2, self.mlpconv3):
        feat = F.max_pooling_2d(F.relu(mlpconv(feat)), 3, stride=2)

    dropped = F.dropout(feat, train=self.train)
    pooled = F.average_pooling_2d(self.mlpconv4(dropped), 6)
    scores = F.reshape(pooled, (batch_size, 1000))

    answers = np.argmax(scores.data, axis=1)
    print(answers)

    # Cross-entropy of the scores against their own arg-max labels.
    labels = chainer.Variable(answers.astype(np.int32), volatile='on')
    sfe = F.softmax_cross_entropy(scores, labels).data
    print(sfe)
    return answers
23 changes: 20 additions & 3 deletions imagenet/food_hack/predict.py
Original file line number Diff line number Diff line change
@@ -93,14 +93,31 @@ def predict_core(img):
t2 = chainer.Variable(xp.asarray([1]).astype(np.int32), volatile=volatile)
t3 = chainer.Variable(xp.asarray([2]).astype(np.int32), volatile=volatile)

model.predict(x)

# print(model.predictor(x).data)
print(model(x,t1).data)
print(model(x,t2).data)
print(model(x,t3).data)
# print(model(x,t1).data)
# print(model(x,t2).data)
# print(model(x,t3).data)

def predict_core_multi(imgs):
    """Classify a batch of preprocessed images in a single forward pass.

    Args:
        imgs: Sequence of image arrays, each assignable to a slot of shape
            ``(3, model.insize, model.insize)``.

    Returns:
        The value returned by ``model.predict`` for the whole batch, so that
        callers such as ``predict_by_data_multi`` actually receive the result.
    """
    x_batch = np.empty((len(imgs), 3, model.insize, model.insize), dtype=np.float32)
    for i, img in enumerate(imgs):
        x_batch[i] = img
    # NOTE(review): volatile='on' is presumably used because this is an
    # inference-only pass -- confirm against the Chainer version in use.
    x = chainer.Variable(xp.asarray(x_batch), volatile='on')
    return model.predict(x)

def predict_by_data(ndArrData):
    """Preprocess one in-memory image array and hand it to ``predict_core``.

    The array goes through ``read_image_data(ndArrData, False, True)`` first;
    the meaning of the two flags is defined by that helper.
    """
    preprocessed = read_image_data(ndArrData, False, True)
    return predict_core(preprocessed)

def predict_by_data_multi(ndArrDatas):
    """Preprocess several in-memory image arrays and classify them as one batch.

    Each array goes through ``read_image_data(arr, False, True)`` and the
    resulting list is passed to ``predict_core_multi``.
    """
    preprocessed = []
    for arr in ndArrDatas:
        preprocessed.append(read_image_data(arr, False, True))
    return predict_core_multi(preprocessed)

def predict(file_path):
    """Load the image at ``file_path`` via ``read_image`` and hand it to ``predict_core``."""
    loaded = read_image(file_path, False, True)
    return predict_core(loaded)

15 changes: 9 additions & 6 deletions imagenet/food_hack/selective_search_and_detect.py
Original file line number Diff line number Diff line change
@@ -10,13 +10,14 @@
import predict

# arg: cropped PIL.Image object
def resize_and_predict(img):
def resize(img):
    """Center-crop ``img`` to a square and scale the result to 256x256.

    Args:
        img: PIL.Image object (typically an already cropped region).

    Returns:
        A new 256x256 PIL.Image.
    """
    # img.size is a (width, height) tuple; depending on the PIL version,
    # img.width and img.height may also be available.
    width, height = img.size
    side = min(width, height)
    # Floor division keeps the coordinates integral under true division too.
    left = width // 2 - side // 2
    upper = height // 2 - side // 2
    # box is a 4-tuple defining the left, upper, right, and lower pixel coordinate.
    box = (left, upper, left + side, upper + side)
    return img.crop(box).resize((256, 256), Image.ANTIALIAS)

def main():

@@ -35,8 +36,7 @@ def main():
img = io.imread(tgt_img_path)

# perform selective search
img_lbl, regions = selectivesearch.selective_search(
img, scale=500, sigma=0.9, min_size=10)
img_lbl, regions = selectivesearch.selective_search(img, scale=500, sigma=0.9, min_size=10)

candidates = set()
for r in regions:
@@ -53,8 +53,11 @@ def main():
candidates.add(r['rect'])

img = Image.open(tgt_img_path)
for rect in candidates:
resize_and_predict(img.crop(rect))
imgs = [resize(img.crop(rect)) for rect in candidates]
# for rect in candidates:
# print rect
# img = resize(img.crop(rect))
predict.predict_by_data_multi(imgs)

# fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(6, 6))
# ax.imshow(img)

0 comments on commit 3c0770e

Please sign in to comment.