|
import random |
|
import regex |
|
import re |
|
import sympy |
|
from latex2sympy2 import latex2sympy |
|
from typing import TypeVar, Iterable, List, Union, Any, Dict |
|
from word2number import w2n |
|
from utils import * |
|
|
|
|
|
def _fix_fracs(string): |
|
substrs = string.split("\\frac") |
|
new_str = substrs[0] |
|
if len(substrs) > 1: |
|
substrs = substrs[1:] |
|
for substr in substrs: |
|
new_str += "\\frac" |
|
if len(substr) > 0 and substr[0] == "{": |
|
new_str += substr |
|
else: |
|
try: |
|
assert len(substr) >= 2 |
|
except: |
|
return string |
|
a = substr[0] |
|
b = substr[1] |
|
if b != "{": |
|
if len(substr) > 2: |
|
post_substr = substr[2:] |
|
new_str += "{" + a + "}{" + b + "}" + post_substr |
|
else: |
|
new_str += "{" + a + "}{" + b + "}" |
|
else: |
|
if len(substr) > 2: |
|
post_substr = substr[2:] |
|
new_str += "{" + a + "}" + b + post_substr |
|
else: |
|
new_str += "{" + a + "}" + b |
|
string = new_str |
|
return string |
|
|
|
|
|
def _fix_a_slash_b(string): |
|
if len(string.split("/")) != 2: |
|
return string |
|
a = string.split("/")[0] |
|
b = string.split("/")[1] |
|
try: |
|
if "sqrt" not in a: |
|
a = int(a) |
|
if "sqrt" not in b: |
|
b = int(b) |
|
assert string == "{}/{}".format(a, b) |
|
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" |
|
return new_string |
|
except: |
|
return string |
|
|
|
|
|
def _fix_sqrt(string): |
|
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string) |
|
return _string |
|
|
|
|
|
def convert_word_number(text: str) -> str:
    """Try to convert a number written in words into digits.

    Uses ``w2n.word_to_num``. This is best-effort: if the conversion fails for
    any reason, ``text`` is returned unchanged.

    Args:
        text: Candidate answer text.

    Returns:
        The converted number as a string, or the original ``text``.
    """
    try:
        return str(w2n.word_to_num(text))
    except Exception:
        # Most math answers are not number words, so failure is the normal
        # case. Catch Exception rather than using a bare ``except:`` so that
        # KeyboardInterrupt/SystemExit still propagate.
        return text
|
|
|
|
|
|
|
# Unit words and other filler tokens that strip_string removes when they
# appear as standalone words. Every entry also gets a naive plural "s" variant.
# NOTE(review): "kmph" and "sec" are each listed twice, and "m sqaure" is kept
# with its original spelling (it may be meant to match misspelled text) —
# confirm before deduplicating or correcting.
unit_texts = [
    "east", "degree", "mph", "kmph", "ft", "m sqaure", " m east", "sq m",
    "deg", "mile", "q .", "monkey", "prime", "ratio", "profit of rs", "rd",
    "o", "gm", "p . m", "lb", "tile", "per", "dm", "lt",
    "gain", "ab", "way", "west", "a .", "b .", "c .", "d .",
    "e .", "f .", "g .", "h .", "t", "a", "h", "no change",
    "men", "soldier", "pie", "bc", "excess", "st", "inches", "noon",
    "percent", "by", "gal", "kmh", "c", "acre", "rise", "a . m",
    "th", "π r 2", "sq", "mark", "l", "toy", "coin", "sq . m",
    "gallon", "° f", "profit", "minw", "yr", "women", "feet", "am",
    "pm", "hr", "cu cm", "square", "v â € ™", "are", "rupee", "rounds",
    "cubic", "cc", "mtr", "s", "ohm", "number", "kmph", "day",
    "hour", "minute", "min", "second", "man", "woman", "sec", "cube",
    "mt", "sq inch", "mp", "∏ cm ³", "hectare", "more", "sec", "unit",
    "cu . m", "cm 2", "rs .", "rs", "kg", "g", "month", "km",
    "m", "cm", "mm", "apple", "liter", "loss", "yard", "pure",
    "year", "increase", "decrease", "d", "less", "Surface", "litre", "pi sq m",
    "s .", "metre", "meter", "inch",
]

