import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    SpeechT5Processor,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    pipeline,
)

# Check if a GPU is available and set the device
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the Whisper ASR model
whisper_model_id = "riteshkr/quantized-whisper-large-v3"
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_id)
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_id)

# Set the language to English using forced_decoder_ids
forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language="english", task="transcribe")

# Build the ASR pipeline from the loaded model and processor
whisper_pipe = pipeline(
    "automatic-speech-recognition",
    model=whisper_model,
    tokenizer=whisper_processor.tokenizer,
    feature_extractor=whisper_processor.feature_extractor,
    device=0 if torch.cuda.is_available() else -1,
)

# Load the SpeechT5 TTS model and its HiFi-GAN vocoder
tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
tts_model.to(device)
vocoder.to(device)

# Load a fixed speaker embedding (x-vector) for TTS from the CMU ARCTIC dataset
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device)
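
# Sketch (assumption, not part of the original app): other rows of the same
# dataset select different voices. The index below is only illustrative; the
# dataset does not document a fixed speaker-to-index mapping.
# alt_speaker = torch.tensor(embeddings_dataset[0]["xvector"]).unsqueeze(0).to(device)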

# Set target data type and max range for speech
target_dtype = np.int16
max_range = np.iinfo(target_dtype).max
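
# Worked example of the scaling used below: max_range is 32767 for int16, so
# a float sample of 0.5 becomes int(0.5 * 32767) = 16383 after the cast.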

# Define the transcription function (Whisper ASR)
def transcribe_speech(filepath):
    batch_size = 16 if torch.cuda.is_available() else 4
    output = whisper_pipe(
        filepath,
        max_new_tokens=256,
        generate_kwargs={"forced_decoder_ids": forced_decoder_ids},
        chunk_length_s=30,
        batch_size=batch_size,
    )
    return output["text"]
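
# Usage sketch (the file path is hypothetical, not part of the app):
# print(transcribe_speech("sample.wav"))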

# Define the synthesis function (SpeechT5 TTS)
def synthesise(text):
    inputs = tts_processor(text=text, return_tensors="pt")
    speech = tts_model.generate_speech(
        inputs["input_ids"].to(device), speaker_embeddings, vocoder=vocoder
    )
    return speech.cpu()
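
# Usage sketch: synthesise returns a 1-D float waveform tensor at 16 kHz,
# SpeechT5's fixed output sample rate.
# waveform = synthesise("Hello world")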

# Define the speech-to-speech function: transcribe, then resynthesize.
# (With task="transcribe" above, the text stays in English, so this is
# same-language speech-to-speech rather than translation.)
def speech_to_speech_translation(audio):
    # Transcribe speech
    transcribed_text = transcribe_speech(audio)
    # Synthesize speech
    synthesised_speech = synthesise(transcribed_text)
    # Scale the float waveform to int16 for Gradio's numpy audio output
    synthesised_speech = (synthesised_speech.numpy() * max_range).astype(np.int16)
    return 16000, synthesised_speech
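
# Usage sketch (hypothetical path): yields the (sample_rate, int16 array)
# tuple that gr.Audio accepts for type="numpy" outputs.
# rate, audio = speech_to_speech_translation("sample.wav")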

# Define the Gradio interfaces for microphone input and file upload
mic_translate = gr.Interface(
    fn=speech_to_speech_translation,
    inputs=gr.Audio(sources="microphone", type="filepath"),
    outputs=gr.Audio(label="Generated Speech", type="numpy"),
)
file_translate = gr.Interface(
    fn=speech_to_speech_translation,
    inputs=gr.Audio(sources="upload", type="filepath"),
    outputs=gr.Audio(label="Generated Speech", type="numpy"),
)

# Define the Gradio interfaces for transcription
mic_transcribe = gr.Interface(
    fn=transcribe_speech,
    inputs=gr.Audio(sources="microphone", type="filepath"),
    outputs=gr.Textbox(),
)
file_transcribe = gr.Interface(
    fn=transcribe_speech,
    inputs=gr.Audio(sources="upload", type="filepath"),
    outputs=gr.Textbox(),
)

# Create the app using Gradio Blocks with tabbed interfaces
demo = gr.Blocks()
with demo:
    gr.TabbedInterface(
        [
            mic_transcribe, file_transcribe,  # For transcription
            mic_translate, file_translate,  # For speech-to-speech
        ],
        [
            "Transcribe Microphone", "Transcribe Audio File",
            "Translate Microphone", "Translate Audio File",
        ],
    )
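
# Optional sketch (assumption, not in the original app): enable Gradio's
# request queue before launching so concurrent users are handled in order.
# demo.queue()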

# Launch the app with debugging enabled
if __name__ == "__main__":
    demo.launch(debug=True, share=True)