Spaces:
Running
on
Zero
Running
on
Zero
import os | |
from time import time_ns | |
import gradio as gr | |
import torch | |
from huggingface_hub import Repository | |
from llama_cpp import Llama, LLAMA_SPLIT_MODE_NONE | |
from transformers import LlamaForCausalLM, LlamaTokenizer | |
from kgen.generate import tag_gen | |
from kgen.metainfo import SPECIAL, TARGET | |
# Model identifier handed to `from_pretrained` for both the tokenizer and the
# causal language model when this file is run as a script.
MODEL_PATH = "KBlueLeaf/DanTagGen"
def get_result(
    text_model: LlamaForCausalLM,
    tokenizer: LlamaTokenizer,
    rating: str = "",
    artist: str = "",
    characters: str = "",
    copyrights: str = "",
    target: str = "long",
    special_tags: list[str] = ["1girl"],
    general: str = "",
    aspect_ratio: float = 0.0,
    blacklist: str = "",
    escape_bracket: bool = False,
    temperature: float = 1.35,
):
    """Generate extra tags with the LLM and stream the formatted prompt.

    This is a generator. While ``tag_gen`` is producing candidates it yields
    ``("", llm_gen, timing_text)`` for every intermediate result; once
    generation is finished it yields a single
    ``(final_prompt, llm_gen, timing_and_count_text)`` tuple.

    Args:
        text_model: Model forwarded to ``tag_gen``.
        tokenizer: Tokenizer forwarded to ``tag_gen``.
        rating: Rating text; an empty value (or ``None``) is sent to the
            model as ``<|empty|>`` and rendered as an empty string in the
            final prompt.
        artist: Comma separated artist tags.
        characters: Comma separated character tags.
        copyrights: Comma separated copyright/series tags.
        target: Key into ``TARGET`` selecting the length target; a falsy
            value is treated as ``"long"``.
        special_tags: Tags placed at the very front of the prompt. The
            default list is only read, never mutated.
        general: Comma separated general tags supplied by the user.
        aspect_ratio: Width divided by height, written with one decimal.
        blacklist: Comma separated tags forwarded to ``tag_gen`` as a set.
        escape_bracket: When true, ``[ ] ( )`` in the final prompt are
            prefixed with a backslash.
        temperature: Sampling temperature forwarded to ``tag_gen``.

    Yields:
        Tuples of ``(final_prompt, llm_output, status_text)``.

    Raises:
        KeyError: If ``target`` is not a key of ``TARGET``.
    """
    start = time_ns()

    def elapsed() -> str:
        # Seconds since this call started, formatted for the status line.
        return f"{(time_ns()-start)/1e9:.2f}s"

    print("=" * 50, "\n")

    # Unselected UI widgets deliver None. Normalise up front so the prompt,
    # the TARGET lookup and the final text all agree on the same value
    # (previously a falsy target produced "<|long|>" in the prompt but then
    # crashed on TARGET[target], and a None rating was printed as "None").
    rating = rating or ""
    target = target or "long"

    # Use LLM to predict possible summary.
    # This prompt allows the model itself to make the request longer based on
    # what it learned, which is better for the preference-sim and pref-sum
    # contrastive scorer.
    # NOTE(review): the original wrote the aspect ratio as
    # `f"{aspect_ratio:.1f}" or '<|empty|>'`; the fallback could never fire
    # because a formatted float is never an empty string, so 0.0 is sent as
    # "0.0". Confirm whether 0.0 was meant to map to '<|empty|>'.
    prompt = "\n".join(
        [
            f"rating: {rating or '<|empty|>'}",
            f"artist: {artist.strip() or '<|empty|>'}",
            f"characters: {characters.strip() or '<|empty|>'}",
            f"copyrights: {copyrights.strip() or '<|empty|>'}",
            f"aspect ratio: {aspect_ratio:.1f}",
            f"target: <|{target}|>",
            f"general: {', '.join(special_tags)}, "
            f"{general.strip().strip(',')}<|input_end|>",
        ]
    )

    # Normalised copies used for the human-readable output below.
    artist = artist.strip().strip(",").replace("_", " ")
    characters = characters.strip().strip(",").replace("_", " ")
    copyrights = copyrights.strip().strip(",").replace("_", " ")
    special_tags = [tag.strip().replace("_", " ") for tag in special_tags]
    general = general.strip().strip(",")
    black_list = {
        tag.strip().replace("_", " ") for tag in blacklist.strip().split(",")
    }
    prompt_tags = special_tags + general.strip().strip(",").split(",")
    len_target = TARGET[target]

    # Pre-bind both loop variables so the code after the loop still works
    # when tag_gen yields nothing at all.
    llm_gen = ""
    extra_tokens = []
    for llm_gen, extra_tokens in tag_gen(
        text_model,
        tokenizer,
        prompt,
        prompt_tags,
        len_target,
        black_list,
        temperature=temperature,
        top_p=0.95,
        top_k=100,
        max_new_tokens=256,
        max_retry=5,
    ):
        # Stream every intermediate generation; the final prompt stays empty
        # until generation has finished.
        yield "", llm_gen, f"Total cost time: {elapsed()}"
    print()
    print("-" * 50)

    # Merge the user's tags with the last batch of generated tags, then drop
    # blanks and separate the "special" tags so they can lead the prompt.
    general = f"{general.strip().strip(',')}, {','.join(extra_tokens)}"
    tags = [tag.strip() for tag in general.strip().split(",") if tag.strip()]
    special = special_tags + [tag for tag in tags if tag in SPECIAL]
    tags = [tag for tag in tags if tag not in special]

    # Layout: special tags, then characters + copyrights sharing one
    # paragraph, then artist, then the general tags and quality suffix.
    final_prompt = ", ".join(special)
    if characters:
        final_prompt += f", \n\n{characters}"
    if copyrights:
        final_prompt += ", "
        if not characters:
            final_prompt += "\n\n"
        final_prompt += copyrights
    if artist:
        final_prompt += f", \n\n{artist}"
    final_prompt += (
        f", \n\n{', '.join(tags)},\n"
        f"masterpiece, newest, absurdres, {rating}"
    )

    print(final_prompt)
    print("=" * 50)

    if escape_bracket:
        # Single pass; equivalent to chaining four str.replace calls because
        # the inserted backslash is not itself one of the escaped characters.
        final_prompt = final_prompt.translate(
            str.maketrans({"[": "\\[", "]": "\\]", "(": "\\(", ")": "\\)"})
        )

    yield (
        final_prompt,
        llm_gen,
        f"Total cost time: {elapsed()} | Total general tags: {len(special+tags)}",
    )
