# LaVie / vsr / sample.py
# camenduru's picture
# thanks to Vchitect ❤
# 2f760fe
import io
import os
import sys
import argparse
o_path = os.getcwd()
sys.path.append(o_path)
import torch
import time
import json
import numpy as np
import imageio
import torchvision
from einops import rearrange
from models.autoencoder_kl import AutoencoderKL
from models.unet import UNet3DVSRModel
from models.pipeline_stable_diffusion_upscale_video_3d import StableDiffusionUpscalePipeline
from diffusers import DDIMScheduler
from omegaconf import OmegaConf
def _build_pipeline(pretrained_path, device):
    """Build the upscaling pipeline used by main().

    Loads ``StableDiffusionUpscalePipeline`` in float16 from
    ``<pretrained_path>/stable-diffusion-x4-upscaler`` and then swaps in the
    project ``AutoencoderKL``, the ``UNet3DVSRModel`` (weights taken from the
    ``'ema'`` entry of ``<pretrained_path>/lavie_vsr.pt``) and a
    ``DDIMScheduler`` with a linear beta schedule.

    Args:
        pretrained_path: Directory containing the pretrained weights.
        device: Device string the assembled pipeline is moved to.

    Returns:
        The assembled pipeline, already moved to ``device``.
    """
    sd_root = pretrained_path + '/stable-diffusion-x4-upscaler'
    pipeline = StableDiffusionUpscalePipeline.from_pretrained(sd_root, torch_dtype=torch.float16)

    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    # vae
    pipeline.vae = AutoencoderKL.from_config("configs/vae_config.json")
    vae_weights = sd_root + "/vae/diffusion_pytorch_model.bin"
    pipeline.vae.load_state_dict(torch.load(vae_weights, map_location="cpu"))

    # unet
    with open("./configs/unet_3d_config.json", "r") as f:
        unet_config = json.load(f)
    unet_config['video_condition'] = False
    pipeline.unet = UNet3DVSRModel.from_config(unet_config)
    unet_weights = pretrained_path + "/lavie_vsr.pt"
    ema_state = torch.load(unet_weights, map_location="cpu")['ema']
    pipeline.unet.load_state_dict(ema_state, True)  # second argument: strict loading
    pipeline.unet = pipeline.unet.half()
    pipeline.unet.eval()  # important!

    # DDIMScheduler
    with open(sd_root + '/scheduler/scheduler_config.json', "r") as f:
        scheduler_config = json.load(f)
    scheduler_config["beta_schedule"] = "linear"
    pipeline.scheduler = DDIMScheduler.from_config(scheduler_config)

    return pipeline.to(device)


def _upscale_frames(pipeline, vframes, prompt, negative_prompt, generator,
                    num_inference_steps, guidance_scale, noise_level, short_seq=8):
    """Run the pipeline over ``vframes``, splitting long clips into chunks.

    Args:
        pipeline: Callable pipeline whose result exposes ``.images``.
        vframes: Input tensor laid out as 1 C T H W.
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        generator: ``torch.Generator`` shared by all pipeline calls.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale passed to the pipeline.
        noise_level: Noise level passed to the pipeline.
        short_seq: Maximum number of frames handed to the pipeline at once.

    Returns:
        The pipeline output; chunk outputs are concatenated along dim 0.
    """
    def run(clip):
        return pipeline(
            prompt,
            image=clip,
            generator=generator,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            noise_level=noise_level,
            negative_prompt=negative_prompt,
        ).images  # T C H W [-1, 1]

    vframes_seq = vframes.shape[2]
    if vframes_seq <= short_seq:
        return run(vframes)

    # for VSR: process the clip short_seq frames at a time
    upscaled_chunks = []
    for start_f in range(0, vframes_seq, short_seq):
        end_f = min(vframes_seq, start_f + short_seq)
        # Report the clamped end index so the last chunk is not overstated.
        print(f'Processing: [{start_f}-{end_f}/{vframes_seq}]')
        torch.cuda.empty_cache()  # delete for VSR
        upscaled_chunks.append(run(vframes[:, :, start_f:end_f]))
    return torch.cat(upscaled_chunks, dim=0)


