"""Utility functions for a Gradio image-generation app.

Covers Hugging Face Hub repo lookups, ControlNet-Union control image
preparation, LoRA loading/fusing, prompt translation/enhancement and saving
generated images as PNG files with embedded metadata.
"""
# NOTE: the import order below is intentionally left as in the original file;
# `spaces` comes before `torch`, and `transformers` is imported only after the
# flash-attn install further down has run.
import spaces
import gradio as gr
import torch
from PIL import Image
from pathlib import Path
import gc
import subprocess
import re
from env import num_cns, model_trigger


# NOTE(review): `env=` replaces the whole environment of the child process
# (PATH included) instead of extending it - confirm that this is intended.
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
subprocess.run('pip cache purge', shell=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_grad_enabled(False)

# Per-slot ControlNet state, indexed by control slot number (0 .. num_cns - 1).
control_images = [None] * num_cns
control_modes = [-1] * num_cns
control_scales = [0] * num_cns

# "<owner>/<name>" with no slashes, commas, whitespace or quotes in either part.
_REPO_NAME_RE = re.compile(r'^[^/,\s\"\']+/[^/,\s\"\']+$')


def is_repo_name(s):
    """Return a match object if *s* looks like ``owner/name``, else ``None``."""
    return _REPO_NAME_RE.fullmatch(s)


def is_repo_exists(repo_id):
    """Return whether *repo_id* exists on the Hugging Face Hub.

    If the lookup itself fails, the error is printed and ``True`` is
    returned, so that a connection problem does not make callers reject the
    repo.
    """
    from huggingface_hub import HfApi
    api = HfApi()
    try:
        return bool(api.repo_exists(repo_id=repo_id))
    except Exception as e:
        print(f"Error: Failed to connect {repo_id}.")
        print(e)
        return True  # for safe


from translatepy import Translator
translator = Translator()


def translate_to_en(input: str):
    """Translate *input* to English; return it unchanged if translation fails."""
    try:
        return str(translator.translate(input, 'English'))
    except Exception as e:
        print(e)
        return input


def clear_cache():
    """Empty the CUDA cache and run the garbage collector.

    Raises:
        Exception: if either step fails (the original error is chained).
    """
    try:
        torch.cuda.empty_cache()
        gc.collect()
    except Exception as e:
        print(e)
        raise Exception(f"Cache clearing error: {e}") from e


def deselect_lora():
    """Return the UI values used when no LoRA is selected.

    Returns:
        A 5-tuple: prompt placeholder update, empty text, ``None`` as the
        selected index, width 1024 and height 1024.
    """
    return (
        gr.update(placeholder="Type a prompt"),
        "",
        None,
        1024,
        1024,
    )


def get_repo_safetensors(repo_id: str):
    """Build a ``gr.update`` listing the ``.safetensors`` files of a Hub repo.

    Returns an update with an empty value and no choices when *repo_id* is
    not a valid/existing repo, contains no ``.safetensors`` file, or the
    listing fails (a Gradio warning is issued in the last case).
    """
    from huggingface_hub import HfApi
    api = HfApi()
    try:
        if not is_repo_name(repo_id) or not is_repo_exists(repo_id):
            return gr.update(value="", choices=[])
        files = api.list_repo_files(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to get {repo_id}'s info.")
        print(e)
        gr.Warning(f"Error: Failed to get {repo_id}'s info.")
        # Reset the value as well, like the other empty-result paths, so that
        # no stale selection is left next to an empty choice list.
        return gr.update(value="", choices=[])
    files = [f for f in files if f.endswith(".safetensors")]
    if not files:
        return gr.update(value="", choices=[])
    return gr.update(value=files[0], choices=files)


def expand2square(pil_img: Image.Image, background_color: tuple = (0, 0, 0)):
    """Pad *pil_img* to a square canvas, centring it along the shorter axis.

    A square input is returned as-is; otherwise a new image filled with
    *background_color* is returned.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    result = Image.new(pil_img.mode, (side, side), background_color)
    # Exactly one of the two offsets is non-zero.
    result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return result


# https://huggingface.co/spaces/DamarJati/FLUX.1-DEV-Canny/blob/main/app.py
def resize_image(image, target_width, target_height, crop=True):
    """Resize *image* to ``target_width`` x ``target_height``.

    With ``crop=True`` the image is first passed through ``c_crop``, scaled
    without stretching until it covers the target, then centre-cropped.
    With ``crop=False`` it is resized directly, ignoring the aspect ratio.
    """
    if not crop:
        return image.resize((target_width, target_height), Image.LANCZOS)

    # Imported only on the path that needs it, so that the plain resize path
    # does not depend on the `image_datasets` package.
    from image_datasets.canny_dataset import c_crop
    image = c_crop(image)  # Crop the image to square
    original_width, original_height = image.size

    # Resize to match the target size without stretching
    scale = max(target_width / original_width, target_height / original_height)
    resized_width = int(scale * original_width)
    resized_height = int(scale * original_height)
    image = image.resize((resized_width, resized_height), Image.LANCZOS)

    # Center crop to match the target dimensions
    left = (resized_width - target_width) // 2
    top = (resized_height - target_height) // 2
    return image.crop((left, top, left + target_width, top + target_height))


# https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union/blob/main/app.py
# https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
controlnet_union_modes = {
    "None": -1,
    #"scribble_hed": 0,
    "canny": 0,  # supported
    "mlsd": 0,  # supported
    "tile": 1,  # supported
    "depth_midas": 2,  # supported
    "blur": 3,  # supported
    "openpose": 4,  # supported
    "gray": 5,  # supported
    "low_quality": 6,  # supported
}


# https://github.com/pytorch/pytorch/issues/123834
def get_control_params():
    """Collect the active control slots.

    Returns:
        ``(modes, images, scales)`` lists for every slot whose mode is not
        ``-1`` and which has a control image; images go through
        ``diffusers.utils.load_image``.
    """
    from diffusers.utils import load_image
    modes = []
    images = []
    scales = []
    for mode, image, scale in zip(control_modes, control_images, control_scales):
        if mode == -1 or image is None:
            continue
        modes.append(mode)
        images.append(load_image(image))
        scales.append(scale)
    return modes, images, scales


from preprocessor import Preprocessor

# Control mode -> (Preprocessor model name, extra keyword arguments).
_DETECTOR_MODES = {
    "depth_midas": ("Midas", {}),
    "openpose": ("Openpose", {"hand_and_face": True}),
    "canny": ("Canny", {}),
    "mlsd": ("MLSD", {}),
    "scribble_hed": ("HED", {}),
}
# Modes that use the squared/resized input image directly as control image.
_PASSTHROUGH_MODES = frozenset({"low_quality", "gray", "blur", "tile"})


def preprocess_image(image: Image.Image, control_mode: str, height: int, width: int,
                     preprocess_resolution: int):
    """Turn *image* into a control image for *control_mode*.

    The input is converted to RGB, padded to a square and resized to
    ``max(width, height)``. Detector modes run it through ``Preprocessor``;
    pass-through modes use it as is. The result is finally resized to
    ``width`` x ``height``.

    Returns:
        The control image, or *image* itself when *control_mode* is "None".

    Raises:
        ValueError: if *control_mode* is not a known mode.
    """
    if control_mode == "None":
        return image
    if control_mode not in _DETECTOR_MODES and control_mode not in _PASSTHROUGH_MODES:
        # Fail early with a clear message instead of hitting an unbound
        # local variable further down.
        raise ValueError(f"Unsupported control mode: {control_mode!r}")

    image_resolution = max(width, height)
    image_before = resize_image(expand2square(image.convert("RGB")), image_resolution, image_resolution, False)

    print("start to generate control image")
    if control_mode in _DETECTOR_MODES:
        model_name, extra_kwargs = _DETECTOR_MODES[control_mode]
        preprocessor = Preprocessor()
        preprocessor.load(model_name)
        control_image = preprocessor(
            image=image_before,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
            **extra_kwargs,
        )
    else:
        control_image = image_before

    # Report the real size of the intermediate control image in the log.
    image_width, image_height = control_image.size
    image_after = resize_image(control_image, width, height, False)
    ref_width, ref_height = image.size
    print(f"generate control image success: {ref_width}x{ref_height} => {image_width}x{image_height}")
    return image_after


def get_control_union_mode():
    """Return the list of selectable control mode names."""
    return list(controlnet_union_modes.keys())


def set_control_union_mode(i: int, mode: str, scale: str):
    """Store mode and scale for control slot *i*.

    Unknown mode names are stored as mode ``0``.

    Returns:
        ``True`` when *mode* is not "None", otherwise
        ``gr.update(visible=True)``.
    """
    control_modes[i] = controlnet_union_modes.get(mode, 0)
    control_scales[i] = scale
    if mode != "None":
        return True
    return gr.update(visible=True)


def set_control_union_image(i: int, mode: str, image: Image.Image | None, height: int, width: int, preprocess_resolution: int):
    """Preprocess *image* and store it as the control image of slot *i*.

    Returns the stored control image, or ``None`` when *image* is ``None``.
    """
    if image is None:
        # NOTE(review): a previously stored image for this slot is kept here;
        # confirm whether it should be cleared when the input is removed.
        return None
    control_images[i] = preprocess_image(image, mode, height, width, preprocess_resolution)
    return control_images[i]


def compose_lora_json(lorajson: list[dict], i: int, name: str, scale: float, filename: str, trigger: str):
    """Fill entry *i* of *lorajson* in place and return the list.

    A *name* of "None" is stored as an empty string.
    """
    entry = lorajson[i]
    entry["name"] = str(name) if name != "None" else ""
    entry["scale"] = float(scale)
    entry["filename"] = str(filename)
    entry["trigger"] = str(trigger)
    return lorajson


def _is_active_lora(entry) -> bool:
    """Return True if *entry* is a dict naming a LoRA other than "None"."""
    if not isinstance(entry, dict):
        return False
    name = entry.get("name")
    return bool(name) and name != "None"


def is_valid_lora(lorajson: list[dict]):
    """Return True if at least one entry of *lorajson* names a LoRA."""
    return any(_is_active_lora(d) for d in lorajson)


def get_trigger_word(lorajson: list[dict]):
    """Return the triggers of all active LoRAs, each prefixed with ", "."""
    return "".join(
        f", {d['trigger']}"
        for d in lorajson
        # `.get` so that an entry without a "trigger" key is simply skipped.
        if _is_active_lora(d) and d.get("trigger")
    )


def get_model_trigger(model_name: str):
    """Return ", <trigger>" for *model_name*, or "" if it has no entry."""
    if model_name in model_trigger:
        return ", " + model_trigger[model_name]
    return ""


# https://huggingface.co/docs/diffusers/v0.23.1/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora
# https://github.com/huggingface/diffusers/issues/4919
def fuse_loras(pipe, lorajson: list[dict]):
    """Load every active LoRA in *lorajson* into *pipe* and fuse them.

    A name that is an existing Hub repo is loaded from the Hub using the
    entry's "filename"; otherwise the name is treated as a local path, and
    missing paths are reported and skipped. Does nothing when no LoRA could
    be loaded.

    Raises:
        Exception: on any loading/fusing failure (original error chained).
    """
    try:
        if not lorajson or not isinstance(lorajson, list):
            return
        adapter_names = []
        adapter_weights = []
        for entry in lorajson:
            # Also skips entries that are not dicts or lack a "name" key.
            if not _is_active_lora(entry):
                continue
            source = entry["name"]
            adapter_name = Path(source).stem
            if is_repo_name(source) and is_repo_exists(source):
                pipe.load_lora_weights(source, weight_name=entry["filename"], adapter_name=adapter_name)
            elif not Path(source).exists():
                print(f"LoRA not found: {source}")
                continue
            else:
                pipe.load_lora_weights(source, weight_name=Path(source).name, adapter_name=adapter_name)
            adapter_names.append(adapter_name)
            adapter_weights.append(entry["scale"])
        if not adapter_names:
            return
        pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
        pipe.fuse_lora(adapter_names=adapter_names, lora_scale=1.0)
    except Exception as e:
        print(f"External LoRA Error: {e}")
        raise Exception(f"External LoRA Error: {e}") from e


def description_ui():
    """Render the credits Markdown block."""
    gr.Markdown(
        """
- Mod of [multimodalart/flux-lora-the-explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer), [jiuface/FLUX.1-dev-Controlnet-Union](https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union), [DamarJati/FLUX.1-DEV-Canny](https://huggingface.co/spaces/DamarJati/FLUX.1-DEV-Canny), [gokaygokay/FLUX-Prompt-Generator](https://huggingface.co/spaces/gokaygokay/FLUX-Prompt-Generator).
"""
    )


from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM


def load_prompt_enhancer():
    """Build the prompt-enhancer text2text pipeline.

    Returns:
        The pipeline, or ``None`` if loading fails (the error is printed).
    """
    try:
        model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
        tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).eval().to(device=device)
        return pipeline('text2text-generation',
                        model=model,
                        tokenizer=tokenizer,
                        repetition_penalty=1.5,
                        device=device)
    except Exception as e:
        print(e)
        return None


enhancer_flux = load_prompt_enhancer()


@spaces.GPU(duration=30)
def enhance_prompt(input_prompt):
    """Translate *input_prompt* to English and expand it with the enhancer.

    Returns the prompt unchanged when the enhancer could not be loaded.
    """
    if enhancer_flux is None:
        # load_prompt_enhancer() returns None on failure; calling it would
        # raise a TypeError.
        print("Prompt enhancer is not available; returning the prompt unchanged.")
        return input_prompt
    result = enhancer_flux("enhance prompt: " + translate_to_en(input_prompt), max_length=256)
    return result[0]['generated_text']


def save_image(image, savefile, modelname, prompt, height, width, steps, cfg, seed):
    """Save *image* as a PNG with the generation settings embedded.

    The settings are stored as JSON in a PNG text chunk named "metadata".
    When *savefile* is ``None`` a name is generated from the model name and
    a random UUID.

    Returns:
        The absolute path of the written file, as a string.

    Raises:
        Exception: if saving fails (original error chained).
    """
    import uuid
    from PIL import PngImagePlugin
    import json
    try:
        model_short_name = modelname.split("/")[-1]
        if savefile is None:
            savefile = f"{model_short_name}_{str(uuid.uuid4())}.png"
        metadata = {
            "prompt": prompt,
            "Model": {"Model": model_short_name},
            "num_inference_steps": steps,
            "guidance_scale": cfg,
            "seed": seed,
            "resolution": f"{width} x {height}",
        }
        info = PngImagePlugin.PngInfo()
        info.add_text("metadata", json.dumps(metadata))
        image.save(savefile, "PNG", pnginfo=info)
        return str(Path(savefile).resolve())
    except Exception as e:
        print(f"Failed to save image file: {e}")
        # Include the cause in the message; it used to be dropped.
        raise Exception(f"Failed to save image file: {e}") from e


load_prompt_enhancer.zerogpu = True
fuse_loras.zerogpu = True
preprocess_image.zerogpu = True
get_control_params.zerogpu = True
clear_cache.zerogpu = True