Skip to content

Commit

Permalink
update some files
Browse files Browse the repository at this point in the history
  • Loading branch information
m-yoshinaka committed Dec 25, 2019
1 parent 4e05932 commit 04f1443
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 58 deletions.
48 changes: 24 additions & 24 deletions run_sapphire.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,39 +6,39 @@ def run_sapphire():

while True:

sentence_src = ""
while sentence_src == "":
print("Input sentence1")
sentence_src = input("> ")
if sentence_src == "exit":
sentence_src = ''
while sentence_src == '':
print('Input sentence1')
sentence_src = input('> ')
if sentence_src == 'exit':
break

sentence_trg = ""
while sentence_trg == "":
print("Input sentence2")
sentence_trg = input("> ")
if sentence_trg == "exit":
sentence_trg = ''
while sentence_trg == '':
print('Input sentence2')
sentence_trg = input('> ')
if sentence_trg == 'exit':
break

tokens_src = sentence_src.split()
tokens_trg = sentence_trg.split()

alignment = aligner.align(tokens_src, tokens_trg)
print("\n{}\n".format(alignment))

if "-" in alignment:
for al in alignment.split(" "):
srcs, trgs = al.split("-")
src_s = int(srcs.split(",")[0])
src_e = int(srcs.split(",")[-1])
trg_s = int(trgs.split(",")[0])
trg_e = int(trgs.split(",")[-1])
text_src = " ".join(tokens_src[src_s - 1:src_e])
text_trg = " ".join(tokens_trg[trg_s - 1:trg_e])
print("{} <--> {}".format(text_src, text_trg))
print('\n{}\n'.format(alignment))

if '-' in alignment:
for al in alignment.split(' '):
srcs, trgs = al.split('-')
src_s = int(srcs.split(',')[0])
src_e = int(srcs.split(',')[-1])
trg_s = int(trgs.split(',')[0])
trg_e = int(trgs.split(',')[-1])
text_src = ' '.join(tokens_src[src_s - 1:src_e])
text_trg = ' '.join(tokens_trg[trg_s - 1:trg_e])
print('{} <--> {}'.format(text_src, text_trg))

print("---")
print('---')


if __name__ == "__main__":
if __name__ == '__main__':
run_sapphire()
40 changes: 20 additions & 20 deletions sapphire/phrase_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class PhraseExtract(object):

def __init__(self):
self.name = ""
self.name = ''

