# qianxiao1111's picture
# upgrade: add benchmarks eval
# 2a26d3b
# raw
# history blame
# 22.7 kB
import random
import regex
import re
import sympy
from latex2sympy2 import latex2sympy
from typing import TypeVar, Iterable, List, Union, Any, Dict
from word2number import w2n
from utils import *
def _fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if len(substr) > 0 and substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def _fix_a_slash_b(string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
if "sqrt" not in a:
a = int(a)
if "sqrt" not in b:
b = int(b)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except:
return string
def _fix_sqrt(string):
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
return _string
def convert_word_number(text: str) -> str:
    """Convert a number written in words to digits via ``w2n.word_to_num``.

    Best-effort: if the conversion raises, ``text`` is returned unchanged.
    """
    try:
        text = str(w2n.word_to_num(text))
    except Exception:
        # Deliberately broad (best-effort), but unlike a bare `except:` this
        # does not swallow KeyboardInterrupt / SystemExit.
        pass
    return text
# Units and answer decorations, mainly from MathQA. strip_string() (unless
# called with skip_unit=True) removes each entry where it appears as a
# standalone token: preceded by the string start or a non-word character and
# followed by the string end or a non-word character. The list is scanned in
# order, twice, with one substitution per entry per pass.
# NOTE(review): "kmph" and "sec" appear twice. Because each substitution pass
# removes only non-overlapping matches, duplicates and ordering can change the
# result, so do not dedupe or reorder without checking.
# NOTE(review): "m sqaure" and "v â € ™" look like a typo / mis-encoded text,
# but may mirror the dataset verbatim. Confirm before "fixing" them.
unit_texts = [
    "east",
    "degree",
    "mph",
    "kmph",
    "ft",
    "m sqaure",
    " m east",
    "sq m",
    "deg",
    "mile",
    "q .",
    "monkey",
    "prime",
    "ratio",
    "profit of rs",
    "rd",
    "o",
    "gm",
    "p . m",
    "lb",
    "tile",
    "per",
    "dm",
    "lt",
    "gain",
    "ab",
    "way",
    "west",
    "a .",
    "b .",
    "c .",
    "d .",
    "e .",
    "f .",
    "g .",
    "h .",
    "t",
    "a",
    "h",
    "no change",
    "men",
    "soldier",
    "pie",
    "bc",
    "excess",
    "st",
    "inches",
    "noon",
    "percent",
    "by",
    "gal",
    "kmh",
    "c",
    "acre",
    "rise",
    "a . m",
    "th",
    "π r 2",
    "sq",
    "mark",
    "l",
    "toy",
    "coin",
    "sq . m",
    "gallon",
    "° f",
    "profit",
    "minw",
    "yr",
    "women",
    "feet",
    "am",
    "pm",
    "hr",
    "cu cm",
    "square",
    "v â € ™",
    "are",
    "rupee",
    "rounds",
    "cubic",
    "cc",
    "mtr",
    "s",
    "ohm",
    "number",
    "kmph",
    "day",
    "hour",
    "minute",
    "min",
    "second",
    "man",
    "woman",
    "sec",
    "cube",
    "mt",
    "sq inch",
    "mp",
    "∏ cm ³",
    "hectare",
    "more",
    "sec",
    "unit",
    "cu . m",
    "cm 2",
    "rs .",
    "rs",
    "kg",
    "g",
    "month",
    "km",
    "m",
    "cm",
    "mm",
    "apple",
    "liter",
    "loss",
    "yard",
    "pure",
    "year",
    "increase",
    "decrease",
    "d",
    "less",
    "Surface",
    "litre",
    "pi sq m",
    "s .",
    "metre",
    "meter",
    "inch",
]
# Append a naive plural (entry + "s") of every entry, e.g. "mile" -> "miles".
# This doubles the list; the singular forms stay first.
unit_texts.extend([t + "s" for t in unit_texts])
def strip_string(string, skip_unit=False):
    r"""Normalize an answer string through a fixed sequence of rewrites.

    In order:
      - LaTeX cleanup: ``\!`` removal; array/bmatrix -> pmatrix;
        tfrac/dfrac -> frac; ``\neq``/``\leq``/``\geq`` -> ``\ne``/``\le``/``\ge``;
        ``\left``/``\right`` removal; ``\{``/``\}`` unescaping.
      - Removal of a trailing ``\text{...}`` and (unless ``skip_unit``) of the
        words in ``unit_texts``.
      - Removal of degree, dollar and percent signs.
      - Word-number conversion.
      - Removal of variable prefixes such as "x=", "y\in" and "z\to".
      - Infinity normalization.
      - Trailing-zero trimming.
      - Canonicalization of ``\sqrt``, ``\frac`` and ``a/b``.

    Args:
        string: the answer; non-strings are converted with ``str()``.
        skip_unit: if True, the ``unit_texts`` words are not removed.

    Returns:
        The normalized string (may be empty).
    """
    string = str(string).strip()
    # linebreaks
    string = string.replace("\n", "")

    # right "."
    string = string.rstrip(".")

    # remove inverse spaces ("\!")
    string = string.replace("\\!", "")

    # matrix: normalize array/bmatrix environments to pmatrix
    string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
    string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
    string = string.replace("bmatrix", "pmatrix")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    string = (
        string.replace("\\neq", "\\ne")
        .replace("\\leq", "\\le")
        .replace("\\geq", "\\ge")
    )

    # remove \left and \right; unescape \{ and \}
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    string = string.replace("\\{", "{")
    string = string.replace("\\}", "}")

    # Remove a trailing \text{...} (e.g. a unit such as miles or dollars),
    # unless that would leave nothing.
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        string = _string

    if not skip_unit:
        # Remove unit words as standalone tokens: preceded by the string start
        # or a non-word char, followed by the string end or a non-word char.
        # Entries are escaped because several contain "." (e.g. "rs .", "a ."),
        # which must match a literal dot rather than any character.
        for _ in range(2):
            for unit_text in unit_texts:
                _string = re.sub(
                    r"(^|\W)" + re.escape(unit_text) + r"($|\W)", r"\1\2", string
                )
                # Never let unit removal empty the answer entirely.
                if _string != "":
                    string = _string

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs and \( \) delimiters
    string = string.replace("\\$", "")
    string = string.replace("$", "")
    string = string.replace("\\(", "").replace("\\)", "")

    # convert word number to digit
    string = convert_word_number(string)

    # replace "\\text{...}" with "..."
    string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
    for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]:
        string = string.replace(key, "")
    string = string.replace("\\emptyset", r"{}")
    string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")

    # remove percentage ("\\%" first so no stray backslash is left behind)
    string = string.replace("\\%", "")
    string = string.replace("%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # Unwrap a single alphanumeric token enclosed in {}, () or [], e.g. "(5)" -> "5".
    # The alphanumeric test must look at the inner text: the whole string can
    # never be alphanumeric because it contains the brackets.
    if (
        len(string) >= 2
        and string[0] + string[-1] in ("{}", "()", "[]")
        and string[1:-1].isalnum()
    ):
        string = string[1:-1]

    # inf
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    # "+\infty" -> "\infty"
    string = string.replace("+\\infty", "\\infty")

    # and (NOTE: also removes "and" inside words, e.g. "random")
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")

    # use regex to remove \mbox{...}
    string = re.sub(r"\\mbox{.*?}", "", string)

    # quote
    string = string.replace("'", "")
    string = string.replace('"', "")

    # i, j: use "i" when only "j" occurs (presumably the imaginary unit)
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # replace a.000b where b is not number or b is end, with ab, use regex
    string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0*$", r"\1", string)

    # if empty, return empty string
    if not string:
        return string
    if string[0] == ".":
        string = "0" + string

    # Drop a short "k = " / "q = " style prefix at the beginning.
    parts = string.split("=")
    if len(parts) == 2 and len(parts[0]) <= 2:
        string = parts[1]

    string = _fix_sqrt(string)
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1).
    string = _fix_fracs(string)

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string
def extract_multi_choice_answer(pred_str):
    """Return the option letter (upper-case A-E) following "answer is".

    Text from the first "Problem:" onward is ignored, and "choice is" is
    treated like "answer is". The letter may be wrapped in parentheses and
    matching is case-insensitive. Returns the sentinel string "placeholder"
    when no letter is found.
    """
    # TODO: SFT models
    head = pred_str.split("Problem:", 1)[0]
    normalized = head.replace("choice is", "answer is").lower()
    found = re.search(r"answer is \(?(?P<ans>[abcde])\)?", normalized)
    return "placeholder" if found is None else found.group("ans").upper()
