Spaces:
Running
Running
Upload 6 files
Browse files- README.md +6 -5
- app.py +57 -0
- civitai_to_hf.py +127 -0
- packages.txt +1 -0
- requirements.txt +2 -0
- utils.py +161 -0
README.md
CHANGED
@@ -1,12 +1,13 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.44.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: CivitAI to HF Downloader Alpha
|
3 |
+
emoji: 🤗
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: blue
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.44.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
license: mit
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from civitai_to_hf import search_civitai, download_civitai, select_civitai_item, CIVITAI_TYPE, CIVITAI_BASEMODEL, CIVITAI_SORT, CIVITAI_PERIOD
|
3 |
+
|
4 |
+
css = """
|
5 |
+
.title { text-align: center; !important; }
|
6 |
+
"""
|
7 |
+
|
8 |
+
with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css, delete_cache=(60, 3600)) as demo:
|
9 |
+
with gr.Column():
|
10 |
+
gr.Markdown("# Civitai to HF Downloader Alpha", elem_classes="title")
|
11 |
+
with gr.Accordion("Search Civitai", open=False):
|
12 |
+
with gr.Row():
|
13 |
+
search_civitai_type = gr.CheckboxGroup(label="Type", choices=CIVITAI_TYPE, value=[])
|
14 |
+
search_civitai_basemodel = gr.CheckboxGroup(label="Base model", choices=CIVITAI_BASEMODEL, value=[])
|
15 |
+
with gr.Row():
|
16 |
+
search_civitai_sort = gr.Radio(label="Sort", choices=CIVITAI_SORT, value=CIVITAI_SORT[0])
|
17 |
+
search_civitai_period = gr.Radio(label="Period", choices=CIVITAI_PERIOD, value=CIVITAI_PERIOD[0])
|
18 |
+
with gr.Row():
|
19 |
+
search_civitai_query = gr.Textbox(label="Query", placeholder="oomuro sakurako...", lines=1)
|
20 |
+
search_civitai_tag = gr.Textbox(label="Tag", lines=1)
|
21 |
+
search_civitai_submit = gr.Button("Search on Civitai")
|
22 |
+
with gr.Row():
|
23 |
+
search_civitai_result = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
|
24 |
+
search_civitai_json = gr.JSON(value={}, visible=False)
|
25 |
+
search_civitai_desc = gr.Markdown(value="", visible=False)
|
26 |
+
dl_url = gr.Textbox(label="Download URL", placeholder="https://civitai.com/api/download/models/28907", value="", lines=1)
|
27 |
+
civitai_key = gr.Textbox(label="Your Civitai Key", value="", max_lines=1)
|
28 |
+
with gr.Row():
|
29 |
+
hf_token = gr.Textbox(label="Your HF write token", placeholder="hf_...", value="", max_lines=1)
|
30 |
+
gr.Markdown("Your token is available at [hf.co/settings/tokens](https://huggingface.co/settings/tokens).")
|
31 |
+
with gr.Row():
|
32 |
+
newrepo_id = gr.Textbox(label="Upload repo ID", placeholder="yourid/yourrepo", value="", max_lines=1)
|
33 |
+
newrepo_type = gr.Radio(label="Upload repo type", choices=["model", "dataset"], value="model")
|
34 |
+
is_private = gr.Checkbox(label="Create private repo", value=True)
|
35 |
+
uploaded_urls = gr.CheckboxGroup(visible=False, choices=[], value=None) # hidden
|
36 |
+
run_button = gr.Button(value="Download and Upload")
|
37 |
+
urls_md = gr.Markdown()
|
38 |
+
gr.DuplicateButton(value="Duplicate Space")
|
39 |
+
|
40 |
+
gr.on(
|
41 |
+
triggers=[run_button.click],
|
42 |
+
fn=download_civitai,
|
43 |
+
inputs=[dl_url, civitai_key, hf_token, uploaded_urls, newrepo_id, newrepo_type, is_private],
|
44 |
+
outputs=[uploaded_urls, urls_md],
|
45 |
+
)
|
46 |
+
gr.on(
|
47 |
+
triggers=[search_civitai_submit.click, search_civitai_query.submit, search_civitai_tag.submit],
|
48 |
+
fn=search_civitai,
|
49 |
+
inputs=[search_civitai_query, search_civitai_type, search_civitai_basemodel, search_civitai_sort, search_civitai_period, search_civitai_tag],
|
50 |
+
outputs=[search_civitai_result, search_civitai_desc, search_civitai_submit, search_civitai_query],
|
51 |
+
queue=True,
|
52 |
+
show_api=False,
|
53 |
+
)
|
54 |
+
search_civitai_result.change(select_civitai_item, [search_civitai_result], [dl_url, search_civitai_desc], queue=False, show_api=False)
|
55 |
+
|
56 |
+
demo.queue()
|
57 |
+
demo.launch()
|
civitai_to_hf.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from huggingface_hub import HfApi, hf_hub_url
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
import gc
|
6 |
+
import requests
|
7 |
+
from requests.adapters import HTTPAdapter
|
8 |
+
from urllib3.util import Retry
|
9 |
+
from utils import get_token, set_token, is_repo_exists, get_user_agent, get_download_file
|
10 |
+
|
11 |
+
|
12 |
+
def upload_safetensors_to_repo(filename, repo_id, repo_type, is_private, progress=gr.Progress(track_tqdm=True)):
|
13 |
+
output_filename = Path(filename).name
|
14 |
+
hf_token = get_token()
|
15 |
+
api = HfApi(token=hf_token)
|
16 |
+
try:
|
17 |
+
if not is_repo_exists(repo_id, repo_type): api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
|
18 |
+
progress(0, desc="Start uploading...")
|
19 |
+
api.upload_file(path_or_fileobj=filename, path_in_repo=output_filename, repo_type=repo_type, revision="main", token=hf_token, repo_id=repo_id)
|
20 |
+
progress(1, desc="Uploaded.")
|
21 |
+
url = hf_hub_url(repo_id=repo_id, repo_type=repo_type, filename=output_filename)
|
22 |
+
except Exception as e:
|
23 |
+
print(f"Error: Failed to upload to {repo_id}. {e}")
|
24 |
+
gr.Warning(f"Error: Failed to upload to {repo_id}. {e}")
|
25 |
+
return None
|
26 |
+
return url
|
27 |
+
|
28 |
+
|
29 |
+
def download_file(dl_url, civitai_key, progress=gr.Progress(track_tqdm=True)):
|
30 |
+
download_dir = "."
|
31 |
+
progress(0, desc="Start downloading...")
|
32 |
+
output_filename = get_download_file(download_dir, dl_url, civitai_key)
|
33 |
+
return output_filename
|
34 |
+
|
35 |
+
|
36 |
+
def download_civitai(dl_url, civitai_key, hf_token, urls,
|
37 |
+
newrepo_id, repo_type="model", is_private=True, progress=gr.Progress(track_tqdm=True)):
|
38 |
+
if hf_token: set_token(hf_token)
|
39 |
+
else: set_token(os.environ.get("HF_TOKEN"))
|
40 |
+
if not civitai_key: civitai_key = os.environ.get("CIVITAI_API_KEY")
|
41 |
+
if not hf_token or not civitai_key: raise gr.Error("HF write token and Civitai API key is required.")
|
42 |
+
file = download_file(dl_url, civitai_key)
|
43 |
+
if not urls: urls = []
|
44 |
+
url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private)
|
45 |
+
progress(1, desc="Processing...")
|
46 |
+
if url: urls.append(url)
|
47 |
+
Path(file).unlink()
|
48 |
+
md = ""
|
49 |
+
for u in urls:
|
50 |
+
md += f"[Uploaded {str(u)}]({str(u)})<br>"
|
51 |
+
gc.collect()
|
52 |
+
return gr.update(value=urls, choices=urls), gr.update(value=md)
|
53 |
+
|
54 |
+
|
55 |
+
CIVITAI_TYPE = ["Checkpoint", "TextualInversion", "Hypernetwork", "AestheticGradient", "LORA", "Controlnet", "Poses"]
|
56 |
+
CIVITAI_BASEMODEL = ["Pony", "SD 1.5", "SDXL 1.0", "Flux.1 D", "Flux.1 S"]
|
57 |
+
CIVITAI_SORT = ["Highest Rated", "Most Downloaded", "Newest"]
|
58 |
+
CIVITAI_PERIOD = ["AllTime", "Year", "Month", "Week", "Day"]
|
59 |
+
|
60 |
+
|
61 |
+
def search_on_civitai(query: str, types: list[str], allow_model: list[str] = [], limit: int = 100,
|
62 |
+
sort: str = "Highest Rated", period: str = "AllTime", tag: str = ""):
|
63 |
+
|
64 |
+
user_agent = get_user_agent()
|
65 |
+
headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
|
66 |
+
base_url = 'https://civitai.com/api/v1/models'
|
67 |
+
params = {'sort': sort, 'period': period, 'limit': limit, 'nsfw': 'true'}
|
68 |
+
if len(types) != 0: params["types"] = types
|
69 |
+
if query: params["query"] = query
|
70 |
+
if tag: params["tag"] = tag
|
71 |
+
session = requests.Session()
|
72 |
+
retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
|
73 |
+
session.mount("https://", HTTPAdapter(max_retries=retries))
|
74 |
+
try:
|
75 |
+
r = session.get(base_url, params=params, headers=headers, stream=True, timeout=(3.0, 30))
|
76 |
+
except Exception as e:
|
77 |
+
print(e)
|
78 |
+
return None
|
79 |
+
else:
|
80 |
+
if not r.ok: return None
|
81 |
+
json = r.json()
|
82 |
+
if 'items' not in json: return None
|
83 |
+
items = []
|
84 |
+
for j in json['items']:
|
85 |
+
for model in j['modelVersions']:
|
86 |
+
item = {}
|
87 |
+
if len(allow_model) != 0 and model['baseModel'] not in set(allow_model): continue
|
88 |
+
item['name'] = j['name']
|
89 |
+
item['creator'] = j['creator']['username']
|
90 |
+
item['tags'] = j['tags']
|
91 |
+
item['model_name'] = model['name']
|
92 |
+
item['base_model'] = model['baseModel']
|
93 |
+
item['dl_url'] = model['downloadUrl']
|
94 |
+
item['md'] = f'<img src="{model["images"][0]["url"]}" alt="thumbnail" width="150" height="240"><br>[LoRA Model URL](https://civitai.com/models/{j["id"]})'
|
95 |
+
items.append(item)
|
96 |
+
return items
|
97 |
+
|
98 |
+
|
99 |
+
civitai_last_results = {}
|
100 |
+
|
101 |
+
|
102 |
+
def search_civitai(query, types, base_model=[], sort=CIVITAI_SORT[0], period=CIVITAI_PERIOD[0], tag=""):
|
103 |
+
global civitai_last_results
|
104 |
+
items = search_on_civitai(query, types, base_model, 100, sort, period, tag)
|
105 |
+
if not items: return gr.update(choices=[("", "")], value="", visible=False),\
|
106 |
+
gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
|
107 |
+
civitai_last_results = {}
|
108 |
+
choices = []
|
109 |
+
for item in items:
|
110 |
+
base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
|
111 |
+
name = f"{item['name']} (for {base_model_name} / By: {item['creator']} / Tags: {', '.join(item['tags'])})"
|
112 |
+
value = item['dl_url']
|
113 |
+
choices.append((name, value))
|
114 |
+
civitai_last_results[value] = item
|
115 |
+
if not choices: return gr.update(choices=[("", "")], value="", visible=False),\
|
116 |
+
gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
|
117 |
+
result = civitai_last_results.get(choices[0][1], "None")
|
118 |
+
md = result['md'] if result else ""
|
119 |
+
return gr.update(choices=choices, value=choices[0][1], visible=True), gr.update(value=md, visible=True),\
|
120 |
+
gr.update(visible=True), gr.update(visible=True)
|
121 |
+
|
122 |
+
|
123 |
+
def select_civitai_item(search_result):
|
124 |
+
if not "http" in search_result: return gr.update(value=""), gr.update(value="None", visible=True)
|
125 |
+
result = civitai_last_results.get(search_result, "None")
|
126 |
+
md = result['md'] if result else ""
|
127 |
+
return gr.update(value=search_result), gr.update(value=md, visible=True)
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
git-lfs aria2
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
huggingface-hub
|
2 |
+
gdown
|
utils.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from huggingface_hub import HfApi, HfFolder, hf_hub_download
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
import shutil
|
6 |
+
import gc
|
7 |
+
import re
|
8 |
+
import urllib.parse
|
9 |
+
|
10 |
+
|
11 |
+
def get_token():
|
12 |
+
try:
|
13 |
+
token = HfFolder.get_token()
|
14 |
+
except Exception:
|
15 |
+
token = ""
|
16 |
+
return token
|
17 |
+
|
18 |
+
|
19 |
+
def set_token(token):
|
20 |
+
try:
|
21 |
+
HfFolder.save_token(token)
|
22 |
+
except Exception:
|
23 |
+
print(f"Error: Failed to save token.")
|
24 |
+
|
25 |
+
|
26 |
+
def get_user_agent():
|
27 |
+
return 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
|
28 |
+
|
29 |
+
|
30 |
+
def is_repo_exists(repo_id: str, repo_type: str="model"):
|
31 |
+
hf_token = get_token()
|
32 |
+
api = HfApi(token=hf_token)
|
33 |
+
try:
|
34 |
+
if api.repo_exists(repo_id=repo_id, repo_type=repo_type, token=hf_token): return True
|
35 |
+
else: return False
|
36 |
+
except Exception as e:
|
37 |
+
print(f"Error: Failed to connect {repo_id} ({repo_type}). {e}")
|
38 |
+
return True # for safe
|
39 |
+
|
40 |
+
|
41 |
+
MODEL_TYPE_CLASS = {
|
42 |
+
"diffusers:StableDiffusionPipeline": "SD 1.5",
|
43 |
+
"diffusers:StableDiffusionXLPipeline": "SDXL",
|
44 |
+
"diffusers:FluxPipeline": "FLUX",
|
45 |
+
}
|
46 |
+
|
47 |
+
|
48 |
+
def get_model_type(repo_id: str):
|
49 |
+
hf_token = get_token()
|
50 |
+
api = HfApi(token=hf_token)
|
51 |
+
lora_filename = "pytorch_lora_weights.safetensors"
|
52 |
+
diffusers_filename = "model_index.json"
|
53 |
+
default = "SDXL"
|
54 |
+
try:
|
55 |
+
if api.file_exists(repo_id=repo_id, filename=lora_filename, token=hf_token): return "LoRA"
|
56 |
+
if not api.file_exists(repo_id=repo_id, filename=diffusers_filename, token=hf_token): return "None"
|
57 |
+
model = api.model_info(repo_id=repo_id, token=hf_token)
|
58 |
+
tags = model.tags
|
59 |
+
for tag in tags:
|
60 |
+
if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
|
61 |
+
except Exception:
|
62 |
+
return default
|
63 |
+
return default
|
64 |
+
|
65 |
+
|
66 |
+
def list_sub(a, b):
|
67 |
+
return [e for e in a if e not in b]
|
68 |
+
|
69 |
+
|
70 |
+
def is_repo_name(s):
|
71 |
+
return re.fullmatch(r'^[^/,\s\"\']+/[^/,\s\"\']+$', s)
|
72 |
+
|
73 |
+
|
74 |
+
def split_hf_url(url: str):
|
75 |
+
try:
|
76 |
+
s = list(re.findall(r'^(?:https?://huggingface.co/)(?:(datasets)/)?(.+?/.+?)/\w+?/.+?/(?:(.+)/)?(.+?.safetensors)(?:\?download=true)?$', url)[0])
|
77 |
+
if len(s) < 4: return "", "", "", ""
|
78 |
+
repo_id = s[1]
|
79 |
+
repo_type = "dataset" if s[0] == "datasets" else "model"
|
80 |
+
subfolder = urllib.parse.unquote(s[2]) if s[2] else None
|
81 |
+
filename = urllib.parse.unquote(s[3])
|
82 |
+
return repo_id, filename, subfolder, repo_type
|
83 |
+
except Exception as e:
|
84 |
+
print(e)
|
85 |
+
|
86 |
+
|
87 |
+
def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)):
|
88 |
+
hf_token = get_token()
|
89 |
+
repo_id, filename, subfolder, repo_type = split_hf_url(url)
|
90 |
+
try:
|
91 |
+
if subfolder is not None: hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type, local_dir=directory, token=hf_token)
|
92 |
+
else: hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token)
|
93 |
+
except Exception as e:
|
94 |
+
print(f"Failed to download: {e}")
|
95 |
+
|
96 |
+
|
97 |
+
def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)): # requires aria2, gdown
|
98 |
+
hf_token = get_token()
|
99 |
+
url = url.strip()
|
100 |
+
if "drive.google.com" in url:
|
101 |
+
original_dir = os.getcwd()
|
102 |
+
os.chdir(directory)
|
103 |
+
os.system(f"gdown --fuzzy {url}")
|
104 |
+
os.chdir(original_dir)
|
105 |
+
elif "huggingface.co" in url:
|
106 |
+
url = url.replace("?download=true", "")
|
107 |
+
if "/blob/" in url:
|
108 |
+
url = url.replace("/blob/", "/resolve/")
|
109 |
+
#user_header = f'"Authorization: Bearer {hf_token}"'
|
110 |
+
if hf_token:
|
111 |
+
download_hf_file(directory, url)
|
112 |
+
#os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
|
113 |
+
else:
|
114 |
+
os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
|
115 |
+
elif "civitai.com" in url:
|
116 |
+
if "?" in url:
|
117 |
+
url = url.split("?")[0]
|
118 |
+
if civitai_api_key:
|
119 |
+
url = url + f"?token={civitai_api_key}"
|
120 |
+
os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
|
121 |
+
else:
|
122 |
+
print("You need an API key to download Civitai models.")
|
123 |
+
else:
|
124 |
+
os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
|
125 |
+
|
126 |
+
|
127 |
+
def get_local_model_list(dir_path):
|
128 |
+
model_list = []
|
129 |
+
valid_extensions = ('.safetensors')
|
130 |
+
for file in Path(dir_path).glob("**/*.*"):
|
131 |
+
if file.is_file() and file.suffix in valid_extensions:
|
132 |
+
file_path = str(file)
|
133 |
+
model_list.append(file_path)
|
134 |
+
return model_list
|
135 |
+
|
136 |
+
|
137 |
+
def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
|
138 |
+
if not "http" in url and is_repo_name(url) and not Path(url).exists():
|
139 |
+
print(f"Use HF Repo: {url}")
|
140 |
+
new_file = url
|
141 |
+
elif not "http" in url and Path(url).exists():
|
142 |
+
print(f"Use local file: {url}")
|
143 |
+
new_file = url
|
144 |
+
elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
|
145 |
+
print(f"File to download alreday exists: {url}")
|
146 |
+
new_file = f"{temp_dir}/{url.split('/')[-1]}"
|
147 |
+
else:
|
148 |
+
print(f"Start downloading: {url}")
|
149 |
+
before = get_local_model_list(temp_dir)
|
150 |
+
try:
|
151 |
+
download_thing(temp_dir, url.strip(), civitai_key)
|
152 |
+
except Exception:
|
153 |
+
print(f"Download failed: {url}")
|
154 |
+
return ""
|
155 |
+
after = get_local_model_list(temp_dir)
|
156 |
+
new_file = list_sub(after, before)[0] if list_sub(after, before) else ""
|
157 |
+
if not new_file:
|
158 |
+
print(f"Download failed: {url}")
|
159 |
+
return ""
|
160 |
+
print(f"Download completed: {url}")
|
161 |
+
return new_file
|