# app.py - commit 054082d "Create app.py" by Georgiy Grigorev (4.55 kB).
# (Hugging Face file-viewer header; original link labels: raw / history / blame.)
import gradio as gr
import os
from torch.optim import AdamW
from diffusers import StableDiffusionPipeline
from torch import autocast, inference_mode
import torch
import numpy as np
from scheduling_ddim import DDIMScheduler
# Device that the pipeline is moved to and that autocast is entered with.
device = 'cuda'

# Scheduler handed to the pipeline instead of its default one.
_ddim_scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
)

# don't forget to add your token or comment if already logged in
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    scheduler=_ddim_scheduler,
    use_auth_token="",
).to(device)

# Freeze every sub-model: none of their weights are trained in this script.
for _frozen_module in (pipe.vae, pipe.text_encoder, pipe.unet):
    _ = _frozen_module.requires_grad_(False)
def preprocess(image):
    """Turn a PIL image into a batched float tensor scaled to ``[-1, 1]``.

    The image is converted to RGB, shrunk so that both sides are integer
    multiples of 32, and rearranged from ``(H, W, C)`` to ``(1, C, H, W)``.

    Args:
        image: A PIL image (anything exposing ``size``, ``convert`` and
            ``resize`` and convertible with ``np.array``).

    Returns:
        A ``torch.float32`` tensor of shape ``(1, 3, H, W)`` with values in
        ``[-1, 1]``.

    Raises:
        ValueError: If either side of the image is shorter than 32 pixels,
            which would round down to an empty image.
    """
    # Integer id of Pillow's LANCZOS resampling filter. The file never imports
    # PIL's ``Image`` module, so ``Image.LANCZOS`` would raise NameError; the
    # numeric id is accepted by ``resize`` and avoids the missing name.
    lanczos = 1

    w, h = image.size
    # round each side down to an integer multiple of 32
    w, h = (side - side % 32 for side in (w, h))
    if w == 0 or h == 0:
        raise ValueError(
            f"image of size {image.size} is too small; "
            "both sides must be at least 32 pixels"
        )

    # Force three channels so grayscale / RGBA inputs do not break the
    # (0, 3, 1, 2) transpose below or produce a 4-channel tensor.
    image = image.convert("RGB").resize((w, h), resample=lanczos)
    image = np.array(image).astype(np.float32) / 255.0
    # add a batch axis, then move channels first: (1, H, W, C) -> (1, C, H, W)
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    # rescale [0, 1] -> [-1, 1]
    return 2.0 * image - 1.0
def im2latent(pipe, im, generator):
    """Encode an image with the pipeline's VAE and return scaled latents.

    Args:
        pipe: Pipeline object providing ``device`` and ``vae``.
        im: Image accepted by ``preprocess``.
        generator: Random generator forwarded to the latent distribution's
            ``sample`` call.

    Returns:
        A sample from the VAE's latent distribution multiplied by 0.18215.
    """
    pixels = preprocess(im).to(pipe.device)
    posterior = pipe.vae.encode(pixels).latent_dist
    sampled = posterior.sample(generator=generator)
    # same fixed scaling constant as before, applied to the sampled latents
    scaled = sampled * 0.18215
    return scaled
def image_mod(init_image, source_prompt, prompt, scale, steps, seed):
    """Invert ``init_image`` with DDIM, then regenerate it under ``prompt``.

    The function works in two phases:

    1. Inversion: starting from the image latents, walk the scheduler's
       timesteps in reverse with ``source_prompt`` as conditioning and record
       the latents visited at every step (the "pivot" trajectory).
    2. Regeneration: run the scheduler forward with classifier-free guidance
       while optimising the unconditional text embedding so that each step
       lands close to the recorded pivot for that step.

    Args:
        init_image: PIL image to edit.
        source_prompt: Text describing the input image; used for inversion.
        prompt: Text used as the conditional prompt during regeneration.
        scale: Guidance scale applied as ``uncond + scale * (cond - uncond)``.
        steps: Number of scheduler timesteps (cast to ``int``).
        seed: Seed for the random generator (cast to ``int``); ``None`` falls
            back to 84, the value that used to be hard-coded.

    Returns:
        The first decoded image as a PIL image.

    Raises:
        ValueError: If ``steps`` is smaller than 1.
    """
    # UI widgets may deliver floats; the scheduler and generator need ints.
    steps = int(steps)
    if steps < 1:
        raise ValueError(f"steps must be at least 1, got {steps}")
    seed = 84 if seed is None else int(seed)

    # Seed from the caller's value (it used to be ignored in favour of 84).
    g = torch.Generator(device=pipe.device).manual_seed(seed)
    image_latents = im2latent(pipe, init_image, g)
    pipe.scheduler.set_timesteps(steps)

    # use text describing an image
    # source_prompt = "a photo of a woman"
    context = pipe._encode_prompt(source_prompt, pipe.device, 1, False, "")

    # ---- phase 1: inversion -------------------------------------------------
    decoded_latents = image_latents.clone()
    with autocast(device), inference_mode():
        # we are pivoting timesteps as we are moving in opposite direction
        timesteps = pipe.scheduler.timesteps.flip(0)
        # this would be our targets for pivoting
        init_trajectory = torch.empty(
            len(timesteps),
            *decoded_latents.size()[1:],
            device=decoded_latents.device,
            dtype=decoded_latents.dtype,
        )
        # pipe.progress_bar replaces the bare ``tqdm`` call, which was never
        # imported in this file and raised NameError.
        for i, t in enumerate(pipe.progress_bar(timesteps)):
            # store the latents *before* this reverse step
            init_trajectory[i:i + 1] = decoded_latents
            noise_pred = pipe.unet(decoded_latents, t, encoder_hidden_states=context).sample
            decoded_latents = pipe.scheduler.reverse_step(noise_pred, t, decoded_latents).next_sample

    # we would need to flip trajectory values for pivoting in right direction
    init_trajectory = init_trajectory.cpu().flip(0)

    # ---- phase 2: guided regeneration with embedding optimisation ----------
    latents = decoded_latents.clone()
    context_uncond = pipe._encode_prompt("", pipe.device, 1, False, "")
    # we will be optimizing uncond text embedding
    context_uncond.requires_grad_(True)
    # use same text
    # prompt = "a photo of a woman"
    context_cond = pipe._encode_prompt(prompt, pipe.device, 1, False, "")
    # default lr works
    opt = AdamW([context_uncond])
    # concat latents for classifier-free guidance
    latents = torch.cat([latents, latents])

    with autocast(device):
        with torch.enable_grad():
            for i, t in enumerate(pipe.progress_bar(pipe.scheduler.timesteps)):
                # Rebuild the batch every step: torch.cat copies its inputs,
                # so a context built once before the loop would never see the
                # optimiser's updates to context_uncond.
                context = torch.cat((context_uncond, context_cond))
                latents = pipe.scheduler.scale_model_input(latents, t)
                uncond, cond = pipe.unet(latents, t, encoder_hidden_states=context).sample.chunk(2)
                latents = pipe.scheduler.step(
                    uncond + scale * (cond - uncond), t, latents, generator=g
                ).prev_sample
                opt.zero_grad()
                # optimize uncond text emb
                pivot_value = init_trajectory[[i]].to(pipe.device)
                # Squared error, not the signed mean difference used before:
                # a signed mean is not a distance and can be driven to
                # arbitrarily negative values without approaching the pivot.
                (latents - pivot_value).pow(2).mean().backward()
                opt.step()
                # drop the graph so the next step starts from plain values
                latents = latents.detach()
        images = pipe.decode_latents(latents)

    im = pipe.numpy_to_pil(images)[0]
    return im
# Path of the bundled example image, resolved relative to this file.
_EXAMPLE_IMAGE = os.path.join(os.path.dirname(__file__), "images/00001.jpg")

# Web UI around image_mod. Inputs are listed in the same order as the
# function's parameters: init_image, source_prompt, prompt, scale, steps, seed.
demo = gr.Interface(
    image_mod,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox("a photo of a person"),
        gr.Textbox("a photo of a person"),
        # value/step are passed by keyword so the call does not depend on
        # whether ``step`` is accepted positionally.
        gr.Slider(0, 10, value=0.5, step=0.1),
        # at least one step: image_mod rejects steps < 1
        gr.Slider(1, 100, value=51, step=1),
        gr.Number(42),
    ],
    outputs="image",
    flagging_options=["blurry", "incorrect", "other"],
    # With several inputs every example must be a full row (one value per
    # input); a flat list of paths is only valid for single-input interfaces.
    examples=[
        [_EXAMPLE_IMAGE, "a photo of a person", "a photo of a person", 0.5, 51, 42],
    ],
)

if __name__ == "__main__":
    demo.launch()