# Hugging Face Space — running on A10G
import io | |
import os | |
import sys | |
import argparse | |
o_path = os.getcwd() | |
sys.path.append(o_path) | |
import torch | |
import time | |
import json | |
import numpy as np | |
import imageio | |
import torchvision | |
from einops import rearrange | |
from models.autoencoder_kl import AutoencoderKL | |
from models.unet import UNet3DVSRModel | |
from models.pipeline_stable_diffusion_upscale_video_3d import StableDiffusionUpscalePipeline | |
from diffusers import DDIMScheduler | |
from omegaconf import OmegaConf | |
def _list_mp4_files(video_root):
    """Return the sorted names of the ``.mp4`` entries found in `video_root`."""
    return sorted(
        name for name in os.listdir(video_root)
        if name.lower().endswith('.mp4')
    )


def _upscale_in_chunks(pipeline, prompt, vframes, chunk_len, **pipeline_kwargs):
    """Run `pipeline` over `vframes`, at most `chunk_len` frames per call.

    `vframes` is sliced as ``vframes[:, :, start:end]``, so frames live on
    dim 2. The ``.images`` result of each call is concatenated along dim 0.
    Clips no longer than `chunk_len` are processed in a single call.
    """
    num_frames = vframes.shape[2]
    if num_frames <= chunk_len:
        return pipeline(prompt, image=vframes, **pipeline_kwargs).images

    upscaled_chunks = []
    for start_f in range(0, num_frames, chunk_len):
        end_f = min(num_frames, start_f + chunk_len)
        print(f'Processing: [{start_f}-{end_f}/{num_frames}]')
        # Release cached GPU blocks between chunks.
        torch.cuda.empty_cache()
        chunk = pipeline(
            prompt,
            image=vframes[:, :, start_f:end_f],
            **pipeline_kwargs,
        ).images
        upscaled_chunks.append(chunk)
    return torch.cat(upscaled_chunks, dim=0)


def main(args):
    """Upscale every ``.mp4`` in ``args.input_path`` and write it to ``args.output_path``.

    Fields read from `args`:
        pretrained_path: directory holding ``stable-diffusion-x4-upscaler/``
            and ``lavie_vsr.pt``.
        input_path: directory with the low-resolution ``.mp4`` files; each
            file name (without extension) is used as the text prompt.
        output_path: directory for the results (created if missing).
        noise_level, guidance_scale, inference_steps: forwarded to the
            pipeline call.

    The VAE and UNet JSON configs are opened relative to the current working
    directory (``configs/...``). A CUDA device is required.
    """
    device = "cuda"

    # ---------------------- load models ----------------------
    sd_upscaler_dir = args.pretrained_path + '/stable-diffusion-x4-upscaler'
    pipeline = StableDiffusionUpscalePipeline.from_pretrained(sd_upscaler_dir, torch_dtype=torch.float16)

    # NOTE: torch.load unpickles the checkpoint files below; only point
    # `pretrained_path` at checkpoints you trust.

    # vae
    pipeline.vae = AutoencoderKL.from_config("configs/vae_config.json")
    vae_weights = sd_upscaler_dir + "/vae/diffusion_pytorch_model.bin"
    pipeline.vae.load_state_dict(torch.load(vae_weights, map_location="cpu"))

    # unet
    with open("./configs/unet_3d_config.json", "r") as f:
        unet_config = json.load(f)
    unet_config['video_condition'] = False
    pipeline.unet = UNet3DVSRModel.from_config(unet_config)

    unet_weights = args.pretrained_path + "/lavie_vsr.pt"
    checkpoint = torch.load(unet_weights, map_location="cpu")['ema']
    # Second positional argument kept as in the original call; presumably
    # `strict` of torch.nn.Module.load_state_dict — TODO confirm.
    pipeline.unet.load_state_dict(checkpoint, True)
    pipeline.unet = pipeline.unet.half()
    pipeline.unet.eval()  # important!

    # DDIMScheduler
    with open(sd_upscaler_dir + '/scheduler/scheduler_config.json', "r") as f:
        scheduler_config = json.load(f)
    scheduler_config["beta_schedule"] = "linear"
    pipeline.scheduler = DDIMScheduler.from_config(scheduler_config)

    pipeline = pipeline.to(device)

    # ---------------------- load user's prompt ----------------------
    # input
    video_root = args.input_path
    video_files = _list_mp4_files(video_root)
    print('video num:', len(video_files))

    # output
    save_root = args.output_path
    os.makedirs(save_root, exist_ok=True)

    # inference params
    noise_level = args.noise_level
    guidance_scale = args.guidance_scale
    num_inference_steps = args.inference_steps
    negative_prompt = "blur, worst quality"
    short_seq = 8  # maximum number of frames handed to the pipeline per call

    # ---------------------- start inferencing ----------------------
    for i, file_name in enumerate(video_files):
        # Strip only the extension; str.replace would also remove a '.mp4'
        # occurring in the middle of the name.
        video_name = os.path.splitext(file_name)[0]
        print(f'[{i+1}/{len(video_files)}]: ', video_name)

        lr_path = os.path.join(video_root, file_name)
        save_path = os.path.join(save_root, f"{video_name}.mp4")

        prompt = video_name
        print('Prompt: ', prompt)

        vframes, _, info = torchvision.io.read_video(filename=lr_path, pts_unit='sec', output_format='TCHW')  # RGB
        vframes = vframes / 255.
        vframes = (vframes - 0.5) * 2  # T C H W [-1, 1]
        vframes = vframes.unsqueeze(dim=0)  # 1 T C H W
        vframes = rearrange(vframes, 'b t c h w -> b c t h w').contiguous()  # 1 C T H W
        print('Input_shape:', vframes.shape, 'Noise_level:', noise_level, 'Guidance_scale:', guidance_scale)

        fps = info['video_fps']
        # Re-seeded for every video, so each clip starts from the same generator state.
        generator = torch.Generator(device=device).manual_seed(10)

        # Synchronize around the timed region so pending GPU work is included.
        torch.cuda.synchronize()
        start_time = time.time()

        with torch.no_grad():
            upscaled_video = _upscale_in_chunks(
                pipeline,
                prompt,
                vframes,
                short_seq,
                generator=generator,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                noise_level=noise_level,
                negative_prompt=negative_prompt,
            )  # T C H W [-1, 1]

        torch.cuda.synchronize()
        run_time = time.time() - start_time

        print('Output:', upscaled_video.shape)

        # save video: [-1, 1] -> [0, 255], T C H W -> T H W C, uint8 on the host
        upscaled_video = (upscaled_video / 2 + 0.5).clamp(0, 1) * 255
        upscaled_video = upscaled_video.permute(0, 2, 3, 1).to(torch.uint8)
        upscaled_video = upscaled_video.cpu().numpy()  # .cpu() is a no-op if already on the host
        imageio.mimwrite(save_path, upscaled_video, fps=fps, quality=9)  # Highest quality is 10, lowest is 0

        print(f'Save upscaled video "{video_name}" in {save_path}, time (sec): {run_time} \n')

    # Report the output directory; `save_path` only names the last file and
    # is undefined when no video was processed.
    print(f'\nAll results are saved in {save_root}')
if __name__ == '__main__':
    # Command-line entry point: load the OmegaConf config named by --config
    # and hand it to main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="",
                        help="path to the config file loaded with OmegaConf")
    args = parser.parse_args()

    # The default is an empty string; fail with a usage message instead of
    # passing an empty path to OmegaConf.load.
    if not args.config:
        parser.error("--config must point to a config file")

    main(OmegaConf.load(args.config))