andito HF staff commited on
Commit
3abafc4
1 Parent(s): f6f039f

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. handler.py +50 -34
  2. test.py +57 -4
handler.py CHANGED
@@ -7,6 +7,7 @@ import numpy as np
7
  from queue import Queue, Empty
8
  import threading
9
  import base64
 
10
 
11
  class EndpointHandler:
12
  def __init__(self, path=""):
@@ -22,7 +23,7 @@ class EndpointHandler:
22
  self.parler_tts_handler_kwargs,
23
  self.melo_tts_handler_kwargs,
24
  self.chat_tts_handler_kwargs,
25
- ) = get_default_arguments(mode='none', lm_model_name='meta-llama/Meta-Llama-3.1-8B-Instruct', log_level='DEBUG')
26
  setup_logger(self.module_kwargs.log_level)
27
 
28
  prepare_all_args(
@@ -57,65 +58,80 @@ class EndpointHandler:
57
 
58
  # Add a new queue for collecting the final output
59
  self.final_output_queue = Queue()
 
60
 
61
- def _collect_output(self):
62
  while True:
63
  try:
64
- output = self.queues_and_events['send_audio_chunks_queue'].get(timeout=5) # 2-second timeout
65
  if isinstance(output, (str, bytes)) and output in (b"END", "END"):
66
- self.final_output_queue.put("END")
67
  break
68
  elif isinstance(output, np.ndarray):
69
- self.final_output_queue.put(output.tobytes())
70
  else:
71
- self.final_output_queue.put(output)
72
  except Empty:
73
- # If no output for 2 seconds, assume processing is complete
74
- self.final_output_queue.put("END")
75
  break
76
 
77
- def __call__(self, data: Dict[str, Any]) -> Generator[Dict[str, Any], None, None]:
78
- """
79
- Args:
80
- data (Dict[str, Any]): The input data containing the necessary arguments.
81
 
82
- Returns:
83
- Generator[Dict[str, Any], None, None]: A generator yielding output chunks from the model or pipeline.
84
- """
85
- # Start a thread to collect the final output
86
- self.output_collector_thread = threading.Thread(target=self._collect_output)
87
- self.output_collector_thread.start()
 
 
 
 
 
 
 
 
88
 
89
  input_type = data.get("input_type", "text")
90
  input_data = data.get("inputs", "")
91
 
92
  if input_type == "speech":
93
- # Convert input audio data to numpy array
94
  audio_array = np.frombuffer(input_data, dtype=np.int16)
95
-
96
- # Put audio data into the recv_audio_chunks_queue
97
  self.queues_and_events['recv_audio_chunks_queue'].put(audio_array.tobytes())
98
  elif input_type == "text":
99
- # Put text data directly into the text_prompt_queue
100
  self.queues_and_events['text_prompt_queue'].put(input_data)
101
  else:
102
  raise ValueError(f"Unsupported input type: {input_type}")
103
 
104
- # Collect all output chunks
105
- output_chunks = []
106
- while True:
107
- chunk = self.final_output_queue.get()
108
- if chunk == "END":
109
- break
110
- output_chunks.append(chunk)
111
 
112
- # Combine all audio chunks into a single byte string
113
- combined_audio = b''.join(output_chunks)
 
 
114
 
115
- # Encode the combined audio as Base64
116
- base64_audio = base64.b64encode(combined_audio).decode('utf-8')
 
117
 
118
- return {"output": base64_audio}
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  def cleanup(self):
121
  # Stop the pipeline
 
7
  from queue import Queue, Empty
8
  import threading
9
  import base64
10
+ import uuid
11
 
12
  class EndpointHandler:
13
  def __init__(self, path=""):
 
23
  self.parler_tts_handler_kwargs,
24
  self.melo_tts_handler_kwargs,
25
  self.chat_tts_handler_kwargs,
26
+ ) = get_default_arguments(mode='none', log_level='DEBUG')
27
  setup_logger(self.module_kwargs.log_level)
28
 
