Update TextGen/router.py
Browse files- TextGen/router.py +17 -4
TextGen/router.py
CHANGED
@@ -25,6 +25,8 @@ class VoiceMessage(BaseModel):
|
|
25 |
npc: str | None = None
|
26 |
input: str | None = None
|
27 |
language: str | None = "en"
|
|
|
|
|
28 |
song_base_api=os.environ["VERCEL_API"]
|
29 |
|
30 |
my_hf_token=os.environ["HF_TOKEN"]
|
@@ -32,7 +34,10 @@ my_hf_token=os.environ["HF_TOKEN"]
|
|
32 |
tts_client = Client("https://jofthomas-xtts.hf.space/",hf_token=my_hf_token)
|
33 |
|
34 |
|
35 |
-
|
|
|
|
|
|
|
36 |
class Generate(BaseModel):
|
37 |
text:str
|
38 |
|
@@ -77,13 +82,21 @@ def inference(message: Message):
|
|
77 |
return generate_text(prompt=message.input)
|
78 |
|
79 |
#Dummy function for now
|
80 |
-
def determine_vocie_from_npc(npc):
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
@app.post("/generate_wav")
|
84 |
async def generate_wav(message:VoiceMessage):
|
85 |
try:
|
86 |
-
voice=determine_vocie_from_npc(message.npc)
|
87 |
# Use the Gradio client to generate the wav file
|
88 |
result = tts_client.predict(
|
89 |
message.input, # str in 'Text Prompt' Textbox component
|
|
|
25 |
npc: str | None = None
|
26 |
input: str | None = None
|
27 |
language: str | None = "en"
|
28 |
+
genre:str | None = "Male"
|
29 |
+
|
30 |
song_base_api=os.environ["VERCEL_API"]
|
31 |
|
32 |
my_hf_token=os.environ["HF_TOKEN"]
|
|
|
34 |
tts_client = Client("https://jofthomas-xtts.hf.space/",hf_token=my_hf_token)
|
35 |
|
36 |
|
37 |
+
main_npcs={
|
38 |
+
"Blacksmith":"./voices/blacksmith.mp3",
|
39 |
+
"Herbalist":"./voices/female.wav"
|
40 |
+
}
|
41 |
class Generate(BaseModel):
|
42 |
text:str
|
43 |
|
|
|
82 |
return generate_text(prompt=message.input)
|
83 |
|
84 |
#Dummy function for now
|
85 |
+
def determine_vocie_from_npc(npc,genre):
|
86 |
+
if npc in main_npcs:
|
87 |
+
return main_npcs[npc]
|
88 |
+
else:
|
89 |
+
if genre =="Male":
|
90 |
+
"./voices/blacksmith.mp3"
|
91 |
+
if genre=="Female":
|
92 |
+
return"./voices/female.wav"
|
93 |
+
else:
|
94 |
+
return "./voices/narator_out.wav"
|
95 |
|
96 |
@app.post("/generate_wav")
|
97 |
async def generate_wav(message:VoiceMessage):
|
98 |
try:
|
99 |
+
voice=determine_vocie_from_npc(message.npc, message.genre)
|
100 |
# Use the Gradio client to generate the wav file
|
101 |
result = tts_client.predict(
|
102 |
message.input, # str in 'Text Prompt' Textbox component
|