diff --git a/sapphire/phrase_alignment.py b/sapphire/phrase_alignment.py index f272f30..9d41ff5 100644 --- a/sapphire/phrase_alignment.py +++ b/sapphire/phrase_alignment.py @@ -6,26 +6,54 @@ class PhraseExtract(object): - def __init__(self): - self.name = '' + def __init__(self, delta, alpha): + self.delta = delta + self.alpha = alpha + + def __call__(self, word_alignments, vectors_src, vectors_trg): + return self.extract(word_alignments, vectors_src, vectors_trg) @staticmethod - def _no_adpoint(s_start, s_end, t_start, t_end, matrix): + def _no_additional_point(ss, se, ts, te, matrix): + """ + Check if there are any more adjacent points to be added. + ss: the index of the start of source phrase + se: the index of the end of source phrase + ts: the index of the start of target phrase + te: the index of the end of target phrase + """ len_src, len_trg = matrix.shape - if s_start - 1 >= 0 and np.any(matrix[s_start - 1, :][t_start:t_end + 1]): + if ss - 1 >= 0 and np.any(matrix[ss - 1, :][ts:te + 1]): return False - elif s_end + 1 < len_src and np.any(matrix[s_end + 1, :][t_start:t_end + 1]): + elif se + 1 < len_src and np.any(matrix[se + 1, :][ts:te + 1]): return False - elif t_start - 1 >= 0 and np.any(matrix[:, t_start - 1][s_start:s_end + 1]): + elif ts - 1 >= 0 and np.any(matrix[:, ts - 1][ss:se + 1]): return False - elif t_end + 1 < len_trg and np.any(matrix[:, t_end + 1][s_start:s_end + 1]): + elif te + 1 < len_trg and np.any(matrix[:, te + 1][ss:se + 1]): return False else: return True - def extract(self, word_alignments: list, vectors_src: np.array, vectors_trg: np.array, - delta, alpha) -> list: + def extract( + self, word_alignments, vectors_src: np.array, vectors_trg: np.array + ): + """ + Extract phrase pairs using the heuristic of phrase-based SMT. + + Parameters + ---------- + word_alignments : list + A return value of 'align' method in WordAlign class. + vectors_src, vectors_trg : np.array + Matrix of word embeddings (one row per word). 
+ + Returns + ------- + list + All candidates of phrase alignment. + """ + phrase_dict = {} len_src = len(vectors_src) len_trg = len(vectors_trg) @@ -34,34 +62,37 @@ def extract(self, word_alignments: list, vectors_src: np.array, vectors_trg: np. for s, t in word_alignments: matrix[s - 1][t - 1] = 1 - for (src1, trg1), (src2, trg2) in itertools.product(word_alignments, word_alignments): - ### s_start, s_end, t_start and t_end are 0-index ### - s_start, s_end = min(src1 - 1, src2 - 1), max(src1 - 1, src2 - 1) - t_start, t_end = min(trg1 - 1, trg2 - 1), max(trg1 - 1, trg2 - 1) + for (src1, trg1), (src2, trg2) in itertools.product(word_alignments, + word_alignments): + ss, se = min(src1 - 1, src2 - 1), max(src1 - 1, src2 - 1) + ts, te = min(trg1 - 1, trg2 - 1), max(trg1 - 1, trg2 - 1) + # ss, se, ts and te are 0-index at this time while True: - if s_start - 1 >= 0 and np.any(matrix[s_start - 1, :][t_start:t_end + 1]): - s_start -= 1 - if s_end + 1 < len_src and np.any(matrix[s_end + 1, :][t_start:t_end + 1]): - s_end += 1 - if t_start - 1 >= 0 and np.any(matrix[:, t_start - 1][s_start:s_end + 1]): - t_start -= 1 - if t_end + 1 < len_trg and np.any(matrix[:, t_end + 1][s_start:s_end + 1]): - t_end += 1 - - if self._no_adpoint(s_start, s_end, t_start, t_end, matrix): + if ss - 1 >= 0 and np.any(matrix[ss - 1, :][ts:te + 1]): + ss -= 1 + if se + 1 < len_src and np.any(matrix[se + 1, :][ts:te + 1]): + se += 1 + if ts - 1 >= 0 and np.any(matrix[:, ts - 1][ss:se + 1]): + ts -= 1 + if te + 1 < len_trg and np.any(matrix[:, te + 1][ss:se + 1]): + te += 1 + + if self._no_additional_point(ss, se, ts, te, matrix): break - if (s_start + 1, s_end + 1, t_start + 1, t_end + 1) not in phrase_dict: - phrase_vec_src = np.array(vectors_src[s_start:s_end + 1]).mean(axis=0) - phrase_vec_trg = np.array(vectors_trg[t_start:t_end + 1]).mean(axis=0) - sim = 1 - distance.cosine(phrase_vec_src, phrase_vec_trg) + if (ss + 1, se + 1, ts + 1, te + 1) not in phrase_dict: + phrase_vec_src = np.array( 
+ vectors_src[ss:se + 1]).mean(axis=0) + phrase_vec_trg = np.array( + vectors_trg[ts:te + 1]).mean(axis=0) - sim -= alpha / (s_end - s_start + t_end - t_start + 2) - - phrase_dict[(s_start + 1, s_end + 1, t_start + 1, t_end + 1)] = sim + sim = 1 - distance.cosine(phrase_vec_src, phrase_vec_trg) + sim -= self.alpha / (se - ss + te - ts + 2) + phrase_dict[(ss + 1, se + 1, ts + 1, te + 1)] = sim - phrase_pairs = [(k[0], k[1], k[2], k[3], v) for k, v in phrase_dict.items() if v >= delta] + phrase_pairs = [(k[0], k[1], k[2], k[3], v) + for k, v in phrase_dict.items() if v >= self.delta] phrase_pairs.sort(key=lambda x: (x[0], x[2], x[1], x[3])) return phrase_pairs @@ -72,24 +103,51 @@ class PhraseAlign(object): def __init__(self): self.name = '' + def __call__(self, phrase_pairs, len_src, len_trg): + return self.search_for_lattice(phrase_pairs, len_src, len_trg) + @staticmethod - def create_lattice(phrase_pairs: list, len_src: int, len_trg: int) -> list: - node_list = defaultdict(lambda: defaultdict(list)) + def search_for_lattice(phrase_pairs, len_src: int, len_trg: int): + """ + Construct a lattice of phrase pairs and depth-first search for the + path with the highest total alignment score. + + Parameters + ---------- + phrase_pairs : list + A return value of 'extract' method in PhraseExtract class. + len_src, len_trg : int + Length of sentence. + + Return + ------ + list + List of tuples consisting of indexes of phrase pairs + = one of the phrase alignments + = the path of the lattice with the highest total alignment score. 
+ """ - bos_node = {'index': (0, 0, 0, 0), 'sim': 0, 'next': []} - eos_node = {'index': (len_src + 1, len_src + 1, len_trg + 1, len_trg + 1), 'sim': 0, 'next': []} + node_list = defaultdict(lambda: defaultdict(list)) + bos_node = {'index': (0, 0, 0, 0), + 'score': 0, 'next': []} + eos_node = {'index': (len_src + 1, len_src + 1, + len_trg + 1, len_trg + 1), + 'score': 0, 'next': []} node_list[0][0].append(bos_node) node_list[len_src + 1][len_trg + 1].append(eos_node) def _forward(s, t, start_node, end_node, pairs): - path = [] + """Depth-first search for a lattice.""" + path = [] if start_node == end_node or not pairs: - return [[sum(similarity)]] + return [[sum(alignment_scores)]] - min_s, _, min_t, _, _ = min(pairs, key=lambda x: ((x[0] - s) ** 2 + (x[2] - t) ** 2)) + min_s, _, min_t, _, _ = min(pairs, key=lambda x: ( + (x[0] - s) ** 2 + (x[2] - t) ** 2)) min_dist = (min_s - s) ** 2 + (min_t - t) ** 2 - nearest_pairs = [p for p in pairs if (p[0] - s) ** 2 + (p[2] - t) ** 2 == min_dist] + nearest_pairs = [p for p in pairs + if (p[0] - s) ** 2 + (p[2] - t) ** 2 == min_dist] for pair in pairs[len(nearest_pairs):]: nearer = False @@ -101,61 +159,73 @@ def _forward(s, t, start_node, end_node, pairs): nearest_pairs.append(pair) for next_pair in nearest_pairs: - s_start, s_end, t_start, t_end, sim = next_pair - next_node = {'index': (s_start, s_end, t_start, t_end), 'sim': sim, 'next': []} - rest_pairs = [p for p in pairs if p[0] > s_end and p[2] > t_end] + ss, se, ts, te, __score = next_pair + next_node = {'index': (ss, se, ts, te), + 'score': __score, 'next': []} + rest_pairs = [p for p in pairs + if p[0] > se and p[2] > te] checked = False - for checked_node in node_list[s_start][t_start]: + for checked_node in node_list[ss][ts]: if next_node['index'] == checked_node['index']: next_node = checked_node checked = True break - if not checked: - node_list[s_start][t_start].append(next_node) + if not checked: + node_list[ss][ts].append(next_node) if next_node != end_node: 
- similarity.append(next_node['sim']) +                    alignment_scores.append(next_node['score']) - for solution in _forward(s_end + 1, t_end + 1, next_node, end_node, rest_pairs): +                for solution in _forward(se + 1, te + 1, +                                         next_node, end_node, rest_pairs): ids = start_node['index'] path.append([(ids)] + solution) if next_node != end_node: - similarity.pop() +                    alignment_scores.pop() return path if not phrase_pairs: return [([], 0)] - _s_start, _s_end, _t_start, _t_end, _sim = sorted(phrase_pairs, key=lambda x: x[4], reverse=True)[0] - top_node = {'index': (_s_start, _s_end, _t_start, _t_end), 'sim': _sim, 'next': []} - node_list[_s_start][_t_start].append(top_node) +        s_start, s_end, t_start, t_end, top_score = sorted( +            phrase_pairs, key=lambda x: x[4], reverse=True)[0] +        top_node = {'index': (s_start, s_end, t_start, +                              t_end), 'score': top_score, 'next': []} +        node_list[s_start][t_start].append(top_node) top_index = [top_node['index']] - prev_pairs = [p for p in phrase_pairs if p[1] < _s_start and p[3] < _t_start] - prev_pairs.append((_s_start, _s_end, _t_start, _t_end, _sim)) - - next_pairs = [p for p in phrase_pairs if p[0] > _s_end and p[2] > _t_end] - next_pairs.append((len_src + 1, len_src + 1, len_trg + 1, len_trg + 1, 0)) - - similarity = [] - prev_align = [(sol[1:-1], sol[-1]) for sol in - _forward(1, 1, bos_node, top_node, prev_pairs)] - - similarity = [] - next_align = [(sol[1:-1], sol[-1]) for sol in - _forward(_s_end + 1, _t_end + 1, top_node, eos_node, next_pairs)] +        prev_pairs = [p for p in phrase_pairs +                      if p[1] < s_start and p[3] < t_start] +        prev_pairs.append((s_start, s_end, t_start, t_end, top_score)) +        next_pairs = [p for p in phrase_pairs +                      if p[0] > s_end and p[2] > t_end] +        next_pairs.append((len_src + 1, len_src + 1, +                           len_trg + 1, len_trg + 1, 0)) + +        alignment_scores = []  # Initialize the stack of alignment scores +        prev_align = [ +            (sol[1:-1], sol[-1]) for sol +            in _forward(1, 1, bos_node, top_node, prev_pairs) +        ] + +        alignment_scores = []  # Re-initialize the 
stack of alignment scores +        next_align = [ +            (sol[1:-1], sol[-1]) for sol +            in _forward(s_end + 1, t_end + 1, top_node, eos_node, next_pairs) +        ] alignments = [] for prev_path, next_path in itertools.product(prev_align, next_align): concat_path = prev_path[0] + top_index + next_path[0] length = len(concat_path) - score = (prev_path[1] + next_path[1] + _sim) / length if length != 0 else 0 +            score = (prev_path[1] + next_path[1] + top_score) / length \ +                if length != 0 else 0 alignments.append((concat_path, str(score))) alignments.sort(key=lambda x: float(x[1]), reverse=True) - return alignments[0][0] +        return alignments[0][0]  # Return only the top one of phrase alignments diff --git a/sapphire/sapphire.py b/sapphire/sapphire.py index 654f414..da9c813 100644 --- a/sapphire/sapphire.py +++ b/sapphire/sapphire.py @@ -10,7 +10,7 @@ def __init__(self, model): self.vectorizer = FastTextVectorize(model) self.set_params() self.word_aligner = WordAlign(self.lambda_, self.use_hungarian) - self.extractor = PhraseExtract() + self.extractor = PhraseExtract(self.delta, self.alpha) self.phrase_aligner = PhraseAlign() def set_params(self, lambda_=0.6, delta=0.6, alpha=0.01, hungarian=False): @@ -29,12 +29,7 @@ def align(self, tokens_src: list, tokens_trg: list): sim_matrix = get_similarity_matrix(vectors_src, vectors_trg) word_alignment = self.word_aligner(sim_matrix) - phrase_pairs = self.extractor.extract(word_alignment, - vectors_src, - vectors_trg, - self.delta, - self.alpha) - phrase_alignment = self.phrase_aligner.create_lattice(phrase_pairs, - len_src, - len_trg) + phrase_pairs = self.extractor(word_alignment, vectors_src, vectors_trg) + phrase_alignment = self.phrase_aligner(phrase_pairs, len_src, len_trg) + return word_alignment, phrase_alignment