Skip to content

Commit

Permalink
[update] Add a parameter & update PhraseAlign
Browse files Browse the repository at this point in the history
- 'epsilon' :  alignment score for a null alignment
- PhraseAlign class has new attributes
  • Loading branch information
m-yoshinaka committed Nov 10, 2020
1 parent fdb5b9a commit 60354d8
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 18 deletions.
44 changes: 31 additions & 13 deletions sapphire/phrase_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,20 @@ def extract(

class PhraseAlign(object):

def __init__(self):
self.name = ''
def __init__(self, prune_k=-1, get_score=False, epsilon=None):
self.prune_k = prune_k
self.get_score = get_score
self.epsilon = epsilon

def __call__(self, phrase_pairs, len_src, len_trg,
prune_k=-1, get_score=False):
return self.search_for_lattice(phrase_pairs, len_src, len_trg,
prune_k=prune_k, get_score=get_score)
def __call__(self, phrase_pairs, len_src, len_trg):
return self.search_for_lattice(phrase_pairs, len_src, len_trg)

@staticmethod
def search_for_lattice(phrase_pairs, len_src: int, len_trg: int,
prune_k=-1, get_score=False):
def set_params(self, prune_k, get_score, epsilon):
self.prune_k = prune_k
self.get_score = get_score
self.epsilon = epsilon

def search_for_lattice(self, phrase_pairs, len_src: int, len_trg: int):
"""
Construct a lattice of phrase pairs and depth-first search for the
path with the highest total alignment score.
Expand Down Expand Up @@ -172,8 +175,8 @@ def _forward(s, t, start_node, end_node, pairs):
if not nearer:
nearest_pairs.append(pair)

if prune_k != -1:
nearest_pairs = nearest_pairs[:prune_k]
if self.prune_k != -1:
nearest_pairs = nearest_pairs[:self.prune_k]

for next_pair in nearest_pairs:
ss, se, ts, te, __score = next_pair
Expand Down Expand Up @@ -205,7 +208,7 @@ def _forward(s, t, start_node, end_node, pairs):
return path

if not phrase_pairs:
return ([], 0) if get_score else []
return ([], 0) if self.get_score else []

s_start, s_end, t_start, t_end, score = sorted(
phrase_pairs, key=lambda x: x[4], reverse=True)[0]
Expand Down Expand Up @@ -243,9 +246,24 @@ def _forward(s, t, start_node, end_node, pairs):
if length != 0 else 0
alignments.append((concat_path, score))

if self.epsilon is not None:
new_alignments = []
for alignment, score in alignments:
nof_align = len(alignment)
nof_null_align = len_src + len_trg
for ss, se, ts, te in alignment:
nof_null_align -= se - ss + 1
nof_null_align -= te - ts + 1

score = (score * nof_align + self.epsilon * nof_null_align) \
/ (nof_align + nof_null_align)
new_alignments.append((alignment, score))

alignments = new_alignments

alignments.sort(key=lambda x: float(x[1]), reverse=True)

if get_score:
if self.get_score:
return alignments[0]

return alignments[0][0] # Return only the top one of phrase alignments
22 changes: 17 additions & 5 deletions sapphire/sapphire.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,18 @@ def __init__(self, model):
self.use_hungarian = False
self.prune_k = -1
self.get_score = False
self.epsilon = None
self.word_aligner = WordAlign(self.lambda_, self.use_hungarian)
self.extractor = PhraseExtract(self.delta, self.alpha)
self.phrase_aligner = PhraseAlign()
self.phrase_aligner = PhraseAlign(self.prune_k,
self.get_score,
self.epsilon)

def __call__(self, tokens_src, tokens_trg):
return self.align(tokens_src, tokens_trg)

def set_params(self, lambda_=0.6, delta=0.6, alpha=0.01, hungarian=False,
prune_k=-1, get_score=False):
prune_k=-1, get_score=False, epsilon=None):
"""
Set hyper-parameters of SAPPHIRE.
Expand All @@ -61,15 +64,26 @@ def set_params(self, lambda_=0.6, delta=0.6, alpha=0.01, hungarian=False,
Biases the phrase alignment score based on the lengths of phrases.
hungarian : bool
Whether to use the extended Hangarian method to get word alignment.
prune_k : int
Prunes the number of nodes following a nodes in the lattice.
get_score : bool
Whether to output alignment scores with phrase alignments.
epsilon : float
Alignment score for a null alignment.
If epsilon is None, SAPPHIRE does not consider null alignment.
"""
self.lambda_ = lambda_
self.delta = delta
self.alpha = alpha
self.use_hungarian = hungarian
self.prune_k = prune_k
self.get_score = get_score
self.epsilon = epsilon
self.word_aligner.set_params(self.lambda_, self.use_hungarian)
self.extractor.set_params(self.delta, self.alpha)
self.phrase_aligner.set_params(self.prune_k,
self.get_score,
self.epsilon)

def align(self, tokens_src: list, tokens_trg: list):
"""
Expand All @@ -95,8 +109,6 @@ def align(self, tokens_src: list, tokens_trg: list):
word_alignment = self.word_aligner(sim_matrix)

phrase_pairs = self.extractor(word_alignment, vectors_src, vectors_trg)
phrase_alignment = self.phrase_aligner(phrase_pairs, len_src, len_trg,
prune_k=self.prune_k,
get_score=self.get_score)
phrase_alignment = self.phrase_aligner(phrase_pairs, len_src, len_trg)

return word_alignment, phrase_alignment

0 comments on commit 60354d8

Please sign in to comment.