direct_answer_trigger_for_fewshot = ("choice is", "answer is")


def choice_answer_clean(pred: str):
    """Extract a multiple-choice letter (A-E) from a completion.

    If any trigger phrase occurs more than once (a few-shot style
    continuation), only the text before the first blank line is used. The
    text after the last trigger, or the whole text if no trigger occurs, is
    searched for standalone letters A-E:
      - With a trigger present, the first letter found is returned.
      - Without a trigger, the last letter found is returned.
      - If no letter is found, the cleaned text itself is returned.
    """
    pred = pred.strip("\n")

    # Repeated triggers indicate in-context examples; keep the first chunk.
    if any(pred.count(trig) > 1 for trig in direct_answer_trigger_for_fewshot):
        pred = pred.split("\n\n")[0]

    pieces = re.split("|".join(direct_answer_trigger_for_fewshot), pred)
    answer_flag = len(pieces) > 1
    if answer_flag:
        pred = pieces[-1]

    pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")

    letters = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
    # `candidates` always holds at least one element here.
    candidates = letters if letters else [pred.strip().strip(".")]
    chosen = candidates[0] if answer_flag else candidates[-1]

    # Remove the period at the end, again!
    return chosen.rstrip(".").rstrip("/")
def find_box(pred_str: str):
    r"""Return the content of the last ``boxed`` expression in ``pred_str``.

    - Braced form (``\boxed{...}``): returns the text up to the matching
      closing brace, or to the end of the string if the braces never balance.
    - Unbraced form (``\boxed 5$``): returns the text up to the next "$",
      stripped.
    - Returns "" when nothing follows the last "boxed".
    - If "boxed" does not occur at all, the whole string is treated as the
      text after it.
    """
    tail = pred_str.split("boxed")[-1]
    if not tail:
        return ""
    if not tail.startswith("{"):
        return tail.split("$")[0].strip()

    depth = 1
    chars = []
    for ch in tail[1:]:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                break
        chars.append(ch)
    return "".join(chars)
