|
import torch |
|
import numpy as np |
|
import random |
|
import os |
|
|
|
from diffusers.utils import load_image |
|
from diffusers import DDIMScheduler |
|
|
|
from huggingface_hub import hf_hub_download |
|
import spaces |
|
import gradio as gr |
|
|
|
from pipeline import PhotoMakerStableDiffusionXLPipeline |
|
from style_template import styles |
|
|
|
|
|
# Hugging Face Hub id of the SDXL base checkpoint the PhotoMaker adapter is loaded onto.
base_model_path = 'SG161222/RealVisXL_V3.0'

# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Upper bound for seeds: the seed slider's maximum and randomize_seed_fn's range.
MAX_SEED = np.iinfo(np.int32).max

# Choices for the style dropdown, in the iteration order of `styles`.
STYLE_NAMES = list(styles)
# Style preselected in the dropdown.
DEFAULT_STYLE_NAME = "Photographic (Default)"
|
|
|
|
|
# Download the PhotoMaker adapter checkpoint from the Hugging Face Hub.
# hf_hub_download returns a local file path, which is split into directory and
# file name below for load_photomaker_adapter.
photomaker_ckpt = hf_hub_download(repo_id="TencentARC/PhotoMaker", filename="photomaker-v1.bin", repo_type="model")

# Build the SDXL-based PhotoMaker pipeline from the base model. The "fp16"
# safetensors weight variant is requested and loaded with dtype bfloat16, and the
# whole pipeline is moved to `device`.
pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    variant="fp16",
).to(device)

# Attach the PhotoMaker adapter weights. "img" is the trigger word;
# NOTE(review): presumably this is what load_photomaker_adapter stores as
# `pipe.trigger_word`, which generate_image requires exactly once per prompt — confirm in pipeline.py.
pipe.load_photomaker_adapter(
    os.path.dirname(photomaker_ckpt),
    subfolder="",
    weight_name=os.path.basename(photomaker_ckpt),
    trigger_word="img"
)
# NOTE(review): the ID encoder is moved to `device` explicitly, presumably because
# it is created by load_photomaker_adapter after the `.to(device)` above — confirm.
pipe.id_encoder.to(device)

# Replace the pipeline's scheduler with DDIM, reusing the existing scheduler config.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

# NOTE(review): presumably folds the adapter's LoRA weights into the base model
# weights once at startup — confirm against PhotoMakerStableDiffusionXLPipeline.
pipe.fuse_lora()
|
|
|
@spaces.GPU
def generate_image(upload_images, prompt, negative_prompt, style_name, num_steps, style_strength_ratio, num_outputs, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
    """Generate identity-customized images from uploaded face photos and a prompt.

    Args:
        upload_images: Files from the upload widget; each item is passed to
            ``load_image``. ``None`` or an empty list means nothing was uploaded.
        prompt: Text prompt; must contain ``pipe.trigger_word`` exactly once.
        negative_prompt: User negative prompt, merged with the style's negative
            prompt by ``apply_style`` and prefixed with "nsfw, naked, ".
        style_name: Style template name passed to ``apply_style``.
        num_steps: Number of sampling steps (``num_inference_steps``).
        style_strength_ratio: Percentage of ``num_steps`` used as
            ``start_merge_step``; the resulting step is capped at 30.
        num_outputs: Number of images to generate (``num_images_per_prompt``).
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Seed for the ``torch.Generator`` on ``device``.
        progress: Gradio progress tracker (``track_tqdm=True`` mirrors tqdm
            bars); not referenced directly.

    Returns:
        ``(images, gr.update(visible=True))``: the generated images and an
        update that reveals the usage-tips panel.

    Raises:
        gr.Error: If the trigger word is missing or repeated in the prompt, or
            if no face image was uploaded.
    """
    # Validate the trigger word on the raw user prompt, before any style
    # template text is wrapped around it.
    trigger_word = pipe.trigger_word
    image_token_id = pipe.tokenizer.convert_tokens_to_ids(trigger_word)
    trigger_count = pipe.tokenizer.encode(prompt).count(image_token_id)
    if trigger_count == 0:
        raise gr.Error(f"Cannot find the trigger word '{trigger_word}' in text prompt! Please refer to step 2️⃣")
    if trigger_count > 1:
        raise gr.Error(f"Cannot use multiple trigger words '{trigger_word}' in text prompt!")

    prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)

    # Always steer away from explicit content, whatever the user typed.
    negative_prompt = f"nsfw, naked, {negative_prompt}"

    # `not` also rejects an empty list, which would otherwise reach the pipeline
    # with no identity images at all.
    if not upload_images:
        raise gr.Error("Cannot find any input face image! Please refer to step 1️⃣")

    input_id_images = [load_image(img) for img in upload_images]

    # Slider values may arrive as floats; Generator.manual_seed and the
    # pipeline's step/count arguments need integers.
    num_steps = int(num_steps)
    num_outputs = int(num_outputs)
    generator = torch.Generator(device=device).manual_seed(int(seed))

    print("Start inference...")
    print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
    # Convert the style-strength percentage into a step index, never later than step 30.
    start_merge_step = min(int(float(style_strength_ratio) / 100 * num_steps), 30)
    print(start_merge_step)

    images = pipe(
        prompt=prompt,
        input_id_images=input_id_images,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_outputs,
        num_inference_steps=num_steps,
        start_merge_step=start_merge_step,
        generator=generator,
        guidance_scale=guidance_scale,
    ).images

    return images, gr.update(visible=True)
