"""Gradio demo bootstrap for M2UGen.

Parses command-line arguments, builds the M2UGen model, loads a cleaned
checkpoint onto the GPU, and monkey-patches ``gr.Chatbot.postprocess`` so
chat messages are rendered through mdtex2html (markdown + TeX -> HTML).
"""
import argparse
import os
import subprocess
import tempfile

import av
import gradio as gr
import librosa
import mdtex2html
import numpy as np
import scipy
import torch
import torch.cuda
import torchaudio
import torchvision.transforms as transforms
from PIL import Image

import llama
from llama.m2ugen import M2UGen

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model", default="./ckpts/checkpoint.pth", type=str,
    help="Name of or path to M2UGen pretrained checkpoint",
)
parser.add_argument(
    "--llama_type", default="7B", type=str,
    help="Type of llama original weight",
)
parser.add_argument(
    "--llama_dir", default="/path/to/llama", type=str,
    help="Path to LLaMA pretrained checkpoint",
)
parser.add_argument(
    "--mert_path", default="m-a-p/MERT-v1-330M", type=str,
    help="Path to MERT pretrained checkpoint",
)
parser.add_argument(
    "--vit_path", default="m-a-p/MERT-v1-330M", type=str,
    help="Path to ViT pretrained checkpoint",
)
parser.add_argument(
    "--vivit_path", default="m-a-p/MERT-v1-330M", type=str,
    help="Path to ViViT pretrained checkpoint",
)
parser.add_argument(
    "--knn_dir", default="./ckpts", type=str,
    help="Path to directory with KNN Index",
)
parser.add_argument(
    "--music_decoder", default="musicgen", type=str,
    help="Decoder to use musicgen/audioldm2",
)
parser.add_argument(
    "--music_decoder_path", default="facebook/musicgen-medium", type=str,
    help="Path to decoder to use musicgen/audioldm2",
)
args = parser.parse_args()

# Audio files produced during the current chat session; shared mutable state
# that the Gradio callbacks append to.
generated_audio_files = []

llama_type = args.llama_type
llama_ckpt_dir = os.path.join(args.llama_dir, llama_type)
# NOTE(review): name keeps the original's "tokenzier" typo on purpose --
# other modules may import this module-level name; the tokenizer path is
# the LLaMA root directory, not the <type> subdirectory.
llama_tokenzier_path = args.llama_dir
model = M2UGen(llama_ckpt_dir, llama_tokenzier_path, args,
               knn=False, stage=None, load_llama=False)

print("Loading Model Checkpoint")
checkpoint = torch.load(args.model, map_location='cpu')

# Clean the checkpoint: drop the (separately loaded) generation-model weights
# and strip the DistributedDataParallel "module." prefix from every key.
new_ckpt = {}
for key, value in checkpoint['model'].items():
    if "generation_model" in key:
        continue
    new_ckpt[key.replace("module.", "")] = value

# strict=False because generation-model keys were deliberately removed above;
# unexpected keys, however, indicate a checkpoint/model mismatch.  Use an
# explicit raise instead of `assert` so the check survives `python -O`.
load_result = model.load_state_dict(new_ckpt, strict=False)
if len(load_result.unexpected_keys) != 0:
    raise RuntimeError(f"Unexpected keys: {load_result.unexpected_keys}")

model.eval()
model.to("cuda")

# Ensure single-channel images become 3-channel tensors so downstream vision
# encoders always see RGB input.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x)])


def postprocess(self, y):
    """Convert a Chatbot history's markdown/TeX messages to HTML.

    Monkey-patched onto ``gr.Chatbot`` below.  ``y`` is a list of
    (user_message, bot_response) pairs; either element may be None.
    Mutates and returns ``y`` (empty list when ``y`` is None).
    """
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert(message),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess
'
            else:
                lines[i] = f'
' else: if i > 0: if count % 2 == 1: line = line.replace("`", "\`") line = line.replace("<", "<") line = line.replace(">", ">") line = line.replace(" ", " ") line = line.replace("*", "*") line = line.replace("_", "_") line = line.replace("-", "-") line = line.replace(".", ".") line = line.replace("!", "!") line = line.replace("(", "(") line = line.replace(")", ")") line = line.replace("$", "$") lines[i] = "
" + line text = "".join(lines) + "
" if image_path is not None: text += f'
' outputs = f'{image_path} ' + outputs if video_path is not None: text += f'