File size: 4,315 Bytes
122057f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import json
import uuid
from typing import Optional

import requests
from huggingface_hub import Discussion, HfApi, get_repo_discussions

from .utils import cached_file, logging


logger = logging.get_logger(__name__)


def previous_pr(api: HfApi, model_id: str, pr_title: str, token: str) -> Optional["Discussion"]:
    main_commit = api.list_repo_commits(model_id, token=token)[0].commit_id
    for discussion in get_repo_discussions(repo_id=model_id, token=token):
        if discussion.title == pr_title and discussion.status == "open" and discussion.is_pull_request:
            commits = api.list_repo_commits(model_id, revision=discussion.git_reference, token=token)

            if main_commit == commits[1].commit_id:
                return discussion
    return None


def spawn_conversion(token: str, private: bool, model_id: str):
    logger.info("Attempting to convert .bin model on the fly to safetensors.")

    safetensors_convert_space_url = "https://safetensors-convert.hf.space"
    sse_url = f"{safetensors_convert_space_url}/queue/join"
    sse_data_url = f"{safetensors_convert_space_url}/queue/data"

    # The `fn_index` is necessary to indicate to gradio that we will use the `run` method of the Space.
    hash_data = {"fn_index": 1, "session_hash": str(uuid.uuid4())}

    def start(_sse_connection, payload):
        for line in _sse_connection.iter_lines():
            line = line.decode()
            if line.startswith("data:"):
                resp = json.loads(line[5:])
                logger.debug(f"Safetensors conversion status: {resp['msg']}")
                if resp["msg"] == "queue_full":
                    raise ValueError("Queue is full! Please try again.")
                elif resp["msg"] == "send_data":
                    event_id = resp["event_id"]
                    response = requests.post(
                        sse_data_url,
                        stream=True,
                        params=hash_data,
                        json={"event_id": event_id, **payload, **hash_data},
                    )
                    response.raise_for_status()
                elif resp["msg"] == "process_completed":
                    return

    with requests.get(sse_url, stream=True, params=hash_data) as sse_connection:
        data = {"data": [model_id, private, token]}
        try:
            logger.debug("Spawning safetensors automatic conversion.")
            start(sse_connection, data)
        except Exception as e:
            logger.warning(f"Error during conversion: {repr(e)}")


def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs):
    private = api.model_info(model_id).private

    logger.info("Attempting to create safetensors variant")
    pr_title = "Adding `safetensors` variant of this model"
    token = kwargs.get("token")

    # This looks into the current repo's open PRs to see if a PR for safetensors was already open. If so, it
    # returns it. It checks that the PR was opened by the bot and not by another user so as to prevent
    # security breaches.
    pr = previous_pr(api, model_id, pr_title, token=token)

    if pr is None or (not private and pr.author != "SFConvertBot"):
        spawn_conversion(token, private, model_id)
        pr = previous_pr(api, model_id, pr_title, token=token)
    else:
        logger.info("Safetensors PR exists")

    sha = f"refs/pr/{pr.num}"

    return sha


def auto_conversion(pretrained_model_name_or_path: str, **cached_file_kwargs):
    api = HfApi(token=cached_file_kwargs.get("token"))
    sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)

    if sha is None:
        return None, None
    cached_file_kwargs["revision"] = sha
    del cached_file_kwargs["_commit_hash"]

    # This is an additional HEAD call that could be removed if we could infer sharded/non-sharded from the PR
    # description.
    sharded = api.file_exists(
        pretrained_model_name_or_path,
        "model.safetensors.index.json",
        revision=sha,
        token=cached_file_kwargs.get("token"),
    )
    filename = "model.safetensors.index.json" if sharded else "model.safetensors"

    resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
    return resolved_archive_file, sha, sharded