Update app.py
Browse files
app.py
CHANGED
@@ -9,12 +9,14 @@ import numpy as np
|
|
9 |
|
10 |
import ChatTTS
|
11 |
|
|
|
|
|
|
|
12 |
print("loading ChatTTS model...")
|
13 |
chat = ChatTTS.Chat()
|
14 |
chat.load_models()
|
15 |
|
16 |
|
17 |
-
|
18 |
def generate_seed():
|
19 |
new_seed = random.randint(1, 100000000)
|
20 |
return {
|
@@ -23,7 +25,7 @@ def generate_seed():
|
|
23 |
}
|
24 |
|
25 |
@spaces.GPU
|
26 |
-
def
|
27 |
|
28 |
torch.manual_seed(audio_seed_input)
|
29 |
rand_spk = torch.randn(768)
|
@@ -57,7 +59,67 @@ def generate_audio(text, temperature, top_P, top_K, audio_seed_input, text_seed_
|
|
57 |
sample_rate = 24000
|
58 |
text_data = text[0] if isinstance(text, list) else text
|
59 |
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
|
63 |
with gr.Blocks() as demo:
|
@@ -83,7 +145,7 @@ with gr.Blocks() as demo:
|
|
83 |
|
84 |
generate_button = gr.Button("Generate")
|
85 |
|
86 |
-
text_output = gr.Textbox(label="Refined Text", interactive=False)
|
87 |
audio_output = gr.Audio(label="Output Audio")
|
88 |
|
89 |
generate_audio_seed.click(generate_seed,
|
@@ -96,7 +158,7 @@ with gr.Blocks() as demo:
|
|
96 |
|
97 |
generate_button.click(generate_audio,
|
98 |
inputs=[text_input, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, refine_text_checkbox, refine_text_input],
|
99 |
-
outputs=
|
100 |
|
101 |
parser = argparse.ArgumentParser(description='ChatTTS demo Launch')
|
102 |
parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Server name')
|
|
|
9 |
|
10 |
import ChatTTS
|
11 |
|
12 |
+
import se_extractor
|
13 |
+
from api import BaseSpeakerTTS, ToneColorConverter
|
14 |
+
|
15 |
print("loading ChatTTS model...")
|
16 |
chat = ChatTTS.Chat()
|
17 |
chat.load_models()
|
18 |
|
19 |
|
|
|
20 |
def generate_seed():
|
21 |
new_seed = random.randint(1, 100000000)
|
22 |
return {
|
|
|
25 |
}
|
26 |
|
27 |
@spaces.GPU
|
28 |
+
def chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, output_path=None):
|
29 |
|
30 |
torch.manual_seed(audio_seed_input)
|
31 |
rand_spk = torch.randn(768)
|
|
|
59 |
sample_rate = 24000
|
60 |
text_data = text[0] if isinstance(text, list) else text
|
61 |
|
62 |
+
if output_path is None:
|
63 |
+
return [(sample_rate, audio_data), text_data]
|
64 |
+
else:
|
65 |
+
soundfile.write(output_path, audio_data, sample_rate)
|
66 |
+
|
67 |
+
# OpenVoice
|
68 |
+
|
69 |
+
ckpt_base_en = 'checkpoints/base_speakers/EN'
|
70 |
+
ckpt_converter_en = 'checkpoints/converter'
|
71 |
+
device = 'cuda:0'
|
72 |
+
|
73 |
+
#device = "cpu"
|
74 |
+
|
75 |
+
base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base_en}/config.json', device=device)
|
76 |
+
base_speaker_tts.load_ckpt(f'{ckpt_base_en}/checkpoint.pth')
|
77 |
+
|
78 |
+
tone_color_converter = ToneColorConverter(f'{ckpt_converter_en}/config.json', device=device)
|
79 |
+
tone_color_converter.load_ckpt(f'{ckpt_converter_en}/checkpoint.pth')
|
80 |
+
|
81 |
+
|
82 |
+
def generate_audio(text, audio_ref, style_mode, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input):
|
83 |
+
if style_mode=="default":
|
84 |
+
source_se = torch.load(f'{ckpt_base_en}/en_default_se.pth').to(device)
|
85 |
+
reference_speaker = audio_ref
|
86 |
+
target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True)
|
87 |
+
save_path = "output.wav"
|
88 |
+
|
89 |
+
# Run the base speaker tts
|
90 |
+
src_path = "tmp.wav"
|
91 |
+
chat_tts(text, text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, output_path=None, src_path)
|
92 |
+
|
93 |
+
# Run the tone color converter
|
94 |
+
encode_message = "@MyShell"
|
95 |
+
tone_color_converter.convert(
|
96 |
+
audio_src_path=src_path,
|
97 |
+
src_se=source_se,
|
98 |
+
tgt_se=target_se,
|
99 |
+
output_path=save_path,
|
100 |
+
message=encode_message)
|
101 |
+
|
102 |
+
else:
|
103 |
+
source_se = torch.load(f'{ckpt_base_en}/en_style_se.pth').to(device)
|
104 |
+
reference_speaker = audio_ref
|
105 |
+
target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True)
|
106 |
+
|
107 |
+
save_path = "output.wav"
|
108 |
+
|
109 |
+
# Run the base speaker tts
|
110 |
+
src_path = "tmp.wav"
|
111 |
+
base_speaker_tts.tts(text, src_path, speaker=style_mode, language='English', speed=0.9)
|
112 |
+
|
113 |
+
# Run the tone color converter
|
114 |
+
encode_message = "@MyShell"
|
115 |
+
tone_color_converter.convert(
|
116 |
+
audio_src_path=src_path,
|
117 |
+
src_se=source_se,
|
118 |
+
tgt_se=target_se,
|
119 |
+
output_path=save_path,
|
120 |
+
message=encode_message)
|
121 |
+
|
122 |
+
return "output.wav"
|
123 |
|
124 |
|
125 |
with gr.Blocks() as demo:
|
|
|
145 |
|
146 |
generate_button = gr.Button("Generate")
|
147 |
|
148 |
+
#text_output = gr.Textbox(label="Refined Text", interactive=False)
|
149 |
audio_output = gr.Audio(label="Output Audio")
|
150 |
|
151 |
generate_audio_seed.click(generate_seed,
|
|
|
158 |
|
159 |
generate_button.click(generate_audio,
|
160 |
inputs=[text_input, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, refine_text_checkbox, refine_text_input],
|
161 |
+
outputs=audio_output)
|
162 |
|
163 |
parser = argparse.ArgumentParser(description='ChatTTS demo Launch')
|
164 |
parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Server name')
|