[update] Refactor PhraseExtract & PhraseAlign
- An instance of PhraseExtract now stores its parameters (delta, alpha) as instance variables
- Add a '__call__' method (see the usage sketch below)
- Add some docstrings
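With this change, extraction parameters are fixed at construction time and both components can be invoked directly. A minimal usage sketch of the new interface, assuming toy inputs (the random embedding matrices and the two-link word alignment below are made up; class names, parameters, and defaults come from the diff):

import numpy as np
from sapphire.phrase_alignment import PhraseExtract, PhraseAlign

extractor = PhraseExtract(delta=0.6, alpha=0.01)   # parameters now live on the instance
phrase_aligner = PhraseAlign()

word_alignments = [(1, 1), (2, 2)]      # 1-indexed (src, trg) word links (toy data)
vectors_src = np.random.rand(2, 300)    # one embedding per source token (toy data)
vectors_trg = np.random.rand(2, 300)    # one embedding per target token (toy data)

phrase_pairs = extractor(word_alignments, vectors_src, vectors_trg)    # __call__ delegates to extract
phrase_alignment = phrase_aligner(phrase_pairs, len_src=2, len_trg=2)  # __call__ delegates to search_for_lattice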
m-yoshinaka committed Oct 5, 2020
1 parent 611f819 commit 24eb132
Showing 2 changed files with 140 additions and 75 deletions.
sapphire/phrase_alignment.py (202 changes: 136 additions & 66 deletions)
@@ -6,26 +6,54 @@
 
 class PhraseExtract(object):
 
-    def __init__(self):
-        self.name = ''
+    def __init__(self, delta, alpha):
+        self.delta = delta
+        self.alpha = alpha
+
+    def __call__(self, word_alignments, vectors_src, vectors_trg):
+        return self.extract(word_alignments, vectors_src, vectors_trg)
 
     @staticmethod
-    def _no_adpoint(s_start, s_end, t_start, t_end, matrix):
+    def _no_additional_point(ss, se, ts, te, matrix):
+        """
+        Check if there are any more adjacent points to be added.
+        ss: index of the start of the source phrase
+        se: index of the end of the source phrase
+        ts: index of the start of the target phrase
+        te: index of the end of the target phrase
+        """
         len_src, len_trg = matrix.shape
 
-        if s_start - 1 >= 0 and np.any(matrix[s_start - 1, :][t_start:t_end + 1]):
+        if ss - 1 >= 0 and np.any(matrix[ss - 1, :][ts:te + 1]):
             return False
-        elif s_end + 1 < len_src and np.any(matrix[s_end + 1, :][t_start:t_end + 1]):
+        elif se + 1 < len_src and np.any(matrix[se + 1, :][ts:te + 1]):
             return False
-        elif t_start - 1 >= 0 and np.any(matrix[:, t_start - 1][s_start:s_end + 1]):
+        elif ts - 1 >= 0 and np.any(matrix[:, ts - 1][ss:se + 1]):
             return False
-        elif t_end + 1 < len_trg and np.any(matrix[:, t_end + 1][s_start:s_end + 1]):
+        elif te + 1 < len_trg and np.any(matrix[:, te + 1][ss:se + 1]):
             return False
         else:
             return True
 
-    def extract(self, word_alignments: list, vectors_src: np.array, vectors_trg: np.array,
-                delta, alpha) -> list:
+    def extract(
+        self, word_alignments, vectors_src: np.array, vectors_trg: np.array
+    ):
+        """
+        Extract phrase pairs using the heuristic of phrase-based SMT.
+
+        Parameters
+        ----------
+        word_alignments : list
+            The return value of the 'align' method of the WordAlign class.
+        vectors_src, vectors_trg : np.array
+            Word embedding matrices of the source and target sentences.
+        Returns
+        -------
+        list
+            All candidate phrase pairs for phrase alignment.
+        """
+
         phrase_dict = {}
         len_src = len(vectors_src)
         len_trg = len(vectors_trg)
@@ -34,34 +62,37 @@ def extract(self, word_alignments: list, vectors_src: np.array, vectors_trg: np.array,
         for s, t in word_alignments:
             matrix[s - 1][t - 1] = 1
 
-        for (src1, trg1), (src2, trg2) in itertools.product(word_alignments, word_alignments):
-            ### s_start, s_end, t_start and t_end are 0-index ###
-            s_start, s_end = min(src1 - 1, src2 - 1), max(src1 - 1, src2 - 1)
-            t_start, t_end = min(trg1 - 1, trg2 - 1), max(trg1 - 1, trg2 - 1)
+        for (src1, trg1), (src2, trg2) in itertools.product(word_alignments,
+                                                            word_alignments):
+            ss, se = min(src1 - 1, src2 - 1), max(src1 - 1, src2 - 1)
+            ts, te = min(trg1 - 1, trg2 - 1), max(trg1 - 1, trg2 - 1)
+            # ss, se, ts and te are 0-indexed at this point
 
             while True:
-                if s_start - 1 >= 0 and np.any(matrix[s_start - 1, :][t_start:t_end + 1]):
-                    s_start -= 1
-                if s_end + 1 < len_src and np.any(matrix[s_end + 1, :][t_start:t_end + 1]):
-                    s_end += 1
-                if t_start - 1 >= 0 and np.any(matrix[:, t_start - 1][s_start:s_end + 1]):
-                    t_start -= 1
-                if t_end + 1 < len_trg and np.any(matrix[:, t_end + 1][s_start:s_end + 1]):
-                    t_end += 1
-
-                if self._no_adpoint(s_start, s_end, t_start, t_end, matrix):
+                if ss - 1 >= 0 and np.any(matrix[ss - 1, :][ts:te + 1]):
+                    ss -= 1
+                if se + 1 < len_src and np.any(matrix[se + 1, :][ts:te + 1]):
+                    se += 1
+                if ts - 1 >= 0 and np.any(matrix[:, ts - 1][ss:se + 1]):
+                    ts -= 1
+                if te + 1 < len_trg and np.any(matrix[:, te + 1][ss:se + 1]):
+                    te += 1
+
+                if self._no_additional_point(ss, se, ts, te, matrix):
                     break
 
-            if (s_start + 1, s_end + 1, t_start + 1, t_end + 1) not in phrase_dict:
-                phrase_vec_src = np.array(vectors_src[s_start:s_end + 1]).mean(axis=0)
-                phrase_vec_trg = np.array(vectors_trg[t_start:t_end + 1]).mean(axis=0)
-                sim = 1 - distance.cosine(phrase_vec_src, phrase_vec_trg)
-
-                sim -= alpha / (s_end - s_start + t_end - t_start + 2)
-
-                phrase_dict[(s_start + 1, s_end + 1, t_start + 1, t_end + 1)] = sim
+            if (ss + 1, se + 1, ts + 1, te + 1) not in phrase_dict:
+                phrase_vec_src = np.array(
+                    vectors_src[ss:se + 1]).mean(axis=0)
+                phrase_vec_trg = np.array(
+                    vectors_trg[ts:te + 1]).mean(axis=0)
+
+                sim = 1 - distance.cosine(phrase_vec_src, phrase_vec_trg)
+                sim -= self.alpha / (se - ss + te - ts + 2)
+                phrase_dict[(ss + 1, se + 1, ts + 1, te + 1)] = sim
 
-        phrase_pairs = [(k[0], k[1], k[2], k[3], v) for k, v in phrase_dict.items() if v >= delta]
+        phrase_pairs = [(k[0], k[1], k[2], k[3], v)
+                        for k, v in phrase_dict.items() if v >= self.delta]
         phrase_pairs.sort(key=lambda x: (x[0], x[2], x[1], x[3]))
 
         return phrase_pairs
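For reference, each candidate pair above is scored as the cosine similarity of the mean source and target phrase embeddings minus the length penalty alpha / (se - ss + te - ts + 2), and only pairs scoring at least delta are kept. A tiny worked sketch of that computation (vectors and parameter values are made up):

import numpy as np
from scipy.spatial import distance

alpha = 0.01
vectors_src = np.array([[1.0, 0.0], [0.8, 0.2]])   # toy source tokens 1-2
vectors_trg = np.array([[0.9, 0.1]])               # toy target token 1

phrase_vec_src = vectors_src[0:2].mean(axis=0)     # ss=0, se=1
phrase_vec_trg = vectors_trg[0:1].mean(axis=0)     # ts=0, te=0

sim = 1 - distance.cosine(phrase_vec_src, phrase_vec_trg)
sim -= alpha / ((1 - 0) + (0 - 0) + 2)             # alpha / (se - ss + te - ts + 2)
# The pair is kept as a candidate only if sim >= delta.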
@@ -72,24 +103,51 @@ class PhraseAlign(object):
 
     def __init__(self):
         self.name = ''
 
+    def __call__(self, phrase_pairs, len_src, len_trg):
+        return self.search_for_lattice(phrase_pairs, len_src, len_trg)
+
     @staticmethod
-    def create_lattice(phrase_pairs: list, len_src: int, len_trg: int) -> list:
-        node_list = defaultdict(lambda: defaultdict(list))
+    def search_for_lattice(phrase_pairs, len_src: int, len_trg: int):
+        """
+        Construct a lattice of phrase pairs and depth-first search for the
+        path with the highest total alignment score.
+
+        Parameters
+        ----------
+        phrase_pairs : list
+            The return value of the 'extract' method of the PhraseExtract class.
+        len_src, len_trg : int
+            Lengths of the source and target sentences.
+        Returns
+        -------
+        list
+            List of phrase-pair index tuples, i.e. one phrase alignment:
+            the path through the lattice with the highest total alignment score.
+        """
 
-        bos_node = {'index': (0, 0, 0, 0), 'sim': 0, 'next': []}
-        eos_node = {'index': (len_src + 1, len_src + 1, len_trg + 1, len_trg + 1), 'sim': 0, 'next': []}
+        node_list = defaultdict(lambda: defaultdict(list))
+        bos_node = {'index': (0, 0, 0, 0),
+                    'score': 0, 'next': []}
+        eos_node = {'index': (len_src + 1, len_src + 1,
+                              len_trg + 1, len_trg + 1),
+                    'score': 0, 'next': []}
         node_list[0][0].append(bos_node)
         node_list[len_src + 1][len_trg + 1].append(eos_node)
 
         def _forward(s, t, start_node, end_node, pairs):
-            path = []
+            """Depth-first search over the lattice."""
+
+            path = []
             if start_node == end_node or not pairs:
-                return [[sum(similarity)]]
+                return [[sum(alignment_scores)]]
 
-            min_s, _, min_t, _, _ = min(pairs, key=lambda x: ((x[0] - s) ** 2 + (x[2] - t) ** 2))
+            min_s, _, min_t, _, _ = min(pairs, key=lambda x: (
+                (x[0] - s) ** 2 + (x[2] - t) ** 2))
             min_dist = (min_s - s) ** 2 + (min_t - t) ** 2
-            nearest_pairs = [p for p in pairs if (p[0] - s) ** 2 + (p[2] - t) ** 2 == min_dist]
+            nearest_pairs = [p for p in pairs
+                             if (p[0] - s) ** 2 + (p[2] - t) ** 2 == min_dist]
 
             for pair in pairs[len(nearest_pairs):]:
                 nearer = False
@@ -101,61 +159,73 @@ def _forward(s, t, start_node, end_node, pairs):
                     nearest_pairs.append(pair)
 
             for next_pair in nearest_pairs:
-                s_start, s_end, t_start, t_end, sim = next_pair
-                next_node = {'index': (s_start, s_end, t_start, t_end), 'sim': sim, 'next': []}
-                rest_pairs = [p for p in pairs if p[0] > s_end and p[2] > t_end]
+                ss, se, ts, te, __score = next_pair
+                next_node = {'index': (ss, se, ts, te),
+                             'score': __score, 'next': []}
+                rest_pairs = [p for p in pairs
+                              if p[0] > se and p[2] > te]
 
                 checked = False
-                for checked_node in node_list[s_start][t_start]:
+                for checked_node in node_list[ss][ts]:
                     if next_node['index'] == checked_node['index']:
                         next_node = checked_node
                         checked = True
                         break
-                if not checked:
-                    node_list[s_start][t_start].append(next_node)
-
+                if not checked:
+                    node_list[ss][ts].append(next_node)
                 if next_node != end_node:
-                    similarity.append(next_node['sim'])
+                    alignment_scores.append(next_node['score'])
 
-                for solution in _forward(s_end + 1, t_end + 1, next_node, end_node, rest_pairs):
+                for solution in _forward(se + 1, te + 1,
+                                         next_node, end_node, rest_pairs):
                     ids = start_node['index']
                     path.append([(ids)] + solution)
 
                 if next_node != end_node:
-                    similarity.pop()
+                    alignment_scores.pop()
 
             return path
 
         if not phrase_pairs:
            return [([], 0)]
 
-        _s_start, _s_end, _t_start, _t_end, _sim = sorted(phrase_pairs, key=lambda x: x[4], reverse=True)[0]
-        top_node = {'index': (_s_start, _s_end, _t_start, _t_end), 'sim': _sim, 'next': []}
-        node_list[_s_start][_t_start].append(top_node)
+        s_start, s_end, t_start, t_end, top_score = sorted(
+            phrase_pairs, key=lambda x: x[4], reverse=True)[0]
+        top_node = {'index': (s_start, s_end, t_start,
+                              t_end), 'score': top_score, 'next': []}
+        node_list[s_start][t_start].append(top_node)
 
         top_index = [top_node['index']]
 
-        prev_pairs = [p for p in phrase_pairs if p[1] < _s_start and p[3] < _t_start]
-        prev_pairs.append((_s_start, _s_end, _t_start, _t_end, _sim))
-
-        next_pairs = [p for p in phrase_pairs if p[0] > _s_end and p[2] > _t_end]
-        next_pairs.append((len_src + 1, len_src + 1, len_trg + 1, len_trg + 1, 0))
-
-        similarity = []
-        prev_align = [(sol[1:-1], sol[-1]) for sol in
-                      _forward(1, 1, bos_node, top_node, prev_pairs)]
-
-        similarity = []
-        next_align = [(sol[1:-1], sol[-1]) for sol in
-                      _forward(_s_end + 1, _t_end + 1, top_node, eos_node, next_pairs)]
+        prev_pairs = [p for p in phrase_pairs
+                      if p[1] < s_start and p[3] < t_start]
+        prev_pairs.append((s_start, s_end, t_start, t_end, top_score))
+        next_pairs = [p for p in phrase_pairs
+                      if p[0] > s_end and p[2] > t_end]
+        next_pairs.append((len_src + 1, len_src + 1,
+                           len_trg + 1, len_trg + 1, 0))
 
+        alignment_scores = []  # Initialize the stack of alignment scores
+        prev_align = [
+            (sol[1:-1], sol[-1]) for sol
+            in _forward(1, 1, bos_node, top_node, prev_pairs)
+        ]
+
+        alignment_scores = []  # Re-initialize the stack of alignment scores
+        next_align = [
+            (sol[1:-1], sol[-1]) for sol
+            in _forward(s_end + 1, t_end + 1, top_node, eos_node, next_pairs)
+        ]
 
         alignments = []
         for prev_path, next_path in itertools.product(prev_align, next_align):
             concat_path = prev_path[0] + top_index + next_path[0]
             length = len(concat_path)
-            score = (prev_path[1] + next_path[1] + _sim) / length if length != 0 else 0
+            score = (prev_path[1] + next_path[1] + top_score) / length \
+                if length != 0 else 0
             alignments.append((concat_path, str(score)))
 
         alignments.sort(key=lambda x: float(x[1]), reverse=True)
 
-        return alignments[0][0]
+        return alignments[0][0]  # Return only the best phrase alignment
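Each lattice node above is a plain dict holding a phrase pair's 1-indexed span ('index'), its score, and a 'next' list for successor nodes; the final score of the best path is the sum of its node scores divided by the path length. A toy illustration of that shape and the averaging (spans and scores are made up):

bos_node = {'index': (0, 0, 0, 0), 'score': 0, 'next': []}
node_a = {'index': (1, 2, 1, 1), 'score': 0.82, 'next': []}   # hypothetical phrase pair
node_b = {'index': (3, 3, 2, 3), 'score': 0.74, 'next': []}   # hypothetical phrase pair

path = [node_a['index'], node_b['index']]
total = (node_a['score'] + node_b['score']) / len(path)       # (0.82 + 0.74) / 2 = 0.78
print(path, total)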
sapphire/sapphire.py (13 changes: 4 additions & 9 deletions)
@@ -10,7 +10,7 @@ def __init__(self, model):
         self.vectorizer = FastTextVectorize(model)
         self.set_params()
         self.word_aligner = WordAlign(self.lambda_, self.use_hungarian)
-        self.extractor = PhraseExtract()
+        self.extractor = PhraseExtract(self.delta, self.alpha)
         self.phrase_aligner = PhraseAlign()
 
     def set_params(self, lambda_=0.6, delta=0.6, alpha=0.01, hungarian=False):
@@ -29,12 +29,7 @@ def align(self, tokens_src: list, tokens_trg: list):
         sim_matrix = get_similarity_matrix(vectors_src, vectors_trg)
         word_alignment = self.word_aligner(sim_matrix)
 
-        phrase_pairs = self.extractor.extract(word_alignment,
-                                              vectors_src,
-                                              vectors_trg,
-                                              self.delta,
-                                              self.alpha)
-        phrase_alignment = self.phrase_aligner.create_lattice(phrase_pairs,
-                                                              len_src,
-                                                              len_trg)
+        phrase_pairs = self.extractor(word_alignment, vectors_src, vectors_trg)
+        phrase_alignment = self.phrase_aligner(phrase_pairs, len_src, len_trg)
 
         return word_alignment, phrase_alignment
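Taken together, align now reads as a thin pipeline over callable components. An end-to-end sketch, assuming the wrapper class in sapphire/sapphire.py is exported as Sapphire and that a fastText model is available locally (the model path and token lists below are illustrative):

import fasttext
from sapphire import Sapphire   # assumed public entry point

model = fasttext.load_model('cc.en.300.bin')   # hypothetical path to a fastText binary
aligner = Sapphire(model)                      # uses the defaults from set_params()

tokens_src = ['the', 'cat', 'sat']
tokens_trg = ['a', 'cat', 'was', 'sitting']
word_alignment, phrase_alignment = aligner.align(tokens_src, tokens_trg)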
