Skip to content

Commit

Permalink
[update] Refactor run_sapphire.py
Browse files Browse the repository at this point in the history
Refactor according to PEP8 and make use of argparse and logging.
  • Loading branch information
m-yoshinaka committed Sep 16, 2020
1 parent b0b8192 commit c0716b1
Showing 1 changed file with 54 additions and 38 deletions.
92 changes: 54 additions & 38 deletions run_sapphire.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,69 @@
import sys
import logging
import argparse
import fasttext

from sapphire import Sapphire, setting
from sapphire import Sapphire


def run_sapphire():
fasttext_path = sys.argv[1]
print(' * Loading pre-trained model ...', flush=True, end='')
model = fasttext.FastText.load_model(fasttext_path)
print(' * - completed')

aligner = Sapphire(model)
formatter = '%(levelname)s : %(asctime)s : %(message)s'
logging.basicConfig(level=logging.INFO, format=formatter)

while True:

sentence_src = ''
while sentence_src == '':
print('Input tokenized sentence (A)')
sentence_src = input('> ')
if sentence_src == 'EXIT':
break
def run_sapphire(args):
logging.info('loading pre-trained model')
model = fasttext.FastText.load_model(args.model_path)
logging.info('loading completed')

sentence_trg = ''
while sentence_trg == '':
print('Input tokenized sentence (B)')
sentence_trg = input('> ')
if sentence_trg == 'exit':
break
aligner = Sapphire(model=model)
aligner.set_params(lambda_=args.lambda_,
delta=args.delta,
alpha=args.alpha,
hungarian=args.use_hungarian)

tokens_src = sentence_src.split()
tokens_trg = sentence_trg.split()
try:
while True:
print('\n' + '=' * 80)
sentence_src = input('Input tokenized sentence (A)\n>>> ')
sentence_trg = input('Input tokenized sentence (B)\n>>> ')

result = aligner.align(tokens_src, tokens_trg)
alignment = result.top_alignment[0][0]
if sentence_src == '' or sentence_trg == '':
logging.warning('please input two sentences!')
continue

if not alignment:
continue
tokens_src = sentence_src.split()
tokens_trg = sentence_trg.split()
_, alignment = aligner.align(tokens_src, tokens_trg)

print('-'.join(['-' for i in range(0, 20)]))
print('\n * Result')
print('{0:^20}{1:<6}{2:^20}'.format('A', '', 'B'))
print('{:-^48}'.format(' Result '))
print('{0:^24}{1:^24}'.format('Sentence A', 'Sentence B'))

for al in alignment:
text_src = ' '.join(tokens_src[al[0] - 1 : al[1]])
text_trg = ' '.join(tokens_trg[al[2] - 1 : al[3]])
print('{0:<20}{1:<6}{2:<20}'.format(text_src, '<-->', text_trg))

print('-'.join(['-' for i in range(0, 20)]))
for src_s, src_e, trg_s, trg_e in alignment:
src_txt = ' '.join(tokens_src[src_s - 1:src_e])
trg_txt = ' '.join(tokens_trg[trg_s - 1:trg_e])
print('{0:>20}{1:^8}{2:<20}'.format(src_txt, '<-->', trg_txt))

except KeyboardInterrupt:
print()
logging.info('interrupted')


def main():
parser = argparse.ArgumentParser()
parser.add_argument('model_path',
default='model/wiki-news-300d-1M-subword.bin',
help='path to fastText model')
parser.add_argument('--lambda_', type=float, default=0.6,
help='threshold of word alignment candidate score')
parser.add_argument('--delta', type=float, default=0.6,
help='threshold of phrase alignment candidate score')
parser.add_argument('--alpha', type=float, default=0.01,
help='bias for length of phrase')
parser.add_argument('--use_hungarian', action='store_true',
help='use Hungarian-method for word alignment')
args = parser.parse_args()

run_sapphire(args)


if __name__ == '__main__':
run_sapphire()
main()

0 comments on commit c0716b1

Please sign in to comment.