File size: 3,997 Bytes
a9e9df1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f04ac3
a9e9df1
 
 
 
 
3f04ac3
a9e9df1
 
 
 
 
 
3f04ac3
a9e9df1
 
 
 
 
3f04ac3
a9e9df1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr
import torch
import numpy as np
from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset

# Select the compute device once; every model below (including the ASR
# pipeline) is placed on this same device so the choice cannot drift.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the Whisper ASR model and its processor (tokenizer + feature extractor)
whisper_model_id = "riteshkr/quantized-whisper-large-v3"
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_id)
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_id)

# Decoder prompt ids that pin Whisper to English transcription; they are
# handed to the pipeline at call time via generate_kwargs.
forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language="english", task="transcribe")

whisper_pipe = pipeline(
    "automatic-speech-recognition",
    model=whisper_model,
    tokenizer=whisper_processor.tokenizer,
    feature_extractor=whisper_processor.feature_extractor,
    # Reuse the device string chosen above instead of re-deriving a CUDA index.
    device=device,
)

# Load the SpeechT5 TTS model, its processor, and the HiFi-GAN vocoder
tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

tts_model.to(device)
vocoder.to(device)

# Load speaker embeddings for TTS: one x-vector (row 7306 of the validation
# split) is used as the voice for all synthesised speech. unsqueeze(0) adds a
# leading batch dimension before the tensor is moved to the device.
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device)

# Integer sample format for the generated audio, and the scale factor
# (largest representable value of that type) used to convert float samples.
target_dtype = np.int16
max_range = np.iinfo(target_dtype).max

# Define the transcription function (Whisper ASR)
def transcribe_speech(filepath):
    """Transcribe an audio file to text with the Whisper ASR pipeline.

    Args:
        filepath: Audio input forwarded unchanged to ``whisper_pipe``. The
            Gradio interfaces in this file supply a file path
            (``type="filepath"``).

    Returns:
        The ``"text"`` field of the pipeline output.

    Raises:
        gr.Error: If ``filepath`` is ``None``, i.e. no audio was supplied.
    """
    # Fail with a readable message in the UI instead of letting a missing
    # recording/upload reach the pipeline and surface as an opaque error.
    if filepath is None:
        raise gr.Error("No audio provided. Please record or upload an audio file first.")

    # Larger batches of audio chunks on GPU, smaller on CPU.
    batch_size = 16 if torch.cuda.is_available() else 4
    output = whisper_pipe(
        filepath,
        max_new_tokens=256,
        generate_kwargs={"forced_decoder_ids": forced_decoder_ids},
        chunk_length_s=30,  # long audio is processed in 30-second chunks
        batch_size=batch_size,
    )
    return output["text"]

# Define the synthesis function (SpeechT5 TTS)
def synthesise(text):
    """Generate speech for ``text`` with SpeechT5 and the HiFi-GAN vocoder.

    Args:
        text: The text to speak.

    Returns:
        The generated speech as a tensor moved to the CPU.
    """
    # Tokenise the text into PyTorch tensors and place the ids on the model's device.
    encoded = tts_processor(text=text, return_tensors="pt")
    token_ids = encoded["input_ids"].to(device)

    # The module-level speaker embedding selects the voice; the vocoder is
    # passed so generate_speech can use it when producing the output.
    waveform = tts_model.generate_speech(token_ids, speaker_embeddings, vocoder=vocoder)
    return waveform.cpu()

# Define the speech-to-speech translation function
def speech_to_speech_translation(audio):
    """Turn input speech into text, then re-synthesise that text as speech.

    Note: despite the name, the text step calls ``transcribe_speech``, and the
    file configures Whisper with ``task="transcribe"``.

    Args:
        audio: Audio input forwarded to ``transcribe_speech``.

    Returns:
        A ``(sample_rate, samples)`` tuple for a Gradio ``Audio`` output with
        ``type="numpy"``: the sample rate is the constant 16000 and
        ``samples`` is an integer array of dtype ``target_dtype``.
    """
    # Speech -> text
    text = transcribe_speech(audio)
    # Text -> float waveform on the CPU
    waveform = synthesise(text).numpy()
    # Scale to the integer sample range. Clipping first guards against any
    # sample outside [-1, 1] wrapping around on the integer cast, and casting
    # to target_dtype keeps the cast in step with max_range, which is derived
    # from that same dtype.
    samples = (np.clip(waveform, -1.0, 1.0) * max_range).astype(target_dtype)
    # NOTE(review): 16000 is presumably the SpeechT5 output rate — confirm.
    return 16000, samples

# Gradio interfaces: a microphone and a file-upload variant for each task.
def _audio_input(source):
    """Return a new audio input component that passes the handler a file path."""
    return gr.Audio(sources=source, type="filepath")


def _speech_output():
    """Return a new audio output component for the generated speech."""
    return gr.Audio(label="Generated Speech", type="numpy")


# Speech-to-speech interfaces
mic_translate = gr.Interface(
    fn=speech_to_speech_translation,
    inputs=_audio_input("microphone"),
    outputs=_speech_output(),
)

file_translate = gr.Interface(
    fn=speech_to_speech_translation,
    inputs=_audio_input("upload"),
    outputs=_speech_output(),
)

# Transcription interfaces (text output)
mic_transcribe = gr.Interface(
    fn=transcribe_speech,
    inputs=_audio_input("microphone"),
    outputs=gr.Textbox(),
)

file_transcribe = gr.Interface(
    fn=transcribe_speech,
    inputs=_audio_input("upload"),
    outputs=gr.Textbox(),
)

# Assemble the app: one tab per interface, listed as (tab title, interface).
_tabs = [
    ("Transcribe Microphone", mic_transcribe),
    ("Transcribe Audio File", file_transcribe),
    ("Translate Microphone", mic_translate),
    ("Translate Audio File", file_translate),
]

with gr.Blocks() as demo:
    gr.TabbedInterface(
        [interface for _, interface in _tabs],
        [title for title, _ in _tabs],
    )

# Launch only when run as a script; debug and share are both enabled.
if __name__ == "__main__":
    demo.launch(debug=True, share=True)