# Append the plural form of every entry ("mile" -> "miles"), in place.
unit_texts += [text + "s" for text in unit_texts]
|
|
|
|
|
def strip_string(string, skip_unit=False):
    r"""Normalize a math answer string so that equivalent answers compare equal.

    Applies a fixed, order-dependent sequence of textual rewrites, in order:
      1. Whitespace and trailing-period cleanup.
      2. LaTeX macro canonicalization.
      3. Optional removal of the words in ``unit_texts``.
      4. Removal of degree, dollar and percent markers.
      5. Word-number conversion.
      6. Bracket unwrapping of a single alphanumeric token.
      7. Infinity and imaginary-unit canonicalization.
      8. Trailing-zero removal.
      9. Removal of every space.
      10. Brace fixes for ``\sqrt``, ``\frac`` and integer ``a/b``.

    Args:
        string: Answer to normalize; non-strings are converted with ``str()``.
        skip_unit: If True, the ``unit_texts`` words are not removed.

    Returns:
        The normalized string (possibly empty).
    """
    string = str(string).strip()
    string = string.replace("\n", "")
    string = string.rstrip(".")
    # Remove LaTeX negative thin spaces.
    string = string.replace("\\!", "")

    # Treat array/bmatrix environments as pmatrix so matrices compare equal.
    string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
    string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
    string = string.replace("bmatrix", "pmatrix")

    # Canonicalize fraction and relation macros.
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    string = (
        string.replace("\\neq", "\\ne")
        .replace("\\leq", "\\le")
        .replace("\\geq", "\\ge")
    )

    # Remove sizing macros and un-escape braces.
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    string = string.replace("\\{", "{")
    string = string.replace("\\}", "}")

    # Drop a \text{...} group that runs to the end of the string, unless that
    # would leave nothing.
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        string = _string

    if not skip_unit:
        # Two passes: when two unit words share one separating character, the
        # first pass consumes it and only the second pass removes the other.
        for _ in range(2):
            for unit_text in unit_texts:
                # Escape the entry: several contain "." (e.g. "a .", "rs .")
                # that must match literally, not as a regex wildcard.
                _string = re.sub(
                    r"(^|\W)" + re.escape(unit_text) + r"($|\W)", r"\1\2", string
                )
                if _string != "":
                    string = _string

    # Remove degree markers.
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # Remove dollar signs and inline-math delimiters.
    string = string.replace("\\$", "")
    string = string.replace("$", "")
    string = string.replace("\\(", "").replace("\\)", "")

    string = convert_word_number(string)

    # Unwrap \text{...}; drop variable prefixes such as "x=" or "x\in" wherever
    # they occur.
    string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
    for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]:
        string = string.replace(key, "")
    string = string.replace("\\emptyset", r"{}")
    string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")

    # Remove percent signs, escaped or bare.
    string = string.replace("\\%", "")
    string = string.replace("%", "")

    # Add a leading zero to bare decimals: " .5" -> " 0.5", "{.5" -> "{0.5".
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # Unwrap one alphanumeric token in matching brackets: "{5}" -> "5",
    # "(A)" -> "A". The inner text must be checked: the whole string includes
    # the bracket characters and so is never alphanumeric.
    if (
        len(string) >= 2
        and (string[0], string[-1]) in (("{", "}"), ("(", ")"), ("[", "]"))
        and string[1:-1].isalnum()
    ):
        string = string[1:-1]

    # Canonicalize infinity; a leading "+" on infinity is dropped.
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    string = string.replace("+\\infty", "\\infty")

    # Remove every occurrence of "and" (also inside words) and bold markup.
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")

    # Remove \mbox{...} groups entirely.
    string = re.sub(r"\\mbox{.*?}", "", string)

    # NOTE: quote characters are kept. Stripping "'" would also erase primes
    # such as f'(x).

    # If "j" is used as the imaginary unit (and "i" is absent), rewrite it to "i".
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # Drop zero-only decimal parts: "2.0" -> "2", "2.00 cm" -> "2 cm".
    string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0*$", r"\1", string)

    if len(string) == 0:
        return string
    # ".5" -> "0.5"
    if string.startswith("."):
        string = "0" + string

    # Drop a short left-hand side such as "k=" in "k=5".
    parts = string.split("=")
    if len(parts) == 2 and len(parts[0]) <= 2:
        string = parts[1]

    string = _fix_sqrt(string)
    string = string.replace(" ", "")
    # "\frac12" -> "\frac{1}{2}"
    string = _fix_fracs(string)
    # "1/2" -> "\frac{1}{2}" for integer operands
    string = _fix_a_slash_b(string)
    return string