def clean_units(pred_str: str):
    """Clean the units in the number.

    Replaces pi with 3.14 and "%" with "/100". Strips "$", "¥", "°C",
    " C" and "°".
    """

    def convert_pi_to_number(code_string):
        code_string = code_string.replace("\\pi", "π")
        # Applied in order; each rule sees the output of the previous one.
        pi_rules = (
            # π (optionally backslashed) not preceded by a digit or "}" -> 3.14
            (r"(?<![\d}])\\?π", "3.14"),
            # digit immediately before π -> implicit multiplication, "3π" -> "3*3.14"
            (r"(\d)(\\?π)", r"\1*3.14"),
            # braced π, "{π}" -> "3.14"
            (r"\{(\\?π)\}", "3.14"),
            # explicit multiplication, "3*π" -> "3*3.14"
            (r"\*(\\?π)", "*3.14"),
        )
        for pattern, replacement in pi_rules:
            code_string = re.sub(pattern, replacement, code_string)
        return code_string

    cleaned = convert_pi_to_number(pred_str).replace("%", "/100")
    for token in ("$", "¥", "°C", " C", "°"):
        cleaned = cleaned.replace(token, "")
    return cleaned
def extract_theoremqa_answer(pred: str, answer_flag: bool = True):
    r"""Normalize a TheoremQA-style prediction.

    Checks are case-insensitive and applied in order:
      - Contains the word "yes" or "true": returns "True".
      - Contains the word "no" or "false": returns "False".
      - Contains an option marker "(a)".."(f)": returned unchanged.
      - Otherwise the last ``\boxed`` content is taken if present. Then:
          - ``answer_flag=True``: the right-hand side of the last "=" is
            cleaned with ``clean_units`` and evaluated via latex2sympy. On
            failure it falls back to the leading number of a
            "<number> <word>" string.
          - ``answer_flag=False``: the last number in the text is taken, or
            "" if there is none.
    """
    lowered = pred.lower()
    # Whole-word matching: a plain substring test made "no" fire on words
    # such as "another", "number" or "none", turning numeric answers into "False".
    if re.search(r"\b(?:yes|true)\b", lowered):
        pred = "True"
    elif re.search(r"\b(?:no|false)\b", lowered):
        pred = "False"
    elif any(
        option in lowered for option in ["(a)", "(b)", "(c)", "(d)", "(e)", "(f)"]
    ):
        pass
    else:
        # Some of the models somehow get used to boxed output from pre-training
        if "boxed" in pred:
            pred = find_box(pred)

        if answer_flag:
            # Extract the numbers out of the string
            pred = pred.split("=")[-1].strip()
            pred = clean_units(pred)
            try:
                tmp = str(latex2sympy(pred))
                # SECURITY NOTE(review): eval() runs text derived from model
                # output; only acceptable for trusted / sandboxed evaluation.
                pred = str(eval(tmp))
            except Exception:
                if re.match(r"-?[\d\.]+\s\D+$", pred):
                    pred = pred.split(" ")[0]
                elif re.match(r"-?[\d\.]+\s[^\s]+$", pred):
                    pred = pred.split(" ")[0]
        else:
            # desperate search over the last number
            preds = re.findall(r"-?\d*\.?\d+", pred)
            pred = preds[-1] if preds else ""

    return pred
