Upload folder using huggingface_hub
Browse files
- handler.py +50 -34
- test.py +57 -4
handler.py
CHANGED
@@ -7,6 +7,7 @@ import numpy as np
|
|
7 |
from queue import Queue, Empty
|
8 |
import threading
|
9 |
import base64
|
|
|
10 |
|
11 |
class EndpointHandler:
|
12 |
def __init__(self, path=""):
|
@@ -22,7 +23,7 @@ class EndpointHandler:
|
|
22 |
self.parler_tts_handler_kwargs,
|
23 |
self.melo_tts_handler_kwargs,
|
24 |
self.chat_tts_handler_kwargs,
|
25 |
-
) = get_default_arguments(mode='none',
|
26 |
setup_logger(self.module_kwargs.log_level)
|
27 |
|
28 |
prepare_all_args(
|
@@ -57,65 +58,80 @@ class EndpointHandler:
|
|
57 |
|
58 |
# Add a new queue for collecting the final output
|
59 |
self.final_output_queue = Queue()
|
|
|
60 |
|
61 |
-
def _collect_output(self):
|
62 |
while True:
|
63 |
try:
|
64 |
-
output = self.queues_and_events['send_audio_chunks_queue'].get(timeout=
|
65 |
if isinstance(output, (str, bytes)) and output in (b"END", "END"):
|
66 |
-
self.
|
67 |
break
|
68 |
elif isinstance(output, np.ndarray):
|
69 |
-
self.
|
70 |
else:
|
71 |
-
self.
|
72 |
except Empty:
|
73 |
-
|
74 |
-
self.final_output_queue.put("END")
|
75 |
break
|
76 |
|
77 |
-
def __call__(self, data: Dict[str, Any]) ->
|
78 |
-
"""
|
79 |
-
Args:
|
80 |
-
data (Dict[str, Any]): The input data containing the necessary arguments.
|
81 |
|
82 |
-
|
83 |
-
|
84 |
-
""
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
input_type = data.get("input_type", "text")
|
90 |
input_data = data.get("inputs", "")
|
91 |
|
92 |
if input_type == "speech":
|
93 |
-
# Convert input audio data to numpy array
|
94 |
audio_array = np.frombuffer(input_data, dtype=np.int16)
|
95 |
-
|
96 |
-
# Put audio data into the recv_audio_chunks_queue
|
97 |
self.queues_and_events['recv_audio_chunks_queue'].put(audio_array.tobytes())
|
98 |
elif input_type == "text":
|
99 |
-
# Put text data directly into the text_prompt_queue
|
100 |
self.queues_and_events['text_prompt_queue'].put(input_data)
|
101 |
else:
|
102 |
raise ValueError(f"Unsupported input type: {input_type}")
|
103 |
|
104 |
-
#
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
if chunk == "END":
|
109 |
-
break
|
110 |
-
output_chunks.append(chunk)
|
111 |
|
112 |
-
|
113 |
-
|
|
|
|
|
114 |
|
115 |
-
|
116 |
-
|
|
|
117 |
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
def cleanup(self):
|
121 |
# Stop the pipeline
|
|
|
7 |
from queue import Queue, Empty
|
8 |
import threading
|
9 |
import base64
|
10 |
+
import uuid
|
11 |
|
12 |
class EndpointHandler:
|
13 |
def __init__(self, path=""):
|
|
|
23 |
self.parler_tts_handler_kwargs,
|
24 |
self.melo_tts_handler_kwargs,
|
25 |
self.chat_tts_handler_kwargs,
|
26 |
+
) = get_default_arguments(mode='none', log_level='DEBUG')
|
27 |
setup_logger(self.module_kwargs.log_level)
|
28 |
|
29 |
prepare_all_args(
|
|
|
58 |
|
59 |
# Add a new queue for collecting the final output
|
60 |
self.final_output_queue = Queue()
|
61 |
+
self.sessions = {} # Store session information
|
62 |
|
63 |
+
def _collect_output(self, session_id):
|
64 |
while True:
|
65 |
try:
|
66 |
+
output = self.queues_and_events['send_audio_chunks_queue'].get(timeout=2)
|
67 |
if isinstance(output, (str, bytes)) and output in (b"END", "END"):
|
68 |
+
self.sessions[session_id]['status'] = 'completed'
|
69 |
break
|
70 |
elif isinstance(output, np.ndarray):
|
71 |
+
self.sessions[session_id]['chunks'].append(output.tobytes())
|
72 |
else:
|
73 |
+
self.sessions[session_id]['chunks'].append(output)
|
74 |
except Empty:
|
75 |
+
self.sessions[session_id]['status'] = 'completed'
|
|
|
76 |
break
|
77 |
|
78 |
+
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
79 |
+
request_type = data.get("request_type", "start")
|
|
|
|
|
80 |
|
81 |
+
if request_type == "start":
|
82 |
+
return self._handle_start_request(data)
|
83 |
+
elif request_type == "continue":
|
84 |
+
return self._handle_continue_request(data)
|
85 |
+
else:
|
86 |
+
raise ValueError(f"Unsupported request type: {request_type}")
|
87 |
+
|
88 |
+
def _handle_start_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
89 |
+
session_id = str(uuid.uuid4())
|
90 |
+
self.sessions[session_id] = {
|
91 |
+
'status': 'processing',
|
92 |
+
'chunks': [],
|
93 |
+
'last_sent_index': 0
|
94 |
+
}
|
95 |
|
96 |
input_type = data.get("input_type", "text")
|
97 |
input_data = data.get("inputs", "")
|
98 |
|
99 |
if input_type == "speech":
|
|
|
100 |
audio_array = np.frombuffer(input_data, dtype=np.int16)
|
|
|
|
|
101 |
self.queues_and_events['recv_audio_chunks_queue'].put(audio_array.tobytes())
|
102 |
elif input_type == "text":
|
|
|
103 |
self.queues_and_events['text_prompt_queue'].put(input_data)
|
104 |
else:
|
105 |
raise ValueError(f"Unsupported input type: {input_type}")
|
106 |
|
107 |
+
# Start output collection in a separate thread
|
108 |
+
threading.Thread(target=self._collect_output, args=(session_id,)).start()
|
109 |
+
|
110 |
+
return {"session_id": session_id, "status": "processing"}
|
|
|
|
|
|
|
111 |
|
112 |
+
def _handle_continue_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
113 |
+
session_id = data.get("session_id")
|
114 |
+
if not session_id or session_id not in self.sessions:
|
115 |
+
raise ValueError("Invalid or missing session_id")
|
116 |
|
117 |
+
session = self.sessions[session_id]
|
118 |
+
chunks_to_send = session['chunks'][session['last_sent_index']:]
|
119 |
+
session['last_sent_index'] = len(session['chunks'])
|
120 |
|
121 |
+
if chunks_to_send:
|
122 |
+
combined_audio = b''.join(chunks_to_send)
|
123 |
+
base64_audio = base64.b64encode(combined_audio).decode('utf-8')
|
124 |
+
return {
|
125 |
+
"session_id": session_id,
|
126 |
+
"status": session['status'],
|
127 |
+
"output": base64_audio
|
128 |
+
}
|
129 |
+
else:
|
130 |
+
return {
|
131 |
+
"session_id": session_id,
|
132 |
+
"status": session['status'],
|
133 |
+
"output": None
|
134 |
+
}
|
135 |
|
136 |
def cleanup(self):
|
137 |
# Stop the pipeline
|
test.py
CHANGED
@@ -1,7 +1,60 @@
|
|
1 |
from handler import EndpointHandler
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
|
5 |
-
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from handler import EndpointHandler
|
2 |
+
import requests
|
3 |
+
import base64
|
4 |
+
import numpy as np
|
5 |
+
import sounddevice as sd
|
6 |
+
import time
|
7 |
|
8 |
+
# Build the handler in-process; the rest of this script calls it directly
# as a function (my_handler(payload)) rather than going through HTTP.
my_handler = EndpointHandler('')
|
9 |
|
10 |
+
|
11 |
+
def play_audio(audio_data, sample_rate=16000):
    """Play ``audio_data`` via sounddevice and block until playback ends.

    Args:
        audio_data: Samples to play, passed straight to ``sd.play``.
        sample_rate: Playback rate in Hz. Defaults to 16000.
    """
    sd.play(audio_data, samplerate=sample_rate)
    # sd.play returns immediately; wait so callers get blocking playback.
    sd.wait()
|
14 |
+
|
15 |
+
def stream_audio(session_id):
    """Poll the handler for a session's audio until it reports completion.

    Sends "continue" requests to ``my_handler`` in a loop. Each non-empty
    ``output`` is base64-decoded, viewed as int16 samples, played
    immediately, and kept. Polling stops once the handler answers with
    status ``"completed"`` and no output.

    Args:
        session_id: Session identifier returned by a "start" request.

    Returns:
        All received samples concatenated into one int16 array, or ``None``
        if nothing was received.
    """
    received = []

    while True:
        reply = my_handler({
            "request_type": "continue",
            "session_id": session_id,
        })

        finished = reply["status"] == "completed"
        if finished and reply["output"] is None:
            break

        encoded = reply["output"]
        if encoded:
            samples = np.frombuffer(base64.b64decode(encoded), dtype=np.int16)
            received.append(samples)

            # Play the chunk immediately (optional)
            play_audio(samples)

        # Small delay to prevent overwhelming the server
        time.sleep(0.01)

    if not received:
        return None
    return np.concatenate(received)
|
38 |
+
|
39 |
+
# Smoke test: start a text session, then stream the spoken reply back.
text_payload = dict(
    request_type="start",
    inputs="Tell me a cool fact about Messi.",
    input_type="text",
)

start_response = my_handler(text_payload)

if "session_id" not in start_response:
    # The handler did not open a session; show whatever it returned.
    print("Error:", start_response)
else:
    print(f"Session started. Session ID: {start_response['session_id']}")
    print("Streaming audio response...")

    # stream_audio plays chunks as they arrive and returns the full signal.
    full_audio = stream_audio(start_response['session_id'])

    if full_audio is None:
        print("No audio received.")
    else:
        print("Received complete audio response. Playing...")
|