Skip to content

Commit

Permalink
[add] Add parameters: prune_k and get_score
Browse files Browse the repository at this point in the history
- 'prune_k': reducing num. of nodes for searching
- 'get_score': specifying whether to return an alignment score
  • Loading branch information
m-yoshinaka committed Nov 10, 2020
1 parent f6af6b4 commit 77df78d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
12 changes: 10 additions & 2 deletions sapphire/phrase_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,13 @@ class PhraseAlign(object):
def __init__(self):
self.name = ''

def __call__(self, phrase_pairs, len_src, len_trg):
def __call__(self, phrase_pairs, len_src, len_trg,
prune_k=-1, get_score=False):
return self.search_for_lattice(phrase_pairs, len_src, len_trg)

@staticmethod
def search_for_lattice(phrase_pairs, len_src: int, len_trg: int):
def search_for_lattice(phrase_pairs, len_src: int, len_trg: int,
prune_k=-1, get_score=False):
"""
Construct a lattice of phrase pairs and depth-first search for the
path with the highest total alignment score.
Expand Down Expand Up @@ -169,6 +171,9 @@ def _forward(s, t, start_node, end_node, pairs):
if not nearer:
nearest_pairs.append(pair)

if prune_k != -1:
nearest_pairs = nearest_pairs[:prune_k]

for next_pair in nearest_pairs:
ss, se, ts, te, __score = next_pair
next_node = {'index': (ss, se, ts, te),
Expand Down Expand Up @@ -239,4 +244,7 @@ def _forward(s, t, start_node, end_node, pairs):

alignments.sort(key=lambda x: float(x[1]), reverse=True)

if get_score:
return alignments[0]

return alignments[0][0] # Return only the top one of phrase alignments
11 changes: 9 additions & 2 deletions sapphire/sapphire.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,17 @@ def __init__(self, model):
self.delta = 0.6
self.alpha = 0.01
self.use_hungarian = False
self.prune_k = -1
self.get_score = False
self.word_aligner = WordAlign(self.lambda_, self.use_hungarian)
self.extractor = PhraseExtract(self.delta, self.alpha)
self.phrase_aligner = PhraseAlign()

def __call__(self, tokens_src, tokens_trg):
return self.align(tokens_src, tokens_trg)

def set_params(self, lambda_=0.6, delta=0.6, alpha=0.01, hungarian=False):
def set_params(self, lambda_=0.6, delta=0.6, alpha=0.01, hungarian=False,
prune_k=-1, get_score=False):
"""
Set hyper-parameters of SAPPHIRE.
Expand All @@ -63,6 +66,8 @@ def set_params(self, lambda_=0.6, delta=0.6, alpha=0.01, hungarian=False):
self.delta = delta
self.alpha = alpha
self.use_hungarian = hungarian
self.prune_k = prune_k
self.get_score = get_score
self.word_aligner.set_params(self.lambda_, self.use_hungarian)
self.extractor.set_params(self.delta, self.alpha)

Expand Down Expand Up @@ -90,6 +95,8 @@ def align(self, tokens_src: list, tokens_trg: list):
word_alignment = self.word_aligner(sim_matrix)

phrase_pairs = self.extractor(word_alignment, vectors_src, vectors_trg)
phrase_alignment = self.phrase_aligner(phrase_pairs, len_src, len_trg)
phrase_alignment = self.phrase_aligner(phrase_pairs, len_src, len_trg,
prune_k=self.prune_k,
get_score=self.get_score)

return word_alignment, phrase_alignment

0 comments on commit 77df78d

Please sign in to comment.