Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Improvements to examples/classification/ #559

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Improvements to examples/classification/
Allow multiple image files
Add batch_size
  • Loading branch information
lukeyeager committed Jan 29, 2016
commit abf719b12415bd08b379e375b0bdeabf8bd947f2
19 changes: 12 additions & 7 deletions examples/classification/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def load_image(path, height, width, mode='RGB'):
image = scipy.misc.imresize(image, (height, width), 'bilinear')
return image

def forward_pass(images, net, transformer, batch_size=1):
def forward_pass(images, net, transformer, batch_size=None):
"""
Returns scores for each image as an np.ndarray (nImages x nClasses)

Expand All @@ -119,6 +119,9 @@ def forward_pass(images, net, transformer, batch_size=1):
batch_size -- how many images can be processed at once
(a high value may result in out-of-memory errors)
"""
if batch_size is None:
batch_size = 1

caffe_images = []
for image in images:
if image.ndim == 2:
Expand Down Expand Up @@ -168,7 +171,7 @@ def read_labels(labels_file):
return labels

def classify(caffemodel, deploy_file, image_files,
mean_file=None, labels_file=None, use_gpu=True):
mean_file=None, labels_file=None, batch_size=None, use_gpu=True):
"""
Classify some images against a Caffe model and print the results

Expand Down Expand Up @@ -197,7 +200,7 @@ def classify(caffemodel, deploy_file, image_files,

# Classify the image
classify_start_time = time.time()
scores = forward_pass(images, net, transformer)
scores = forward_pass(images, net, transformer, batch_size=batch_size)
print 'Classification took %s seconds.' % (time.time() - classify_start_time,)

### Process the results
Expand Down Expand Up @@ -231,23 +234,25 @@ def classify(caffemodel, deploy_file, image_files,

parser.add_argument('caffemodel', help='Path to a .caffemodel')
parser.add_argument('deploy_file', help='Path to the deploy file')
parser.add_argument('image', help='Path to an image')
parser.add_argument('image_file',
nargs='+',
help='Path[s] to an image')

### Optional arguments

parser.add_argument('-m', '--mean',
help='Path to a mean file (*.npy)')
parser.add_argument('-l', '--labels',
help='Path to a labels file')
parser.add_argument('--batch-size',
type=int)
parser.add_argument('--nogpu',
action='store_true',
help="Don't use the GPU")

args = vars(parser.parse_args())

image_files = [args['image']]

classify(args['caffemodel'], args['deploy_file'], image_files,
classify(args['caffemodel'], args['deploy_file'], args['image_file'],
args['mean'], args['labels'], not args['nogpu'])

print 'Script took %s seconds.' % (time.time() - script_start_time,)
Expand Down
22 changes: 15 additions & 7 deletions examples/classification/use_archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def unzip_archive(archive):
Arguments:
archive -- the path to an archive file
"""
assert os.path.exists(archive), 'File not found - %s' % archive

tmpdir = os.path.join(tempfile.gettempdir(),
os.path.basename(archive))
assert tmpdir != archive # That wouldn't work out
Expand All @@ -43,7 +45,7 @@ def unzip_archive(archive):
return tmpdir


def classify_with_archive(archive, image_files, use_gpu=True):
def classify_with_archive(archive, image_files, batch_size=None, use_gpu=True):
"""
"""
tmpdir = unzip_archive(archive)
Expand All @@ -68,7 +70,8 @@ def classify_with_archive(archive, image_files, use_gpu=True):
assert deploy_file is not None, 'Deploy file not found'

classify(caffemodel, deploy_file, image_files,
mean_file=mean_file, labels_file=labels_file, use_gpu=use_gpu)
mean_file=mean_file, labels_file=labels_file,
batch_size=batch_size, use_gpu=use_gpu)


if __name__ == '__main__':
Expand All @@ -78,20 +81,25 @@ def classify_with_archive(archive, image_files, use_gpu=True):

### Positional arguments

parser.add_argument('archive', help='Path to a DIGITS model archive')
parser.add_argument('image', help='Path to an image')
parser.add_argument('archive', help='Path to a DIGITS model archive')
parser.add_argument('image_file',
nargs='+',
help='Path[s] to an image')

### Optional arguments

parser.add_argument('--batch-size',
type=int)
parser.add_argument('--nogpu',
action='store_true',
help="Don't use the GPU")

args = vars(parser.parse_args())

image_files = [args['image']]

classify_with_archive(args['archive'], image_files, not args['nogpu'])
classify_with_archive(args['archive'], args['image_file'],
batch_size=args['batch_size'],
use_gpu=(not args['nogpu']),
)

print 'Script took %s seconds.' % (time.time() - script_start_time,)