import os
from time import time_ns
from typing import Optional

import gradio as gr
import torch
from huggingface_hub import Repository
from llama_cpp import Llama, LLAMA_SPLIT_MODE_NONE
from transformers import LlamaForCausalLM, LlamaTokenizer

from kgen.generate import tag_gen
from kgen.metainfo import SPECIAL, TARGET


MODEL_PATH = "KBlueLeaf/DanTagGen"


def _clean_field(text: str) -> str:
    """Trim whitespace and stray commas, and turn underscores into spaces."""
    return text.strip().strip(",").replace("_", " ")


def _elapsed_seconds(start_ns: int) -> float:
    """Return the seconds elapsed since ``start_ns`` (a ``time_ns()`` reading)."""
    return (time_ns() - start_ns) / 1e9


def _compose_final_prompt(
    special: list[str],
    characters: str,
    copyrights: str,
    artist: str,
    tags: list[str],
    rating: str,
) -> str:
    """Assemble the final prompt from its already-normalised sections.

    Sections are emitted in the order: special tags, characters, copyrights,
    artist, general tags, fixed quality tags, rating. Empty sections are
    skipped.
    """
    final_prompt = ", ".join(special)
    if characters:
        final_prompt += f", \n\n{characters}"
    if copyrights:
        final_prompt += ", "
        # Copyrights share a paragraph with characters; only open a new
        # paragraph when no characters paragraph precedes them.
        if not characters:
            final_prompt += "\n\n"
        final_prompt += copyrights
    if artist:
        final_prompt += f", \n\n{artist}"
    # NOTE(review): the blank line before the quality tags is reconstructed
    # from a triple-quoted literal whose line breaks were lost -- confirm.
    final_prompt += f", \n\n{', '.join(tags)},\n\nmasterpiece, newest, absurdres"
    # Only append a rating when one was chosen, so an unset rating does not
    # leak "None" (or a dangling ", ") into the prompt.
    if rating:
        final_prompt += f", {rating}"
    return final_prompt


@torch.no_grad()
def get_result(
    text_model: LlamaForCausalLM,
    tokenizer: LlamaTokenizer,
    rating: str = "",
    artist: str = "",
    characters: str = "",
    copyrights: str = "",
    target: str = "long",
    special_tags: Optional[list[str]] = None,
    general: str = "",
    aspect_ratio: float = 0.0,
    blacklist: str = "",
    escape_bracket: bool = False,
    temperature: float = 1.35,
):
    """Generate extra tags with the LLM and stream the results.

    This is a generator. While ``tag_gen`` is running it yields
    ``("", llm_output, timing_text)`` for every intermediate result; once
    generation finishes it yields ``(final_prompt, llm_output, summary_text)``
    exactly once.

    Args:
        text_model: Causal LM handed to ``tag_gen``.
        tokenizer: Tokenizer handed to ``tag_gen``.
        rating: Rating tag; falsy values are sent to the model as
            ``<|empty|>`` and omitted from the final prompt.
        artist: Comma separated artist tags.
        characters: Comma separated character tags.
        copyrights: Comma separated copyright/series tags.
        target: Key into ``TARGET`` selecting the target length; a falsy
            value falls back to ``"long"``.
        special_tags: Tags placed first in the prompt. ``None`` means
            ``["1girl"]``.
        general: Comma separated general tags supplied by the user.
        aspect_ratio: Width divided by height; ``0`` means "not specified"
            and is sent to the model as ``<|empty|>``.
        blacklist: Comma separated tags passed to ``tag_gen`` as a set.
        escape_bracket: When true, backslash-escape ``[]()`` in the final
            prompt.
        temperature: Sampling temperature forwarded to ``tag_gen``.

    Yields:
        Tuples of ``(final_prompt, llm_output, status_text)``.

    Raises:
        KeyError: If ``target`` is not a key of ``TARGET``.
    """
    start = time_ns()
    print("=" * 50, "\n")

    # A fresh list per call: a literal list default would be shared (and
    # mutable) across calls.
    if special_tags is None:
        special_tags = ["1girl"]
    # The UI radio may deliver None when nothing is selected; the prompt
    # already treated that as "long", so make the TARGET lookup agree.
    target = target or "long"
    # Formatting a float always yields a non-empty string, so the empty
    # marker has to be chosen from the number itself.
    ratio_text = f"{aspect_ratio:.1f}" if aspect_ratio else "<|empty|>"

    # Metadata header fed to the model, one field per line; it is built from
    # the raw (not yet underscore-normalised) inputs.
    prompt = "\n".join(
        (
            f"rating: {rating or '<|empty|>'}",
            f"artist: {artist.strip() or '<|empty|>'}",
            f"characters: {characters.strip() or '<|empty|>'}",
            f"copyrights: {copyrights.strip() or '<|empty|>'}",
            f"aspect ratio: {ratio_text}",
            f"target: <|{target}|>",
            f"general: {', '.join(special_tags)}, "
            f"{general.strip().strip(',')}<|input_end|>",
        )
    ).strip()

    artist = _clean_field(artist)
    characters = _clean_field(characters)
    copyrights = _clean_field(copyrights)
    special_tags = [tag.strip().replace("_", " ") for tag in special_tags]
    general = general.strip().strip(",")
    # Drop empty entries: "".split(",") would otherwise blacklist "".
    black_list = {
        tag.strip().replace("_", " ")
        for tag in blacklist.strip().split(",")
        if tag.strip()
    }

    prompt_tags = special_tags + general.strip().strip(",").split(",")
    len_target = TARGET[target]

    llm_gen = ""
    # Pre-bind so the post-loop code works even if tag_gen yields nothing.
    extra_tokens: list[str] = []
    for llm_gen, extra_tokens in tag_gen(
        text_model,
        tokenizer,
        prompt,
        prompt_tags,
        len_target,
        black_list,
        temperature=temperature,
        top_p=0.95,
        top_k=100,
        max_new_tokens=256,
        max_retry=5,
    ):
        yield "", llm_gen, f"Total cost time: {_elapsed_seconds(start):.2f}s"
    print()
    print("-" * 50)

    # Merge user tags with the generated ones, then split the special tags
    # out so they can lead the final prompt.
    general = f"{general.strip().strip(',')}, {','.join(extra_tokens)}"
    tags = [tag.strip() for tag in general.strip().split(",") if tag.strip()]
    special = special_tags + [tag for tag in tags if tag in SPECIAL]
    tags = [tag for tag in tags if tag not in special]

    final_prompt = _compose_final_prompt(
        special, characters, copyrights, artist, tags, rating
    )
    print(final_prompt)
    print("=" * 50)

    if escape_bracket:
        final_prompt = (
            final_prompt.replace("[", "\\[")
            .replace("]", "\\]")
            .replace("(", "\\(")
            .replace(")", "\\)")
        )

    yield (
        final_prompt,
        llm_gen,
        f"Total cost time: {_elapsed_seconds(start):.2f}s | "
        f"Total general tags: {len(special+tags)}",
    )


