From c0716b1c11a19b4c31ec106538f1ceb4f0dbd527 Mon Sep 17 00:00:00 2001 From: m-yoshinaka <49668018+m-yoshinaka@users.noreply.github.com> Date: Wed, 16 Sep 2020 08:40:43 +0000 Subject: [PATCH] [update] Refactor run_sapphire.py Refactor according to PEP8 and make use of argparse and logging. --- run_sapphire.py | 92 +++++++++++++++++++++++++++++-------------------- 1 file changed, 54 insertions(+), 38 deletions(-) diff --git a/run_sapphire.py b/run_sapphire.py index 5c913ce..9118346 100644 --- a/run_sapphire.py +++ b/run_sapphire.py @@ -1,53 +1,69 @@ -import sys +import logging +import argparse import fasttext -from sapphire import Sapphire, setting +from sapphire import Sapphire -def run_sapphire(): - fasttext_path = sys.argv[1] - print(' * Loading pre-trained model ...', flush=True, end='') - model = fasttext.FastText.load_model(fasttext_path) - print(' * - completed') - - aligner = Sapphire(model) +formatter = '%(levelname)s : %(asctime)s : %(message)s' +logging.basicConfig(level=logging.INFO, format=formatter) - while True: - sentence_src = '' - while sentence_src == '': - print('Input tokenized sentence (A)') - sentence_src = input('> ') - if sentence_src == 'EXIT': - break +def run_sapphire(args): + logging.info('loading pre-trained model') + model = fasttext.FastText.load_model(args.model_path) + logging.info('loading completed') - sentence_trg = '' - while sentence_trg == '': - print('Input tokenized sentence (B)') - sentence_trg = input('> ') - if sentence_trg == 'exit': - break + aligner = Sapphire(model=model) + aligner.set_params(lambda_=args.lambda_, + delta=args.delta, + alpha=args.alpha, + hungarian=args.use_hungarian) - tokens_src = sentence_src.split() - tokens_trg = sentence_trg.split() + try: + while True: + print('\n' + '=' * 80) + sentence_src = input('Input tokenized sentence (A)\n>>> ') + sentence_trg = input('Input tokenized sentence (B)\n>>> ') - result = aligner.align(tokens_src, tokens_trg) - alignment = result.top_alignment[0][0] + if sentence_src == '' or sentence_trg == '': + logging.warning('please input two sentences!') + continue - if not alignment: - continue + tokens_src = sentence_src.split() + tokens_trg = sentence_trg.split() + _, alignment = aligner.align(tokens_src, tokens_trg) - print('-'.join(['-' for i in range(0, 20)])) - print('\n * Result') - print('{0:^20}{1:<6}{2:^20}'.format('A', '', 'B')) + print('{:-^48}'.format(' Result ')) + print('{0:^24}{1:^24}'.format('Sentence A', 'Sentence B')) - for al in alignment: - text_src = ' '.join(tokens_src[al[0] - 1 : al[1]]) - text_trg = ' '.join(tokens_trg[al[2] - 1 : al[3]]) - print('{0:<20}{1:<6}{2:<20}'.format(text_src, '<-->', text_trg)) - - print('-'.join(['-' for i in range(0, 20)])) + for src_s, src_e, trg_s, trg_e in alignment: + src_txt = ' '.join(tokens_src[src_s - 1:src_e]) + trg_txt = ' '.join(tokens_trg[trg_s - 1:trg_e]) + print('{0:>20}{1:^8}{2:<20}'.format(src_txt, '<-->', trg_txt)) + + except KeyboardInterrupt: + print() + logging.info('interrupted') + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('model_path', + default='model/wiki-news-300d-1M-subword.bin', + help='path to fastText model') + parser.add_argument('--lambda_', type=float, default=0.6, + help='threshold of word alignment candidate score') + parser.add_argument('--delta', type=float, default=0.6, + help='threshold of phrase alignment candidate score') + parser.add_argument('--alpha', type=float, default=0.01, + help='bias for length of phrase') + parser.add_argument('--use_hungarian', action='store_true', + help='use Hungarian-method for word alignment') + args = parser.parse_args() + + run_sapphire(args) if __name__ == '__main__': - run_sapphire() + main()