Skip to content

Commit

Permalink
[merge] Merge branch 'dev_metrics' into develop
Browse files Browse the repository at this point in the history
* Enable to get phrase alignment scores from SAPPHIRE
* Enable to use BERT and chunker
  • Loading branch information
m-yoshinaka committed Feb 12, 2021
2 parents f6af6b4 + 096f9dd commit d1eec4d
Show file tree
Hide file tree
Showing 5 changed files with 321 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
__pycache__/
/.vscode/
/*.egg-info/
41 changes: 34 additions & 7 deletions sapphire/phrase_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def extract(
self, word_alignments, vectors_src: np.array, vectors_trg: np.array
):
"""
Extract phrase pairs using the hueristic of phrase-based SMT.
Extract phrase pairs using the heuristic of phrase-based SMT.
Parameters
----------
Expand Down Expand Up @@ -111,14 +111,20 @@ def extract(

class PhraseAlign(object):

def __init__(self):
self.name = ''
def __init__(self, prune_k=-1, get_score=False, epsilon=None):
self.prune_k = prune_k
self.get_score = get_score
self.epsilon = epsilon

def __call__(self, phrase_pairs, len_src, len_trg):
return self.search_for_lattice(phrase_pairs, len_src, len_trg)

@staticmethod
def search_for_lattice(phrase_pairs, len_src: int, len_trg: int):
def set_params(self, prune_k, get_score, epsilon):
self.prune_k = prune_k
self.get_score = get_score
self.epsilon = epsilon

def search_for_lattice(self, phrase_pairs, len_src: int, len_trg: int):
"""
Construct a lattice of phrase pairs and depth-first search for the
path with the highest total alignment score.
Expand Down Expand Up @@ -169,6 +175,9 @@ def _forward(s, t, start_node, end_node, pairs):
if not nearer:
nearest_pairs.append(pair)

if self.prune_k != -1:
nearest_pairs = nearest_pairs[:self.prune_k]

for next_pair in nearest_pairs:
ss, se, ts, te, __score = next_pair
next_node = {'index': (ss, se, ts, te),
Expand Down Expand Up @@ -199,7 +208,7 @@ def _forward(s, t, start_node, end_node, pairs):
return path

if not phrase_pairs:
return [([], 0)]
return ([], 0) if self.get_score else []

s_start, s_end, t_start, t_end, score = sorted(
phrase_pairs, key=lambda x: x[4], reverse=True)[0]
Expand Down Expand Up @@ -235,8 +244,26 @@ def _forward(s, t, start_node, end_node, pairs):
length = len(concat_path)
score = (prev_path[1] + next_path[1] + score) / length \
if length != 0 else 0
alignments.append((concat_path, str(score)))
alignments.append((concat_path, score))

if self.epsilon is not None:
new_alignments = []
for alignment, score in alignments:
nof_align = len(alignment)
nof_null_align = len_src + len_trg
for ss, se, ts, te in alignment:
nof_null_align -= se - ss + 1
nof_null_align -= te - ts + 1

score = (score * nof_align + self.epsilon * nof_null_align) \
/ (nof_align + nof_null_align)
new_alignments.append((alignment, score))

alignments = new_alignments

alignments.sort(key=lambda x: float(x[1]), reverse=True)

if self.get_score:
return alignments[0]

return alignments[0][0] # Return only the top one of phrase alignments
23 changes: 21 additions & 2 deletions sapphire/sapphire.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,20 @@ def __init__(self, model):
self.delta = 0.6
self.alpha = 0.01
self.use_hungarian = False
self.prune_k = -1
self.get_score = False
self.epsilon = None
self.word_aligner = WordAlign(self.lambda_, self.use_hungarian)
self.extractor = PhraseExtract(self.delta, self.alpha)
self.phrase_aligner = PhraseAlign()
self.phrase_aligner = PhraseAlign(self.prune_k,
self.get_score,
self.epsilon)

def __call__(self, tokens_src, tokens_trg):
return self.align(tokens_src, tokens_trg)

def set_params(self, lambda_=0.6, delta=0.6, alpha=0.01, hungarian=False):
def set_params(self, lambda_=0.6, delta=0.6, alpha=0.01, hungarian=False,
prune_k=-1, get_score=False, epsilon=None):
"""
Set hyper-parameters of SAPPHIRE.
Expand All @@ -58,13 +64,26 @@ def set_params(self, lambda_=0.6, delta=0.6, alpha=0.01, hungarian=False):
Biases the phrase alignment score based on the lengths of phrases.
hungarian : bool
            Whether to use the extended Hungarian method to get word alignment.
prune_k : int
            Limits the number of successor nodes expanded from each node in the lattice.
get_score : bool
Whether to output alignment scores with phrase alignments.
epsilon : float
Alignment score for a null alignment.
If epsilon is None, SAPPHIRE does not consider null alignment.
"""
self.lambda_ = lambda_
self.delta = delta
self.alpha = alpha
self.use_hungarian = hungarian
self.prune_k = prune_k
self.get_score = get_score
self.epsilon = epsilon
self.word_aligner.set_params(self.lambda_, self.use_hungarian)
self.extractor.set_params(self.delta, self.alpha)
self.phrase_aligner.set_params(self.prune_k,
self.get_score,
self.epsilon)

def align(self, tokens_src: list, tokens_trg: list):
"""
Expand Down
137 changes: 137 additions & 0 deletions sapphire/sapphire_using_bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import numpy as np
import torch
from transformers import BertModel, BertTokenizer

from .word_alignment import WordAlign, WordEmbedding, get_similarity_matrix
from .phrase_alignment import PhraseExtract, PhraseAlign


# Run the BERT encoder on GPU when one is available; every tensor and model
# in this module is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class BertEmbedding(WordEmbedding):
    """Word embedder that derives per-word vectors from a BERT encoder.

    BERT tokenizes into WordPiece sub-words; this class re-merges sub-word
    vectors back into one vector per input word so the rest of SAPPHIRE can
    stay word-based.
    """

    def __init__(self, model_name):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name)
        self.model.to(device)

    def __call__(self, words):
        return self.vectorize(words)

    @staticmethod
    def _to_word_vectors(words, subwords, subword_vectors):
        """Collapse sub-word vectors into word vectors.

        Parameters
        ----------
        words : list of str
            The original (pre-tokenization) words.
        subwords : list of str
            WordPiece tokens produced by the BERT tokenizer.
        subword_vectors : np.array
            One vector per entry of ``subwords``.

        Returns
        -------
        np.array
            One vector per word (continuation pieces are summed into their
            head piece; words the tokenizer split differently are averaged).
        """
        tokenized = []
        word_vectors = []
        for subword, vector in zip(subwords, subword_vectors):
            # WordPiece marks continuation pieces with a leading '##'.
            # Checking the prefix (rather than '#' in subword) keeps tokens
            # that legitimately contain '#' from being merged into the
            # previous word, and strips exactly the two marker characters.
            if subword.startswith('##'):
                tokenized[-1] += subword[2:]
                # Rebind instead of '+=' so the caller's array is not
                # mutated in place.
                word_vectors[-1] = word_vectors[-1] + vector
            else:
                tokenized.append(subword)
                word_vectors.append(vector)

        if len(words) == len(word_vectors):
            return np.array(word_vectors)

        # Fallback: the tokenizer split some word in a way the '##' merge did
        # not undo. Greedily re-join consecutive merged tokens until they
        # spell the next (lower-cased) input word, averaging their vectors.
        _words = [w.lower() for w in words]
        tmp = []
        j = 0
        for i, token in enumerate(tokenized):
            if j >= len(_words):
                break
            if token == _words[j]:
                tmp.append(word_vectors[i])
                j += 1
            else:
                for k in range(1, len(tokenized) - i + 1):
                    cand = ''.join(tokenized[i:i + k + 1])
                    if cand == _words[j]:
                        new_vectors = np.array(word_vectors[i:i + k + 1])
                        new_vector = np.mean(new_vectors, axis=0)
                        tmp.append(new_vector)
                        j += 1
                        break

        return np.array(tmp)

    def vectorize(self, words):
        """Return a (len(words), hidden_size) array of BERT word vectors."""
        self.model.eval()

        # Wrap the sentence in BERT's special tokens; they are sliced off
        # again below so only real words get vectors.
        text = '[CLS] ' + ' '.join(words) + ' [SEP]'
        tokenized = self.tokenizer.tokenize(text)

        index_tokens = self.tokenizer.convert_tokens_to_ids(tokenized)
        tokens_tensor = torch.tensor([index_tokens]).to(device)

        with torch.no_grad():
            # NOTE(review): tuple unpacking assumes an older transformers
            # API; recent versions return a ModelOutput unless
            # return_dict=False is passed — confirm the pinned version.
            encoded_layers, _ = self.model(tokens_tensor)

        # Drop [CLS] (first) and [SEP] (last) positions.
        encoded_layers = encoded_layers[0][1:-1].detach().cpu().numpy()
        vectors = self._to_word_vectors(words, tokenized[1:-1],
                                        np.array(encoded_layers))

        return vectors


class SapphireBert(object):
    """SAPPHIRE aligner that uses BERT embeddings instead of static vectors.

    Mirrors the hyper-parameters of the word2vec-based ``Sapphire`` class:
    lambda_/hungarian drive word alignment, delta/alpha drive phrase
    extraction, and prune_k/get_score/epsilon drive phrase alignment.
    """

    def __init__(self, model_name='bert-base-uncased'):
        self.vectorizer = BertEmbedding(model_name)
        # Defaults match the word2vec-based Sapphire class.
        self.lambda_ = 0.6
        self.delta = 0.6
        self.alpha = 0.01
        self.use_hungarian = False
        self.prune_k = -1
        self.get_score = False
        self.epsilon = None
        self.word_aligner = WordAlign(self.lambda_, self.use_hungarian)
        self.extractor = PhraseExtract(self.delta, self.alpha)
        self.phrase_aligner = PhraseAlign(self.prune_k,
                                          self.get_score,
                                          self.epsilon)

    def __call__(self, tokens_src, tokens_trg):
        return self.align(tokens_src, tokens_trg)

    def set_params(self, lambda_=0.6, delta=0.6, alpha=0.01, hungarian=False,
                   prune_k=-1, get_score=False, epsilon=None):
        """Store the hyper-parameters and push them into the sub-aligners."""
        self.lambda_ = lambda_
        self.delta = delta
        self.alpha = alpha
        self.use_hungarian = hungarian
        self.prune_k = prune_k
        self.get_score = get_score
        self.epsilon = epsilon
        self.word_aligner.set_params(self.lambda_, self.use_hungarian)
        self.extractor.set_params(self.delta, self.alpha)
        self.phrase_aligner.set_params(self.prune_k,
                                       self.get_score,
                                       self.epsilon)

    def align(self, tokens_src: list, tokens_trg: list):
        """Align two token lists; returns (word_alignment, phrase_alignment)."""
        try:
            src_len = len(tokens_src)
            trg_len = len(tokens_trg)

            src_vecs = self.vectorizer(tokens_src)
            trg_vecs = self.vectorizer(tokens_trg)

            sim_matrix = get_similarity_matrix(src_vecs, trg_vecs)
            word_alignment = self.word_aligner(sim_matrix)

            pairs = self.extractor(word_alignment, src_vecs, trg_vecs)
            phrase_alignment = self.phrase_aligner(pairs, src_len, trg_len)
        except ValueError:
            # Vectorization/alignment failed (e.g. empty input): return the
            # empty result in the shape the caller expects.
            word_alignment = []
            phrase_alignment = ([], 0) if self.get_score else []

        return word_alignment, phrase_alignment
Loading

0 comments on commit d1eec4d

Please sign in to comment.