|
|
|
def swap_to_gallery(images):
    """Replace the file picker with a thumbnail gallery of the uploaded images.

    Returns updates for ``(uploaded_files, clear_button, files)``: show the
    gallery filled with *images*, reveal the clear button, hide the picker.
    """
    gallery_update = gr.update(value=images, visible=True)
    clear_button_update = gr.update(visible=True)
    picker_update = gr.update(visible=False)
    return gallery_update, clear_button_update, picker_update
|
|
|
def upload_example_to_gallery(images, prompt, style, negative_prompt):
    """Display an example's face photos exactly as a manual upload would.

    ``gr.Examples`` calls this with one value per component in its ``inputs``
    list, so ``prompt``, ``style`` and ``negative_prompt`` are accepted but
    unused. Only the images drive the gallery/picker visibility swap.
    """
    return swap_to_gallery(images)
|
|
|
def remove_back_to_files():
    """Undo the gallery view: hide the gallery and clear button, show the picker.

    Returns updates for ``(uploaded_files, clear_button, files)``.
    """
    hide_gallery = gr.update(visible=False)
    hide_clear_button = gr.update(visible=False)
    show_file_picker = gr.update(visible=True)
    return hide_gallery, hide_clear_button, show_file_picker
|
|
|
def remove_tips():
    """Return a Gradio update that hides the usage-tips panel."""
    hidden = gr.update(visible=False)
    return hidden