29
  prepare_all_args(
 
58
 
59
  # Add a new queue for collecting the final output
60
  self.final_output_queue = Queue()
61
+ self.sessions = {} # Store session information
62
 
63
+ def _collect_output(self, session_id):
64
  while True:
65
  try:
66
+ output = self.queues_and_events['send_audio_chunks_queue'].get(timeout=2)
67
  if isinstance(output, (str, bytes)) and output in (b"END", "END"):
68
+ self.sessions[session_id]['status'] = 'completed'
69
  break
70
  elif isinstance(output, np.ndarray):
71
+ self.sessions[session_id]['chunks'].append(output.tobytes())
72
  else:
73
+ self.sessions[session_id]['chunks'].append(output)
74
  except Empty:
75
+ self.sessions[session_id]['status'] = 'completed'
 
76
  break
77
 
78
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
79
+ request_type = data.get("request_type", "start")
 
 
80
 
81
+ if request_type == "start":
82
+ return self._handle_start_request(data)
83
+ elif request_type == "continue":
84
+ return self._handle_continue_request(data)
85
+ else:
86
+ raise ValueError(f"Unsupported request type: {request_type}")
87
+
88
+ def _handle_start_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
89
+ session_id = str(uuid.uuid4())
90
+ self.sessions[session_id] = {
91
+ 'status': 'processing',
92
+ 'chunks': [],
93
+ 'last_sent_index': 0
94
+ }
95
 
96
  input_type = data.get("input_type", "text")
97
  input_data = data.get("inputs", "")
98
 
99
  if input_type == "speech":
 
100
  audio_array = np.frombuffer(input_data, dtype=np.int16)
 
 
101
  self.queues_and_events['recv_audio_chunks_queue'].put(audio_array.tobytes())
102
  elif input_type == "text":
 
103
  self.queues_and_events['text_prompt_queue'].put(input_data)
104
  else:
105
  raise ValueError(f"Unsupported input type: {input_type}")
106
 
107
+ # Start output collection in a separate thread
108
+ threading.Thread(target=self._collect_output, args=(session_id,)).start()
109
+
110
+ return {"session_id": session_id, "status": "processing"}
 
 
 
111
 
112
+ def _handle_continue_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
113
+ session_id = data.get("session_id")
114
+ if not session_id or session_id not in self.sessions:
115
+ raise ValueError("Invalid or missing session_id")
116
 
117
+ session = self.sessions[session_id]
118
+ chunks_to_send = session['chunks'][session['last_sent_index']:]
119
+ session['last_sent_index'] = len(session['chunks'])
120
 
121
+ if chunks_to_send:
122
+ combined_audio = b''.join(chunks_to_send)
123
+ base64_audio = base64.b64encode(combined_audio).decode('utf-8')
124
+ return {
125
+ "session_id": session_id,
126
+ "status": session['status'],
127
+ "output": base64_audio
128
+ }
129
+ else:
130
+ return {
131
+ "session_id": session_id,
132
+ "status": session['status'],
133
+ "output": None
134
+ }
135
 
136
  def cleanup(self):
137
  # Stop the pipeline
test.py CHANGED
@@ -1,7 +1,60 @@
1
  from handler import EndpointHandler
 
 
 
 
 
2
 
3
- endpoint = EndpointHandler('')
4
 
5
- for x in endpoint({'text': 'how are you?'}):
6
- print('passed')
7
- print(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from handler import EndpointHandler
2
+ import requests
3
+ import base64
4
+ import numpy as np
5
+ import sounddevice as sd
6
+ import time
7
 
8
+ my_handler = EndpointHandler('')
9
 
10
+
11
+ def play_audio(audio_data, sample_rate=16000):
12
+ sd.play(audio_data, sample_rate)
13
+ sd.wait()
14
+
15
+ def stream_audio(session_id):
16
+ audio_chunks = []
17
+ while True:
18
+ continue_payload = {
19
+ "request_type": "continue",
20
+ "session_id": session_id
21
+ }
22
+ response = my_handler(continue_payload)
23
+
24
+ if response["status"] == "completed" and response["output"] is None:
25
+ break
26
+
27
+ if response["output"]:
28
+ audio_bytes = base64.b64decode(response["output"])
29
+ audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
30
+ audio_chunks.append(audio_np)
31
+
32
+ # Play the chunk immediately (optional)
33
+ play_audio(audio_np)
34
+
35
+ time.sleep(0.01) # Small delay to prevent overwhelming the server
36
+
37
+ return np.concatenate(audio_chunks) if audio_chunks else None
38
+
39
+ # Test with text input
40
+ text_payload = {
41
+ "request_type": "start",
42
+ "inputs": "Tell me a cool fact about Messi.",
43
+ "input_type": "text",
44
+ }
45
+
46
+ start_response = my_handler(text_payload)
47
+
48
+
49
+ if "session_id" in start_response:
50
+ print(f"Session started. Session ID: {start_response['session_id']}")
51
+ print("Streaming audio response...")
52
+
53
+ full_audio = stream_audio(start_response['session_id'])
54
+
55
+ if full_audio is not None:
56
+ print("Received complete audio response. Playing...")
57
+ else:
58
+ print("No audio received.")
59
+ else:
60
+ print("Error:", start_response)