|
import os |
|
import json |
|
import random |
|
import datasets |
|
from datasets import load_dataset, Dataset, concatenate_datasets |
|
from utils import load_jsonl, lower_keys |
|
|
|
|
|
def load_data(data_name, split, data_dir="./data"): |
|
data_file = f"{data_dir}/{data_name}/{split}.jsonl" |
|
if os.path.exists(data_file): |
|
examples = list(load_jsonl(data_file)) |
|
else: |
|
if data_name == "math": |
|
dataset = load_dataset( |
|
"competition_math", |
|
split=split, |
|
name="main", |
|
cache_dir=f"{data_dir}/temp", |
|
) |
|
elif data_name == "gsm8k": |
|
dataset = load_dataset(data_name, split=split) |
|
elif data_name == "svamp": |
|
|
|
dataset = load_dataset("ChilleD/SVAMP", split="train") |
|
dataset = concatenate_datasets( |
|
[dataset, load_dataset("ChilleD/SVAMP", split="test")] |
|
) |
|
elif data_name == "asdiv": |
|
dataset = load_dataset("EleutherAI/asdiv", split="validation") |
|
dataset = dataset.filter( |
|
lambda x: ";" not in x["answer"] |
|
) |
|
elif data_name == "mawps": |
|
examples = [] |
|
|
|
for data_name in ["singleeq", "singleop", "addsub", "multiarith"]: |
|
sub_examples = list(load_jsonl(f"{data_dir}/mawps/{data_name}.jsonl")) |
|
for example in sub_examples: |
|
example["type"] = data_name |
|
examples.extend(sub_examples) |
|
dataset = Dataset.from_list(examples) |
|
elif data_name == "mmlu_stem": |
|
dataset = load_dataset("hails/mmlu_no_train", "all", split="test") |
|
|
|
stem_subjects = [ |
|
"abstract_algebra", |
|
"astronomy", |
|
"college_biology", |
|
"college_chemistry", |
|
"college_computer_science", |
|
"college_mathematics", |
|
"college_physics", |
|
"computer_security", |
|
"conceptual_physics", |
|
"electrical_engineering", |
|
"elementary_mathematics", |
|
"high_school_biology", |
|
"high_school_chemistry", |
|
"high_school_computer_science", |
|
"high_school_mathematics", |
|
"high_school_physics", |
|
"high_school_statistics", |
|
"machine_learning", |
|
] |
|
dataset = dataset.rename_column("subject", "type") |
|
dataset = dataset.filter(lambda x: x["type"] in stem_subjects) |
|
elif data_name == "carp_en": |
|
dataset = load_jsonl(f"{data_dir}/carp_en/test.jsonl") |
|
else: |
|
raise NotImplementedError(data_name) |
|
|
|
examples = list(dataset) |
|
examples = [lower_keys(example) for example in examples] |
|
dataset = Dataset.from_list(examples) |
|
os.makedirs(f"{data_dir}/{data_name}", exist_ok=True) |
|
dataset.to_json(data_file) |
|
|
|
|
|
if "idx" not in examples[0]: |
|
examples = [{"idx": i, **example} for i, example in enumerate(examples)] |
|
|
|
|
|
examples = sorted(examples, key=lambda x: x["idx"]) |
|
return examples |
|
|