|
|
|
|
|
def extract_multi_choice_answer(pred_str):
    """Return the option letter (A-E) that follows "answer is" / "choice is".

    Text from the first "Problem:" onwards is ignored. Matching is
    case-insensitive and tolerates parentheses around the letter.

    Returns:
        The upper-case letter, or the string "placeholder" when none is found.
    """
    # Keep only the text before any "Problem:" marker.
    before_problem = pred_str.partition("Problem:")[0]
    normalized = before_problem.replace("choice is", "answer is").lower()
    found = re.search(r"answer is \(?(?P<ans>[abcde])\)?", normalized)
    if found is None:
        return "placeholder"
    return found.group("ans").upper()
|
|
|
|
|
# Phrases that introduce the chosen option in a completion.
direct_answer_trigger_for_fewshot = ("choice is", "answer is")


def choice_answer_clean(pred: str):
    """Extract a multiple-choice answer (usually a letter A-E) from a completion.

    If a trigger phrase is present, the text after the last trigger is used
    and the first A-E letter in it wins. Otherwise the last standalone A-E
    letter in the whole text wins. Letters are matched case-insensitively.
    When no letter is found, the cleaned text itself is returned.
    """
    pred = pred.strip("\n")

    # A trigger occurring more than once suggests the completion contains
    # extra example-style Q/A blocks; keep only the first paragraph.
    repeated = any(pred.count(t) > 1 for t in direct_answer_trigger_for_fewshot)
    if repeated:
        pred = pred.split("\n\n")[0]

    segments = re.split("|".join(direct_answer_trigger_for_fewshot), pred)
    answer_flag = len(segments) > 1
    if answer_flag:
        pred = segments[-1]

    pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")

    letters = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
    candidates = letters or [pred.strip().strip(".")]
    chosen = candidates[0] if answer_flag else candidates[-1]
    return chosen.rstrip(".").rstrip("/")
|
|
|
|
|
def find_box(pred_str: str):
    r"""Return the contents of the last ``\boxed`` in ``pred_str``.

    For ``\boxed{...}`` the brace-balanced contents are returned; an
    unterminated group yields everything after the opening brace. For an
    unbraced form such as ``\boxed 5$``, the text up to the next "$" is
    returned, stripped. Returns "" when nothing follows the last "boxed".
    """
    tail = pred_str.split("boxed")[-1]
    if not tail:
        return ""
    if not tail.startswith("{"):
        return tail.split("$")[0].strip()

    depth = 1
    collected = []
    for ch in tail[1:]:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                break
        collected.append(ch)
    return "".join(collected)