@staticmethod
def _no_adpoint(s_start, s_end, t_start, t_end, matrix):
Expand Down Expand Up @@ -70,14 +70,14 @@ def extract(self, word_alignments: list, vectors_src: np.array, vectors_trg: np.
class PhraseAlign(object):

def __init__(self):
self.name = ""
self.name = ''

@staticmethod
def create_lattice(phrase_pairs: list, len_src: int, len_trg: int) -> list:
node_list = defaultdict(lambda: defaultdict(list))

bos_node = {"index": (0, 0, 0, 0), "sim": 0, "next": []}
eos_node = {"index": (len_src + 1, len_src + 1, len_trg + 1, len_trg + 1), "sim": 0, "next": []}
bos_node = {'index': (0, 0, 0, 0), 'sim': 0, 'next': []}
eos_node = {'index': (len_src + 1, len_src + 1, len_trg + 1, len_trg + 1), 'sim': 0, 'next': []}
node_list[0][0].append(bos_node)
node_list[len_src + 1][len_trg + 1].append(eos_node)

Expand All @@ -102,42 +102,42 @@ def _forward(s, t, start_node, end_node, pairs):

for next_pair in nearest_pairs:
s_start, s_end, t_start, t_end, sim = next_pair
next_node = {"index": (s_start, s_end, t_start, t_end), "sim": sim, "next": []}
next_node = {'index': (s_start, s_end, t_start, t_end), 'sim': sim, 'next': []}
rest_pairs = [p for p in pairs if p[0] > s_end and p[2] > t_end]

checked = False
for checked_node in node_list[s_start][t_start]:
if next_node["index"] == checked_node["index"]:
if next_node['index'] == checked_node['index']:
next_node = checked_node
checked = True
break
if not checked:
node_list[s_start][t_start].append(next_node)

if next_node != end_node:
similarity.append(next_node["sim"])
similarity.append(next_node['sim'])

for solution in _forward(s_end + 1, t_end + 1, next_node, end_node, rest_pairs):
ids = start_node["index"]
ids_src = ",".join([str(i) for i in range(ids[0], ids[1] + 1)])
ids_trg = ",".join([str(i) for i in range(ids[2], ids[3] + 1)])
path.append(["{}-{}".format(ids_src, ids_trg)] + solution)
ids = start_node['index']
ids_src = ','.join([str(i) for i in range(ids[0], ids[1] + 1)])
ids_trg = ','.join([str(i) for i in range(ids[2], ids[3] + 1)])
path.append(['{}-{}'.format(ids_src, ids_trg)] + solution)

if next_node != end_node:
similarity.pop()

return path

if not phrase_pairs:
return [("", "0")]
return [('', '0')]

_s_start, _s_end, _t_start, _t_end, _sim = sorted(phrase_pairs, key=lambda x: x[4], reverse=True)[0]
top_node = {"index": (_s_start, _s_end, _t_start, _t_end), "sim": _sim, "next": []}
top_node = {'index': (_s_start, _s_end, _t_start, _t_end), 'sim': _sim, 'next': []}
node_list[_s_start][_t_start].append(top_node)

top_src = ",".join([str(i) for i in range(_s_start, _s_end + 1)])
top_trg = ",".join([str(i) for i in range(_t_start, _t_end + 1)])
top_index = "{}-{}".format(top_src, top_trg)
top_src = ','.join([str(i) for i in range(_s_start, _s_end + 1)])
top_trg = ','.join([str(i) for i in range(_t_start, _t_end + 1)])
top_index = '{}-{}'.format(top_src, top_trg)

prev_pairs = [p for p in phrase_pairs if p[1] < _s_start and p[3] < _t_start]
prev_pairs.append((_s_start, _s_end, _t_start, _t_end, _sim))
Expand All @@ -146,17 +146,17 @@ def _forward(s, t, start_node, end_node, pairs):
next_pairs.append((len_src + 1, len_src + 1, len_trg + 1, len_trg + 1, 0))

similarity = []
prev_align = [(" ".join(sol[1:-1]), sol[-1]) for sol in
prev_align = [(' '.join(sol[1:-1]), sol[-1]) for sol in
_forward(1, 1, bos_node, top_node, prev_pairs)]

similarity = []
next_align = [(" ".join(sol[1:-1]), sol[-1]) for sol in
next_align = [(' '.join(sol[1:-1]), sol[-1]) for sol in
_forward(_s_end + 1, _t_end + 1, top_node, eos_node, next_pairs)]

alignments = []
for prev_path, next_path in itertools.product(prev_align, next_align):
concat_path = (" ".join((prev_path[0], top_index, next_path[0]))).strip()
length = len(concat_path.split(" "))
concat_path = (' '.join((prev_path[0], top_index, next_path[0]))).strip()
length = len(concat_path.split(' '))
score = (prev_path[1] + next_path[1] + _sim) / length if length != 0 else 0
alignments.append((concat_path, str(score)))

Expand Down
2 changes: 1 addition & 1 deletion sapphire/sapphire.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
class Sapphire(object):

def __init__(self):
self.name = ""
self.name = ''

self.model_path = setting.MODEL_PATH
self.hungarian = setting.HUGARIAN
Expand Down
8 changes: 4 additions & 4 deletions sapphire/setting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import sys

FASTTEXT_PATH = "model/wiki-news-300d-1M-subword.bin"
FASTTEXT_PATH = 'model/wiki-news-300d-1M-subword.bin'

HUGARIAN = False # word alignment option (default: grow-diag-final)

Expand All @@ -10,7 +10,7 @@
ALPHA = 0.05 # bias for length of phrase

if os.path.exists(FASTTEXT_PATH):
MODEL_PATH = "model/wiki-news-300d-1M-subword.bin" # path of pre-trained word embedding model (default: fastText)
MODEL_PATH = 'model/wiki-news-300d-1M-subword.bin' # path of pre-trained word embedding model (default: fastText)
else:
print("Input the path of your pre-trained word embedding model.")
MODEL_PATH = input("> ")
print('Input the path of your pre-trained word embedding model.')
MODEL_PATH = input('> ')
6 changes: 3 additions & 3 deletions sapphire/word_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ class FastTextVectorize(WordEmbedding):

def __init__(self, model_path):
super().__init__()
print("Loading model: ", flush=True, end="")
print('Loading model: ', flush=True, end='')
self.model = fasttext.FastText.load_model(model_path)
print("DONE\n")
print('DONE\n')

def vectorize(self, words: list) -> np.array:
vector = []
Expand All @@ -37,7 +37,7 @@ def vectorize(self, words: list) -> np.array:
class WordAlign(object):

def __init__(self):
self.name = ""
self.name = ''

@staticmethod
def similarity_matrix(vectors_src: np.array, vectors_trg: np.array) -> np.ndarray:
Expand Down
12 changes: 6 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ def read_requirements():


setup(
name="sapphire",
version="0.1.0",
description="Simple Aligner for Phrasal Paraphrase with Hierarchical Representation",
author="Masato Yoshinaka",
author_email="yoshinaka.masato@ist.osaka-u.ac.jp",
name='sapphire',
version='0.1.0',
description='Simple Aligner for Phrasal Paraphrase with Hierarchical Representation',
author='Masato Yoshinaka',
author_email='yoshinaka.masato@ist.osaka-u.ac.jp',
install_requires=read_requirements(),
url="https://github.com/mybon13/sapphire",
url='https://github.com/mybon13/sapphire',
packages=find_packages()
)

0 comments on commit 04f1443

Please sign in to comment.