MXNet benchmark scripts (#3971)
nswamy authored and piiswrong committed Nov 25, 2016
1 parent 550ae83 commit ddc27e0
Showing 4 changed files with 304 additions and 21 deletions.
22 changes: 20 additions & 2 deletions example/image-classification/README.md
@@ -1,7 +1,6 @@
# Image Classification

This fold contains examples for image classifications. In this task, we assign
labels to an image with confidence scores, see the following figure for example ([source](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf)):
This folder contains examples for image classifications. In this task, we assign labels to an image with confidence scores, see the following figure for example ([source](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf)):

<img src=https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/image-classification.png
width=400/>
@@ -95,6 +94,25 @@ We can train a model using multiple machines.
See this [tutorial](http://mxnet.readthedocs.io/en/latest/how_to/multi_devices.html) for more launch options (e.g. by `Yarn`) and for how to write a distributed training program.

### Benchmark
To benchmark the imagenet networks, pass **--benchmark** as an argument to `train_imagenet.py`. An example is shown below:

```
python train_imagenet.py --gpus 0,1 --network inception-v3 --batch-size 64 --data-shape 299 --num-epochs 1 --kv-store device --benchmark
```
When running in benchmark mode, the script generates synthetic data of the given data shape and batch size.
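For illustration only, synthetic batches of the requested shape can be produced along the following lines (a minimal sketch mirroring the `SyntheticDataIter` added to `train_imagenet.py` in this commit; the variable names are illustrative):

```
import numpy as np
import mxnet as mx

batch_size, image_size, num_classes = 64, 299, 1000
# imagenet-style input batch of shape (N, C, H, W) with values in [-1, 1)
data = mx.nd.array(np.random.uniform(-1, 1, (batch_size, 3, image_size, image_size)))
# one random class label per image
label = mx.nd.array(np.random.randint(0, num_classes, (batch_size,)))
```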
`benchmark.py` can be used to run a series of benchmarks for different image networks on a given set of workers; it takes the following arguments:

- **--worker_file**: file containing the list of workers.
- **--worker_count**: total number of workers to run the benchmark on.
- **--gpu_count**: number of GPUs to use on each worker.
- **--networks**: one or more networks in the format `network_name:batch_size:image_size`.

The script first benchmarks on a single node with a varying number of GPUs, starting at 1 and doubling up to **gpu_count**, using **kv-store=device**. It then benchmarks on a varying number of nodes, each using all **gpu_count** GPUs, doubling up to **worker_count**, using **kv-store=dist_sync_device** (the doubling schedule is sketched after the example command). An example with 8 workers and 16 GPUs on each worker is shown below:

```
python benchmark.py --worker_file /opt/deeplearning/workers --worker_count 8 --gpu_count 16 --networks 'inception-v3:32:299'
```
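For clarity, the doubling schedule described above is generated by the `series` helper in `benchmark.py`; a minimal restatement:

```
def doubling_series(max_count):
    # e.g. max_count=16 -> [1, 2, 4, 8, 16]
    s, i = [], max_count
    while i >= 1:
        s.append(i)
        i //= 2
    return s[::-1]
```

Note that in `main()` the distributed runs iterate over `series(worker_count)[1:]`, skipping the single-node entry, since single-node throughput is already measured with `kv-store=device`.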

### Predict

- [predict with pre-trained model](../notebooks/predict-with-pretrained-model.ipynb)
207 changes: 207 additions & 0 deletions example/image-classification/benchmark.py
@@ -0,0 +1,207 @@
import logging
import argparse
import os
import time
import sys
import shutil
import csv
import re
import subprocess
import threading
import pygal
import importlib
import collections
import copy

'''
Set up the logger and log level.
'''
def setup_logging(log_loc):
    if os.path.exists(log_loc):
        shutil.move(log_loc, log_loc + "_" + str(int(os.path.getctime(log_loc))))
    os.makedirs(log_loc)

    log_file = '{}/benchmark.log'.format(log_loc)
    LOGGER = logging.getLogger('benchmark')
    LOGGER.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s %(levelname)s:%(name)s %(message)s')
    file_handler = logging.FileHandler(log_file)
    console_handler = logging.StreamHandler()
    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)

    LOGGER.addHandler(file_handler)
    LOGGER.addHandler(console_handler)
    return LOGGER

'''
Runs the command given in cmd_args for the specified timeout period
and terminates it after the timeout expires.
'''
class RunCmd(threading.Thread):
    def __init__(self, cmd_args, logfile):
        threading.Thread.__init__(self)
        self.cmd_args = cmd_args
        self.logfile = logfile
        self.process = None

    def run(self):
        LOGGER = logging.getLogger('benchmark')
        LOGGER.info('started running %s', ' '.join(self.cmd_args))
        log_fd = open(self.logfile, 'w')
        self.process = subprocess.Popen(self.cmd_args, stdout=log_fd,
                                        stderr=subprocess.STDOUT, universal_newlines=True)
        for line in self.process.communicate():
            LOGGER.debug(line)
        log_fd.close()
        LOGGER.info('finished running %s', ' '.join(self.cmd_args))

    def startCmd(self, timeout):
        LOGGER.debug('Attempting to start Thread to run %s', ' '.join(self.cmd_args))
        self.start()
        self.join(timeout)
        if self.is_alive():
            LOGGER.debug('Terminating process running %s', ' '.join(self.cmd_args))
            self.process.terminate()
            self.join()
        time.sleep(1)
        return

log_loc = './benchmark'
LOGGER = setup_logging(log_loc)

class Network(object):
    def __init__(self, name, img_size, batch_size):
        self.name = name
        self.img_size = img_size
        self.batch_size = batch_size
        self.gpu_speedup = collections.OrderedDict()

def parse_args():
    class NetworkArgumentAction(argparse.Action):
        def validate(self, attrs):
            args = attrs.split(':')
            if len(args) != 3 or isinstance(args[0], str) == False:
                print 'expected network attributes in format network_name:batch_size:image_size'
                sys.exit(1)
            try:
                # check that the network exists
                importlib.import_module('symbol_' + args[0]).get_symbol(1000)
                batch_size = int(args[1])
                img_size = int(args[2])
                return Network(name=args[0], batch_size=batch_size, img_size=img_size)
            except Exception as e:
                print 'expected network attributes in format network_name:batch_size:image_size'
                print e
                sys.exit(1)

        def __init__(self, *args, **kw):
            kw['nargs'] = '+'
            argparse.Action.__init__(self, *args, **kw)

        def __call__(self, parser, namespace, values, option_string=None):
            if isinstance(values, list) == True:
                setattr(namespace, self.dest, map(self.validate, values))
            else:
                setattr(namespace, self.dest, self.validate(values))

    parser = argparse.ArgumentParser(description='Run benchmarks on various imagenet networks using train_imagenet.py')
    parser.add_argument('--networks', dest='networks', nargs='+', type=str, action=NetworkArgumentAction,
                        help='one or more networks in the format network_name:batch_size:image_size')
    parser.add_argument('--worker_file', type=str, help='file that contains a list of workers', required=True)
    parser.add_argument('--worker_count', type=int, help='number of workers to run benchmark on', required=True)
    parser.add_argument('--gpu_count', type=int, help='number of gpus on each worker to use', required=True)
    args = parser.parse_args()
    return args

def series(max_count):
    i = max_count
    s = []
    while i >= 1:
        s.append(i)
        i = i / 2
    return s[::-1]

'''
Choose the middle iteration to get the images processed per sec
'''
def images_processed(log_loc):
    f = open(log_loc)
    img_per_sec = re.findall("(?:Batch\s+\[30\]\\\\tSpeed:\s+)(\d+\.\d+)(?:\s+)", str(f.readlines()))
    f.close()
    img_per_sec = map(float, img_per_sec)
    total_img_per_sec = sum(img_per_sec)
    return total_img_per_sec

def generate_hosts_file(num_nodes, workers_file, args_workers_file):
    f = open(workers_file, 'w')
    output = subprocess.check_output(['head', '-n', str(num_nodes), args_workers_file])
    f.write(output)
    f.close()
    return

def stop_old_processes(hosts_file):
    stop_args = ['python', '../../tools/kill-mxnet.py', hosts_file]
    stop_args_str = ' '.join(stop_args)
    LOGGER.info('killing old remote processes\n %s', stop_args_str)
    stop = subprocess.check_output(stop_args, stderr=subprocess.STDOUT)
    LOGGER.debug(stop)
    time.sleep(1)

def run_imagenet(kv_store, data_shape, batch_size, num_gpus, num_nodes, network, args_workers_file):
    imagenet_args = ['python', 'train_imagenet.py', '--gpus', ','.join(str(i) for i in xrange(num_gpus)),
                     '--network', network, '--batch-size', str(batch_size * num_gpus),
                     '--data-shape', str(data_shape), '--num-epochs', '1', '--kv-store', kv_store, '--benchmark']
    log = log_loc + '/' + network + '_' + str(num_nodes * num_gpus) + '_log'
    hosts = log_loc + '/' + network + '_' + str(num_nodes * num_gpus) + '_workers'
    generate_hosts_file(num_nodes, hosts, args_workers_file)
    stop_old_processes(hosts)
    launch_args = ['../../tools/launch.py', '-n', str(num_nodes), '-s', str(num_nodes * 2),
                   '-H', hosts, ' '.join(imagenet_args)]

    # use train_imagenet when running on a single node
    if kv_store == 'device':
        imagenet = RunCmd(imagenet_args, log)
        imagenet.startCmd(timeout=60 * 10)
    else:
        launch = RunCmd(launch_args, log)
        launch.startCmd(timeout=60 * 10)

    stop_old_processes(hosts)
    img_per_sec = images_processed(log)
    LOGGER.info('network: %s, num_gpus: %d, image/sec: %f', network, num_gpus * num_nodes, img_per_sec)
    return img_per_sec

def plot_graph(args):
    speedup_chart = pygal.Line(x_title='gpus', y_title='speedup', logarithmic=True)
    speedup_chart.x_labels = map(str, series(args.worker_count * args.gpu_count))
    speedup_chart.add('ideal speedup', series(args.worker_count * args.gpu_count))
    for net in args.networks:
        # use the single-gpu throughput as the speedup baseline when it exists and is non-zero
        image_single_gpu = net.gpu_speedup[1] if 1 in net.gpu_speedup and net.gpu_speedup[1] else 1
        y_values = [each / image_single_gpu for each in net.gpu_speedup.values()]
        LOGGER.info('%s: image_single_gpu:%.2f' % (net.name, image_single_gpu))
        LOGGER.debug('network:%s, y_values: %s' % (net.name, ' '.join(map(str, y_values))))
        speedup_chart.add(net.name, y_values,
                          formatter=lambda y_val, img=copy.deepcopy(image_single_gpu), batch_size=copy.deepcopy(net.batch_size):
                              'speedup:%.2f, img/sec:%.2f, batch/gpu:%d' %
                              (0 if y_val is None else y_val, 0 if y_val is None else y_val * img, batch_size))
    speedup_chart.render_to_file(log_loc + '/speedup.svg')

def write_csv(log_loc, args):
    for net in args.networks:
        with open(log_loc + '/' + net.name + '.csv', 'wb') as f:
            w = csv.writer(f)
            w.writerow(['num_gpus', 'img_processed_per_sec'])
            w.writerows(net.gpu_speedup.items())

def main():
    args = parse_args()
    for net in args.networks:
        # use kv_store='device' when running on 1 node
        for num_gpus in series(args.gpu_count):
            imgs_per_sec = run_imagenet(kv_store='device', data_shape=net.img_size, batch_size=net.batch_size,
                                        num_gpus=num_gpus, num_nodes=1, network=net.name,
                                        args_workers_file=args.worker_file)
            net.gpu_speedup[num_gpus] = imgs_per_sec
        for num_nodes in series(args.worker_count)[1::]:
            imgs_per_sec = run_imagenet(kv_store='dist_sync_device', data_shape=net.img_size, batch_size=net.batch_size,
                                        num_gpus=args.gpu_count, num_nodes=num_nodes, network=net.name,
                                        args_workers_file=args.worker_file)
            net.gpu_speedup[num_nodes * args.gpu_count] = imgs_per_sec
        LOGGER.info('Network: %s (num_gpus, images_processed): %s', net.name, ','.join(map(str, net.gpu_speedup.items())))
    write_csv(log_loc, args)
    plot_graph(args)

if __name__ == '__main__':
    main()
61 changes: 55 additions & 6 deletions example/image-classification/train_imagenet.py
@@ -7,10 +7,12 @@

# don't use -n and -s, which are reserved for the distributed training
parser = argparse.ArgumentParser(description='train an image classifier on imagenet')
mutually_exclusive_parser_group = parser.add_mutually_exclusive_group(required=True)
parser.add_argument('--network', type=str, default='inception-bn',
                    choices = ['alexnet', 'vgg', 'googlenet', 'inception-bn', 'inception-bn-full', 'inception-v3'],
                    choices = ['alexnet', 'vgg', 'googlenet', 'inception-bn',
                               'inception-bn-full', 'inception-v3', 'resnet'],
                    help = 'the cnn to use')
parser.add_argument('--data-dir', type=str, required=True,
mutually_exclusive_parser_group.add_argument('--data-dir', type=str,
                    help='the input data directory')
parser.add_argument('--model-prefix', type=str,
                    help='the prefix of the model to load')
@@ -31,14 +33,14 @@
parser.add_argument('--batch-size', type=int, default=32,
                    help='the batch size')
parser.add_argument('--gpus', type=str,
                    help='the gpus will be used, e.g "0,1,2,3"')
                    help='gpus to be used, e.g "0,1,2,3"')
parser.add_argument('--kv-store', type=str, default='local',
                    help='the kvstore type')
parser.add_argument('--num-examples', type=int, default=1281167,
                    help='the number of training examples')
parser.add_argument('--num-classes', type=int, default=1000,
                    help='the number of classes')
parser.add_argument('--log-file', type=str,
                    help='the name of log file')
parser.add_argument('--log-dir', type=str, default="/tmp/",
                    help='directory of the log file')
@@ -48,13 +50,58 @@
help="validation dataset name")
parser.add_argument('--data-shape', type=int, default=224,
help='set image\'s shape')
mutually_exclusive_parser_group.add_argument('--benchmark', default=False, action='store_true',
help='benchmark for 50 iterations using randomly generated Synthetic data')
args = parser.parse_args()

# network
import importlib
net = importlib.import_module('symbol_' + args.network).get_symbol(args.num_classes)


# data
import random
from mxnet.io import DataBatch, DataIter
import numpy as np

class SyntheticDataIter(DataIter):
    def __init__(self, num_classes, data_shape, max_iter):
        self.batch_size = data_shape[0]
        self.cur_iter = 0
        self.max_iter = max_iter
        label = np.random.randint(0, num_classes, [self.batch_size,])
        data = np.random.uniform(-1, 1, data_shape)
        self.data = mx.nd.array(data)
        self.label = mx.nd.array(label)
    def __iter__(self):
        return self
    @property
    def provide_data(self):
        return [('data', self.data.shape)]
    @property
    def provide_label(self):
        return [('softmax_label', (self.batch_size,))]
    def next(self):
        self.cur_iter += 1
        if self.cur_iter <= self.max_iter:
            return DataBatch(data=(self.data,),
                             label=(self.label,),
                             pad=0,
                             index=None,
                             provide_data=self.provide_data,
                             provide_label=self.provide_label)
        else:
            raise StopIteration
    def __next__(self):
        return self.next()
    def reset(self):
        self.cur_iter = 0

def get_synthetic_data_iter(args, kv):
    data_shape = (args.batch_size, 3, args.data_shape, args.data_shape)
    train = SyntheticDataIter(args.num_classes, data_shape, 50)
    val = SyntheticDataIter(args.num_classes, data_shape, 1)
    return (train, val)

def get_iterator(args, kv):
    data_shape = (3, args.data_shape, args.data_shape)
    train = mx.io.ImageRecordIter(
@@ -83,5 +130,7 @@ def get_iterator(args, kv):

    return (train, val)

# train
train_model.fit(args, net, get_iterator)
if args.benchmark:
    train_model.fit(args, net, get_synthetic_data_iter)
else:
    train_model.fit(args, net, get_iterator)
35 changes: 22 additions & 13 deletions example/image-classification/train_model.py
@@ -83,22 +83,31 @@ def fit(args, network, data_loader, batch_end_callback=None):
                 initializer = mx.init.Xavier(factor_type="in", magnitude=2.34),
                 **model_args)

    eval_metrics = ['accuracy']
    ## TopKAccuracy only allows top_k > 1
    for top_k in [5, 10, 20]:
        eval_metrics.append(mx.metric.create('top_k_accuracy', top_k=top_k))

    if batch_end_callback is not None:
        if not isinstance(batch_end_callback, list):
            batch_end_callback = [batch_end_callback]
    else:
        batch_end_callback = []
    batch_end_callback.append(mx.callback.Speedometer(args.batch_size, 50))

    model.fit(
        X = train,
        eval_data = val,
        eval_metric = eval_metrics,
        kvstore = kv,
        batch_end_callback = batch_end_callback,
        epoch_end_callback = checkpoint)

    if args.benchmark:
        batch_end_callback.append(mx.callback.Speedometer(args.batch_size, 10))
        # don't run evaluation for benchmark
        model.fit(
            X = train,
            eval_data = val,
            kvstore = kv,
            batch_end_callback = batch_end_callback,
            epoch_end_callback = checkpoint)
    else:
        eval_metrics = ['accuracy']
        # TopKAccuracy only allows top_k > 1
        for top_k in [5, 10, 20]:
            eval_metrics.append(mx.metric.create('top_k_accuracy', top_k=top_k))
        batch_end_callback.append(mx.callback.Speedometer(args.batch_size, 50))
        model.fit(
            X = train,
            eval_data = val,
            eval_metric = eval_metrics,
            kvstore = kv,
            batch_end_callback = batch_end_callback,
            epoch_end_callback = checkpoint)
