import random
import regex
import re
import sympy
from latex2sympy2 import latex2sympy
from typing import TypeVar, Iterable, List, Union, Any, Dict
from word2number import w2n
from utils import *


def _fix_fracs(string):
    r"""Add missing braces to ``\frac`` arguments.

    ``\frac12`` -> ``\frac{1}{2}``, ``\frac1b`` -> ``\frac{1}{b}`` and
    ``\frac1{72}`` -> ``\frac{1}{72}``. Segments that already start with ``{``
    are copied unchanged. If any ``\frac`` is followed by fewer than two
    characters, the input string is returned unchanged.
    """
    substrs = string.split("\\frac")
    new_str = substrs[0]
    for substr in substrs[1:]:
        new_str += "\\frac"
        if substr and substr[0] == "{":
            new_str += substr
            continue
        if len(substr) < 2:
            # Malformed "\frac" with nothing usable after it: give up.
            return string
        a, b, rest = substr[0], substr[1], substr[2:]
        if b != "{":
            # "\frac12..." -> "\frac{1}{2}..."
            new_str += "{" + a + "}{" + b + "}" + rest
        else:
            # "\frac1{72}..." -> "\frac{1}{72}..."
            new_str += "{" + a + "}" + b + rest
    return new_str


def _fix_a_slash_b(string):
    r"""Rewrite a simple ``a/b`` string as ``\frac{a}{b}``.

    Applies only when the string contains exactly one "/", and each side
    either contains "sqrt" or is an integer literal that round-trips exactly
    through ``int`` (so "3/4" is rewritten but " 3/4" and "03/4" are not).
    Otherwise the input is returned unchanged.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    a, b = parts
    try:
        if "sqrt" not in a:
            a = int(a)
        if "sqrt" not in b:
            b = int(b)
    except ValueError:
        return string
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if string != "{}/{}".format(a, b):
        return string
    return "\\frac{" + str(a) + "}{" + str(b) + "}"


def _fix_sqrt(string):
    r"""Brace a bare ``\sqrt`` argument: ``\sqrt3`` -> ``\sqrt{3}``.

    The whole run of word characters directly after ``\sqrt`` is wrapped.
    """
    return re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)


def convert_word_number(text: str) -> str:
    """Best-effort conversion of an English number phrase to a digit string.

    Uses ``w2n.word_to_num``; if it raises, ``text`` is returned unchanged.
    """
    try:
        return str(w2n.word_to_num(text))
    except Exception:
        # Deliberate best-effort; narrowed from a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        return text


# units mainly from MathQA
unit_texts = [
    "east",
    "degree",
    "mph",
    "kmph",
    "ft",
    "m sqaure",
    " m east",
    "sq m",
    "deg",
    "mile",
    "q .",
    "monkey",
    "prime",
    "ratio",
    "profit of rs",
    "rd",
    "o",
    "gm",
    "p . m",
    "lb",
    "tile",
    "per",
    "dm",
    "lt",
    "gain",
    "ab",
    "way",
    "west",
    "a .",
    "b .",
    "c .",
    "d .",
    "e .",
    "f .",
    "g .",
    "h .",
    "t",
    "a",
    "h",
    "no change",
    "men",
    "soldier",
    "pie",
    "bc",
    "excess",
    "st",
    "inches",
    "noon",
    "percent",
    "by",
    "gal",
    "kmh",
    "c",
    "acre",
    "rise",
    "a . m",
    "th",
    "π r 2",
    "sq",
    "mark",
    "l",
    "toy",
    "coin",
    "sq . m",
    "gallon",
    "° f",
    "profit",
    "minw",
    "yr",
    "women",
    "feet",
    "am",
    "pm",
    "hr",
    "cu cm",
    "square",
    "v â € ™",
    "are",
    "rupee",
    "rounds",
    "cubic",
    "cc",
    "mtr",
    "s",
    "ohm",
    "number",
    "kmph",
    "day",
    "hour",
    "minute",
    "min",
    "second",
    "man",
    "woman",
    "sec",
    "cube",
    "mt",
    "sq inch",
    "mp",
    "∏ cm ³",
    "hectare",
    "more",
    "sec",
    "unit",
    "cu . m",
    "cm 2",
    "rs .",
    "rs",
    "kg",
    "g",
    "month",
    "km",
    "m",
    "cm",
    "mm",
    "apple",
    "liter",
    "loss",
    "yard",
    "pure",
    "year",
    "increase",
    "decrease",
    "d",
    "less",
    "Surface",
    "litre",
    "pi sq m",
    "s .",
    "metre",
    "meter",
    "inch",
]

# Also match the plural form of every unit.
unit_texts.extend([t + "s" for t in unit_texts])


def strip_string(string, skip_unit=False):
    r"""Normalize an answer string so equivalent answers compare equal.

    The input is ``str()``-converted, then (in this order) has LaTeX spacing,
    ``\left``/``\right``, trailing ``\text{...}``, units from ``unit_texts``
    (unless ``skip_unit``), degree marks, dollar signs, percent signs, and
    variable prefixes such as ``x=`` removed; word numbers are converted to
    digits, ``\tfrac``/``\dfrac`` become ``\frac``, trailing ``.000`` is
    dropped, spaces are removed, and ``\frac`` / ``a/b`` forms are
    canonicalized. The order of these steps matters and is load-bearing.

    Args:
        string: Raw answer (any object; converted with ``str``).
        skip_unit: If True, skip the ``unit_texts`` removal pass.

    Returns:
        The normalized string ("" stays "").
    """
    string = str(string).strip()
    # linebreaks
    string = string.replace("\n", "")

    # right "."
    string = string.rstrip(".")

    # remove inverse spaces
    # replace \\ with \
    string = string.replace("\\!", "")

    # matrix
    string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
    string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
    string = string.replace("bmatrix", "pmatrix")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    string = (
        string.replace("\\neq", "\\ne")
        .replace("\\leq", "\\le")
        .replace("\\geq", "\\ge")
    )

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    string = string.replace("\\{", "{")
    string = string.replace("\\}", "}")

    # Remove a trailing \text{...} unit (miles, dollars, ...) unless that
    # would leave nothing behind.
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        string = _string

    if not skip_unit:
        # Remove unit texts (two passes, so stacked units are also removed).
        for _ in range(2):
            for unit_text in unit_texts:
                # The prefix must be the start of the string or a non-word
                # character, and the suffix the end or a non-word character.
                # NOTE(review): unit_text is not re.escape()d, so "." in
                # entries like "rs ." matches any character; kept as-is to
                # preserve existing grading behaviour.
                _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string)
                if _string != "":
                    string = _string

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")
    string = string.replace("$", "")
    string = string.replace("\\(", "").replace("\\)", "")

    # convert word number to digit
    string = convert_word_number(string)

    # replace "\\text{...}" to "..."
    string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
    for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]:
        string = string.replace(key, "")
    string = string.replace("\\emptyset", r"{}")
    string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")

    # remove percentage: "\%" first, then any remaining bare "%"
    # (a former `replace("\%", "")` was the same literal as "\\%" and is gone)
    string = string.replace("\\%", "")
    string = string.replace("%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively,
    # add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # NOTE(review): str.isalnum() is False for any string that contains a
    # bracket, so this condition can never hold and the brackets are never
    # stripped. Left unchanged to preserve existing grading behaviour; the
    # intent was probably ``string[1:-1].isalnum()`` -- confirm before fixing.
    if (
        string.startswith("{")
        and string.endswith("}")
        and string.isalnum()
        or string.startswith("(")
        and string.endswith(")")
        and string.isalnum()
        or string.startswith("[")
        and string.endswith("]")
        and string.isalnum()
    ):
        string = string[1:-1]

    # inf
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    string = string.replace("+\\inity", "\\infty")

    # and
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")

    # use regex to remove \mbox{...}
    string = re.sub(r"\\mbox{.*?}", "", string)

    # Quotes are not removed: the previous code called str.replace for "'"
    # and '"' without assigning the result, so those calls were no-ops and
    # have been dropped without changing behaviour.

    # i, j: treat "j" as the imaginary unit when no "i" is present
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # replace a.000b where b is not a digit or b is the end, with ab
    string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0*$", r"\1", string)

    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # get rid of e.g. "k = " or "q = " at beginning
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    string = _fix_sqrt(string)
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with
    # \frac1{72} (but not \frac{72}1).
    string = _fix_fracs(string)

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in
    # case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string


def extract_multi_choice_answer(pred_str):
    """Return the upper-case letter A-E following "answer is" / "choice is".

    Text from the first "Problem:" onward is ignored. Returns the literal
    string "placeholder" when no letter is found.
    """
    # TODO: SFT models
    if "Problem:" in pred_str:
        pred_str = pred_str.split("Problem:", 1)[0]
    pred_str = pred_str.replace("choice is", "answer is")
    # The named group "ans" is required: it is read back below.
    patt = regex.search(r"answer is \(?(?P<ans>[abcde])\)?", pred_str.lower())
    if patt is not None:
        return patt.group("ans").upper()
    return "placeholder"


direct_answer_trigger_for_fewshot = ("choice is", "answer is")


def choice_answer_clean(pred: str):
    """Extract a multiple-choice letter (A-E) from a model completion.

    If a trigger from ``direct_answer_trigger_for_fewshot`` occurs more than
    once, only the first "\\n\\n"-separated chunk is kept. The text after the
    last trigger is then searched for standalone letters A-E: the first match
    is used when a trigger was present, otherwise the last. With no letter,
    the cleaned remaining text is returned.
    """
    pred = pred.strip("\n")

    # Determine if this is ICL (repeated triggers); if so, keep the first chunk.
    is_icl = any(pred.count(trigger) > 1 for trigger in direct_answer_trigger_for_fewshot)
    if is_icl:
        pred = pred.split("\n\n")[0]

    # Split on the triggers to find the answer.
    preds = re.split("|".join(direct_answer_trigger_for_fewshot), pred)
    answer_flag = len(preds) > 1
    if answer_flag:
        pred = preds[-1]

    pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")

    letters = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
    if letters:
        pred = letters[0] if answer_flag else letters[-1]
    else:
        pred = pred.strip().strip(".")

    # Remove the period at the end, again!
    return pred.rstrip(".").rstrip("/")


def find_box(pred_str: str):
    """Return the contents of the last ``boxed{...}`` in ``pred_str``.

    Nested braces are honoured; an unbalanced group returns everything after
    the opening brace. If the text after the last "boxed" does not start with
    "{", the text up to the next "$" (stripped) is returned. If "boxed" is
    absent, the whole string is treated as that trailing text.
    """
    tail = pred_str.split("boxed")[-1]
    if not tail:
        return ""
    if tail[0] != "{":
        return tail.split("$")[0].strip()
    depth = 1
    chars = []
    for c in tail[1:]:
        if c == "{":
            depth += 1
        elif c == "}":
            depth -= 1
            if depth == 0:
                break
        chars.append(c)
    return "".join(chars)


def clean_units(pred_str: str):
    """Clean the units in the number.

    Replaces pi with 3.14 (inserting "*" after a preceding digit), turns "%"
    into "/100", and removes "$", "¥", "°C", " C" and "°".
    """

    def convert_pi_to_number(code_string):
        code_string = code_string.replace("\\pi", "π")
        # Replace \pi or π not preceded by a digit or } with 3.14
        code_string = re.sub(r"(?<![\d}])\\?π", "3.14", code_string)
        # Replace instances where π is preceded by a digit but without a
        # multiplication symbol, e.g., "3π" -> "3*3.14"
        code_string = re.sub(r"(\d)(\\?π)", r"\1*3.14", code_string)
        # Handle cases where π is within braces or follows a multiplication
        # symbol: "{π}" -> "3.14" and "3*π" -> "3*3.14"
        code_string = re.sub(r"\{(\\?π)\}", "3.14", code_string)
        code_string = re.sub(r"\*(\\?π)", "*3.14", code_string)
        return code_string

    pred_str = convert_pi_to_number(pred_str)
    pred_str = pred_str.replace("%", "/100")
    pred_str = pred_str.replace("$", "")
    pred_str = pred_str.replace("¥", "")
    pred_str = pred_str.replace("°C", "")
    pred_str = pred_str.replace(" C", "")
    pred_str = pred_str.replace("°", "")
    return pred_str


def extract_theoremqa_answer(pred: str, answer_flag: bool = True):
    """Extract a TheoremQA-style answer from ``pred``.

    Returns "True"/"False" for yes/true and no/false answers, returns ``pred``
    unchanged if it mentions an option "(a)".."(f)", and otherwise extracts a
    value: the right-hand side of the last "=" evaluated through
    ``latex2sympy`` when ``answer_flag`` is True, or the last number in the
    text when it is False.
    """
    lowered = pred.lower()
    # NOTE(review): these are plain substring tests, so e.g. "not", "know"
    # or "number" also count as "no"; kept as-is to preserve scoring.
    if any(option in lowered for option in ["yes", "true"]):
        return "True"
    if any(option in lowered for option in ["no", "false"]):
        return "False"
    if any(option in lowered for option in ["(a)", "(b)", "(c)", "(d)", "(e)", "(f)"]):
        return pred

    # Some of the models somehow get used to boxed output from pre-training
    if "boxed" in pred:
        pred = find_box(pred)

    if not answer_flag:
        # desperate search over the last number
        numbers = re.findall(r"-?\d*\.?\d+", pred)
        return numbers[-1] if numbers else ""

    # Extract the numbers out of the string
    pred = pred.split("=")[-1].strip()
    pred = clean_units(pred)
    try:
        tmp = str(latex2sympy(pred))
        # SECURITY NOTE(review): eval() runs on text derived from model
        # output. Only use this on trusted/sandboxed evaluation inputs.
        pred = str(eval(tmp))
    except Exception:
        # "<number> <unit>" -> "<number>"
        if re.match(r"-?[\d\.]+\s\D+$", pred) or re.match(r"-?[\d\.]+\s[^\s]+$", pred):
            pred = pred.split(" ")[0]
    return pred


# Datasets whose answers skip unit stripping / full ground-truth normalization.
STRIP_EXCEPTIONS = ["carp_en", "minerva_math"]


def extract_answer(pred_str, data_name, use_last_number=True):
    """Extract and normalize the final answer from a model completion.

    Multiple-choice datasets ("mmlu_stem", "sat_math", "aqua", "gaokao2023")
    are delegated to ``choice_answer_clean``. Otherwise the first matching
    source wins: Minerva "final answer is $...$. I hope", the last
    ``\\boxed{...}``, "he answer is", "final answer is", "答案是", and finally
    the last number in the text (only when ``use_last_number`` is True).
    The result is passed through ``strip_string``.
    """
    # Drop the "ки" step-tag token.
    pred_str = pred_str.replace("\u043a\u0438", "")
    if data_name in ["mmlu_stem", "sat_math", "aqua", "gaokao2023"]:
        # TODO check multiple choice
        return choice_answer_clean(pred_str)

    if "final answer is $" in pred_str and "$. I hope" in pred_str:
        # minerva_math
        tmp = pred_str.split("final answer is $", 1)[1]
        pred = tmp.split("$. I hope", 1)[0].strip()
    elif "boxed" in pred_str:
        pred = find_box(pred_str)
    elif "he answer is" in pred_str:
        pred = pred_str.split("he answer is")[-1].strip()
    elif "final answer is" in pred_str:
        pred = pred_str.split("final answer is")[-1].strip()
    elif "答案是" in pred_str:
        # Handle Chinese few-shot multiple choice problem answer extraction
        pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
    elif use_last_number:
        # use the last number
        numbers = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
        pred = numbers[-1] if numbers else ""
    else:
        pred = ""

    # choice answer
    if data_name in ["sat_math", "aqua"] or "mmlu" in data_name:
        letters = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
        if letters:
            pred = letters[-1]
        else:
            pred = pred.strip().strip(".")

    # multiple line: join lines, dropping leading indentation
    pred = re.sub(r"\n\s*", "", pred)
    if pred != "" and pred[0] == ":":
        pred = pred[1:]
    if pred != "" and pred[-1] == ".":
        pred = pred[:-1]
    if pred != "" and pred[-1] == "/":
        pred = pred[:-1]
    pred = strip_string(pred, skip_unit=data_name in STRIP_EXCEPTIONS)
    return pred


def parse_ground_truth(example: Dict[str, Any], data_name):
    """Return ``(gt_cot, gt_ans)`` for one dataset example.

    If the example already has "gt_cot" and "gt", those are used directly
    (with the answer extracted or normalized). Otherwise the fields are read
    per ``data_name``; ``gt_cot`` is then ``str()``-converted and stripped (so
    a missing chain becomes "None"), and ``gt_ans`` is normalized with
    ``strip_string`` unless the dataset is in ``STRIP_EXCEPTIONS``.

    Raises:
        NotImplementedError: for an unknown ``data_name``.
    """
    if "gt_cot" in example and "gt" in example:
        if data_name in ["math"]:
            gt_ans = extract_answer(example["gt_cot"], data_name)
        elif data_name in STRIP_EXCEPTIONS:
            gt_ans = example["gt"]
        else:
            gt_ans = strip_string(example["gt"])
        return example["gt_cot"], gt_ans

    # parse ground truth
    if data_name in ["math", "minerva_math"]:
        gt_cot = example["solution"]
        gt_ans = extract_answer(gt_cot, data_name)
    elif data_name == "gsm8k":
        gt_cot, gt_ans = example["answer"].split("####")
    elif data_name == "svamp":
        gt_cot, gt_ans = example["Equation"], example["Answer"]
    elif data_name == "asdiv":
        gt_cot = example["formula"]
        gt_ans = re.sub(r"\(.*?\)", "", example["answer"])
    elif data_name == "mawps":
        gt_cot, gt_ans = None, example["target"]
    elif data_name == "tabmwp":
        gt_cot = example["solution"]
        gt_ans = example["answer"]
        if example["ans_type"] in ["integer_number", "decimal_number"]:
            if "/" in gt_ans:
                gt_ans = int(gt_ans.split("/")[0]) / int(gt_ans.split("/")[1])
            elif "," in gt_ans:
                gt_ans = float(gt_ans.replace(",", ""))
            elif "%" in gt_ans:
                gt_ans = float(gt_ans.split("%")[0]) / 100
            else:
                gt_ans = float(gt_ans)
    elif data_name == "carp_en":
        gt_cot, gt_ans = example["steps"], example["answer"]
    elif data_name == "mmlu_stem":
        abcd = "ABCD"
        gt_cot, gt_ans = None, abcd[example["answer"]]
    elif data_name == "sat_math":
        gt_cot, gt_ans = None, example["Answer"]
    elif data_name == "aqua":
        gt_cot, gt_ans = None, example["correct"]
    elif data_name in ["gaokao2023en", "college_math", "gaokao_math_cloze"]:
        gt_cot, gt_ans = None, example["answer"].replace("$", "").strip()
    elif data_name == "gaokao_math_qa":
        gt_cot, gt_ans = None, example["label"]
    elif data_name in ["gaokao2024_mix", "cn_middle_school"]:
        if len(example["choice_answer"]) > 0:
            gt_cot, gt_ans = None, example["choice_answer"]
        else:
            gt_cot, gt_ans = None, example["answer"]
    elif data_name == "olympiadbench":
        gt_cot, gt_ans = None, example["final_answer"][0].strip("$")
    elif data_name in [
        "aime24",
        "amc23",
        "cmath",
        "gaokao2024_I",
        "gaokao2024_II",
        "imo2024",
    ]:
        gt_cot, gt_ans = None, example["answer"]
    else:
        raise NotImplementedError(f"`{data_name}`")

    # post process
    gt_cot = str(gt_cot).strip()
    if data_name not in STRIP_EXCEPTIONS:
        gt_ans = strip_string(gt_ans, skip_unit=data_name == "carp_en")
    else:
        gt_ans = (
            gt_ans.replace("\\neq", "\\ne")
            .replace("\\leq", "\\le")
            .replace("\\geq", "\\ge")
        )
    return gt_cot, gt_ans


def parse_question(example, data_name):
    """Build the question text for an example of dataset ``data_name``.

    Multiple-choice datasets get their options appended ("Answer Choices: ..."
    or "选项: ..."). If the ground-truth answer is true/false or yes/no, a
    " (True or False)" / " (Yes or No)" hint is appended. The example dict is
    not modified.
    """
    question = ""
    if data_name == "asdiv":
        question = f"{example['body'].strip()} {example['question'].strip()}"
    elif data_name == "svamp":
        body = example["Body"].strip()
        if not body.endswith("."):
            body = body + "."
        question = f'{body} {example["Question"].strip()}'
    elif data_name == "tabmwp":
        title_str = (
            f'regarding "{example["table_title"]}" ' if example["table_title"] else ""
        )
        question = f"Read the following table {title_str}and answer a question:\n"
        question += f'{example["table"]}\n{example["question"]}'
        if example["choices"]:
            question += (
                f' Please select from the following options: {example["choices"]}'
            )
    elif data_name == "carp_en":
        question = example["content"]
    elif data_name == "mmlu_stem":
        choices = example["choices"]
        assert len(choices) == 4
        # Build a new list: the previous code overwrote example["choices"] in
        # place, so calling this twice on one example double-labelled options.
        labelled = [
            f"({label}) {str(option).strip()}" for label, option in zip("ABCD", choices)
        ]
        options = " ".join(labelled)
        question = f"{example['question'].strip()}\nAnswer Choices: {options}"
    elif data_name == "sat_math":
        options = example["options"].strip()
        assert "A" == options[0]
        options = "(" + options
        for ch in "BCD":
            if f" {ch}) " in options:
                options = regex.sub(rf" {ch}\) ", f" ({ch}) ", options)
        question = f"{example['question'].strip()}\nAnswer Choices: {options}"
    elif "aqua" in data_name:
        options = example["options"]
        choice = "(" + "(".join(options)
        choice = choice.replace("(", " (").replace(")", ") ").strip()
        choice = "\nAnswer Choices: " + choice
        question = example["question"].strip() + choice
    elif data_name == "gaokao_math_qa":
        options_dict = example["options"]
        options = " ".join(f"({key}) {options_dict[key]}" for key in options_dict)
        question = f"{example['question'].strip()}\n选项: {options}"
    else:
        for key in ["question", "problem", "Question", "input"]:
            if key in example:
                question = example[key]
                break

    # Yes or No question
    _, gt_ans = parse_ground_truth(example, data_name)
    if isinstance(gt_ans, str):
        gt_lower = gt_ans.lower()
        if gt_lower in ["true", "false"]:
            question += " (True or False)"
        if gt_lower in ["yes", "no"]:
            question += " (Yes or No)"
    return question.strip()


def run_execute(executor, result, prompt_type, data_name, execute=False):
    """Turn a raw completion into ``(prediction, report)``.

    Returns ``(None, None)`` for an empty or "error" result. For
    "program_only" prompt types the program output is extracted; for
    "pot"/"pal" with ``execute=True`` the program is run through
    ``executor.apply`` (which also supplies ``report``); otherwise
    ``extract_answer`` is used. The prediction is normalized with
    ``strip_string``.
    """
    if not result or result == "error":
        return None, None
    report = None

    if "program_only" in prompt_type:
        prediction = extract_program_output(result)
    elif prompt_type in ["pot", "pal"] and execute:
        code = extract_program(result)
        prediction, report = executor.apply(code)
    else:
        prediction = extract_answer(result, data_name)

    prediction = strip_string(prediction, skip_unit=data_name in STRIP_EXCEPTIONS)
    return prediction, report


def _test_extract_answer():
    """Manual smoke test: print a couple of extraction results."""
    # Raw string so LaTeX such as "\frac" is not turned into a form feed.
    text = r"""
This is still not equal to $0$, so we must have made another mistake.

When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that:

\[\frac{386}{64} - 7 = \frac{386}{64} - \frac{7 \cdot 64}{1 \cdot 64} = \frac{386 - 448}{64} = \frac{-62}{64}.\]

This is still not equal to $0$, so we must have made another mistake.

When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that:

\[\frac{386}{64}
"""
    print(extract_answer(text, "math-oai", use_last_number=False))
    print(choice_answer_clean(r"\mathrm{(D)\}1,008,016"))
    # should output "D"


if __name__ == "__main__":
    _test_extract_answer()