Hilley commited on
Commit
b57d37a
1 Parent(s): 2fe2568

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -5
app.py CHANGED
@@ -9,12 +9,14 @@ import numpy as np
9
 
10
  import ChatTTS
11
 
 
 
 
12
  print("loading ChatTTS model...")
13
  chat = ChatTTS.Chat()
14
  chat.load_models()
15
 
16
 
17
-
18
  def generate_seed():
19
  new_seed = random.randint(1, 100000000)
20
  return {
@@ -23,7 +25,7 @@ def generate_seed():
23
  }
24
 
25
  @spaces.GPU
26
- def generate_audio(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input):
27
 
28
  torch.manual_seed(audio_seed_input)
29
  rand_spk = torch.randn(768)
@@ -57,7 +59,67 @@ def generate_audio(text, temperature, top_P, top_K, audio_seed_input, text_seed_
57
  sample_rate = 24000
58
  text_data = text[0] if isinstance(text, list) else text
59
 
60
- return [(sample_rate, audio_data), text_data]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
 
63
  with gr.Blocks() as demo:
@@ -83,7 +145,7 @@ with gr.Blocks() as demo:
83
 
84
  generate_button = gr.Button("Generate")
85
 
86
- text_output = gr.Textbox(label="Refined Text", interactive=False)
87
  audio_output = gr.Audio(label="Output Audio")
88
 
89
  generate_audio_seed.click(generate_seed,
@@ -96,7 +158,7 @@ with gr.Blocks() as demo:
96
 
97
  generate_button.click(generate_audio,
98
  inputs=[text_input, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, refine_text_checkbox, refine_text_input],
99
- outputs=[audio_output, text_output])
100
 
101
  parser = argparse.ArgumentParser(description='ChatTTS demo Launch')
102
  parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Server name')
 
9
 
10
  import ChatTTS
11
 
12
+ import se_extractor
13
+ from api import BaseSpeakerTTS, ToneColorConverter
14
+
15
  print("loading ChatTTS model...")
16
  chat = ChatTTS.Chat()
17
  chat.load_models()
18
 
19
 
 
20
  def generate_seed():
21
  new_seed = random.randint(1, 100000000)
22
  return {
 
25
  }
26
 
27
  @spaces.GPU
28
+ def chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, output_path=None):
29
 
30
  torch.manual_seed(audio_seed_input)
31
  rand_spk = torch.randn(768)
 
59
  sample_rate = 24000
60
  text_data = text[0] if isinstance(text, list) else text
61
 
62
+ if output_path is None:
63
+ return [(sample_rate, audio_data), text_data]
64
+ else:
65
+ soundfile.write(output_path, audio_data, sample_rate)
66
+
67
+ # OpenVoice
68
+
69
+ ckpt_base_en = 'checkpoints/base_speakers/EN'
70
+ ckpt_converter_en = 'checkpoints/converter'
71
+ device = 'cuda:0'
72
+
73
+ #device = "cpu"
74
+
75
+ base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base_en}/config.json', device=device)
76
+ base_speaker_tts.load_ckpt(f'{ckpt_base_en}/checkpoint.pth')
77
+
78
+ tone_color_converter = ToneColorConverter(f'{ckpt_converter_en}/config.json', device=device)
79
+ tone_color_converter.load_ckpt(f'{ckpt_converter_en}/checkpoint.pth')
80
+
81
+
82
+ def generate_audio(text, audio_ref, style_mode, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input):
83
+ if style_mode=="default":
84
+ source_se = torch.load(f'{ckpt_base_en}/en_default_se.pth').to(device)
85
+ reference_speaker = audio_ref
86
+ target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True)
87
+ save_path = "output.wav"
88
+
89
+ # Run the base speaker tts
90
+ src_path = "tmp.wav"
91
+ chat_tts(text, text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, output_path=None, src_path)
92
+
93
+ # Run the tone color converter
94
+ encode_message = "@MyShell"
95
+ tone_color_converter.convert(
96
+ audio_src_path=src_path,
97
+ src_se=source_se,
98
+ tgt_se=target_se,
99
+ output_path=save_path,
100
+ message=encode_message)
101
+
102
+ else:
103
+ source_se = torch.load(f'{ckpt_base_en}/en_style_se.pth').to(device)
104
+ reference_speaker = audio_ref
105
+ target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True)
106
+
107
+ save_path = "output.wav"
108
+
109
+ # Run the base speaker tts
110
+ src_path = "tmp.wav"
111
+ base_speaker_tts.tts(text, src_path, speaker=style_mode, language='English', speed=0.9)
112
+
113
+ # Run the tone color converter
114
+ encode_message = "@MyShell"
115
+ tone_color_converter.convert(
116
+ audio_src_path=src_path,
117
+ src_se=source_se,
118
+ tgt_se=target_se,
119
+ output_path=save_path,
120
+ message=encode_message)
121
+
122
+ return "output.wav"
123
 
124
 
125
  with gr.Blocks() as demo:
 
145
 
146
  generate_button = gr.Button("Generate")
147
 
148
+ #text_output = gr.Textbox(label="Refined Text", interactive=False)
149
  audio_output = gr.Audio(label="Output Audio")
150
 
151
  generate_audio_seed.click(generate_seed,
 
158
 
159
  generate_button.click(generate_audio,
160
  inputs=[text_input, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, refine_text_checkbox, refine_text_input],
161
+ outputs=audio_output)
162
 
163
  parser = argparse.ArgumentParser(description='ChatTTS demo Launch')
164
  parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Server name')