|
|
|
|
|
from huggingface_hub import snapshot_download |
|
|
|
from src.backend.envs import EVAL_REQUESTS_PATH_BACKEND |
|
from src.backend.manage_requests import get_eval_requests |
|
from src.backend.manage_requests import EvalRequest |
|
from src.backend.run_eval_suite import run_evaluation |
|
|
|
from src.backend.tasks.xsum.task import XSum |
|
from src.backend.tasks.xsum.task_v2 import XSumv2 |
|
|
|
from src.backend.tasks.cnndm.task import CNNDM |
|
from src.backend.tasks.cnndm.task_v2 import CNNDMv2 |
|
|
|
from src.backend.tasks.selfcheckgpt.task import SelfCheckGPT |
|
|
|
from lm_eval.tasks import TaskManager |
|
from lm_eval import tasks, evaluator, utils |
|
|
|
from src.backend.envs import Tasks, EVAL_REQUESTS_PATH_BACKEND, EVAL_RESULTS_PATH_BACKEND, DEVICE, LIMIT, Task |
|
from src.envs import QUEUE_REPO |
|
|
|
from lm_eval.models.huggingface import HFLM |
|
|
|
|
|
def main(
    model_name: str = "meta-llama/Llama-2-7b-hf",
    device: str = "mps",
    limit: "int | None" = 2,
    debug: bool = True,
) -> None:
    """Run a small lm-eval-harness evaluation for one queued model, for debugging.

    Loads every evaluation request (in any status) from the local request
    directory without downloading, picks the first request whose ``model``
    contains ``model_name``, and evaluates it on the ``xsum_v2`` task (0-shot)
    using the custom tasks under ``./src/backend/tasks/``. Results are printed.

    Args:
        model_name: Substring matched against ``EvalRequest.model`` to select
            the request to evaluate.
        device: Device string forwarded to ``evaluator.simple_evaluate``.
        limit: Example limit forwarded to ``evaluator.simple_evaluate``
            (presumably ``None`` means "no limit" — confirm against lm-eval).
        debug: When true, drop into the debugger after each task's results
            are printed (the original, unconditional behaviour).

    Raises:
        LookupError: if no loaded request matches ``model_name``.
    """
    import logging

    pending_status = "PENDING"
    running_status = "RUNNING"
    finished_status = "FINISHED"
    failed_status = "FAILED"
    # Accept requests in every status so an already-evaluated model can be re-run.
    status = [pending_status, running_status, finished_status, failed_status]

    # do_download=False: rely on the request files already present locally.
    eval_requests: list[EvalRequest] = get_eval_requests(
        job_status=status, hf_repo=QUEUE_REPO, local_dir=EVAL_REQUESTS_PATH_BACKEND, do_download=False
    )

    eval_request = next((r for r in eval_requests if model_name in r.model), None)
    if eval_request is None:
        raise LookupError(
            f"No evaluation request matching {model_name!r} found in {EVAL_REQUESTS_PATH_BACKEND}"
        )

    # Positional args are presumably (benchmark, metric, column name, num_fewshot);
    # only .benchmark and .num_fewshot are read below. "XXX" is a placeholder.
    my_task = Task("xsum_v2", "rougeL", "XXX", 0)

    eval_logger = utils.eval_logger
    eval_logger.setLevel(logging.DEBUG)

    tasks_harness = [my_task]

    # Register the project's custom task definitions alongside lm-eval's built-ins.
    task_manager = TaskManager(include_path="./src/backend/tasks/")
    print(task_manager.all_tasks)

    for task in tasks_harness:
        print(f"Selected Tasks: [{task}]")

        results = evaluator.simple_evaluate(
            model="hf",
            model_args=eval_request.get_model_args(),
            tasks=[task.benchmark],
            num_fewshot=task.num_fewshot,
            batch_size=1,
            device=device,
            use_cache=None,
            limit=limit,
            write_out=True,
            task_manager=task_manager,
        )
        print("AAA", results["results"])

        if debug:
            # Pause so the results object can be inspected interactively.
            breakpoint()
|
|
|
|
|
# Run the debug evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|