-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor according to PEP8 and make use of argparse and logging.
- Loading branch information
1 parent
b0b8192
commit c0716b1
Showing
1 changed file
with
54 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,53 +1,69 @@ | ||
import sys | ||
import logging | ||
import argparse | ||
import fasttext | ||
|
||
from sapphire import Sapphire, setting | ||
from sapphire import Sapphire | ||
|
||
|
||
def run_sapphire(): | ||
fasttext_path = sys.argv[1] | ||
print(' * Loading pre-trained model ...', flush=True, end='') | ||
model = fasttext.FastText.load_model(fasttext_path) | ||
print(' * - completed') | ||
|
||
aligner = Sapphire(model) | ||
formatter = '%(levelname)s : %(asctime)s : %(message)s' | ||
logging.basicConfig(level=logging.INFO, format=formatter) | ||
|
||
while True: | ||
|
||
sentence_src = '' | ||
while sentence_src == '': | ||
print('Input tokenized sentence (A)') | ||
sentence_src = input('> ') | ||
if sentence_src == 'EXIT': | ||
break | ||
def run_sapphire(args): | ||
logging.info('loading pre-trained model') | ||
model = fasttext.FastText.load_model(args.model_path) | ||
logging.info('loading completed') | ||
|
||
sentence_trg = '' | ||
while sentence_trg == '': | ||
print('Input tokenized sentence (B)') | ||
sentence_trg = input('> ') | ||
if sentence_trg == 'exit': | ||
break | ||
aligner = Sapphire(model=model) | ||
aligner.set_params(lambda_=args.lambda_, | ||
delta=args.delta, | ||
alpha=args.alpha, | ||
hungarian=args.use_hungarian) | ||
|
||
tokens_src = sentence_src.split() | ||
tokens_trg = sentence_trg.split() | ||
try: | ||
while True: | ||
print('\n' + '=' * 80) | ||
sentence_src = input('Input tokenized sentence (A)\n>>> ') | ||
sentence_trg = input('Input tokenized sentence (B)\n>>> ') | ||
|
||
result = aligner.align(tokens_src, tokens_trg) | ||
alignment = result.top_alignment[0][0] | ||
if sentence_src == '' or sentence_trg == '': | ||
logging.warning('please input two sentences!') | ||
continue | ||
|
||
if not alignment: | ||
continue | ||
tokens_src = sentence_src.split() | ||
tokens_trg = sentence_trg.split() | ||
_, alignment = aligner.align(tokens_src, tokens_trg) | ||
|
||
print('-'.join(['-' for i in range(0, 20)])) | ||
print('\n * Result') | ||
print('{0:^20}{1:<6}{2:^20}'.format('A', '', 'B')) | ||
print('{:-^48}'.format(' Result ')) | ||
print('{0:^24}{1:^24}'.format('Sentence A', 'Sentence B')) | ||
|
||
for al in alignment: | ||
text_src = ' '.join(tokens_src[al[0] - 1 : al[1]]) | ||
text_trg = ' '.join(tokens_trg[al[2] - 1 : al[3]]) | ||
print('{0:<20}{1:<6}{2:<20}'.format(text_src, '<-->', text_trg)) | ||
|
||
print('-'.join(['-' for i in range(0, 20)])) | ||
for src_s, src_e, trg_s, trg_e in alignment: | ||
src_txt = ' '.join(tokens_src[src_s - 1:src_e]) | ||
trg_txt = ' '.join(tokens_trg[trg_s - 1:trg_e]) | ||
print('{0:>20}{1:^8}{2:<20}'.format(src_txt, '<-->', trg_txt)) | ||
|
||
except KeyboardInterrupt: | ||
print() | ||
logging.info('interrupted') | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('model_path', | ||
default='model/wiki-news-300d-1M-subword.bin', | ||
help='path to fastText model') | ||
parser.add_argument('--lambda_', type=float, default=0.6, | ||
help='threshold of word alignment candidate score') | ||
parser.add_argument('--delta', type=float, default=0.6, | ||
help='threshold of phrase alignment candidate score') | ||
parser.add_argument('--alpha', type=float, default=0.01, | ||
help='bias for length of phrase') | ||
parser.add_argument('--use_hungarian', action='store_true', | ||
help='use Hungarian-method for word alignment') | ||
args = parser.parse_args() | ||
|
||
run_sapphire(args) | ||
|
||
|
||
if __name__ == '__main__': | ||
run_sapphire() | ||
main() |