"""Gradio demo: edit a real image with Stable Diffusion using DDIM inversion
plus per-step optimization of the unconditional ("null-text") embedding."""

import os

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from torch import autocast, inference_mode
from torch.optim import AdamW
from tqdm import tqdm

# Local DDIM scheduler; this script relies on its `reverse_step` method.
from scheduling_ddim import DDIMScheduler

device = 'cuda'

# Provide a Hugging Face token via the HF_TOKEN environment variable. When it
# is unset, None is passed and authentication is left to diffusers' default
# behaviour (presumably a token cached by `huggingface-cli login`).
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    scheduler=DDIMScheduler(beta_end=0.012, beta_schedule="scaled_linear", beta_start=0.00085),
    use_auth_token=os.environ.get("HF_TOKEN"),
).to(device)
# All model weights stay frozen; only the unconditional text embedding is optimized.
_ = pipe.vae.requires_grad_(False)
_ = pipe.text_encoder.requires_grad_(False)
_ = pipe.unet.requires_grad_(False)


def preprocess(image):
    """Convert a PIL image into a (1, 3, H, W) float tensor scaled to [-1, 1].

    Width and height are rounded down to a multiple of 32 before resizing.

    Raises:
        ValueError: if either side is smaller than 32 pixels.
    """
    # Drop alpha / expand grayscale so the VAE always receives 3 channels.
    image = image.convert("RGB")
    w, h = (x - x % 32 for x in image.size)
    if w == 0 or h == 0:
        raise ValueError("image must be at least 32x32 pixels")
    image = image.resize((w, h), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # (H, W, C) -> (1, C, H, W)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0


def im2latent(pipe, im, generator):
    """Encode PIL image *im* into scaled VAE latents.

    A sample is drawn from the VAE posterior using *generator* and multiplied
    by 0.18215, the Stable Diffusion v1 latent scaling factor.
    """
    init_image = preprocess(im).to(pipe.device)
    init_latent_dist = pipe.vae.encode(init_image).latent_dist
    init_latents = init_latent_dist.sample(generator=generator)
    return init_latents * 0.18215


def _invert(pipe, image_latents, context):
    """DDIM-invert *image_latents* conditioned on *context* (no guidance).

    Returns ``(noised_latents, trajectory)``. ``trajectory[i]`` is the latent
    *before* the i-th inversion step, in inversion (low-to-high noise) order.
    """
    # Walk the timesteps in the opposite direction to denoising.
    timesteps = pipe.scheduler.timesteps.flip(0)
    latents = image_latents.clone()
    trajectory = torch.empty(
        len(timesteps), *latents.size()[1:], device=latents.device, dtype=latents.dtype
    )
    for i, t in enumerate(tqdm(timesteps)):
        trajectory[i:i + 1] = latents
        noise_pred = pipe.unet(latents, t, encoder_hidden_states=context).sample
        latents = pipe.scheduler.reverse_step(noise_pred, t, latents).next_sample
    return latents, trajectory


def _null_text_denoise(pipe, latents, trajectory, context_cond, scale, generator):
    """Denoise *latents* with classifier-free guidance while tuning the
    unconditional embedding so each step's output tracks *trajectory*.

    ``trajectory[i]`` must be the target for denoising step ``i``. One AdamW
    step is taken per timestep, after that timestep's sample has been
    computed, so each update influences the following steps.
    """
    context_uncond = pipe._encode_prompt("", pipe.device, 1, False, "")
    # The unconditional text embedding is the only optimized tensor.
    context_uncond.requires_grad_(True)
    # default lr works
    opt = AdamW([context_uncond])

    # Duplicate the latents so one UNet call yields both uncond and cond predictions.
    latents = torch.cat([latents, latents])

    # enable_grad keeps this working even if the caller runs under no_grad.
    with autocast(device), torch.enable_grad():
        for i, t in enumerate(tqdm(pipe.scheduler.timesteps)):
            # Rebuild every step: torch.cat copies, so a context built once
            # would never see the optimizer's in-place updates.
            context = torch.cat((context_uncond, context_cond))
            latents = pipe.scheduler.scale_model_input(latents, t)
            uncond, cond = pipe.unet(latents, t, encoder_hidden_states=context).sample.chunk(2)
            latents = pipe.scheduler.step(
                uncond + scale * (cond - uncond), t, latents, generator=generator
            ).prev_sample

            opt.zero_grad()
            pivot_value = trajectory[[i]].to(pipe.device)
            # Squared error to the pivot: a signed mean would have a gradient
            # that does not depend on the pivot at all.
            ((latents - pivot_value) ** 2).mean().backward()
            opt.step()
            latents = latents.detach()
    return latents


def image_mod(init_image, source_prompt, prompt, scale, steps, seed):
    """Edit *init_image* from *source_prompt* toward *prompt*.

    The image is encoded to latents and DDIM-inverted under *source_prompt*.
    It is then re-denoised under *prompt* with classifier-free guidance of
    strength *scale*. Meanwhile the unconditional embedding is tuned to keep
    the sample on the inversion trajectory.

    Args:
        init_image: PIL image from the Gradio input.
        source_prompt: text describing *init_image*; used for inversion.
        prompt: target text used while denoising.
        scale: classifier-free guidance scale.
        steps: number of DDIM steps (coerced to an int, at least 1).
        seed: seed for the generator that samples the VAE posterior and is
            passed to the scheduler (coerced to an int).

    Returns:
        The edited image, converted via ``pipe.numpy_to_pil``.
    """
    # Gradio Number/Slider values may arrive as floats; torch wants ints.
    g = torch.Generator(device=pipe.device).manual_seed(int(seed))
    pipe.scheduler.set_timesteps(max(int(steps), 1))

    with inference_mode():
        image_latents = im2latent(pipe, init_image, g)
        context = pipe._encode_prompt(source_prompt, pipe.device, 1, False, "")
        with autocast(device):
            noised_latents, trajectory = _invert(pipe, image_latents, context)

    # Reverse so entry i is the target of denoising step i. Done outside
    # inference_mode so the results are ordinary (non-inference) tensors.
    trajectory = trajectory.cpu().flip(0)
    latents = noised_latents.clone()

    context_cond = pipe._encode_prompt(prompt, pipe.device, 1, False, "")
    latents = _null_text_denoise(pipe, latents, trajectory, context_cond, scale, g)

    images = pipe.decode_latents(latents)
    return pipe.numpy_to_pil(images)[0]


demo = gr.Interface(
    image_mod,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox("a photo of a person", label="Source prompt"),
        gr.Textbox("a photo of a person", label="Target prompt"),
        gr.Slider(0, 10, value=0.5, step=0.1, label="Guidance scale"),
        gr.Slider(1, 100, value=51, step=1, label="Steps"),
        gr.Number(42, label="Seed"),
    ],
    outputs="image",
    flagging_options=["blurry", "incorrect", "other"],
    # With several inputs, each example must be a list with one value per input.
    examples=[
        [
            os.path.join(os.path.dirname(__file__), "images/00001.jpg"),
            "a photo of a person",
            "a photo of a person",
            0.5,
            51,
            42,
        ],
    ],
)

if __name__ == "__main__":
    demo.launch()