-
Notifications
You must be signed in to change notification settings - Fork 0
/
sapphire.py
114 lines (98 loc) · 3.93 KB
/
sapphire.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from .word_alignment import (
FastTextVectorize, WordAlign, get_similarity_matrix
)
from .phrase_alignment import PhraseExtract, PhraseAlign
class Sapphire:
    """
    SAPPHIRE: a monolingual phrase aligner.

    Attributes
    ----------
    vectorizer : FastTextVectorize
        Vectorizes words using fastText (Bojanowski et al., 2017).
    word_aligner : WordAlign
        Aligns the words of two sentences.
    extractor : PhraseExtract
        Extracts phrase pairs from two sentences based on the word
        alignment and calculates their alignment scores.
    phrase_aligner : PhraseAlign
        Searches for the phrase alignment with the highest total
        alignment score.
    lambda_, delta, alpha, use_hungarian, prune_k, get_score, epsilon
        Hyper-parameter values currently in effect (see ``set_params``).

    Methods
    -------
    set_params(lambda_=0.6, delta=0.6, alpha=0.01, hungarian=False,
               prune_k=-1, get_score=False, epsilon=None)
        Set hyper-parameters of SAPPHIRE.
    align(tokens_src, tokens_trg)
        Get word alignment and phrase alignment.
    """

    def __init__(self, model):
        """
        Create the pipeline components with default hyper-parameters.

        Parameters
        ----------
        model
            Handed unchanged to ``FastTextVectorize``.
        """
        self.vectorizer = FastTextVectorize(model)

        # Same default values as the keyword defaults of ``set_params``.
        self.lambda_, self.delta, self.alpha = 0.6, 0.6, 0.01
        self.use_hungarian = False
        self.prune_k, self.get_score, self.epsilon = -1, False, None

        self.word_aligner = WordAlign(self.lambda_, self.use_hungarian)
        self.extractor = PhraseExtract(self.delta, self.alpha)
        self.phrase_aligner = PhraseAlign(
            self.prune_k, self.get_score, self.epsilon
        )

    def __call__(self, tokens_src, tokens_trg):
        """Shorthand for ``align(tokens_src, tokens_trg)``."""
        return self.align(tokens_src, tokens_trg)

    def set_params(self, lambda_=0.6, delta=0.6, alpha=0.01, hungarian=False,
                   prune_k=-1, get_score=False, epsilon=None):
        """
        Set hyper-parameters of SAPPHIRE.

        The values are stored on this object and forwarded to the word
        aligner, the phrase extractor and the phrase aligner.
        Details are discussed in the following paper:
        https://www.aclweb.org/anthology/2020.lrec-1.847/ .

        Parameters
        ----------
        lambda_ : float
            Prunes word alignment candidates.
        delta : float
            Prunes phrase alignment candidates.
        alpha : float
            Biases the phrase alignment score based on the lengths of
            phrases.
        hungarian : bool
            Whether to use the extended Hungarian method to get the word
            alignment.
        prune_k : int
            Prunes the number of nodes following a node in the lattice.
        get_score : bool
            Whether to output alignment scores with phrase alignments.
        epsilon : float
            Alignment score for a null alignment.
            If epsilon is None, SAPPHIRE does not consider null alignment.
        """
        # Record every value first, then push them to the components.
        (self.lambda_, self.delta, self.alpha, self.use_hungarian,
         self.prune_k, self.get_score, self.epsilon) = (
            lambda_, delta, alpha, hungarian, prune_k, get_score, epsilon
        )

        self.word_aligner.set_params(lambda_, hungarian)
        self.extractor.set_params(delta, alpha)
        self.phrase_aligner.set_params(prune_k, get_score, epsilon)

    def align(self, tokens_src: list, tokens_trg: list):
        """
        Align phrase pairs in two sentences.

        Parameters
        ----------
        tokens_src, tokens_trg : list
            A tokenized sentence represented by a list of words.

        Returns
        -------
        tuple
            (word_alignment, phrase_alignment)
        """
        n_src, n_trg = len(tokens_src), len(tokens_trg)

        embed = self.vectorizer
        src_vecs = embed(tokens_src)
        trg_vecs = embed(tokens_trg)

        # Word alignment is derived from the pairwise similarity matrix.
        word_links = self.word_aligner(
            get_similarity_matrix(src_vecs, trg_vecs)
        )
        candidates = self.extractor(word_links, src_vecs, trg_vecs)
        return word_links, self.phrase_aligner(candidates, n_src, n_trg)