|
|
|
|
|
def clean_units(pred_str: str):
    """Clean the units in the number.

    Replaces pi with 3.14 (inserting "*" after a digit), turns "%" into
    "/100", and deletes "$", "¥", "°C", " C" and "°".
    """

    def convert_pi_to_number(code_string):
        # Normalize the LaTeX macro to the Unicode symbol first.
        code_string = code_string.replace("\\pi", "π")
        # Applied in this order; each rule sees the previous rule's output.
        rules = (
            (r"(?<![\d}])\\?π", "3.14"),  # π not preceded by a digit or "}"
            (r"(\d)(\\?π)", r"\1*3.14"),  # "2π" -> "2*3.14"
            (r"\{(\\?π)\}", "3.14"),  # "{π}" -> "3.14"
            (r"\*(\\?π)", "*3.14"),  # "*π" -> "*3.14"
        )
        for pattern, replacement in rules:
            code_string = re.sub(pattern, replacement, code_string)
        # NOTE(review): a π preceded by "}" (e.g. "\frac{1}{2}\pi") matches
        # none of the rules above and is left as "π".
        return code_string

    pred_str = convert_pi_to_number(pred_str)
    pred_str = pred_str.replace("%", "/100")
    for symbol in ("$", "¥", "°C", " C", "°"):
        pred_str = pred_str.replace(symbol, "")
    return pred_str
|
|
|
|
|
def extract_theoremqa_answer(pred: str, answer_flag: bool = True):
    r"""Normalize a TheoremQA-style prediction.

    Checked in order, case-insensitively:

    * the word "yes" or "true" -> "True";
    * the word "no" or "false" -> "False";
    * an option marker "(a)".."(f)" -> ``pred`` is returned unchanged.

    Otherwise the last ``\boxed`` contents are used if present, and then:

    * ``answer_flag=True``: keep the text after the last "=", clean units and
      try to evaluate it through ``latex2sympy``. On failure, a
      "<number> <word>" answer is reduced to the number.
    * ``answer_flag=False``: return the last number in the text ("" if none).
    """
    lowered = pred.lower()
    # Whole-word matches only: a substring test would classify answers such as
    # "none", "cannot" or "5 nodes" as "no".
    if re.search(r"\b(?:yes|true)\b", lowered):
        pred = "True"
    elif re.search(r"\b(?:no|false)\b", lowered):
        pred = "False"
    elif any(
        option in lowered for option in ["(a)", "(b)", "(c)", "(d)", "(e)", "(f)"]
    ):
        pass
    else:
        if "boxed" in pred:
            pred = find_box(pred)

        if answer_flag:
            # Keep only the right-hand side of the last "=".
            pred = pred.split("=")[-1].strip()
            pred = clean_units(pred)
            try:
                tmp = str(latex2sympy(pred))
                # SECURITY: eval() executes text derived from model output.
                # Only run this inside a sandboxed evaluation environment.
                pred = str(eval(tmp))
            except Exception:
                # "<number> <unit>" -> "<number>"
                if re.match(r"-?[\d\.]+\s\D+$", pred) or re.match(
                    r"-?[\d\.]+\s[^\s]+$", pred
                ):
                    pred = pred.split(" ")[0]
        else:
            # Fall back to the last number in the text.
            preds = re.findall(r"-?\d*\.?\d+", pred)
            pred = preds[-1] if preds else ""

    return pred
