## Dirty one-file implementation for experimental (and fun) purposes only
import os
import gradio as gr
from gradio_client import Client
import requests
from dotenv import load_dotenv
from pydub import AudioSegment
from tqdm.auto import tqdm

print("starting")
load_dotenv()
HF_API = os.getenv("HF_API")
SEAMLESS_API_URL = os.getenv("SEAMLESS_API_URL")  # SeamlessM4T API endpoint (only used by the commented-out gradio_client call below)
GPU_AVAILABLE = os.getenv("GPU_AVAILABLE", "cpu")  # torch device string, e.g. "cuda" or "cpu"
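# Expected .env contents (names from the getenv calls above, values illustrative):
# HF_API=hf_xxxxxxxxxxxxxxxx       # Hugging Face access token
# SEAMLESS_API_URL=https://...     # SeamlessM4T endpoint, optional
# GPU_AVAILABLE=cuda               # device string passed to torch.device()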
DEFAULT_TARGET_LANGUAGE = "French"
MISTRAL_SUMMARY_URL = (
    "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
)
LLAMA_SUMMARY_URL = (
    "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
)
print("env setup ok")
DESCRIPTION = """ | |
# Transcribe and create a summary of a conversation. | |
""" | |
DUPLICATE = """ | |
To duplicate this repo, you have to give permission from three reopsitories and accept all user conditions: | |
1- https://huggingface.co/pyannote/voice-activity-detection | |
2- https://hf.co/pyannote/segmentation | |
3- https://hf.co/pyannote/speaker-diarization | |
""" | |
from pyannote.audio import Pipeline
import torch

# initialize the diarization pipeline (a gated model, see DUPLICATE above)
diarizer = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token=HF_API
)
# send pipeline to GPU (when available)
diarizer.to(torch.device(GPU_AVAILABLE))
print("diarizer setup ok")

# predict is a generator that incrementally yields recognized text with speaker labels
def predict(target_language, input_audio):
    print("->predict started")
    print(target_language, type(input_audio), input_audio)
    print("-->diarization")
    diarized = diarizer(input_audio, min_speakers=2, max_speakers=5)
    print("-->automatic speech recognition")
    # split audio according to diarization; pydub slices in milliseconds
    song = AudioSegment.from_wav(input_audio)
    # client = Client(SEAMLESS_API_URL, hf_token=HF_API, serialize=False)
    output_text = ""
    # itertracks yields (segment, track_id, speaker_label) tuples
    for turn, _, speaker in diarized.itertracks(yield_label=True):
        print(speaker, turn)
        try:
            filename = f"{turn.start}_segment.wav"
            clipped = song[turn.start * 1000 : turn.end * 1000]
            # resample to 16 kHz, the rate the ASR endpoints expect
            clipped.set_frame_rate(16000).export(filename, format="wav")
            # result = client.predict(f"my.wav", target_language, api_name="/asr")
            result = automatic_speech_recognition(target_language, filename)
            if result:
                current_text = f"speaker: {speaker} text: {result} "
                print(current_text)
                output_text = output_text + "\n" + current_text
                yield output_text
        except Exception as e:
            print(e)
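
# Note: each turn writes a "<start>_segment.wav" next to the script; switching to
# tempfile.NamedTemporaryFile would keep the working directory clean.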

def automatic_speech_recognition(language, filename):
    # pick a language-specific ASR model served by the HF Inference API
    match language:
        case "French":
            api_url = "https://api-inference.huggingface.co/models/bofenghuang/whisper-large-v3-french"
        case "English":
            api_url = "https://api-inference.huggingface.co/models/facebook/wav2vec2-base-960h"
        case _:
            return f"Unknown language {language}"
    print(f"-> automatic_speech_recognition with {api_url}")
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.post(
        api_url, headers={"Authorization": f"Bearer {HF_API}"}, data=data
    )
    result = response.json()
    print(result)
    # success looks like {"text": "..."}; failures come back as {"error": "..."}
    if "text" not in result:
        return None
    return result["text"]

def generate_summary_llama3(language, transcript):
    # Llama 3 instruct format: one <|begin_of_text|>, then header-tagged turns,
    # each closed by <|eot_id|>; generation continues after the assistant header
    queryTxt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful and truthful patient-doctor encounter summary writer.
Users send you transcripts of patient-doctor encounters and you create accurate and concise summaries.
The summary only contains information from the transcript.
Your summary is written in {language}.
The summary only includes relevant sections.
<template>
# Chief Complaint
# History of Present Illness (HPI)
# Relevant Past Medical History
# Physical Examination
# Assessment and Plan
# Follow-up
# Additional Notes
</template><|eot_id|><|start_header_id|>user<|end_header_id|>
<transcript>
{transcript}
</transcript><|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
    payload = {
        "inputs": queryTxt,
        "parameters": {
            "return_full_text": False,
            "max_new_tokens": 1000,  # cap on the generated summary length
        },
        # wait_for_model and use_cache are request options, not generation parameters
        "options": {"use_cache": False, "wait_for_model": True},
    }
    response = requests.post(
        LLAMA_SUMMARY_URL, headers={"Authorization": f"Bearer {HF_API}"}, json=payload
    )
    print(response.json())
    return response.json()[0]["generated_text"].strip()
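# Text-generation endpoints respond with [{"generated_text": "..."}] on success
# and {"error": "..."} on failure, hence the [0] indexing above.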

def generate_summary_mistral(language, transcript):
    # Mistral instruct format has no system role: the system instructions are
    # simply prepended to the user content inside one [INST] ... [/INST] block
    sysPrompt = f"""<s>[INST]
You are a helpful and truthful patient-doctor encounter summary writer.
Users send you transcripts of patient-doctor encounters and you create accurate and concise summaries.
The summary only contains information from the transcript.
Your summary is written in {language}.
The summary only includes relevant sections.
<template>
# Chief Complaint
# History of Present Illness (HPI)
# Relevant Past Medical History
# Physical Examination
# Assessment and Plan
# Follow-up
# Additional Notes
</template>
"""
    queryTxt = f"""
<transcript>
{transcript}
</transcript>
[/INST]
"""
    payload = {
        "inputs": sysPrompt + queryTxt,
        "parameters": {
            "return_full_text": False,
            "max_new_tokens": 1000,  # cap on the generated summary length
        },
        "options": {"use_cache": False, "wait_for_model": True},
    }
    response = requests.post(
        MISTRAL_SUMMARY_URL, headers={"Authorization": f"Bearer {HF_API}"}, json=payload
    )
    print(response.json())
    return response.json()[0]["generated_text"].strip()

def generate_summary(model, language, transcript):
    match model:
        case "Mistral-7B":
            print("-> summarize with mistral")
            return generate_summary_mistral(language, transcript)
        case "LLAMA3":
            print("-> summarize with llama3")
            return generate_summary_llama3(language, transcript)
        case _:
            return f"Unknown model {model}"

def update_audio_ui(audio_source: str) -> tuple[dict, dict]:
    mic = audio_source == "microphone"
    return (
        gr.update(visible=mic, value=None),  # input_audio_mic
        gr.update(visible=not mic, value=None),  # input_audio_file
    )
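# update_audio_ui is currently unused: the UI below always shows a single
# filepath Audio input rather than a microphone/file toggle.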

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        with gr.Row():
            target_language = gr.Dropdown(
                choices=["French", "English"],
                label="Output Language",
                value=DEFAULT_TARGET_LANGUAGE,
                interactive=True,
                info="Select your target language",
            )
        with gr.Row() as audio_box:
            input_audio = gr.Audio(type="filepath")
    submit = gr.Button("Transcribe")
    transcribe_output = gr.Textbox(
        label="Transcribed Text",
        value="",
        interactive=False,
        lines=10,
        scale=10,
        max_lines=100,
    )
    submit.click(
        fn=predict,
        inputs=[target_language, input_audio],
        outputs=[transcribe_output],
        api_name="predict",
    )
    with gr.Row():
        summary_model = gr.Dropdown(
            choices=["Mistral-7B", "LLAMA3"],
            label="Summary model",
            value="Mistral-7B",
            interactive=True,
            info="Select your summary model",
        )
    summarize = gr.Button("Summarize")
    summary_output = gr.Textbox(
        label="Summarized Text",
        value="",
        interactive=False,
        lines=10,
        scale=10,
        max_lines=100,
    )
    summarize.click(
        fn=generate_summary,
        inputs=[summary_model, target_language, transcribe_output],
        outputs=[summary_output],
        api_name="summarize",  # unique route name ("predict" is taken above)
    )
    gr.Markdown(DUPLICATE)

demo.queue(max_size=50).launch()
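
# Example client-side call (sketch: "user/space-name" is a placeholder for this
# Space's id; recent gradio_client versions wrap files with handle_file()):
# from gradio_client import Client
# client = Client("user/space-name")
# job = client.submit("French", "consultation.wav", api_name="/predict")
# for partial_transcript in job:
#     print(partial_transcript)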