if __name__ == "__main__":
    tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
    text_model = LlamaForCausalLM.from_pretrained(MODEL_PATH)
    text_model = text_model.eval()

    def wrapper(
        rating: str,
        artist: str,
        characters: str,
        copyrights: str,
        target: str,
        special_tags: list[str],
        general: str,
        width: float,
        height: float,
        blacklist: str,
        escape_bracket: bool,
        temperature: float = 1.35,
    ):
        """Adapt the UI inputs to ``get_result``.

        Binds the loaded model and tokenizer and converts width/height into
        an aspect ratio; every yielded tuple is forwarded unchanged.
        """
        yield from get_result(
            text_model,
            tokenizer,
            rating,
            artist,
            characters,
            copyrights,
            target,
            special_tags,
            general,
            width / height,
            blacklist,
            escape_bracket,
            temperature,
        )

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        with gr.Row():
            # Left side: all inputs plus the submit button.
            with gr.Column(scale=4):
                with gr.Row():
                    with gr.Column(scale=2):
                        rating = gr.Radio(
                            ["safe", "sensitive", "nsfw", "nsfw, explicit"],
                            label="Rating",
                        )
                        special_tags = gr.Dropdown(
                            SPECIAL,
                            value=["1girl"],
                            label="Special tags",
                            multiselect=True,
                        )
                        characters = gr.Textbox(label="Characters")
                        copyrights = gr.Textbox(label="Copyrights(Series)")
                        artist = gr.Textbox(label="Artist")
                        target = gr.Radio(
                            ["very_short", "short", "long", "very_long"],
                            label="Target length",
                        )
                    with gr.Column(scale=2):
                        general = gr.TextArea(label="Input your general tags")
                        black_list = gr.TextArea(
                            label="tag Black list (seperated by comma)"
                        )
                        with gr.Row():
                            width = gr.Slider(
                                value=1024,
                                minimum=256,
                                maximum=4096,
                                step=32,
                                label="Width",
                            )
                            height = gr.Slider(
                                value=1024,
                                minimum=256,
                                maximum=4096,
                                step=32,
                                label="Height",
                            )
                        with gr.Row():
                            temperature = gr.Slider(
                                value=1.35,
                                minimum=0.1,
                                maximum=2,
                                step=0.05,
                                label="Temperature",
                            )
                            escape_bracket = gr.Checkbox(
                                value=False,
                                label="Escape bracket",
                            )
                submit = gr.Button("Submit")
            # Right side: the three outputs produced by each yield.
            with gr.Column(scale=3):
                formated_result = gr.TextArea(
                    label="Final output", lines=14, show_copy_button=True
                )
                llm_result = gr.TextArea(label="LLM output", lines=10)
                cost_time = gr.Markdown()
        # Input order must match the positional parameters of ``wrapper``.
        submit.click(
            wrapper,
            inputs=[
                rating,
                artist,
                characters,
                copyrights,
                target,
                special_tags,
                general,
                width,
                height,
                black_list,
                escape_bracket,
                temperature,
            ],
            outputs=[
                formated_result,
                llm_result,
                cost_time,
            ],
            show_progress=True,
        )

    demo.launch()