|
|
|
|
|
def extract_answer(pred_str, data_name, use_last_number=True):
    r"""Extract and normalize the final answer from a model completion.

    Tried in order:
      1. Multiple-choice datasets: delegate to ``choice_answer_clean``.
      2. "final answer is $X$. I hope ..." template.
      3. The last ``\boxed{...}``.
      4. "he answer is" / "final answer is" / "答案是" markers.
      5. The last number in the text, when ``use_last_number`` is set.

    The result is passed through ``strip_string``.

    Args:
        pred_str: Raw model output.
        data_name: Dataset name; selects multiple-choice handling and whether
            units are kept (names in ``STRIP_EXCEPTIONS``).
        use_last_number: Fall back to the last number when no answer marker
            is found; otherwise the answer is empty.

    Returns:
        The normalized answer ("" if nothing could be extracted).
    """
    # Remove every occurrence of the Cyrillic token "ки".
    # NOTE(review): presumably a step-separator tag in some model outputs.
    pred_str = pred_str.replace("\u043a\u0438", "")
    if data_name in ["mmlu_stem", "sat_math", "aqua", "gaokao2023"]:
        return choice_answer_clean(pred_str)

    if "final answer is $" in pred_str and "$. I hope" in pred_str:
        tmp = pred_str.split("final answer is $", 1)[1]
        pred = tmp.split("$. I hope", 1)[0].strip()
    elif "boxed" in pred_str:
        if pred_str.endswith("boxed"):
            # Nothing follows the last "boxed" marker.
            return ""
        pred = find_box(pred_str)
    elif "he answer is" in pred_str:
        # Matches both "The answer is" and "the answer is".
        pred = pred_str.split("he answer is")[-1].strip()
    elif "final answer is" in pred_str:
        pred = pred_str.split("final answer is")[-1].strip()
    elif "答案是" in pred_str:
        # Chinese "the answer is": keep the first paragraph after the marker.
        pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
    elif use_last_number:
        # Raw string: "\d" and "\." are regex escapes, not Python ones.
        numbers = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
        pred = numbers[-1] if numbers else ""
    else:
        pred = ""

    # Multiple-choice datasets reaching this point keep the last option letter.
    if data_name in ["sat_math", "aqua"] or "mmlu" in data_name:
        letters = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
        pred = letters[-1] if letters else pred.strip().strip(".")

    # Join multi-line answers; trim a leading ":" and a trailing "." then "/".
    pred = re.sub(r"\n\s*", "", pred)
    if pred.startswith(":"):
        pred = pred[1:]
    if pred.endswith("."):
        pred = pred[:-1]
    if pred.endswith("/"):
        pred = pred[:-1]
    pred = strip_string(pred, skip_unit=data_name in STRIP_EXCEPTIONS)
    return pred
|
|
|
|
|
# Datasets whose ground-truth answers skip strip_string (only relation
# macros are canonicalized) and whose predictions keep their units.
STRIP_EXCEPTIONS = ["carp_en", "minerva_math"]


