Merge pull request #89 from jordiclive/CIDEr
CIDEr
tuetschek authored Mar 25, 2022
2 parents 2bc146a + 25ce278 commit 1a183cc
Showing 3 changed files with 238 additions and 0 deletions.
1 change: 1 addition & 0 deletions gem_metrics/__init__.py
@@ -52,6 +52,7 @@ def metric_list_to_metric_dict(metric_list: List[str]) -> Dict[str, List]:
"questeval": "QuestEval",
"prism": "Prism",
"ter": "TER",
"cider": "CIDER",
}

referenced_list, referenceless_list, sourced_and_referenced_list = [], [], []
220 changes: 220 additions & 0 deletions gem_metrics/cider.py
@@ -0,0 +1,220 @@
#!/usr/bin/env python3
# Authors: Ramakrishna Vedantam <vrama91@vt.edu> and Tsung-Yi Lin <tl483@cornell.edu>

import copy
import math
from collections import defaultdict
from typing import Dict

import numpy as np

from .metric import ReferencedMetric
from .texts import Predictions, References


class CIDER(ReferencedMetric):
"""CIDEr (Consensus-Based Image Description Evaluation) Metric. Computation is done on lower-cased data without punctuation
(http://arxiv.org/abs/1411.5726).
    This is based on scripts by the original authors, Ramakrishna Vedantam <vrama91@vt.edu> and Tsung-Yi Lin <tl483@cornell.edu>:
    https://github.com/vrama91/cider/blob/master/pyciderevalcap/cider/cider.py. The implementation should be nearly identical
    to the original; the only difference is the tokenization (the original uses stanford-corenlp-3.4.1.jar).
"""

def __init__(self, n=4, sigma=6.0):
self._n = n
# set the standard deviation parameter for gaussian penalty
self._sigma = sigma
self.crefs = []
self.ctest = []
self.document_frequency = defaultdict(float)
self.ref_len = None
self.cook_append(None, None)

def __iadd__(self, other):
"""add an instance (e.g., from another sentence)."""

if type(other) is tuple:
            ## avoid creating new CIDER instances
self.cook_append(other[0], other[1])
else:
self.ctest.extend(other.ctest)
self.crefs.extend(other.crefs)

return self

def support_caching(self):
        # CIDEr is computed at corpus level, so per-example scores cannot be cached and re-aggregated.
return False

def compute(self, cache, predictions: Predictions, references: References) -> Dict:
refs = references.list_tokenized_lower_nopunct
preds = predictions.list_tokenized_lower_nopunct
for i, pred in enumerate(preds):
self += (pred, refs[i])

(score, _) = self.compute_score()

return {"CIDEr": round(score, 5)}

@staticmethod
def precook(s, n=4):
"""
        Takes a pre-tokenized sentence (list of tokens) and returns an object
        that can be given to either cook_refs or cook_test.
        :param s: list of str : tokenized sentence to be converted into ngrams
        :param n: int : number of ngrams for which representation is calculated
        :return: term frequency vector for occurring ngrams
"""
# words = s.split()
words = s
counts = defaultdict(int)
for k in range(1, n + 1):
for i in range(len(words) - k + 1):
ngram = tuple(words[i : i + k])
counts[ngram] += 1
return counts
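    # Worked example (illustrative note, not in the original): for the pre-tokenized
    # input ["the", "cat", "sat"] with n=2, precook returns
    #   {("the",): 1, ("cat",): 1, ("sat",): 1, ("the", "cat"): 1, ("cat", "sat"): 1},
    # i.e. a count for every unigram and bigram in the sentence.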

def cook_refs(self, refs, n=4): ## lhuang: oracle will call with "average"
"""Takes a list of reference sentences for a single segment
        and returns an object that encapsulates everything that CIDEr
        needs to know about them.
        :param refs: list of list of str : tokenized reference sentences for some image
:param n: int : number of ngrams for which (ngram) representation is calculated
:return: result (list of dict)
"""
return [self.precook(ref, n) for ref in refs]

def cook_test(self, test, n=4):
"""Takes a test sentence and returns an object that
        encapsulates everything that CIDEr needs to know about it.
        :param test: list of str : tokenized hypothesis sentence for some image
:param n: int : number of ngrams for which (ngram) representation is calculated
:return: result (dict)
"""
return self.precook(test, n)

def copy(self):
"""copy the refs."""
new = CIDER(n=self._n)
new.ctest = copy.copy(self.ctest)
new.crefs = copy.copy(self.crefs)
return new

def cook_append(self, test, refs):
"""called by constructor and __iadd__ to avoid creating new instances."""

if refs is not None:
self.crefs.append(self.cook_refs(refs))
if test is not None:
self.ctest.append(self.cook_test(test)) ## N.B.: -1
else:
self.ctest.append(None) # lens of crefs and ctest have to match

def size(self):
assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (
len(self.crefs),
len(self.ctest),
)
return len(self.crefs)

def compute_doc_freq(self):
"""
        Compute document frequency for the reference data.
        This will later be used to compute idf (inverse document frequency).
        The document frequency is stored on the object.
:return: None
"""
for refs in self.crefs:
# refs, k ref captions of one image
for ngram in set([ngram for ref in refs for (ngram, count) in ref.items()]):
self.document_frequency[ngram] += 1
# maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
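    # Note on the statistic above (explanatory comment, not in the original):
    # each distinct n-gram is counted at most once per reference set, so
    # document_frequency[ngram] is the number of segments whose references contain
    # that n-gram, no matter how often it occurs within them. E.g. an n-gram that
    # appears somewhere in 3 of 5 reference sets gets a document frequency of 3.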

def compute_cider(self):
def counts2vec(cnts):
"""
            Maps n-gram counts to a vector of tf-idf weights.
            Returns vec, an array of dictionaries mapping each n-gram to its tf-idf weight;
            the n-th entry of the array holds the weights for n-grams of length n + 1.
            :param cnts: dict : n-gram counts as produced by precook
:return: vec (array of dict), norm (array of float), length (int)
"""
vec = [defaultdict(float) for _ in range(self._n)]
length = 0
norm = [0.0 for _ in range(self._n)]
for (ngram, term_freq) in cnts.items():
                # use a document frequency of 1 (log-df of 0) for ngrams not seen in the reference corpus
df = np.log(max(1.0, self.document_frequency[ngram]))
# ngram index
n = len(ngram) - 1
# tf (term_freq) * idf (precomputed idf) for n-grams
vec[n][ngram] = float(term_freq) * (self.ref_len - df)
# compute norm for the vector. the norm will be used for computing similarity
norm[n] += pow(vec[n][ngram], 2)

if n == 1:
length += term_freq
norm = [np.sqrt(n) for n in norm]
return vec, norm, length
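        # Worked equation (explanatory note, not in the original): with N = number
        # of reference sets and ref_len = log(N), each weight computed above is
        #   vec[n][ngram] = tf * (log(N) - log(max(1, doc_freq)))
        #                 = tf * log(N / max(1, doc_freq)),
        # i.e. a tf-idf weight; norm[n] is the Euclidean norm of the length-(n+1) n-gram vector.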

def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
"""
Compute the cosine similarity of two vectors.
:param vec_hyp: array of dictionary for vector corresponding to hypothesis
:param vec_ref: array of dictionary for vector corresponding to reference
:param norm_hyp: array of float for vector corresponding to hypothesis
:param norm_ref: array of float for vector corresponding to reference
:param length_hyp: int containing length of hypothesis
:param length_ref: int containing length of reference
:return: array of score for each n-grams cosine similarity
"""
delta = float(length_hyp - length_ref)
            # measure cosine similarity
val = np.array([0.0 for _ in range(self._n)])
for n in range(self._n):
# ngram
for (ngram, count) in vec_hyp[n].items():
# vrama91 : added clipping
val[n] += (
min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
)

if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
val[n] /= norm_hyp[n] * norm_ref[n]

assert not math.isnan(val[n])
# vrama91: added a length based gaussian penalty
val[n] *= np.e ** (-(delta**2) / (2 * self._sigma**2))
return val
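        # Worked equation (explanatory note, not in the original): for each n,
        #   val[n] = sum_ngram min(hyp_w, ref_w) * ref_w / (||hyp|| * ||ref||)
        #            * exp(-(len_hyp - len_ref)^2 / (2 * sigma^2)),
        # i.e. a clipped cosine similarity between tf-idf vectors (when both norms are
        # non-zero), scaled by a Gaussian penalty on the hypothesis/reference length difference.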

# compute log reference length
self.ref_len = np.log(float(len(self.crefs)))

scores = []
for test, refs in zip(self.ctest, self.crefs):
# compute vector for test captions
vec, norm, length = counts2vec(test)
# compute vector for ref captions
score = np.array([0.0 for _ in range(self._n)])
for ref in refs:
vec_ref, norm_ref, length_ref = counts2vec(ref)
score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
# change by vrama91 - mean of ngram scores, instead of sum
score_avg = np.mean(score)
# divide by number of references
score_avg /= len(refs)
# multiply score by 10
score_avg *= 10.0
# append score of an image to the score list
scores.append(score_avg)
return scores
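    # Scaling example (explanatory note, not in the original): a prediction that is
    # identical to its single reference has cosine similarity 1.0 for every n and a
    # length penalty factor of 1, so mean(score) = 1.0 and the final value is
    # 1.0 / 1 reference * 10.0 = 10.0, which matches the identical pred/ref test below.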

def compute_score(self, option=None, verbose=0):
# compute idf
self.compute_doc_freq()
# assert to check document frequency
assert len(self.ctest) >= max(self.document_frequency.values())
# compute cider score
score = self.compute_cider()
return np.mean(np.array(score)), np.array(score)
17 changes: 17 additions & 0 deletions tests/test_cider.py
@@ -0,0 +1,17 @@
import unittest
import gem_metrics.cider
from tests.test_referenced import TestReferencedMetric


class TestCider(TestReferencedMetric, unittest.TestCase):
def setUp(self):
super().setUp()
self.metric = gem_metrics.cider.CIDER()
self.true_results_basic = {"CIDEr": 1.89}
self.true_results_identical_pred_ref = {"CIDEr": 10.0}
self.true_results_mismatched_pred_ref = {"CIDEr": 0.0}
self.true_results_empty_pred = {"CIDEr": 0.0}


if __name__ == "__main__":
unittest.main()