def main(args):
    """Upscale every ``.mp4`` in ``args.input_path`` and write results to ``args.output_path``.

    The file name (without extension) of each input video is used as its text
    prompt. Each result is saved under the same file stem in the output folder.

    Args:
        args: Config object providing ``pretrained_path``, ``input_path``,
            ``output_path``, ``noise_level``, ``guidance_scale`` and
            ``inference_steps``.
    """
    device = "cuda"

    # ---------------------- load models ----------------------
    pipeline = _build_pipeline(args.pretrained_path, device)

    # ---------------------- load user's prompt ----------------------
    # input: keep only .mp4 entries so stray files are not turned into bogus paths
    video_root = args.input_path
    video_list = sorted(
        entry for entry in os.listdir(video_root) if entry.lower().endswith('.mp4')
    )
    print('video num:', len(video_list))

    # output
    save_root = args.output_path
    os.makedirs(save_root, exist_ok=True)

    # inference params
    noise_level = args.noise_level
    guidance_scale = args.guidance_scale
    num_inference_steps = args.inference_steps
    negative_prompt = "blur, worst quality"

    # ---------------------- start inferencing ----------------------
    for i, file_name in enumerate(video_list):
        # Strip only the trailing extension; str.replace would also remove
        # '.mp4' occurring in the middle of a name.
        video_name = os.path.splitext(file_name)[0]
        print(f'[{i+1}/{len(video_list)}]: ', video_name)

        lr_path = f"{video_root}/{file_name}"
        save_path = f"{save_root}/{video_name}.mp4"

        prompt = video_name
        print('Prompt: ', prompt)

        vframes, _, info = torchvision.io.read_video(filename=lr_path, pts_unit='sec', output_format='TCHW')  # RGB
        vframes = vframes / 255.
        vframes = (vframes - 0.5) * 2  # T C H W [-1, 1]
        vframes = vframes.unsqueeze(dim=0)  # 1 T C H W
        vframes = rearrange(vframes, 'b t c h w -> b c t h w').contiguous()  # 1 C T H W
        print('Input_shape:', vframes.shape, 'Noise_level:', noise_level, 'Guidance_scale:', guidance_scale)

        fps = info['video_fps']
        # Re-seeded per video so every clip starts from the same random state.
        generator = torch.Generator(device=device).manual_seed(10)

        # Synchronize around the timed region so run_time covers finished GPU work.
        torch.cuda.synchronize()
        start_time = time.time()

        with torch.no_grad():
            upscaled_video = _upscale_frames(
                pipeline,
                vframes,
                prompt,
                negative_prompt,
                generator,
                num_inference_steps,
                guidance_scale,
                noise_level,
            )

        torch.cuda.synchronize()
        run_time = time.time() - start_time

        print('Output:', upscaled_video.shape)

        # save video: map [-1, 1] -> [0, 255], T C H W -> T H W C
        upscaled_video = (upscaled_video / 2 + 0.5).clamp(0, 1) * 255
        upscaled_video = upscaled_video.permute(0, 2, 3, 1).to(torch.uint8)
        # .cpu() is a no-op for CPU tensors and avoids a failure in .numpy() for CUDA ones.
        upscaled_video = upscaled_video.cpu().numpy()
        imageio.mimwrite(save_path, upscaled_video, fps=fps, quality=9)  # Highest quality is 10, lowest is 0

        print(f'Save upscaled video "{video_name}" in {save_path}, time (sec): {run_time} \n')

    # Report the output folder; the per-file path is unbound when no video was processed.
    print(f'\nAll results are saved in {save_root}')
if __name__ == '__main__':
    # Script entry point: load the config file named by --config with
    # OmegaConf and hand the resulting object to main().
    parser = argparse.ArgumentParser()
    # Required: an empty default would only surface later as an opaque
    # failure inside OmegaConf.load("").
    parser.add_argument("--config", type=str, required=True,
                        help="path to the config file loaded with OmegaConf")
    args = parser.parse_args()

    main(OmegaConf.load(args.config))