Mocci lutha committed on
Commit 4384ac1
Parent(s): 8ea7968

Upload app.py

Files changed (1)
app.py +188 -0
app.py ADDED
@@ -0,0 +1,188 @@
import os
import json
import argparse
import traceback
import logging
import gradio as gr
import numpy as np
import librosa
import torch
import asyncio
import edge_tts
from datetime import datetime
from fairseq import checkpoint_utils
from infer_pack.models import SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid_nono
from vc_infer_pipeline import VC
from config import (
    is_half,
    device
)

logging.getLogger("numba").setLevel(logging.WARNING)
limitation = os.getenv("SYSTEM") == "spaces"  # limit audio length in Hugging Face Spaces

def create_vc_fn(tgt_sr, net_g, vc, if_f0, file_index, file_big_npy):
    def vc_fn(
        input_audio,
        f0_up_key,
        f0_method,
        index_rate,
        tts_mode,
        tts_text,
        tts_voice
    ):
        try:
            if tts_mode:
                # validate the inputs before measuring length (tts_text may be None)
                if tts_text is None or tts_voice is None:
                    return "You need to enter text and select a voice", None
                if len(tts_text) > 100 and limitation:
                    return "Text is too long", None
                asyncio.run(edge_tts.Communicate(tts_text, "-".join(tts_voice.split('-')[:-1])).save("tts.mp3"))
                audio, sr = librosa.load("tts.mp3", sr=16000, mono=True)
            else:
                if args.files:
                    audio, sr = librosa.load(input_audio, sr=16000, mono=True)
                else:
                    if input_audio is None:
                        return "You need to upload an audio", None
                    sampling_rate, audio = input_audio
                    duration = audio.shape[0] / sampling_rate
                    if duration > 20 and limitation:
                        return "Please upload an audio file that is less than 20 seconds. If you need to generate a longer audio file, please use Colab.", None
                    # normalize to float32, downmix to mono, and resample to 16 kHz
                    audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32)
                    if len(audio.shape) > 1:
                        audio = librosa.to_mono(audio.transpose(1, 0))
                    if sampling_rate != 16000:
                        audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
            times = [0, 0, 0]
            f0_up_key = int(f0_up_key)
            audio_opt = vc.pipeline(
                hubert_model,
                net_g,
                0,
                audio,
                times,
                f0_up_key,
                f0_method,
                file_index,
                file_big_npy,
                index_rate,
                if_f0,
            )
            print(
                f"[{datetime.now().strftime('%Y-%m-%d %H:%M')}]: npy: {times[0]}s, f0: {times[1]}s, infer: {times[2]}s"
            )
            return "Success", (tgt_sr, audio_opt)
        except Exception:
            info = traceback.format_exc()
            print(info)
            return info, (None, None)
    return vc_fn

def load_hubert():
    # load the HuBERT content encoder used as the feature extractor
    global hubert_model
    models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
        ["hubert_base.pt"],
        suffix="",
    )
    hubert_model = models[0]
    hubert_model = hubert_model.to(device)
    if is_half:
        hubert_model = hubert_model.half()
    else:
        hubert_model = hubert_model.float()
    hubert_model.eval()

def change_to_tts_mode(tts_mode):
    # toggle visibility: hide the audio input and show the TTS controls, or vice versa
    if tts_mode:
        return gr.Audio.update(visible=False), gr.Textbox.update(visible=True), gr.Dropdown.update(visible=True)
    else:
        return gr.Audio.update(visible=True), gr.Textbox.update(visible=False), gr.Dropdown.update(visible=False)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--api', action="store_true", default=False)
    parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
    parser.add_argument("--files", action="store_true", default=False, help="load audio from path")
    args, unknown = parser.parse_known_args()
    load_hubert()
    models = []
    tts_voice_list = asyncio.get_event_loop().run_until_complete(edge_tts.list_voices())
    voices = [f"{v['ShortName']}-{v['Gender']}" for v in tts_voice_list]
    with open("weights/model_info.json", "r", encoding="utf-8") as f:
        models_info = json.load(f)
    for name, info in models_info.items():
        if not info['enable']:
            continue
        title = info['title']
        author = info.get("author", None)
        cover = f"weights/{name}/{info['cover']}"
        index = f"weights/{name}/{info['feature_retrieval_library']}"
        npy = f"weights/{name}/{info['feature_file']}"
        cpt = torch.load(f"weights/{name}/{name}.pth", map_location="cpu")
        tgt_sr = cpt["config"][-1]
        cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]  # n_spk
        if_f0 = cpt.get("f0", 1)
        if if_f0 == 1:
            net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=is_half)
        else:
            net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
        del net_g.enc_q
        print(net_g.load_state_dict(cpt["weight"], strict=False))  # without this line the weights are not cleaned up properly; strange, but required
        net_g.eval().to(device)
        if is_half:
            net_g = net_g.half()
        else:
            net_g = net_g.float()
        vc = VC(tgt_sr, device, is_half)
        models.append((name, title, author, cover, create_vc_fn(tgt_sr, net_g, vc, if_f0, index, npy)))
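    # A hypothetical weights/model_info.json entry, inferred from the keys this
    # loader reads ("enable", "title", "author", "cover",
    # "feature_retrieval_library", "feature_file") and from the
    # weights/<name>/<name>.pth convention above; the file that actually ships
    # with the Space may differ:
    #
    # {
    #   "some-model": {
    #     "enable": true,
    #     "title": "Some Model",
    #     "author": "someone",
    #     "cover": "cover.png",
    #     "feature_retrieval_library": "added.index",
    #     "feature_file": "total_fea.npy"
    #   }
    # }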
    with gr.Blocks() as app:
        gr.Markdown(
            "# <center> RVC Models\n"
            "## <center> The input audio should be clean, pure vocals without background music.\n"
            "![visitor badge](https://visitor-badge.glitch.me/badge?page_id=zomehwh.Rvc-Models)\n\n"
            "[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/16MXRcKEjGDqQzVanvi8xYOOOlhdNBopM?usp=share_link)\n\n"
            "[![Duplicate this Space](https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm-dark.svg)](https://huggingface.co/spaces/zomehwh/rvc-models?duplicate=true)\n\n"
            "[![Original Repo](https://badgen.net/badge/icon/github?icon=github&label=Original%20Repo)](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI)"
        )
        with gr.Tabs():
            for (name, title, author, cover, vc_fn) in models:
                with gr.TabItem(name):
                    with gr.Row():
                        gr.Markdown(
                            '<div align="center">'
                            f'<div>{title}</div>\n' +
                            (f'<div>Model author: {author}</div>' if author else "") +
                            (f'<img style="width:auto;height:300px;" src="file/{cover}">' if cover else "") +
                            '</div>'
                        )
                    with gr.Row():
                        with gr.Column():
                            if args.files:
                                vc_input = gr.Textbox(label="Input audio path")
                            else:
                                # parenthesize the conditional so the base label is kept when no limit applies
                                vc_input = gr.Audio(label="Input audio" + (" (less than 20 seconds)" if limitation else ""))
                            vc_transpose = gr.Number(label="Transpose", value=0)
                            vc_f0method = gr.Radio(
                                label="Pitch extraction algorithm: PM is fast, but Harvest handles low frequencies better",
                                choices=["pm", "harvest"],
                                value="pm",
                                interactive=True,
                            )
                            vc_index_ratio = gr.Slider(
                                minimum=0,
                                maximum=1,
                                label="Retrieval feature ratio",
                                value=0.6,
                                interactive=True,
                            )
                            tts_mode = gr.Checkbox(label="tts (use edge-tts as input)", value=False)
                            tts_text = gr.Textbox(visible=False, label="TTS text (100 characters max)" if limitation else "TTS text")
                            tts_voice = gr.Dropdown(label="Edge-tts speaker", choices=voices, visible=False, allow_custom_value=False, value="en-US-AnaNeural-Female")
                            vc_submit = gr.Button("Generate", variant="primary")
                        with gr.Column():
                            vc_output1 = gr.Textbox(label="Output Message")
                            vc_output2 = gr.Audio(label="Output Audio")
                    vc_submit.click(vc_fn, [vc_input, vc_transpose, vc_f0method, vc_index_ratio, tts_mode, tts_text, tts_voice], [vc_output1, vc_output2])
                    tts_mode.change(change_to_tts_mode, [tts_mode], [vc_input, tts_text, tts_voice])
    app.queue(concurrency_count=1, max_size=20, api_open=args.api).launch(share=args.share)
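
The Space runs this file directly, so no flags are needed there; locally, the argparse flags above enable sharing (`--share`), the API (`--api`), and file-path input (`--files`), e.g. `python app.py --files --share`. As a rough illustration of what each per-model `vc_fn` produces, here is a minimal sketch, assuming the loading loop above has populated `models`, the app was started with `--files`, and `input.wav` is a hypothetical path:

    name, title, author, cover, vc_fn = models[0]
    message, output = vc_fn("input.wav", 0, "pm", 0.6, False, None, None)
    # `message` is "Success" or an error string; on success, `output` is
    # (target_sampling_rate, converted_audio_as_numpy_array)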