import hmac
import logging
import os
from typing import Optional

import requests
from fastapi import BackgroundTasks, FastAPI, Header, HTTPException
from fastapi.responses import FileResponse
from huggingface_hub.hf_api import HfApi

from .models import WebhookPayload, config

logger = logging.getLogger(__name__)

WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")
HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")
AUTOTRAIN_API_URL = "https://api.autotrain.huggingface.co"
AUTOTRAIN_UI_URL = "https://ui.autotrain.huggingface.co"

# Seconds to wait on any AutoTrain API call. Without a timeout `requests`
# may block indefinitely, permanently tying up the background worker.
AUTOTRAIN_REQUEST_TIMEOUT = 60

app = FastAPI()


@app.get("/")
async def home():
    """Serve the static ``home.html`` landing page."""
    return FileResponse("home.html")


@app.post("/webhook")
async def post_webhook(
    payload: WebhookPayload,
    task_queue: BackgroundTasks,
    x_webhook_secret: Optional[str] = Header(default=None),
):
    """Receive a Hub webhook and queue a retrain when the input dataset changes.

    FastAPI reads ``x_webhook_secret`` from the ``X-Webhook-Secret`` header.

    Returns:
        ``{"processed": True}`` when a retrain was queued, otherwise
        ``{"processed": False}`` (the event is not a content update of
        ``config.input_dataset``).

    Raises:
        HTTPException: 401 if the secret header is missing; 403 if it does not
            match ``WEBHOOK_SECRET`` or ``WEBHOOK_SECRET`` is not configured.
    """
    if x_webhook_secret is None:
        raise HTTPException(401)
    # Constant-time comparison so response timing cannot leak the secret.
    # An unset WEBHOOK_SECRET rejects every request (same as before).
    if WEBHOOK_SECRET is None or not hmac.compare_digest(
        x_webhook_secret.encode("utf-8"), WEBHOOK_SECRET.encode("utf-8")
    ):
        raise HTTPException(403)

    logger.info("Received webhook payload: %s", payload)

    # Only a content update of the configured input dataset triggers a retrain.
    if not (
        payload.event.action == "update"
        and payload.event.scope.startswith("repo.content")
        and payload.repo.name == config.input_dataset
        and payload.repo.type == "dataset"
    ):
        # no-op
        return {"processed": False}

    # The AutoTrain calls are slow; run them after the response is sent.
    task_queue.add_task(schedule_retrain, payload)
    return {"processed": True}


def _log_autotrain_error(err: requests.HTTPError) -> None:
    """Log a failed AutoTrain API call without letting a bad body mask it."""
    response = err.response
    if response is None:
        logger.error("ERROR while requesting AutoTrain API: %s", err)
        return
    try:
        body = response.json()
    except ValueError:
        # Error pages (e.g. an HTML 502) are not JSON; decoding them inside the
        # handler would otherwise replace the original HTTPError.
        body = response.text
    logger.error(
        "ERROR while requesting AutoTrain API:\n  code: %s\n  %s",
        response.status_code,
        body,
    )


def schedule_retrain(payload: WebhookPayload):
    """Create and start an AutoTrain project, then announce it on the Hub.

    Intended to run as a FastAPI background task; the return value is kept
    for backward compatibility.

    Raises:
        requests.HTTPError: if any AutoTrain API call returns an error status
            (logged before being re-raised).
    """
    # Create the autotrain project
    try:
        project = AutoTrain.create_project(payload)
        AutoTrain.add_data(project_id=project["id"])
        AutoTrain.start_processing(project_id=project["id"])
    except requests.HTTPError as err:
        _log_autotrain_error(err)
        raise

    # Notify in the community tab
    notify_success(project["id"])
    return {"processed": True}


class AutoTrain:
    """Wrapper around the AutoTrain REST endpoints used by this app."""

    @staticmethod
    def _headers() -> dict:
        """Return the bearer-token authorization header for AutoTrain calls."""
        return {"Authorization": f"Bearer {HF_ACCESS_TOKEN}"}

    @staticmethod
    def create_project(payload: WebhookPayload) -> dict:
        """Create a project named after the first 7 chars of the repo head SHA.

        Returns:
            The decoded JSON response; callers read its ``"id"`` key.

        Raises:
            requests.HTTPError: on an error status code.
        """
        project_resp = requests.post(
            f"{AUTOTRAIN_API_URL}/projects/create",
            json={
                "username": config.target_namespace,
                "proj_name": (
                    f"{config.autotrain_project_prefix}-{payload.repo.headSha[:7]}"
                ),
                "task": 18,  # image-multi-class-classification
                "config": {
                    "hub-model": config.input_model,
                    "max_models": 1,
                    "language": "unk",
                },
            },
            headers=AutoTrain._headers(),
            timeout=AUTOTRAIN_REQUEST_TIMEOUT,
        )
        project_resp.raise_for_status()
        return project_resp.json()

    @staticmethod
    def add_data(project_id: int) -> requests.Response:
        """Attach the ``train`` split of ``config.input_dataset`` to a project.

        Returns:
            The raw response (previously discarded; callers may ignore it).

        Raises:
            requests.HTTPError: on an error status code.
        """
        resp = requests.post(
            f"{AUTOTRAIN_API_URL}/projects/{project_id}/data/dataset",
            json={
                "dataset_id": config.input_dataset,
                "dataset_split": "train",
                # NOTE(review): meaning of split=4 is AutoTrain-specific — confirm.
                "split": 4,
                "col_mapping": {
                    "image": "image",
                    "label": "target",
                },
            },
            headers=AutoTrain._headers(),
            timeout=AUTOTRAIN_REQUEST_TIMEOUT,
        )
        resp.raise_for_status()
        return resp

    @staticmethod
    def start_processing(project_id: int) -> requests.Response:
        """Ask AutoTrain to start processing the project's data.

        Raises:
            requests.HTTPError: on an error status code.
        """
        resp = requests.post(
            f"{AUTOTRAIN_API_URL}/projects/{project_id}/data/start_processing",
            headers=AutoTrain._headers(),
            timeout=AUTOTRAIN_REQUEST_TIMEOUT,
        )
        resp.raise_for_status()
        return resp


# Markdown body of the community-tab discussion opened by notify_success.
NOTIFICATION_TEMPLATE = """\
🌸 Hello there!

Following an update of [{input_dataset}](https://huggingface.co/datasets/{input_dataset}), an automatic re-training of [{input_model}](https://huggingface.co/{input_model}) has been scheduled on AutoTrain!

Please review and approve the project [here]({ui_url}/{project_id}/trainings) to start the training job.

(This is an automated message)
"""


def notify_success(project_id: int):
    """Open a discussion on the input dataset linking to the new AutoTrain project.

    Returns:
        Whatever ``HfApi.create_discussion`` returns.
    """
    message = NOTIFICATION_TEMPLATE.format(
        input_model=config.input_model,
        input_dataset=config.input_dataset,
        project_id=project_id,
        ui_url=AUTOTRAIN_UI_URL,
    )
    return HfApi(token=HF_ACCESS_TOKEN).create_discussion(
        repo_id=config.input_dataset,
        repo_type="dataset",
        title="✨ Retraining started!",
        description=message,
        token=HF_ACCESS_TOKEN,
    )