def extract_answer(pred_str, data_name, use_last_number=True):
    r"""Extract and normalize the final answer from a model completion.

    The answer is located by the first rule that applies:
      1. A "final answer is $...$. I hope" sentence (minerva_math format).
      2. The content of the last ``\boxed`` expression (via ``find_box``).
      3. The text after the last "...he answer is" or "final answer is".
      4. The first "答案是" ("the answer is") segment, up to a blank line.
      5. Optionally, the last number in the text.

    Args:
        pred_str: raw model output.
        data_name: dataset name. Selects multiple-choice handling, and
            whether unit stripping is skipped (names in STRIP_EXCEPTIONS).
        use_last_number: when no answer marker is found, fall back to the last
            number in the text; otherwise the answer is empty.

    Returns:
        The answer normalized with ``strip_string``. For "mmlu_stem",
        "sat_math", "aqua" and "gaokao2023" the result of
        ``choice_answer_clean`` is returned directly. Returns "" when nothing
        follows the last "boxed".
    """
    # Remove every "ки" (Cyrillic) token; presumably a step tag from annotated
    # solution data — confirm against the data source.
    pred_str = pred_str.replace("\u043a\u0438", "")
    if data_name in ["mmlu_stem", "sat_math", "aqua", "gaokao2023"]:
        # TODO check multiple choice
        return choice_answer_clean(pred_str)

    if "final answer is $" in pred_str and "$. I hope" in pred_str:
        # minerva_math
        tmp = pred_str.split("final answer is $", 1)[1]
        pred = tmp.split("$. I hope", 1)[0].strip()
    elif "boxed" in pred_str:
        # Nothing after the last "boxed" (e.g. truncated output): no answer.
        if not pred_str.split("boxed")[-1]:
            return ""
        pred = find_box(pred_str)
    elif "he answer is" in pred_str:
        # Matches both "The answer is" and "the answer is".
        pred = pred_str.split("he answer is")[-1].strip()
    elif "final answer is" in pred_str:
        pred = pred_str.split("final answer is")[-1].strip()
    elif "答案是" in pred_str:
        # Handle Chinese few-shot multiple choice problem answer extraction
        pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
    elif use_last_number:
        # Fallback: the last number in the text, ignoring thousands commas.
        numbers = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
        pred = numbers[-1] if numbers else ""
    else:
        pred = ""

    # choice answer (sat_math / aqua / mmlu_stem already returned above, so in
    # practice this handles other dataset names containing "mmlu")
    if data_name in ["sat_math", "aqua"] or "mmlu" in data_name:
        letters = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
        pred = letters[-1] if letters else pred.strip().strip(".")

    # multiple line: join lines, dropping each newline and following whitespace
    pred = re.sub(r"\n\s*", "", pred)
    if pred.startswith(":"):
        pred = pred[1:]
    if pred.endswith("."):
        pred = pred[:-1]
    if pred.endswith("/"):
        pred = pred[:-1]
    pred = strip_string(pred, skip_unit=data_name in STRIP_EXCEPTIONS)
    return pred
# Datasets whose ground-truth answers skip strip_string in parse_ground_truth
# (only \neq / \leq / \geq are normalized); run_execute also skips unit
# stripping for them.
STRIP_EXCEPTIONS = ["carp_en", "minerva_math"]