|
|
|
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed for the next generation.

    Returns *seed* unchanged when *randomize_seed* is falsy; otherwise returns
    a uniformly random integer in ``[0, MAX_SEED]`` (both ends inclusive).
    """
    if not randomize_seed:
        return seed
    return random.randint(0, MAX_SEED)
|
|
|
def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
    """Wrap the user's prompts in a style template from ``styles``.

    Args:
        style_name: Template name; names not present in ``styles`` fall back to
            ``DEFAULT_STYLE_NAME``.
        positive: User prompt, substituted for every ``{prompt}`` placeholder
            in the template's positive text.
        negative: User negative prompt, appended after the template's negative
            text with a single space between them.

    Returns:
        ``(styled_positive, combined_negative)``.
    """
    # Look the default up only when it is actually needed. Passing
    # styles[DEFAULT_STYLE_NAME] as the .get() default would index it on every
    # call and raise KeyError even for valid style names.
    if style_name in styles:
        positive_template, negative_template = styles[style_name]
    else:
        positive_template, negative_template = styles[DEFAULT_STYLE_NAME]
    return positive_template.replace("{prompt}", positive), negative_template + ' ' + negative
|
|
|
def get_image_path_list(folder_name):
    """Return the sorted paths of the visible regular files directly in *folder_name*.

    Hidden entries (names starting with '.', e.g. '.DS_Store') and
    subdirectories are skipped. Every returned path is offered as an example
    face photo and later passed to ``load_image``, and such entries cannot be
    loaded as images.

    Raises:
        FileNotFoundError: If *folder_name* does not exist (from ``os.listdir``).
    """
    candidate_paths = (
        os.path.join(folder_name, basename)
        for basename in os.listdir(folder_name)
        if not basename.startswith(".")
    )
    return sorted(path for path in candidate_paths if os.path.isfile(path))
|
|
|
def get_example():
    """Build the rows for the ``gr.Examples`` widget.

    Each row is ``[image_paths, prompt, style_name, negative_prompt]``, matching
    the order of the widget's ``inputs`` (files, prompt, style, negative_prompt).
    Image paths come from ``get_image_path_list`` on each example folder.
    """
    no_style = "(No style)"
    shared_negative = "(asymmetry, worst quality, low quality, illustration, 3d, 2d, painting, cartoons, sketch), open mouth"
    example_specs = (
        ('./examples/scarletthead_woman',
         "instagram photo, portrait photo of a woman img, colorful, perfect face, natural skin, hard shadows, film grain"),
        ('./examples/newton_man',
         "sci-fi, closeup portrait photo of a man img wearing the sunglasses in Iron man suit, face, slim body, high quality, film grain"),
    )
    rows = []
    for folder, example_prompt in example_specs:
        rows.append([get_image_path_list(folder), example_prompt, no_style, shared_negative])
    return rows
|
|
|
|
|
# Markdown for the "Usage tips" panel. The panel starts hidden and is revealed
# when generate_image returns, so this must contain real text (it used to be a
# single space, which showed an empty panel).
tips = r"""
### Usage tips of PhotoMaker
1. Upload one or more photos of the same person's face.
2. Put the trigger word `img` exactly once in the prompt, right after the class word, e.g. `a photo of a woman img`.
3. Fewer output images and fewer sample steps make generation faster.
"""

# Header/footer markdown rendered by the Blocks layout. These names are used by
# the layout but were not defined anywhere else in this file, which raised
# NameError at startup. Empty strings render nothing; fill in real text as needed.
logo = ""
title = ""
description = ""
article = ""

css = '''
.gradio-container {width: 85% !important}
'''
|
# Gradio UI: an input column (face photos, prompt, style, advanced sampling
# options) beside an output column (generated images, usage tips), followed by
# clickable examples.
# NOTE(review): logo, title, description and article must be defined at module
# level before this block runs; otherwise the gr.Markdown calls raise NameError.
with gr.Blocks(css=css) as demo:
    gr.Markdown(logo)
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            # Multi-file picker. swap_to_gallery hides it on upload and shows
            # `uploaded_files` in its place.
            files = gr.File(
                label="Drag (Select) 1 or more photos of your face",
                file_types=["image"],
                file_count="multiple"
            )
            uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=200)
            # Hidden until files are uploaded. Clicking clears `files` and, via
            # remove_back_to_files, restores the picker.
            with gr.Column(visible=False) as clear_button:
                remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
            prompt = gr.Textbox(label="Prompt",
                                info="Try something like 'a photo of a man/woman img', 'img' is the trigger word.",
                                placeholder="A photo of a [man/woman img]...")
            style = gr.Dropdown(label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
            submit = gr.Button("Submit")

            # Sampling parameters, passed positionally to generate_image.
            with gr.Accordion(open=False, label="Advanced Options"):
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="low quality",
                    value="nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
                )
                num_steps = gr.Slider(
                    label="Number of sample steps",
                    minimum=20,
                    maximum=100,
                    step=1,
                    value=50,
                )
                # Percentage of the sample steps that generate_image converts
                # into start_merge_step (capped at step 30).
                style_strength_ratio = gr.Slider(
                    label="Style strength (%)",
                    minimum=15,
                    maximum=50,
                    step=1,
                    value=20,
                )
                num_outputs = gr.Slider(
                    label="Number of output images",
                    minimum=1,
                    maximum=4,
                    step=1,
                    value=2,
                )
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.1,
                    maximum=10.0,
                    step=0.1,
                    value=5,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        # Right column: outputs.
        with gr.Column():
            gallery = gr.Gallery(label="Generated Images")
            # Hidden until generate_image returns gr.update(visible=True) for it.
            usage_tips = gr.Markdown(label="Usage tips of PhotoMaker", value=tips, visible=False)

        # Both handlers return updates for (uploaded_files, clear_button, files).
        files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
        remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])

        # Submit chain: hide the tips, then optionally replace the seed (written
        # back to the slider), then generate using that slider value.
        submit.click(
            fn=remove_tips,
            outputs=usage_tips,
        ).then(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        ).then(
            fn=generate_image,
            inputs=[files, prompt, negative_prompt, style, num_steps, style_strength_ratio, num_outputs, guidance_scale, seed],
            outputs=[gallery, usage_tips]
        )

    # Clicking an example fills the four inputs and runs
    # upload_example_to_gallery to switch to the gallery view.
    gr.Examples(
        examples=get_example(),
        inputs=[files, prompt, style, negative_prompt],
        run_on_click=True,
        fn=upload_example_to_gallery,
        outputs=[uploaded_files, clear_button, files],
    )

    gr.Markdown(article)

demo.launch()