if __name__ == "__main__":
    # Load the tokenizer and model once at start-up; every request handled by
    # `wrapper` below reuses them.
    tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
    text_model = LlamaForCausalLM.from_pretrained(MODEL_PATH)
    text_model = text_model.eval()

    def wrapper(
        rating: str,
        artist: str,
        characters: str,
        copyrights: str,
        target: str,
        special_tags: list[str],
        general: str,
        width: float,
        height: float,
        blacklist: str,
        escape_bracket: bool,
        temperature: float = 1.35,
    ):
        """Adapt the UI inputs to `get_result` and stream its results.

        Width and height are collapsed into a single aspect ratio. Radio
        values that were never selected arrive as None and are replaced by
        the same defaults `get_result` declares ("" and "long"), so the call
        does not fail on the target-length lookup.
        """
        # Guard against a zero height so a bad request cannot raise
        # ZeroDivisionError; 0.0 is get_result's own default ratio.
        aspect_ratio = width / height if height else 0.0
        yield from get_result(
            text_model,
            tokenizer,
            rating or "",
            artist,
            characters,
            copyrights,
            target or "long",
            special_tags,
            general,
            aspect_ratio,
            blacklist,
            escape_bracket,
            temperature,
        )

    # NOTE(review): the nesting of rows/columns below is reconstructed from
    # the order of the statements; confirm it against the intended layout.
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        with gr.Row():
            # Left side: all inputs plus the submit button.
            with gr.Column(scale=4):
                with gr.Row():
                    with gr.Column(scale=2):
                        rating = gr.Radio(
                            ["safe", "sensitive", "nsfw", "nsfw, explicit"],
                            label="Rating",
                        )
                        special_tags = gr.Dropdown(
                            SPECIAL,
                            value=["1girl"],
                            label="Special tags",
                            multiselect=True,
                        )
                        characters = gr.Textbox(label="Characters")
                        copyrights = gr.Textbox(label="Copyrights(Series)")
                        artist = gr.Textbox(label="Artist")
                        # Pre-select "long": without a default the radio
                        # submits None, which has no entry in TARGET.
                        target = gr.Radio(
                            ["very_short", "short", "long", "very_long"],
                            value="long",
                            label="Target length",
                        )
                    with gr.Column(scale=2):
                        general = gr.TextArea(label="Input your general tags")
                        black_list = gr.TextArea(
                            label="tag Black list (seperated by comma)"
                        )
                        with gr.Row():
                            width = gr.Slider(
                                value=1024,
                                minimum=256,
                                maximum=4096,
                                step=32,
                                label="Width",
                            )
                            height = gr.Slider(
                                value=1024,
                                minimum=256,
                                maximum=4096,
                                step=32,
                                label="Height",
                            )
                        with gr.Row():
                            temperature = gr.Slider(
                                value=1.35,
                                minimum=0.1,
                                maximum=2,
                                step=0.05,
                                label="Temperature",
                            )
                            escape_bracket = gr.Checkbox(
                                value=False,
                                label="Escape bracket",
                            )
                submit = gr.Button("Submit")
            # Right side: formatted prompt, raw LLM text and timing info.
            with gr.Column(scale=3):
                formated_result = gr.TextArea(
                    label="Final output", lines=14, show_copy_button=True
                )
                llm_result = gr.TextArea(label="LLM output", lines=10)
                cost_time = gr.Markdown()

        # The input order must match the positional parameters of `wrapper`;
        # the three outputs match the tuples yielded by `get_result`.
        submit.click(
            wrapper,
            inputs=[
                rating,
                artist,
                characters,
                copyrights,
                target,
                special_tags,
                general,
                width,
                height,
                black_list,
                escape_bracket,
                temperature,
            ],
            outputs=[
                formated_result,
                llm_result,
                cost_time,
            ],
            show_progress=True,
        )

    demo.launch()