# Datasets whose reference answer is a single field copied verbatim and which
# provide no reference chain-of-thought: dataset name -> field name.
_ANSWER_ONLY_FIELD = {
    "mawps": "target",
    "sat_math": "Answer",
    "aqua": "correct",
    "gaokao_math_qa": "label",
    "aime24": "answer",
    "amc23": "answer",
    "cmath": "answer",
    "gaokao2024_I": "answer",
    "gaokao2024_II": "answer",
    "imo2024": "answer",
}


def _tabmwp_number(text):
    """Convert a TabMWP numeric answer ("3/4", "1,200", "15%", "2.5") to a number."""
    if "/" in text:
        parts = text.split("/")
        return int(parts[0]) / int(parts[1])
    if "," in text:
        return float(text.replace(",", ""))
    if "%" in text:
        return float(text.split("%")[0]) / 100
    return float(text)


def parse_ground_truth(example: Dict[str, Any], data_name):
    """Return ``(gt_cot, gt_ans)`` for one dataset example.

    Examples that already carry both "gt_cot" and "gt" use those fields
    directly. Otherwise the reference solution and answer are read from
    dataset-specific fields. Datasets without a reference solution yield the
    string "None" as ``gt_cot``.

    Post-processing:
      - ``gt_cot`` is stringified and stripped.
      - ``gt_ans`` is passed through ``strip_string``, except for
        STRIP_EXCEPTIONS datasets, where only \\neq / \\leq / \\geq are
        normalized.

    Raises:
        NotImplementedError: for an unknown ``data_name``.
    """
    # Pre-annotated examples short-circuit the per-dataset parsing.
    if "gt_cot" in example and "gt" in example:
        if data_name in ["math"]:
            answer = extract_answer(example["gt_cot"], data_name)
        elif data_name in STRIP_EXCEPTIONS:
            answer = example["gt"]
        else:
            answer = strip_string(example["gt"])
        return example["gt_cot"], answer

    cot = None
    if data_name in _ANSWER_ONLY_FIELD:
        answer = example[_ANSWER_ONLY_FIELD[data_name]]
    elif data_name in ("math", "minerva_math"):
        cot = example["solution"]
        answer = extract_answer(cot, data_name)
    elif data_name == "gsm8k":
        cot, answer = example["answer"].split("####")
    elif data_name == "svamp":
        cot, answer = example["Equation"], example["Answer"]
    elif data_name == "asdiv":
        cot = example["formula"]
        answer = re.sub(r"\(.*?\)", "", example["answer"])
    elif data_name == "tabmwp":
        cot = example["solution"]
        answer = example["answer"]
        if example["ans_type"] in ["integer_number", "decimal_number"]:
            answer = _tabmwp_number(answer)
    elif data_name == "carp_en":
        cot, answer = example["steps"], example["answer"]
    elif data_name == "mmlu_stem":
        answer = "ABCD"[example["answer"]]
    elif data_name in ("gaokao2023en", "college_math", "gaokao_math_cloze"):
        answer = example["answer"].replace("$", "").strip()
    elif data_name in ("gaokao2024_mix", "cn_middle_school"):
        if len(example["choice_answer"]) > 0:
            answer = example["choice_answer"]
        else:
            answer = example["answer"]
    elif data_name == "olympiadbench":
        answer = example["final_answer"][0].strip("$")
    else:
        raise NotImplementedError(f"`{data_name}`")

    # post process
    cot = str(cot).strip()
    if data_name in STRIP_EXCEPTIONS:
        answer = (
            answer.replace("\\neq", "\\ne")
            .replace("\\leq", "\\le")
            .replace("\\geq", "\\ge")
        )
    else:
        answer = strip_string(answer, skip_unit=data_name == "carp_en")
    return cot, answer
