diff --git a/sapphire/phrase_alignment.py b/sapphire/phrase_alignment.py index 895e73c..f272f30 100644 --- a/sapphire/phrase_alignment.py +++ b/sapphire/phrase_alignment.py @@ -158,4 +158,4 @@ def _forward(s, t, start_node, end_node, pairs): alignments.sort(key=lambda x: float(x[1]), reverse=True) - return alignments + return alignments[0][0] diff --git a/sapphire/sapphire.py b/sapphire/sapphire.py index cec26ec..49bd4dd 100644 --- a/sapphire/sapphire.py +++ b/sapphire/sapphire.py @@ -1,36 +1,21 @@ -from . import setting from .word_alignment import FastTextVectorize, WordAlign from .phrase_alignment import PhraseExtract, PhraseAlign -class SapphireAlignment(object): - - def __init__(self, word_alignment: list, top_alignment: list): - self.name = '' - self.word_alignment = word_alignment - self.top_alignment = top_alignment - - class Sapphire(object): def __init__(self, model): - self.name = '' - - self._hungarian = setting.HUGARIAN - self._lambda = setting.LAMBDA - self._delta = setting.DELTA - self._alpha = setting.ALPHA - self.vectorizer = FastTextVectorize(model) self.word_aligner = WordAlign() self.extractor = PhraseExtract() self.phrase_aligner = PhraseAlign() + self.set_params() - def set_params(self, LAMBDA, DELTA, ALPHA, HUGARIAN=False): - self._lambda = LAMBDA - self._delta = DELTA - self._alpha = ALPHA - self._hungarian = HUGARIAN + def set_params(self, lambda_=0.6, delta=0.6, alpha=0.01, hungarian=False): + self.lambda_ = lambda_ + self.delta = delta + self.alpha = alpha + self.hungarian = hungarian def align(self, tokens_src: list, tokens_trg: list): len_src = len(tokens_src) @@ -39,12 +24,18 @@ def align(self, tokens_src: list, tokens_trg: list): vectors_src = self.vectorizer.vectorize(tokens_src) vectors_trg = self.vectorizer.vectorize(tokens_trg) - sim_matrix = self.word_aligner.similarity_matrix(vectors_src, vectors_trg) - word_alignment = self.word_aligner.align(sim_matrix, self._lambda, self._hungarian) - - phrase_pairs = self.extractor.extract(word_alignment, vectors_src, vectors_trg, self._delta, self._alpha) - phrase_alignment = self.phrase_aligner.create_lattice(phrase_pairs, len_src, len_trg) - - result = SapphireAlignment(word_alignment, phrase_alignment[:10]) - - return result + sim_matrix = self.word_aligner.similarity_matrix(vectors_src, + vectors_trg) + word_alignment = self.word_aligner.align(sim_matrix, + self.lambda_, + self.hungarian) + + phrase_pairs = self.extractor.extract(word_alignment, + vectors_src, + vectors_trg, + self.delta, + self.alpha) + phrase_alignment = self.phrase_aligner.create_lattice(phrase_pairs, + len_src, + len_trg) + return word_alignment, phrase_alignment diff --git a/sapphire/word_alignment.py b/sapphire/word_alignment.py index 01ee3bd..70f0ef3 100644 --- a/sapphire/word_alignment.py +++ b/sapphire/word_alignment.py @@ -18,6 +18,7 @@ class FastTextVectorize(WordEmbedding): def __init__(self, model): super().__init__() self.model = model + self.dim = model.get_dimension() def vectorize(self, words: list) -> np.array: vector = [] @@ -25,9 +26,8 @@ def vectorize(self, words: list) -> np.array: if words: for word in words: vector.append(self.model.get_word_vector(word.lower())) - # vector.append(self.model.get_word_vector(word)) else: - vector.append(np.zeros(300)) + vector.append(np.zeros(self.dim)) return np.array(vector) @@ -99,7 +99,8 @@ def _final(matrix): align_matrix = np.logical_and(src2trg, trg2src) union_matrix = np.logical_or(src2trg, trg2src) - neighbors = [(-1, 0), (0, -1), (1, 0), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)] + neighbors = [(-1, 0), (0, -1), (1, 0), (0, 1), + (-1, -1), (-1, 1), (1, -1), (1, 1)] _grow_diag() _final(src2trg) diff --git a/setup.py b/setup.py index 473430c..347fb67 100644 --- a/setup.py +++ b/setup.py @@ -11,11 +11,12 @@ def read_requirements(): setup( name='sapphire', - version='0.1.0', - description='Simple Aligner for Phrasal Paraphrase with Hierarchical Representation', + version='0.1.1', + description='Simple Aligner for Phrasal Paraphrase \ + with Hierarchical Representation', author='Masato Yoshinaka', author_email='yoshinaka.masato@ist.osaka-u.ac.jp', install_requires=read_requirements(), - url='https://github.com/mybon13/sapphire', + url='https://github.com/m-yoshinaka/sapphire', packages=find_packages() -) \ No newline at end of file +)