Skip to content

Commit

Permalink
[update] Change the behavior of Sapphire class
Browse files Browse the repository at this point in the history
- Sapphire class returns only 1 word alignment & 1 phrase alignment
- Some fixes in *_alignment.py according to changes in Sapphire class
  • Loading branch information
m-yoshinaka committed Sep 16, 2020
1 parent b58f304 commit 1ea07f6
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 38 deletions.
2 changes: 1 addition & 1 deletion sapphire/phrase_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,4 @@ def _forward(s, t, start_node, end_node, pairs):

alignments.sort(key=lambda x: float(x[1]), reverse=True)

return alignments
return alignments[0][0]
51 changes: 21 additions & 30 deletions sapphire/sapphire.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,21 @@
from . import setting
from .word_alignment import FastTextVectorize, WordAlign
from .phrase_alignment import PhraseExtract, PhraseAlign


class SapphireAlignment(object):

def __init__(self, word_alignment: list, top_alignment: list):
self.name = ''
self.word_alignment = word_alignment
self.top_alignment = top_alignment


class Sapphire(object):

def __init__(self, model):
self.name = ''

self._hungarian = setting.HUGARIAN
self._lambda = setting.LAMBDA
self._delta = setting.DELTA
self._alpha = setting.ALPHA

self.vectorizer = FastTextVectorize(model)
self.word_aligner = WordAlign()
self.extractor = PhraseExtract()
self.phrase_aligner = PhraseAlign()
self.set_params()

def set_params(self, LAMBDA, DELTA, ALPHA, HUGARIAN=False):
self._lambda = LAMBDA
self._delta = DELTA
self._alpha = ALPHA
self._hungarian = HUGARIAN
def set_params(self, lambda_=0.6, delta=0.6, alpha=0.01, hungarian=False):
self.lambda_ = lambda_
self.delta = delta
self.alpha = alpha
self.hungarian = hungarian

def align(self, tokens_src: list, tokens_trg: list):
len_src = len(tokens_src)
Expand All @@ -39,12 +24,18 @@ def align(self, tokens_src: list, tokens_trg: list):
vectors_src = self.vectorizer.vectorize(tokens_src)
vectors_trg = self.vectorizer.vectorize(tokens_trg)

sim_matrix = self.word_aligner.similarity_matrix(vectors_src, vectors_trg)
word_alignment = self.word_aligner.align(sim_matrix, self._lambda, self._hungarian)

phrase_pairs = self.extractor.extract(word_alignment, vectors_src, vectors_trg, self._delta, self._alpha)
phrase_alignment = self.phrase_aligner.create_lattice(phrase_pairs, len_src, len_trg)

result = SapphireAlignment(word_alignment, phrase_alignment[:10])

return result
sim_matrix = self.word_aligner.similarity_matrix(vectors_src,
vectors_trg)
word_alignment = self.word_aligner.align(sim_matrix,
self.lambda_,
self.hungarian)

phrase_pairs = self.extractor.extract(word_alignment,
vectors_src,
vectors_trg,
self.delta,
self.alpha)
phrase_alignment = self.phrase_aligner.create_lattice(phrase_pairs,
len_src,
len_trg)
return word_alignment, phrase_alignment
7 changes: 4 additions & 3 deletions sapphire/word_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@ class FastTextVectorize(WordEmbedding):
def __init__(self, model):
super().__init__()
self.model = model
self.dim = model.get_dimension()

def vectorize(self, words: list) -> np.array:
vector = []

if words:
for word in words:
vector.append(self.model.get_word_vector(word.lower()))
# vector.append(self.model.get_word_vector(word))
else:
vector.append(np.zeros(300))
vector.append(np.zeros(self.dim))

return np.array(vector)

Expand Down Expand Up @@ -99,7 +99,8 @@ def _final(matrix):

align_matrix = np.logical_and(src2trg, trg2src)
union_matrix = np.logical_or(src2trg, trg2src)
neighbors = [(-1, 0), (0, -1), (1, 0), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]
neighbors = [(-1, 0), (0, -1), (1, 0), (0, 1),
(-1, -1), (-1, 1), (1, -1), (1, 1)]

_grow_diag()
_final(src2trg)
Expand Down
9 changes: 5 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ def read_requirements():

setup(
name='sapphire',
version='0.1.0',
description='Simple Aligner for Phrasal Paraphrase with Hierarchical Representation',
version='0.1.1',
description='Simple Aligner for Phrasal Paraphrase \
with Hierarchical Representation',
author='Masato Yoshinaka',
author_email='yoshinaka.masato@ist.osaka-u.ac.jp',
install_requires=read_requirements(),
url='https://github.com/mybon13/sapphire',
url='https://github.com/m-yoshinaka/sapphire',
packages=find_packages()
)
)

0 comments on commit 1ea07f6

Please sign in to comment.