def parse_question(example, data_name):
    """Build the question text for one dataset example.

    The text is assembled from dataset-specific fields, with answer choices
    appended for multiple-choice datasets. Unknown datasets use the first of
    "question", "problem", "Question", "input" present in ``example``.
    " (True or False)" or " (Yes or No)" is appended when the parsed
    ground-truth answer is one of those words. ``example`` is not modified.

    Returns:
        The stripped question string.

    Raises:
        AssertionError: if a mmlu_stem example does not have exactly four
            choices, or sat_math options do not start with "A".
    """
    question = ""
    if data_name == "asdiv":
        question = f"{example['body'].strip()} {example['question'].strip()}"
    elif data_name == "svamp":
        body = example["Body"].strip()
        if not body.endswith("."):
            body = body + "."
        question = f'{body} {example["Question"].strip()}'
    elif data_name == "tabmwp":
        title_str = (
            f'regarding "{example["table_title"]}" ' if example["table_title"] else ""
        )
        question = f"Read the following table {title_str}and answer a question:\n"
        question += f'{example["table"]}\n{example["question"]}'
        if example["choices"]:
            question += (
                f' Please select from the following options: {example["choices"]}'
            )
    elif data_name == "carp_en":
        question = example["content"]
    elif data_name == "mmlu_stem":
        choices = example["choices"]
        assert len(choices) == 4
        # Build a new string instead of writing labelled options back into
        # example["choices"]: that mutated the caller's data, so repeat calls
        # produced "(A) (A) ...".
        options = " ".join(
            f"({label}) {str(option).strip()}" for label, option in zip("ABCD", choices)
        )
        question = f"{example['question'].strip()}\nAnswer Choices: {options}"
    elif data_name == "sat_math":
        options = example["options"].strip()
        assert "A" == options[0]
        options = "(" + options
        # " B) " -> " (B) " etc. Plain literal replacement (no regex escaping
        # needed); a no-op when the marker is absent.
        for ch in "BCD":
            options = options.replace(f" {ch}) ", f" ({ch}) ")
        question = f"{example['question'].strip()}\nAnswer Choices: {options}"
    elif "aqua" in data_name:
        options = example["options"]
        choice = "(" + "(".join(options)
        choice = choice.replace("(", " (").replace(")", ") ").strip()
        choice = "\nAnswer Choices: " + choice
        question = example["question"].strip() + choice
    elif data_name == "gaokao_math_qa":
        options_dict = example["options"]
        options = " ".join(f"({key}) {options_dict[key]}" for key in options_dict)
        # "选项" means "Options".
        question = f"{example['question'].strip()}\n选项: {options}"
    else:
        for key in ["question", "problem", "Question", "input"]:
            if key in example:
                question = example[key]
                break

    # Yes or No / True or False question: add an answer-format hint.
    _, gt_ans = parse_ground_truth(example, data_name)
    if isinstance(gt_ans, str):
        gt_lower = gt_ans.lower()
        if gt_lower in ["true", "false"]:
            question += " (True or False)"
        if gt_lower in ["yes", "no"]:
            question += " (Yes or No)"
    return question.strip()
def run_execute(executor, result, prompt_type, data_name, execute=False):
    """Turn one model completion into a normalized prediction.

    The prediction comes from:
      - ``extract_program_output`` for "program_only" prompt types;
      - ``executor.apply`` on the extracted program for "pot"/"pal" when
        ``execute`` is set;
      - ``extract_answer`` otherwise.
    It is then passed through ``strip_string``, skipping unit removal for
    STRIP_EXCEPTIONS datasets.

    Returns:
        ``(prediction, report)``. Returns ``(None, None)`` for an empty or
        "error" result. ``report`` comes from ``executor.apply`` when a
        program was executed, otherwise it is ``None``.
    """
    if not result or result == "error":
        return None, None

    if "program_only" in prompt_type:
        prediction, report = extract_program_output(result), None
    elif prompt_type in ["pot", "pal"] and execute:
        prediction, report = executor.apply(extract_program(result))
    else:
        prediction, report = extract_answer(result, data_name), None

    normalized = strip_string(prediction, skip_unit=data_name in STRIP_EXCEPTIONS)
    return normalized, report
def _test_extract_answer():
    """Manual smoke test: print the answers extracted from two sample completions."""
    # Raw strings: in a normal literal the "\f" of "\frac" becomes a form feed,
    # and "\[", "\c", "\m", "\}" are invalid escapes (SyntaxWarning on 3.12+).
    text = r"""
This is still not equal to $0$, so we must have made another mistake.
When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that:
\[\frac{386}{64} - 7 = \frac{386}{64} - \frac{7 \cdot 64}{1 \cdot 64} = \frac{386 - 448}{64} = \frac{-62}{64}.\]
This is still not equal to $0$, so we must have made another mistake.
When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that:
\[\frac{386}{64}
"""
    print(extract_answer(text, "math-oai", use_last_number=False))
    print(choice_answer_clean(r"\mathrm{(D)\}1,008,016"))


if __name__ == "__main__":
    _test_extract_answer()