forked from nlpxucan/WizardLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
657bcb8
commit 8c68227
Showing
3 changed files
with
150 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
""" | ||
This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from: | ||
- https://github.com/microsoft/ProphetNet/tree/master/CRITIC | ||
""" | ||
import multiprocessing | ||
from math import isclose | ||
from typing import Union | ||
|
||
from sympy import simplify, N | ||
from sympy.parsing.sympy_parser import parse_expr | ||
from sympy.parsing.latex import parse_latex | ||
|
||
|
||
def is_digit(s): | ||
try: | ||
float(str(s).replace(",", "")) | ||
return True | ||
except ValueError: | ||
return False | ||
|
||
def math_equal(prediction: Union[bool, float, str], | ||
reference: Union[float, str], | ||
include_percentage: bool = True, | ||
is_close: bool = True, | ||
timeout: bool = False, | ||
) -> bool: | ||
""" | ||
Exact match of math if and only if: | ||
1. numerical equal: both can convert to float and are equal | ||
2. symbolic equal: both can convert to sympy expression and are equal | ||
""" | ||
try: # 1. numerical equal | ||
if is_digit(prediction) and is_digit(reference): | ||
prediction = float(str(prediction).replace(",", "")) | ||
reference = float(str(reference).replace(",", "")) | ||
# number questions | ||
if include_percentage: | ||
gt_result = [reference / 100, reference, reference * 100] | ||
else: | ||
gt_result = [reference] | ||
for item in gt_result: | ||
try: | ||
if is_close: | ||
if isclose(item, prediction, rel_tol=1e-4): | ||
return True | ||
else: | ||
if item == prediction: | ||
return True | ||
except Exception: | ||
continue | ||
return False | ||
except: | ||
pass | ||
|
||
if not prediction and prediction not in [0, False]: | ||
return False | ||
|
||
# 2. symbolic equal | ||
reference = str(reference).strip() | ||
prediction = str(prediction).strip() | ||
|
||
## deal with [], (), {} | ||
pred_str, ref_str = prediction, reference | ||
if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or \ | ||
(prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")): | ||
pred_str = pred_str.strip("[]()") | ||
ref_str = ref_str.strip("[]()") | ||
for s in ['{', "}", "(", ")"]: | ||
ref_str = ref_str.replace(s, "") | ||
pred_str = pred_str.replace(s, "") | ||
if pred_str == ref_str: | ||
return True | ||
|
||
## [a, b] vs. [c, d], return a==c and b==d | ||
if (prediction.startswith("[") and prediction.endswith("]")) and (reference.startswith("[") and reference.endswith("]")) or \ | ||
(prediction.startswith("(") and prediction.endswith(")")) and (reference.startswith("(") and reference.endswith(")")): | ||
pred_parts = prediction[1:-1].split(",") | ||
ref_parts = reference[1:-1].split(",") | ||
if len(pred_parts) == len(ref_parts): | ||
if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]): | ||
return True | ||
|
||
# symbolic equal with sympy | ||
if timeout: | ||
if call_with_timeout(symbolic_equal_process, prediction, reference): | ||
return True | ||
else: | ||
if symbolic_equal(prediction, reference): | ||
return True | ||
|
||
return False | ||
|
||
|
||
def math_equal_process(param): | ||
return math_equal(param[-2], param[-1]) | ||
|
||
|
||
def symbolic_equal(a, b): | ||
def _parse(s): | ||
for f in [parse_latex, parse_expr]: | ||
try: | ||
return f(s) | ||
except: | ||
pass | ||
return s | ||
a = _parse(a) | ||
b = _parse(b) | ||
|
||
try: | ||
if simplify(a-b) == 0: | ||
return True | ||
except: | ||
pass | ||
|
||
try: | ||
if isclose(N(a), N(b), rel_tol=1e-3): | ||
return True | ||
except: | ||
pass | ||
return False | ||
|
||
|
||
def symbolic_equal_process(a, b, output_queue): | ||
result = symbolic_equal(a, b) | ||
output_queue.put(result) | ||
|
||
|
||
def call_with_timeout(func, *args, timeout=1, **kwargs): | ||
output_queue = multiprocessing.Queue() | ||
process_args = args + (output_queue,) | ||
process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs) | ||
process.start() | ||
process.join(timeout) | ||
|
||
if process.is_alive(): | ||
process.terminate() | ||
process.join() | ||
return False | ||
|
||
return output_queue.get() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters