Spaces:
Runtime error
Runtime error
import gradio as gr | |
import os | |
from torch.optim import AdamW | |
from diffusers import StableDiffusionPipeline | |
from torch import autocast, inference_mode | |
import torch | |
import numpy as np | |
from scheduling_ddim import DDIMScheduler | |
# All model work targets this device; the same string is reused by the
# autocast contexts further down.
device = 'cuda'

# Build the Stable Diffusion v1.5 pipeline around the project-local DDIM
# scheduler and move it to `device`.
# don't forget to add your token or comment if already logged in
# NOTE(review): an empty token string is passed as-is; confirm the hub client
# treats "" the same as "no token".
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    scheduler=DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
    ),
    use_auth_token="",
).to(device)

# Freeze every sub-model: none of the pipeline weights are trained here.
for frozen_module in (pipe.vae, pipe.text_encoder, pipe.unet):
    frozen_module.requires_grad_(False)
def preprocess(image):
    """Convert an image into a batched float tensor scaled to ``2 * x - 1``.

    Each side is rounded down to the nearest multiple of 32, the image is
    resized to that size with Lanczos resampling, the pixel values are
    divided by 255 and the array gets a leading batch axis with its last
    axis moved to position 1.

    NOTE(review): ``Image`` is not imported in the visible part of this
    file; presumably ``from PIL import Image`` is intended — confirm.

    Args:
        image: Object exposing ``size`` and ``resize`` (presumably a PIL
            image whose array form is height x width x channels — confirm).

    Returns:
        A ``torch`` tensor built from the float32 array, mapped through
        ``2.0 * x - 1.0``.
    """
    width, height = image.size
    # resize to integer multiple of 32 (round each side down)
    width -= width % 32
    height -= height % 32
    resized = image.resize((width, height), resample=Image.LANCZOS)

    pixels = np.array(resized).astype(np.float32) / 255.0
    # prepend a batch axis, then reorder axes (0, 3, 1, 2)
    batched = pixels[None].transpose(0, 3, 1, 2)
    tensor = torch.from_numpy(batched)
    return 2.0 * tensor - 1.0
def im2latent(pipe, im, generator):
    """Encode an image with the pipeline's VAE and return a scaled latent sample.

    Args:
        pipe: Pipeline providing ``device`` and ``vae``.
        im: Image accepted by ``preprocess``.
        generator: Random generator forwarded to the latent distribution's
            ``sample`` call.

    Returns:
        The sampled latents multiplied by 0.18215.
    """
    pixels = preprocess(im).to(pipe.device)
    posterior = pipe.vae.encode(pixels).latent_dist
    sampled = posterior.sample(generator=generator)
    # 0.18215: presumably the Stable Diffusion VAE latent scaling factor — confirm
    return sampled * 0.18215
def image_mod(init_image, source_prompt, prompt, scale, steps, seed):
    """Invert an image with the DDIM scheduler and regenerate it under a prompt.

    The image is encoded to latents and pushed through the scheduler's
    timesteps in flipped order under ``source_prompt`` while every
    intermediate latent is recorded.  The final latent is then denoised with
    classifier-free guidance towards ``prompt``; after each denoising step
    the unconditional text embedding receives one optimiser update that
    pulls the step's result towards the latent recorded at the matching
    point of the inversion.

    Args:
        init_image: Input image, passed to ``im2latent``.
        source_prompt: Text describing the input image; conditions the inversion.
        prompt: Text conditioning the regeneration.
        scale: Classifier-free guidance scale.
        steps: Number of scheduler steps; must be at least 1.
        seed: Seed of the random generator used for VAE sampling and for the
            scheduler steps.

    Returns:
        The first image returned by ``pipe.numpy_to_pil`` for the decoded latents.

    Raises:
        ValueError: If ``steps`` is smaller than 1.
    """
    # UI widgets may deliver floats; the scheduler and the generator need ints.
    steps = int(steps)
    seed = int(seed)
    if steps < 1:
        raise ValueError("steps must be at least 1")

    # fix seed (this used to be hard-coded to 84, silently ignoring `seed`)
    g = torch.Generator(device=pipe.device).manual_seed(seed)
    image_latents = im2latent(pipe, init_image, g)
    pipe.scheduler.set_timesteps(steps)

    # use text describing an image
    context = pipe._encode_prompt(source_prompt, pipe.device, 1, False, "")
    decoded_latents = image_latents.clone()
    with autocast(device), inference_mode():
        # we are pivoting timesteps as we are moving in opposite direction
        timesteps = pipe.scheduler.timesteps.flip(0)
        # this would be our targets for pivoting
        init_trajectory = torch.empty(
            len(timesteps),
            *decoded_latents.size()[1:],
            device=decoded_latents.device,
            dtype=decoded_latents.dtype,
        )
        # `tqdm` is not imported in this file; the pipeline's own progress
        # bar helper is used instead.
        for i, t in enumerate(pipe.progress_bar(timesteps)):
            init_trajectory[i:i + 1] = decoded_latents
            noise_pred = pipe.unet(decoded_latents, t, encoder_hidden_states=context).sample
            decoded_latents = pipe.scheduler.reverse_step(noise_pred, t, decoded_latents).next_sample
    # we would need to flip trajectory values for pivoting in right direction
    init_trajectory = init_trajectory.cpu().flip(0)

    latents = decoded_latents.clone()
    context_uncond = pipe._encode_prompt("", pipe.device, 1, False, "")
    # we will be optimizing uncond text embedding
    context_uncond.requires_grad_(True)
    context_cond = pipe._encode_prompt(prompt, pipe.device, 1, False, "")
    # default lr works
    opt = AdamW([context_uncond])
    # concat latents for classifier-free guidance; the latents themselves are
    # not optimised, so they do not need gradients
    latents = torch.cat([latents, latents])

    with autocast(device):
        for i, t in enumerate(pipe.progress_bar(pipe.scheduler.timesteps)):
            with torch.enable_grad():
                # Rebuild the context every step: torch.cat copies its
                # inputs, so a context built once before the loop would
                # never see the optimiser's updates to `context_uncond`.
                context = torch.cat((context_uncond, context_cond))
                latents = pipe.scheduler.scale_model_input(latents, t)
                uncond, cond = pipe.unet(latents, t, encoder_hidden_states=context).sample.chunk(2)
                guided = uncond + scale * (cond - uncond)
                latents = pipe.scheduler.step(guided, t, latents, generator=g).prev_sample

                # optimize uncond text emb
                opt.zero_grad()
                pivot_value = init_trajectory[[i]].to(pipe.device)
                # Squared error: a signed mean difference has no minimum at
                # the pivot, so it cannot pull the latents towards it.
                loss = (latents - pivot_value).pow(2).mean()
                loss.backward()
                opt.step()
            latents = latents.detach()
        images = pipe.decode_latents(latents)

    im = pipe.numpy_to_pil(images)[0]
    return im
# Gradio UI.  The inputs are handed to `image_mod` positionally, in this
# order: image, source prompt, target prompt, guidance scale, steps, seed.
demo = gr.Interface(
    image_mod,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox("a photo of a person"),
        gr.Textbox("a photo of a person"),
        # Slider arguments are spelled out by keyword: `step` is keyword-only
        # in Gradio 3 and later, so passing it as a fourth positional
        # argument raises a TypeError when the module is loaded.
        gr.Slider(minimum=0, maximum=10, value=0.5, step=0.1),
        gr.Slider(minimum=0, maximum=100, value=51, step=1),
        gr.Number(42),
    ],
    outputs="image",
    flagging_options=["blurry", "incorrect", "other"],
    examples=[
        os.path.join(os.path.dirname(__file__), "images/00001.jpg"),
    ],
)

if __name__ == "__main__":
    demo.launch()