Skip to content

Commit

Permalink
Merge pull request #23 from Jeffwhen/main
Browse files Browse the repository at this point in the history
Move lmdb thread sync into tpu-perf
  • Loading branch information
Jeffwhen authored Nov 16, 2022
2 parents dc14734 + c6373ca commit 8a098db
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 9 deletions.
40 changes: 34 additions & 6 deletions python/tpu_perf/make_lmdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,41 @@
from .preprocess import load_plugins
load_plugins()

from threading import Lock

lock = Lock()

from .preprocess import get_preprocess_method

def build_lmdb(tree, path, config):
from .preprocess import get_preprocess_method
if 'input' not in config:
return
data_config = config['input']
preprocess = get_preprocess_method(data_config['preprocess'])
preprocess(tree, config)
try:
if 'input' not in config:
return
data_config = config['input']
if 'preprocess' not in data_config:
return
out_path = config['lmdb_out']

preprocess = get_preprocess_method(data_config['preprocess'])

with lock:
if os.path.exists(os.path.join(out_path, 'info.yaml')):
logging.info(f'{config["name"]} {out_path} already exist')
return
os.makedirs(out_path, exist_ok=True)
info_fn = os.path.join(out_path, 'info.yaml')
import yaml
with open(info_fn, 'w') as f:
yaml.dump(data_config, f)

preprocess(tree, config)
except Exception as err:
import shutil
shutil.rmtree(out_path, ignore_errors=True)
import sys
print(sys.exc_info())
logging.error(f'{path} quit because of exception, {err}')
os._exit(-1)

def main():
logging.basicConfig(
Expand Down
6 changes: 3 additions & 3 deletions python/tpu_perf/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ def load_plugins(name):
for dn in dirnames:
if dn != name:
continue
import_path = os.path.relpath(os.path.join(dirpath, dn), '.')
import_path = import_path.replace('/', '.')
rel_path = os.path.relpath(os.path.join(dirpath, dn), '.')
import_path = rel_path.replace('/', '.')
try:
importlib.import_module(import_path)
except ModuleNotFoundError as err:
if err.name not in import_path:
raise err
logging.warning('No dataset plugin')
logging.warning(f'No {name} plugin in {rel_path}')

def dict_override(a, b):
r = a.copy()
Expand Down

0 comments on commit 8a098db

Please sign in to comment.