Skip to content

Commit

Permalink
update_wizardmath
Browse files Browse the repository at this point in the history
  • Loading branch information
flyinghpluo authored Dec 19, 2023
1 parent 657bcb8 commit 8c68227
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 4 deletions.
141 changes: 141 additions & 0 deletions WizardMath/inference/grader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""
This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
- https://github.com/microsoft/ProphetNet/tree/master/CRITIC
"""
import multiprocessing
from math import isclose
from typing import Union

from sympy import simplify, N
from sympy.parsing.sympy_parser import parse_expr
from sympy.parsing.latex import parse_latex


def is_digit(s):
try:
float(str(s).replace(",", ""))
return True
except ValueError:
return False

def math_equal(prediction: Union[bool, float, str],
reference: Union[float, str],
include_percentage: bool = True,
is_close: bool = True,
timeout: bool = False,
) -> bool:
"""
Exact match of math if and only if:
1. numerical equal: both can convert to float and are equal
2. symbolic equal: both can convert to sympy expression and are equal
"""
try: # 1. numerical equal
if is_digit(prediction) and is_digit(reference):
prediction = float(str(prediction).replace(",", ""))
reference = float(str(reference).replace(",", ""))
# number questions
if include_percentage:
gt_result = [reference / 100, reference, reference * 100]
else:
gt_result = [reference]
for item in gt_result:
try:
if is_close:
if isclose(item, prediction, rel_tol=1e-4):
return True
else:
if item == prediction:
return True
except Exception:
continue
return False
except:
pass

if not prediction and prediction not in [0, False]:
return False

# 2. symbolic equal
reference = str(reference).strip()
prediction = str(prediction).strip()

## deal with [], (), {}
pred_str, ref_str = prediction, reference
if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or \
(prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")):
pred_str = pred_str.strip("[]()")
ref_str = ref_str.strip("[]()")
for s in ['{', "}", "(", ")"]:
ref_str = ref_str.replace(s, "")
pred_str = pred_str.replace(s, "")
if pred_str == ref_str:
return True

## [a, b] vs. [c, d], return a==c and b==d
if (prediction.startswith("[") and prediction.endswith("]")) and (reference.startswith("[") and reference.endswith("]")) or \
(prediction.startswith("(") and prediction.endswith(")")) and (reference.startswith("(") and reference.endswith(")")):
pred_parts = prediction[1:-1].split(",")
ref_parts = reference[1:-1].split(",")
if len(pred_parts) == len(ref_parts):
if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]):
return True

# symbolic equal with sympy
if timeout:
if call_with_timeout(symbolic_equal_process, prediction, reference):
return True
else:
if symbolic_equal(prediction, reference):
return True

return False


def math_equal_process(param):
return math_equal(param[-2], param[-1])


def symbolic_equal(a, b):
def _parse(s):
for f in [parse_latex, parse_expr]:
try:
return f(s)
except:
pass
return s
a = _parse(a)
b = _parse(b)

try:
if simplify(a-b) == 0:
return True
except:
pass

try:
if isclose(N(a), N(b), rel_tol=1e-3):
return True
except:
pass
return False


def symbolic_equal_process(a, b, output_queue):
result = symbolic_equal(a, b)
output_queue.put(result)


def call_with_timeout(func, *args, timeout=1, **kwargs):
output_queue = multiprocessing.Queue()
process_args = args + (output_queue,)
process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
process.start()
process.join(timeout)

if process.is_alive():
process.terminate()
process.join()
return False

return output_queue.get()

4 changes: 2 additions & 2 deletions WizardMath/inference/gsm8k_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fraction import Fraction
from vllm import LLM, SamplingParams
import sys

from grader import math_equal
MAX_INT = sys.maxsize

def is_number(s):
Expand Down Expand Up @@ -110,7 +110,7 @@ def gsm8k_test(model, data_path, start=0, end=MAX_INT, batch_size=1, tensor_para
doc = {'question': prompt}
y_pred = extract_answer_number(completion)
if y_pred != None:
result.append(float(y_pred) == float(prompt_answer))
result.append(float(y_pred) == float(prompt_answer) or math_equal(y_pred, prompt_answer))
else:
result.append(False)
temp = {'question': prompt, 'output': completion, 'answer': prompt_answer}
Expand Down
9 changes: 7 additions & 2 deletions WizardMath/inference/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pprint
from grader import math_equal

def last_boxed_only(sample):
q, a = sample
Expand Down Expand Up @@ -239,9 +240,13 @@ def is_equiv(str1, str2, verbose=False):
#pdb.set_trace()
if verbose:
print(ss1, ss2)
return ss1 == ss2
#return ss1 == ss2
res = math_equal(ss1,ss2) or ss1 == ss2
return res
except Exception:
return str1 == str2
#return str1 == str2
res = math_equal(str1,str1) or str1 == str2
return res

class NotEqual:
def __eq__(self, other):
Expand Down

0 comments on commit 8c68227

Please sign in to comment.