def parse_ground_truth(example: Dict[str, Any], data_name):
    """Return ``(gt_cot, gt_ans)``: reference solution text and normalized answer.

    Args:
        example: One dataset record; the keys read depend on ``data_name``.
            Records that already carry "gt_cot" and "gt" are handled first.
        data_name: Dataset name.

    Returns:
        ``(gt_cot, gt_ans)``. For the per-dataset path, ``gt_cot`` is passed
        through ``str(...).strip()``, so a missing solution becomes "None".

    Raises:
        NotImplementedError: If ``data_name`` is not supported.
    """
    if "gt_cot" in example and "gt" in example:
        if data_name in ["math"]:
            gt_ans = extract_answer(example["gt_cot"], data_name)
        elif data_name in STRIP_EXCEPTIONS:
            gt_ans = example["gt"]
        else:
            gt_ans = strip_string(example["gt"])
        return example["gt_cot"], gt_ans

    if data_name in ["math", "minerva_math"]:
        gt_cot = example["solution"]
        gt_ans = extract_answer(gt_cot, data_name)
    elif data_name == "gsm8k":
        # The answer follows the final "####"; rsplit tolerates extra markers
        # in the rationale instead of failing to unpack.
        gt_cot, gt_ans = example["answer"].rsplit("####", 1)
    elif data_name == "svamp":
        gt_cot, gt_ans = example["Equation"], example["Answer"]
    elif data_name == "asdiv":
        gt_cot = example["formula"]
        # Drop parenthesized annotations from the answer.
        gt_ans = re.sub(r"\(.*?\)", "", example["answer"])
    elif data_name == "mawps":
        gt_cot, gt_ans = None, example["target"]
    elif data_name == "tabmwp":
        gt_cot = example["solution"]
        gt_ans = example["answer"]
        if example["ans_type"] in ["integer_number", "decimal_number"]:
            # Remove thousands separators first so "1,200/3" and "1,000%"
            # parse instead of failing in int()/float().
            value = gt_ans.replace(",", "")
            if "/" in value:
                parts = value.split("/")
                gt_ans = int(parts[0]) / int(parts[1])
            elif "%" in value:
                gt_ans = float(value.split("%")[0]) / 100
            else:
                gt_ans = float(value)
    elif data_name == "carp_en":
        gt_cot, gt_ans = example["steps"], example["answer"]
    elif data_name == "mmlu_stem":
        abcd = "ABCD"
        # The answer is stored as an option index.
        gt_cot, gt_ans = None, abcd[example["answer"]]
    elif data_name == "sat_math":
        gt_cot, gt_ans = None, example["Answer"]
    elif data_name == "aqua":
        gt_cot, gt_ans = None, example["correct"]
    elif data_name in ["gaokao2023en", "college_math", "gaokao_math_cloze"]:
        gt_cot, gt_ans = None, example["answer"].replace("$", "").strip()
    elif data_name == "gaokao_math_qa":
        gt_cot, gt_ans = None, example["label"]
    elif data_name in ["gaokao2024_mix", "cn_middle_school"]:
        if len(example["choice_answer"]) > 0:
            gt_cot, gt_ans = None, example["choice_answer"]
        else:
            gt_cot, gt_ans = None, example["answer"]
    elif data_name == "olympiadbench":
        gt_cot, gt_ans = None, example["final_answer"][0].strip("$")
    elif data_name in [
        "aime24",
        "amc23",
        "cmath",
        "gaokao2024_I",
        "gaokao2024_II",
        "imo2024",
    ]:
        gt_cot, gt_ans = None, example["answer"]
    else:
        raise NotImplementedError(f"`{data_name}`")

    gt_cot = str(gt_cot).strip()
    if data_name not in STRIP_EXCEPTIONS:
        # With the default STRIP_EXCEPTIONS, "carp_en" never reaches this
        # branch, so skip_unit is False here.
        gt_ans = strip_string(gt_ans, skip_unit=data_name == "carp_en")
    else:
        gt_ans = (
            gt_ans.replace("\\neq", "\\ne")
            .replace("\\leq", "\\le")
            .replace("\\geq", "\\ge")
        )
    return gt_cot, gt_ans
