import io
import os
import sys
import argparse

o_path = os.getcwd()
sys.path.append(o_path)

import torch
import time
import json
import numpy as np
import imageio
import torchvision
from einops import rearrange

from models.autoencoder_kl import AutoencoderKL
from models.unet import UNet3DVSRModel
from models.pipeline_stable_diffusion_upscale_video_3d import StableDiffusionUpscalePipeline
from diffusers import DDIMScheduler
from omegaconf import OmegaConf


def main(args):
    device = "cuda"

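    # Load the pretrained Stable Diffusion x4 upscaler pipeline in half precision.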
    pipeline = StableDiffusionUpscalePipeline.from_pretrained(
        args.pretrained_path + '/stable-diffusion-x4-upscaler', torch_dtype=torch.float16)

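    # Replace the default VAE with the custom autoencoder, initialized from the
    # upscaler's pretrained VAE weights.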
    pipeline.vae = AutoencoderKL.from_config("configs/vae_config.json")
    pretrained_model = args.pretrained_path + "/stable-diffusion-x4-upscaler/vae/diffusion_pytorch_model.bin"
    pipeline.vae.load_state_dict(torch.load(pretrained_model, map_location="cpu"))

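    # Replace the 2D UNet with the 3D video super-resolution UNet and load its
    # EMA checkpoint.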
    config_path = "./configs/unet_3d_config.json"
    with open(config_path, "r") as f:
        config = json.load(f)
    config['video_condition'] = False
    pipeline.unet = UNet3DVSRModel.from_config(config)

    pretrained_model = args.pretrained_path + "/lavie_vsr.pt"
    checkpoint = torch.load(pretrained_model, map_location="cpu")['ema']

    pipeline.unet.load_state_dict(checkpoint, strict=True)
    pipeline.unet = pipeline.unet.half()
    pipeline.unet.eval()

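    # Sample with DDIM, forcing a linear beta schedule.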
    with open(args.pretrained_path + '/stable-diffusion-x4-upscaler/scheduler/scheduler_config.json', "r") as f:
        config = json.load(f)
    config["beta_schedule"] = "linear"
    pipeline.scheduler = DDIMScheduler.from_config(config)

    pipeline = pipeline.to(device)

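    # Collect the input videos and prepare the output directory.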
    video_root = args.input_path
    video_list = sorted(os.listdir(video_root))
    print('video num:', len(video_list))

    save_root = args.output_path
    os.makedirs(save_root, exist_ok=True)

    noise_level = args.noise_level
    guidance_scale = args.guidance_scale
    num_inference_steps = args.inference_steps

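    # Upscale each video in turn; the file name (without extension) doubles as
    # the text prompt.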
    for i, video_name in enumerate(video_list):
        video_name = video_name.replace('.mp4', '')
        print(f'[{i+1}/{len(video_list)}]: ', video_name)

        lr_path = f"{video_root}/{video_name}.mp4"
        save_path = f"{save_root}/{video_name}.mp4"

        prompt = video_name
        print('Prompt: ', prompt)

        negative_prompt = "blur, worst quality"

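        # Read the LR video, map frames from [0, 255] to [-1, 1], and reshape
        # to (b, c, t, h, w) as the pipeline expects.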
        vframes, aframes, info = torchvision.io.read_video(filename=lr_path, pts_unit='sec', output_format='TCHW')
        vframes = vframes / 255.
        vframes = (vframes - 0.5) * 2
        t, _, h, w = vframes.shape
        vframes = vframes.unsqueeze(dim=0)
        vframes = rearrange(vframes, 'b t c h w -> b c t h w').contiguous()
        print('Input_shape:', vframes.shape, 'Noise_level:', noise_level, 'Guidance_scale:', guidance_scale)

        fps = info['video_fps']
        generator = torch.Generator(device=device).manual_seed(10)

        torch.cuda.synchronize()
        start_time = time.time()

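        # Run inference; clips longer than `short_seq` frames are processed in
        # chunks to limit peak GPU memory.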
        with torch.no_grad():
            short_seq = 8
            vframes_seq = vframes.shape[2]
            if vframes_seq > short_seq:
                upscaled_video_list = []
                for start_f in range(0, vframes_seq, short_seq):
                    torch.cuda.empty_cache()
                    end_f = min(vframes_seq, start_f + short_seq)
                    print(f'Processing: [{start_f}-{end_f}/{vframes_seq}]')

                    upscaled_video_ = pipeline(
                        prompt,
                        image=vframes[:, :, start_f:end_f],
                        generator=generator,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        noise_level=noise_level,
                        negative_prompt=negative_prompt,
                    ).images
                    upscaled_video_list.append(upscaled_video_)
                upscaled_video = torch.cat(upscaled_video_list, dim=0)
            else:
                upscaled_video = pipeline(
                    prompt,
                    image=vframes,
                    generator=generator,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    noise_level=noise_level,
                    negative_prompt=negative_prompt,
                ).images

        torch.cuda.synchronize()
        run_time = time.time() - start_time

        print('Output:', upscaled_video.shape)

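        # Map the output back from [-1, 1] to [0, 255] and write it out as mp4.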
        upscaled_video = (upscaled_video / 2 + 0.5).clamp(0, 1) * 255
        upscaled_video = upscaled_video.permute(0, 2, 3, 1).to(torch.uint8)
        upscaled_video = upscaled_video.numpy().astype(np.uint8)
        imageio.mimwrite(save_path, upscaled_video, fps=fps, quality=9)

        print(f'Saved upscaled video "{video_name}" to {save_path}, time (sec): {run_time} \n')

    print(f'\nAll results are saved in {save_root}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="")
    args = parser.parse_args()

    main(OmegaConf.load(args.config))