Skip to content

Commit

Permalink
Add GLEU scorer
Browse files Browse the repository at this point in the history
  • Loading branch information
Tbabm committed Nov 27, 2020
1 parent efa9549 commit 49f5592
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 16 deletions.
188 changes: 172 additions & 16 deletions compare_mt/scorers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import subprocess
import tempfile
from collections import Counter
from itertools import chain

from compare_mt import corpus_utils
from compare_mt import align_utils
Expand All @@ -25,7 +26,7 @@ def scale(self):

def score_corpus(self, ref, out, src=None):
pass

def score_sentence(self, ref, out, src=None):
pass

Expand Down Expand Up @@ -85,7 +86,7 @@ def cache_stats(self, ref, out, src=None):
src = [None for _ in ref] if src is None else src
for r, o, s in zip(ref, out, src):
cached_scores.append(self.score_sentence(r, o, s)[0])

return cached_scores

def score_cached_corpus(self, sent_ids, cached_stats):
Expand Down Expand Up @@ -157,7 +158,7 @@ def _precision(self, ref, out, n):
denom = max(1, denom)

return num, denom

def cache_stats(self, ref, out, src=None):
"""
Cache sufficient statistics for caculating BLEU score
Expand Down Expand Up @@ -202,7 +203,7 @@ def score_cached_corpus(self, sent_ids, cached_stats):

num_prec = Counter()
denom_prec = Counter()

ref_len = 0
out_len = 0
for sent_id in sent_ids:
Expand All @@ -220,8 +221,8 @@ def score_cached_corpus(self, sent_ids, cached_stats):
for i, w in enumerate(self.weights, start=1):
p = num_prec[i] / denom_prec[i] if denom_prec[i] != 0 else 0
p = math.log(p) if p > 0 else 0
prec += p * w
prec += p * w

bp = min(1, math.exp(1 - ref_len/out_len)) if out_len != 0 else 0

return self.scale * bp * math.exp(prec), None
Expand Down Expand Up @@ -258,7 +259,7 @@ def score_sentence(self, ref, out, src=None):
chencherry = nltk.translate.bleu_score.SmoothingFunction()
if self.case_insensitive:
bleu_score = nltk.translate.bleu_score.sentence_bleu([corpus_utils.lower(ref)], corpus_utils.lower(out), smoothing_function=chencherry.method2)
else:
else:
bleu_score = nltk.translate.bleu_score.sentence_bleu([ref], out, smoothing_function=chencherry.method2)
return self.scale * bleu_score, None

Expand Down Expand Up @@ -386,7 +387,7 @@ def _kendall_tau_distance(self, alignment):
for j in range(i+1, n):
if alignment[j] > alignment[i]:
dis += 1
return 2*dis/(n*n-n)
return 2*dis/(n*n-n)

def score_sentence(self, ref, out, src=None):
"""
Expand All @@ -401,7 +402,7 @@ def score_sentence(self, ref, out, src=None):
The RIBES score, and None
"""
alignment = align_utils.ngram_context_align(ref, out, order=self.order, case_insensitive=self.case_insensitive)
kt_dis = self._kendall_tau_distance(alignment)
kt_dis = self._kendall_tau_distance(alignment)
prec = len(alignment)/ len(out) if len(out) != 0 else 0
bp = min(1, math.exp(1-len(ref)/len(out))) if len(out) != 0 else 0
return self.scale * kt_dis * (prec**self.alpha) * (bp**self.beta), None
Expand Down Expand Up @@ -547,7 +548,7 @@ def __init__(self, rouge_type, score_type='fmeasure', use_stemmer=False, case_in
@property
def scale(self):
return global_scorer_scale

def score_sentence(self, ref, out, src=None):
if self.case_insensitive:
ref = corpus_utils.lower(ref)
Expand All @@ -556,7 +557,7 @@ def score_sentence(self, ref, out, src=None):
if self._stemmer:
ref = [self._stemmer.stem(x) if len(x) > 3 else x for x in ref]
out = [self._stemmer.stem(x) if len(x) > 3 else x for x in out]

if self.rouge_type == 'rougeL':
ref, out = self.tokenize(" ".join(ref)), self.tokenize(" ".join(out))
scores = rouge_scorer._score_lcs(ref, out)
Expand Down Expand Up @@ -679,7 +680,7 @@ def _edit_distance(self, ref, out, src=None):
if self.case_insensitive:
ref = corpus_utils.lower(ref)
out = corpus_utils.lower(out)

sp1 = len(ref)+1
tp1 = len(out)+1
scores = np.zeros((sp1, tp1))
Expand All @@ -692,10 +693,10 @@ def _edit_distance(self, ref, out, src=None):
for j in range(0, len(out)):
my_action = 0 if equals[i,j] else 1
my_score = scores[i,j] + my_action * self.sub_pen
del_score = scores[i,j+1] + self.del_pen
del_score = scores[i,j+1] + self.del_pen
if del_score < my_score:
my_score = del_score
ins_score = scores[i+1,j] + self.ins_pen
ins_score = scores[i+1,j] + self.ins_pen
if ins_score < my_score:
my_score = ins_score
scores[i+1,j+1] = my_score
Expand Down Expand Up @@ -824,9 +825,9 @@ def score_cached_corpus(self, sent_ids, cached_stats):
out_total_match = np.sum(out_content_match_stage) + np.sum(out_func_match_stage)
ref_total_match = np.sum(ref_content_match_stage) + np.sum(ref_func_match_stage)

frag = float(chunks) / (float(out_word_match+ref_word_match)/2)
frag = float(chunks) / (float(out_word_match+ref_word_match)/2)
frag = 0 if out_total_match == out_len and ref_total_match == ref_len and chunks == 1 else frag

frag_penalty = gamma * math.pow(frag, beta)

score = fmean * (1.0-frag_penalty)
Expand Down Expand Up @@ -905,6 +906,159 @@ def name(self):
def idstr(self):
return "comet"

class GleuScorer(Scorer):
  """
  A scorer that calculates GLEU score, an n-gram metric for grammatical
  error correction. GLEU rewards n-grams the output shares with the
  reference and penalizes n-grams the output left unchanged from the
  source that do not appear in the reference.

  References:
    "Ground Truth for Grammatical Error Correction Metrics", Napoles et al.
    "GLEU Without Tuning", Napoles et al.
  """
  def __init__(self, weights=(0.25, 0.25, 0.25, 0.25), case_insensitive=False):
    # weights[i-1] is the weight of the i-gram log-precision; the tuple
    # length sets the maximum n-gram order considered.
    self.weights = weights
    self.case_insensitive = case_insensitive

  @property
  def scale(self):
    # Follow the module-wide scaling convention shared by all scorers.
    return global_scorer_scale

  def score_corpus(self, ref, out, src=None):
    """
    Score a corpus using GLEU score

    Args:
      ref: A reference corpus
      out: An output corpus
      src: A source corpus. Required

    Returns:
      A tuple containing a single value for the GLEU score and a string summarizing auxiliary information
    """
    cached_stats = self.cache_stats(ref, out, src)
    return self.score_cached_corpus(range(len(ref)), cached_stats)

  def score_sentence(self, ref, out, src=None):
    """
    Score a single sentence using GLEU score

    Args:
      ref: A reference sentence
      out: An output sentence
      src: A source sentence. Required

    Returns:
      A tuple containing a single value for the GLEU score and a string summarizing auxiliary information
    """
    cached_stats = self.cache_stats([ref], [out], [src])
    # Sentence-level smoothing: clip every numerator/denominator at 1 so a
    # single zero count does not zero the whole score. Follows
    # https://github.com/cnap/gec-ranking/blob/master/scripts/gleu.py
    ref_len, out_len, prec = cached_stats[0]
    cached_stats[0] = (ref_len, out_len,
                       [(max(num, 1), max(denom, 1)) for num, denom in prec])
    return self.score_cached_corpus(range(1), cached_stats)

  def _precision(self, ref, out, src, n):
    """
    Calculate GLEU-specific n-gram precision

    Args:
      ref: A reference sentence
      out: An output sentence
      src: A source sentence
      n: The n-gram order

    Returns:
      Numerator and denominator of the precision
    """
    ref_cnt = Counter(ngram_utils.sent_ngrams_list(ref, n))
    out_cnt = Counter(ngram_utils.sent_ngrams_list(out, n))
    src_cnt = Counter(ngram_utils.sent_ngrams_list(src, n))

    # Multiset intersections: n-grams the output shares with ref / src.
    out_join_ref = out_cnt & ref_cnt
    out_join_src = out_cnt & src_cnt

    # Reward reference matches; penalize n-grams copied from the source
    # that do not also appear in the reference.
    num = sum(out_join_ref.values()) - \
          sum((out_join_src - out_join_ref).values())
    # Clip negative numerators at 0, following
    # https://github.com/cnap/gec-ranking/blob/master/scripts/gleu.py
    num = max(num, 0)
    denom = sum(out_cnt.values())

    return num, denom

  def cache_stats(self, ref, out, src=None):
    """
    Cache sufficient statistics for calculating GLEU score

    Args:
      ref: A reference corpus
      out: An output corpus
      src: A source corpus. Required.

    Returns:
      A list of tuples (ref_len, out_len, [(num, denom) per n-gram order])
    """
    if self.case_insensitive:
      ref = corpus_utils.lower(ref)
      out = corpus_utils.lower(out)
      src = corpus_utils.lower(src)

    cached_stats = []
    for r, o, s in zip(ref, out, src):
      prec = [self._precision(r, o, s, n)
              for n in range(1, len(self.weights) + 1)]
      cached_stats.append((len(r), len(o), prec))
    return cached_stats

  def score_cached_corpus(self, sent_ids, cached_stats):
    """
    Score a corpus using GLEU score with cached statistics

    Args:
      sent_ids: The sentence ids for reference and output corpora
      cached_stats: A list of cached statistics

    Returns:
      A tuple containing a single value for the GLEU score and a string summarizing auxiliary information
    """
    if len(cached_stats) == 0:
      return 0.0, None

    cached_ref_len, cached_out_len, cached_prec = zip(*cached_stats)

    num_prec = Counter()
    denom_prec = Counter()

    ref_len = 0
    out_len = 0
    for sent_id in sent_ids:
      ref_len += cached_ref_len[sent_id]
      out_len += cached_out_len[sent_id]
      for n in range(1, len(self.weights) + 1):
        num, denom = cached_prec[sent_id][n-1]
        num_prec[n] += num
        denom_prec[n] += denom

    # A zero count makes the log-precision undefined, so the score is 0,
    # following https://github.com/cnap/gec-ranking/blob/master/scripts/gleu.py
    # BUGFIX: iterate the Counter *values*; iterating the Counters directly
    # yields the n-gram orders 1..N, which are never 0, so the original
    # guard could never fire and math.log(0) / ZeroDivisionError could occur.
    counts = list(chain(num_prec.values(), denom_prec.values()))
    if not counts or any(c == 0 for c in counts):
      return 0, None

    prec = sum(w * math.log(num_prec[i] / denom_prec[i])
               for i, w in enumerate(self.weights, start=1))

    # Brevity penalty, as in BLEU.
    bp = min(1, math.exp(1 - ref_len/out_len)) if out_len != 0 else 0

    return self.scale * bp * math.exp(prec), None

  def name(self):
    return "GLEU"

  def idstr(self):
    return "gleu"

def create_scorer_from_profile(profile, case_insensitive=False, meteor_directory=None, options=None):
"""
Create a scorer from a profile string
Expand Down Expand Up @@ -939,5 +1093,7 @@ def create_scorer_from_profile(profile, case_insensitive=False, meteor_directory
return ExactMatchScorer()
elif profile == 'comet':
return COMETScorer()
elif profile == 'gleu':
return GleuScorer()
else:
raise ValueError(f'Invalid profile for scorer {profile}'.format(profile=profile))
23 changes: 23 additions & 0 deletions tests/test_scorers.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,5 +156,28 @@ def test_detok_bleu_corpus(self):
self.assertAlmostEqual(detok_bleu, 21.7, places=0)


class TestGleuScorer(unittest.TestCase):
  """Tests for the GLEU scorer against the reference gec-ranking scores."""

  @classmethod
  def setUpClass(cls) -> None:
    # Load the bundled TED example corpora: reference, system output, source.
    example_path = os.path.join(compare_mt_root, "example")
    cls.ref = load_tokens(os.path.join(example_path, "ted.ref.eng"))
    cls.out = load_tokens(os.path.join(example_path, "ted.sys1.eng"))
    cls.src = load_tokens(os.path.join(example_path, "ted.orig.slk"))
    cls.scorer = scorers.create_scorer_from_profile("gleu", case_insensitive=False)

  def test_score_corpus(self):
    # Compare to https://github.com/cnap/gec-ranking
    score, _ = self.scorer.score_corpus(self.ref, self.out, self.src)
    self.assertAlmostEqual(score, 22.39, places=1)

  def test_score_sentence(self):
    # Compare to https://github.com/cnap/gec-ranking
    src_sent = "A simple src sentence of test .".split()
    ref_sent = "A simple source sentence for testing .".split()
    out_sent = "A simple src sentence for testing .".split()
    score, _ = self.scorer.score_sentence(ref_sent, out_sent, src_sent)
    self.assertAlmostEqual(score, 33.03, places=1)


if __name__ == "__main__":
  # Allow running this test module directly (e.g. `python tests/test_scorers.py`).
  unittest.main()

0 comments on commit 49f5592

Please sign in to comment.