|
|
|
|
|
def parse_question(example, data_name):
    """Build the question text for ``example`` from dataset ``data_name``.

    Multiple-choice datasets get their options appended after an
    "Answer Choices:" label (or a Chinese label for gaokao_math_qa). Unknown
    datasets fall back to the first present key among "question", "problem",
    "Question" and "input". If the parsed ground truth is true/false or
    yes/no, a matching hint is appended.

    Args:
        example: One dataset record; it is not modified.
        data_name: Dataset name.

    Returns:
        The stripped question text.
    """
    question = ""
    if data_name == "asdiv":
        question = f"{example['body'].strip()} {example['question'].strip()}"
    elif data_name == "svamp":
        body = example["Body"].strip()
        if not body.endswith("."):
            body = body + "."
        question = f'{body} {example["Question"].strip()}'
    elif data_name == "tabmwp":
        title_str = (
            f'regarding "{example["table_title"]}" ' if example["table_title"] else ""
        )
        question = f"Read the following table {title_str}and answer a question:\n"
        question += f'{example["table"]}\n{example["question"]}'
        if example["choices"]:
            question += (
                f' Please select from the following options: {example["choices"]}'
            )
    elif data_name == "carp_en":
        question = example["content"]
    elif data_name == "mmlu_stem":
        options = example["choices"]
        assert len(options) == 4
        # Build new strings instead of relabeling example["choices"] in place,
        # so the caller's record is untouched and repeated calls do not
        # produce "(A) (A) ...".
        options = " ".join(
            f"({label}) {str(option).strip()}"
            for label, option in zip("ABCD", options)
        )
        question = f"{example['question'].strip()}\nAnswer Choices: {options}"
    elif data_name == "sat_math":
        options = example["options"].strip()
        assert "A" == options[0]
        options = "(" + options
        for ch in "BCD":
            # Literal substring replacement: " B) " -> " (B) ".
            options = options.replace(f" {ch}) ", f" ({ch}) ")
        question = f"{example['question'].strip()}\nAnswer Choices: {options}"
    elif "aqua" in data_name:
        options = example["options"]
        choice = "(" + "(".join(options)
        choice = choice.replace("(", " (").replace(")", ") ").strip()
        choice = "\nAnswer Choices: " + choice
        question = example["question"].strip() + choice
    elif data_name == "gaokao_math_qa":
        options_dict = example["options"]
        options = []
        for key in options_dict:
            options.append(f"({key}) {options_dict[key]}")
        options = " ".join(options)
        question = f"{example['question'].strip()}\n选项: {options}"
    else:
        for key in ["question", "problem", "Question", "input"]:
            if key in example:
                question = example[key]
                break

    # Append an answer-format hint for boolean-style ground truths.
    _, gt_ans = parse_ground_truth(example, data_name)
    if isinstance(gt_ans, str):
        gt_lower = gt_ans.lower()
        if gt_lower in ["true", "false"]:
            question += " (True or False)"
        if gt_lower in ["yes", "no"]:
            question += " (Yes or No)"
    return question.strip()
|
|
|
|
|
def run_execute(executor, result, prompt_type, data_name, execute=False):
    """Turn a raw model completion into a normalized prediction.

    Args:
        executor: Object whose ``apply(code)`` returns ``(prediction, report)``.
            It is only used for "pot"/"pal" prompts when ``execute`` is True.
        result: Raw completion text. A falsy value or the literal "error"
            yields ``(None, None)``.
        prompt_type: Prompt format. If it contains "program_only", the output
            is read with ``extract_program_output``.
        data_name: Dataset name, forwarded to answer extraction and used to
            decide whether units are kept.
        execute: Whether to run the extracted program for "pot"/"pal" prompts.

    Returns:
        ``(prediction, report)``. ``report`` is only set by ``executor.apply``.
    """
    if not result or result == "error":
        return None, None

    report = None
    if "program_only" in prompt_type:
        prediction = extract_program_output(result)
    elif execute and prompt_type in ["pot", "pal"]:
        prediction, report = executor.apply(extract_program(result))
    else:
        prediction = extract_answer(result, data_name)

    normalized = strip_string(prediction, skip_unit=data_name in STRIP_EXCEPTIONS)
    return normalized, report
|
|
|
|
|
def _test_extract_answer():
    """Smoke test: print extracted answers for two hand-written completions."""
    # Raw string so LaTeX such as "\frac" is not read as Python escapes
    # ("\f" would otherwise become a form feed).
    text = r"""
This is still not equal to $0$, so we must have made another mistake.

When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that:

\[\frac{386}{64} - 7 = \frac{386}{64} - \frac{7 \cdot 64}{1 \cdot 64} = \frac{386 - 448}{64} = \frac{-62}{64}.\]

This is still not equal to $0$, so we must have made another mistake.

When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that:

\[\frac{386}{64}
"""
    print(extract_answer(text, "math-oai", use_last_number=False))
    # Raw string: same value as before, without invalid "\m" / "\}" escapes.
    print(choice_answer_clean(r"\mathrm{(D)\}1,008,016"))


if __name__ == "__main__":
    _test_extract_answer()
|
|