Skip to content

Commit

Permalink
Support default_cases and case list file
Browse files Browse the repository at this point in the history
Change-Id: I8f1de9cd90404a42a1d7646c030d3a12e13f4970
  • Loading branch information
Jeffwhen committed Aug 8, 2022
1 parent 5ff04fd commit 6e529b6
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 47 deletions.
18 changes: 5 additions & 13 deletions python/tpu_perf/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,17 +146,15 @@ def main():

import argparse
parser = argparse.ArgumentParser(description='tpu-perf benchmark tool')
parser.add_argument(
'models', metavar='MODEL', type=str, nargs='*',
help='models to build')
parser.add_argument('--time', action='store_true')
parser.add_argument('--mlir', action='store_true')
parser.add_argument('--exit-on-error', action='store_true')
BuildTree.add_arguments(parser)
args = parser.parse_args()
global option_time_only
option_time_only = args.time

tree = BuildTree(os.path.abspath('.'))
tree = BuildTree(os.path.abspath('.'), args)

mem_size = sys_memory_size()
max_workers = max(1, int(mem_size / 1024 / 1024 / 12))
Expand All @@ -172,15 +170,9 @@ def main():
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = []

if not args.models:
for path, config in tree.walk():
f = executor.submit(build_fn, tree, path, config)
futures.append(f)
else:
for name in args.models:
for path, config in tree.read_dir(name):
f = executor.submit(build_fn, tree, path, config)
futures.append(f)
for path, config in tree.walk():
f = executor.submit(build_fn, tree, path, config)
futures.append(f)

for f in as_completed(futures):
err = f.exception()
Expand Down
40 changes: 35 additions & 5 deletions python/tpu_perf/buildtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,41 @@ def check_buildtree():
return ok

class BuildTree:
def __init__(self, root):
def __init__(self, root, args = None):
self.root = root
self.global_config = global_config = read_config(root) or dict()
global_config['root'] = root
if 'workdir' not in global_config:
global_config['workdir'] = os.path.join(root, 'output')

self.cases = []
if not args.full:
if 'default_cases' in self.global_config:
self.cases = self.global_config['default_cases']
if args.list:
with open(args.list) as f:
lines = [l.strip(' \n') for l in f.readlines()]
lines = [l for l in lines if l]
self.cases = lines
if args.models:
self.cases = args.models

self.output_names = set()

def read_global_variable(self, name, config = dict()):
return self.expand_variables(config, self.global_config[name])
@staticmethod
def add_arguments(parser):
parser.add_argument(
'models', metavar='MODEL', type=str, nargs='*',
help='model directories to run')
parser.add_argument('--full', action='store_true', help='Run all cases')
parser.add_argument('--list', '-l', type=str, help='Case list')

def read_global_variable(self, name, config = dict(), default=None):
if default is None and name not in self.global_config:
logging.error(f'Invalid global config field {name}')
raise RuntimeError('Invalid Field')
return self.expand_variables(
config, self.global_config.get(name, default))

whole_var_pattern = '^\$\(([a-z0-9_]+)\)$'

Expand Down Expand Up @@ -167,13 +191,19 @@ def _read_dir(self, config_fn):

if 'input' in config:
key = hash_name(config['input'])
config['lmdb_out'] = os.path.join(
self.read_global_variable('data_dir'), key)
data_dir = self.read_global_variable(
'data_dir', default='$(root)/data')
config['lmdb_out'] = os.path.join(data_dir, key)

yield path, copy.deepcopy(config)

def walk(self, path=None):
if path is None:
if self.cases:
for path in self.cases:
for ret in self.read_dir(path):
yield ret
return
path = self.root
if not os.path.isdir(path):
return
Expand Down
18 changes: 8 additions & 10 deletions python/tpu_perf/make_lmdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ def main():

if not check_buildtree():
sys.exit(1)
tree = BuildTree(os.path.abspath('.'))
import argparse
parser = argparse.ArgumentParser(description='tpu-perf benchmark tool')
BuildTree.add_arguments(parser)
args = parser.parse_args()
tree = BuildTree(os.path.abspath('.'), args)

# Prepare data path
try:
Expand Down Expand Up @@ -76,15 +80,9 @@ def main():
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = []

if len(sys.argv) == 1:
for path, config in tree.walk():
f = executor.submit(build_lmdb, data_dir, tree, path, config)
futures.append(f)
else:
for name in sys.argv[1:]:
for path, config in tree.read_dir(name):
f = executor.submit(build_lmdb, data_dir, tree, path, config)
futures.append(f)
for path, config in tree.walk():
f = executor.submit(build_lmdb, data_dir, tree, path, config)
futures.append(f)

for f in as_completed(futures):
err = f.exception()
Expand Down
16 changes: 8 additions & 8 deletions python/tpu_perf/precision_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ def main():
if not check_buildtree():
sys.exit(1)

tree = BuildTree(os.path.abspath('.'))
import argparse
parser = argparse.ArgumentParser(description='tpu-perf benchmark tool')
BuildTree.add_arguments(parser)
args = parser.parse_args()

tree = BuildTree(os.path.abspath('.'), args)
runner = Runner()
if len(sys.argv) == 1:
for path, config in tree.walk():
runner.run(tree, path, config)
else:
for name in sys.argv[1:]:
for path, config in tree.read_dir(name):
runner.run(tree, path, config)
for path, config in tree.walk():
runner.run(tree, path, config)

if __name__ == '__main__':
main()
15 changes: 4 additions & 11 deletions python/tpu_perf/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,7 @@ def main():

import argparse
parser = argparse.ArgumentParser(description='tpu-perf benchmark tool')
parser.add_argument(
'models', metavar='MODEL', type=str, nargs='*',
help='models to run')
BuildTree.add_arguments(parser)
parser.add_argument('--cmodel', action='store_true')
parser.add_argument('--mlir', action='store_true')
args = parser.parse_args()
Expand All @@ -203,7 +201,7 @@ def main():
if not check_buildtree():
sys.exit(1)

tree = BuildTree(os.path.abspath('.'))
tree = BuildTree(os.path.abspath('.'), args)
stat_fn = os.path.join(tree.global_config['workdir'], 'stats.csv')
run_func = run_mlir if args.mlir else run_nntc
with open(stat_fn, 'w') as f:
Expand All @@ -230,13 +228,8 @@ def main():
'cpu_usage',
'ddr_utilization'])

if not args.models:
for path, config in tree.walk():
run_func(tree, path, config, csv_f)
else:
for name in args.models:
for path, config in tree.read_dir(name):
run_func(tree, path, config, csv_f)
for path, config in tree.walk():
run_func(tree, path, config, csv_f)

if __name__ == '__main__':
main()

0 comments on commit 6e529b6

Please sign in to comment.