import os import time import pdb import gradio as gr import spaces import numpy as np import sys import subprocess from huggingface_hub import snapshot_download import requests import argparse import os from omegaconf import OmegaConf import numpy as np import cv2 import torch import glob import pickle from tqdm import tqdm import copy from argparse import Namespace import shutil import gdown def download_model(): if not os.path.exists(CheckpointsDir): os.makedirs(CheckpointsDir) print("Checkpoint Not Downloaded, start downloading...") tic = time.time() snapshot_download( repo_id="TMElyralab/MuseTalk", local_dir=CheckpointsDir, max_workers=8, local_dir_use_symlinks=True, ) # weight snapshot_download( repo_id="stabilityai/sd-vae-ft-mse", local_dir=CheckpointsDir, max_workers=8, local_dir_use_symlinks=True, ) #dwpose snapshot_download( repo_id="yzd-v/DWPose", local_dir=CheckpointsDir, max_workers=8, local_dir_use_symlinks=True, ) #vae url = "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt" response = requests.get(url) # 确保请求成功 if response.status_code == 200: # 指定文件保存的位置 file_path = f"{CheckpointsDir}/whisper/tiny.pt" os.makedirs(f"{CheckpointsDir}/whisper/") # 将文件内容写入指定位置 with open(file_path, "wb") as f: f.write(response.content) else: print(f"请求失败,状态码:{response.status_code}") #gdown face parse url = "https://drive.google.com/uc?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812" os.makedirs(f"{CheckpointsDir}/face-parse-bisent/") file_path = f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth" gdown.download(url, output, quiet=False) #resnet url = "https://download.pytorch.org/models/resnet18-5c106cde.pth" response = requests.get(url) # 确保请求成功 if response.status_code == 200: # 指定文件保存的位置 file_path = f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth" # 将文件内容写入指定位置 with open(file_path, "wb") as f: f.write(response.content) else: print(f"请求失败,状态码:{response.status_code}") toc = time.time() print(f"download cost {toc-tic} seconds") else: print("Already download the model.") download_model() # for huggingface deployment. from musetalk.utils.utils import get_file_type,get_video_fps,datagen from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder from musetalk.utils.blending import get_image from musetalk.utils.utils import load_all_model ProjectDir = os.path.abspath(os.path.dirname(__file__)) CheckpointsDir = os.path.join(ProjectDir, "checkpoints") @spaces.GPU(duration=600) @torch.no_grad() def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=True)): args_dict={"result_dir":'./results', "fps":25, "batch_size":8, "output_vid_name":'', "use_saved_coord":False}#same with inferenece script args = Namespace(**args_dict) input_basename = os.path.basename(video_path).split('.')[0] audio_basename = os.path.basename(audio_path).split('.')[0] output_basename = f"{input_basename}_{audio_basename}" result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input os.makedirs(result_img_save_path,exist_ok =True) if args.output_vid_name=="": output_vid_name = os.path.join(args.result_dir, output_basename+".mp4") else: output_vid_name = os.path.join(args.result_dir, args.output_vid_name) ############################################## extract frames from source video ############################################## if get_file_type(video_path)=="video": save_dir_full = os.path.join(args.result_dir, input_basename) os.makedirs(save_dir_full,exist_ok = True) cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png" os.system(cmd) input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]'))) fps = get_video_fps(video_path) else: # input img folder input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]')) input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) fps = args.fps #print(input_img_list) ############################################## extract audio feature ############################################## whisper_feature = audio_processor.audio2feat(audio_path) whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps) ############################################## preprocess input image ############################################## if os.path.exists(crop_coord_save_path) and args.use_saved_coord: print("using extracted coordinates") with open(crop_coord_save_path,'rb') as f: coord_list = pickle.load(f) frame_list = read_imgs(input_img_list) else: print("extracting landmarks...time consuming") coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift) with open(crop_coord_save_path, 'wb') as f: pickle.dump(coord_list, f) i = 0 input_latent_list = [] for bbox, frame in zip(coord_list, frame_list): if bbox == coord_placeholder: continue x1, y1, x2, y2 = bbox crop_frame = frame[y1:y2, x1:x2] crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4) latents = vae.get_latents_for_unet(crop_frame) input_latent_list.append(latents) # to smooth the first and the last frame frame_list_cycle = frame_list + frame_list[::-1] coord_list_cycle = coord_list + coord_list[::-1] input_latent_list_cycle = input_latent_list + input_latent_list[::-1] ############################################## inference batch by batch ############################################## print("start inference") video_num = len(whisper_chunks) batch_size = args.batch_size gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size) res_frame_list = [] for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))): tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch] audio_feature_batch = torch.stack(tensor_list).to(unet.device) # torch, B, 5*N,384 audio_feature_batch = pe(audio_feature_batch) pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample recon = vae.decode_latents(pred_latents) for res_frame in recon: res_frame_list.append(res_frame) ############################################## pad to full image ############################################## print("pad talking image to original video") for i, res_frame in enumerate(tqdm(res_frame_list)): bbox = coord_list_cycle[i%(len(coord_list_cycle))] ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))]) x1, y1, x2, y2 = bbox try: res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1)) except: # print(bbox) continue combine_frame = get_image(ori_frame,res_frame,bbox) cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame) cmd_img2video = f"ffmpeg -y -v fatal -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4" print(cmd_img2video) os.system(cmd_img2video) cmd_combine_audio = f"ffmpeg -y -v fatal -i {audio_path} -i temp.mp4 {output_vid_name}" print(cmd_combine_audio) os.system(cmd_combine_audio) os.remove("temp.mp4") shutil.rmtree(result_img_save_path) print(f"result is save to {output_vid_name}") return output_vid_name # load model weights audio_processor,vae,unet,pe = load_all_model() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") timesteps = torch.tensor([0], device=device) def check_video(video): # Define the output video file name dir_path, file_name = os.path.split(video) if file_name.startswith("outputxxx_"): return video # Add the output prefix to the file name output_file_name = "outputxxx_" + file_name # Combine the directory path and the new file name output_video = os.path.join(dir_path, output_file_name) # Run the ffmpeg command to change the frame rate to 25fps command = f"ffmpeg -i {video} -r 25 {output_video} -y" subprocess.run(command, shell=True, check=True) return output_video css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height: 576px}""" with gr.Blocks(css=css) as demo: gr.Markdown( "