"""Video super-resolution inference.

Upscales low-resolution .mp4 videos 4x using a Stable Diffusion x4 upscaler
pipeline whose VAE and UNet are swapped for a video VAE and a 3D VSR UNet
(LaVie VSR checkpoint), running DDIM sampling over short frame chunks.
"""
import os
import sys
import argparse

o_path = os.getcwd()
sys.path.append(o_path)

import torch
import time
import json
import numpy as np
import imageio
import torchvision
from einops import rearrange
from models.autoencoder_kl import AutoencoderKL
from models.unet import UNet3DVSRModel
from models.pipeline_stable_diffusion_upscale_video_3d import StableDiffusionUpscalePipeline
from diffusers import DDIMScheduler
from omegaconf import OmegaConf


def main(args):
    device = "cuda"

    # ---------------------- load models ----------------------
    pipeline = StableDiffusionUpscalePipeline.from_pretrained(
        args.pretrained_path + '/stable-diffusion-x4-upscaler',
        torch_dtype=torch.float16,
    )

    # vae: rebuild from the local config, then load the upscaler's VAE weights
    pipeline.vae = AutoencoderKL.from_config("configs/vae_config.json")
    pretrained_model = args.pretrained_path + "/stable-diffusion-x4-upscaler/vae/diffusion_pytorch_model.bin"
    pipeline.vae.load_state_dict(torch.load(pretrained_model, map_location="cpu"))

    # unet: 3D VSR UNet initialized from the LaVie VSR EMA checkpoint
    config_path = "./configs/unet_3d_config.json"
    with open(config_path, "r") as f:
        config = json.load(f)
    config['video_condition'] = False
    pipeline.unet = UNet3DVSRModel.from_config(config)
    pretrained_model = args.pretrained_path + "/lavie_vsr.pt"
    checkpoint = torch.load(pretrained_model, map_location="cpu")['ema']
    pipeline.unet.load_state_dict(checkpoint, strict=True)
    pipeline.unet = pipeline.unet.half()
    pipeline.unet.eval()  # important!

    # DDIMScheduler: reuse the upscaler's scheduler config with a linear beta schedule
    with open(args.pretrained_path + '/stable-diffusion-x4-upscaler/scheduler/scheduler_config.json', "r") as f:
        config = json.load(f)
    config["beta_schedule"] = "linear"
    pipeline.scheduler = DDIMScheduler.from_config(config)

    pipeline = pipeline.to(device)

    # ---------------------- load user's prompt ----------------------
    # input
    video_root = args.input_path
    video_list = sorted(os.listdir(video_root))
    print('video num:', len(video_list))

    # output
    save_root = args.output_path
    os.makedirs(save_root, exist_ok=True)

    # inference params
    noise_level = args.noise_level
    guidance_scale = args.guidance_scale
    num_inference_steps = args.inference_steps

    # ---------------------- start inferencing ----------------------
    for i, video_name in enumerate(video_list):
        video_name = video_name.replace('.mp4', '')
        print(f'[{i+1}/{len(video_list)}]:', video_name)
        lr_path = f"{video_root}/{video_name}.mp4"
        save_path = f"{save_root}/{video_name}.mp4"

        # the file name (without extension) doubles as the text prompt
        prompt = video_name
        print('Prompt:', prompt)
        negative_prompt = "blur, worst quality"

        vframes, aframes, info = torchvision.io.read_video(filename=lr_path, pts_unit='sec', output_format='TCHW')  # RGB, uint8
        vframes = vframes / 255.  # uint8 -> float in [0, 1]
        vframes = (vframes - 0.5) * 2  # T C H W, [-1, 1]
        t, _, h, w = vframes.shape
        vframes = vframes.unsqueeze(dim=0)  # 1 T C H W
        vframes = rearrange(vframes, 'b t c h w -> b c t h w').contiguous()  # 1 C T H W
        print('Input_shape:', vframes.shape, 'Noise_level:', noise_level, 'Guidance_scale:', guidance_scale)
        fps = info['video_fps']

        generator = torch.Generator(device=device).manual_seed(10)

        torch.cuda.synchronize()
        start_time = time.time()

        with torch.no_grad():
            short_seq = 8
            vframes_seq = vframes.shape[2]
            if vframes_seq > short_seq:  # long videos: run VSR in chunks of `short_seq` frames
                upscaled_video_list = []
                for start_f in range(0, vframes_seq, short_seq):
                    print(f'Processing: [{start_f}-{start_f + short_seq}/{vframes_seq}]')
                    torch.cuda.empty_cache()  # free memory between chunks
                    end_f = min(vframes_seq, start_f + short_seq)
                    upscaled_video_ = pipeline(
                        prompt,
                        image=vframes[:, :, start_f:end_f],
                        generator=generator,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        noise_level=noise_level,
                        negative_prompt=negative_prompt,
                    ).images  # T C H W, [-1, 1]
                    upscaled_video_list.append(upscaled_video_)
                upscaled_video = torch.cat(upscaled_video_list, dim=0)
            else:
                upscaled_video = pipeline(
                    prompt,
                    image=vframes,
                    generator=generator,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    noise_level=noise_level,
                    negative_prompt=negative_prompt,
                ).images  # T C H W, [-1, 1]

        torch.cuda.synchronize()
        run_time = time.time() - start_time
        print('Output:', upscaled_video.shape)

        # save video: [-1, 1] -> [0, 255] uint8, THWC
        upscaled_video = (upscaled_video / 2 + 0.5).clamp(0, 1) * 255
        upscaled_video = upscaled_video.permute(0, 2, 3, 1).to(torch.uint8)
        upscaled_video = upscaled_video.numpy().astype(np.uint8)
        imageio.mimwrite(save_path, upscaled_video, fps=fps, quality=9)  # quality: 0 (lowest) to 10 (highest)
        print(f'Save upscaled video "{video_name}" in {save_path}, time (sec): {run_time} \n')

    print(f'\nAll results are saved in {save_root}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="")
    args = parser.parse_args()
    main(OmegaConf.load(args.config))
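
# Usage sketch (script and config file names below are illustrative, not part
# of the repo; only the --config flag and the fields read in main() are known):
#
#   python sample.py --config configs/sample.yaml
#
# The OmegaConf YAML passed via --config must define the attributes accessed
# above. Example values here are assumptions to be tuned per setup:
#
#   pretrained_path: /path/to/pretrained   # holds stable-diffusion-x4-upscaler/ and lavie_vsr.pt
#   input_path: ./inputs                   # directory of low-resolution .mp4 videos
#   output_path: ./outputs                 # upscaled .mp4 files are written here
#   noise_level: 100                       # illustrative
#   guidance_scale: 7.5                    # illustrative
#   inference_steps: 50                    # illustrative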