import gradio as gr
import torch
import torchaudio
from transformers import AutoTokenizer, AutoModelForCausalLM
from speechtokenizer import SpeechTokenizer
from audiotools import AudioSignal
from safetensors.torch import load_file
import bitsandbytes as bnb  # only needed if INT8 loading is enabled below
import numpy as np
from uuid import uuid4

# Load the tokenizer
model_path = "Vikhrmodels/salt-116k"
tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=".")

# Special tokens delimiting the audio part of a sequence.
# NOTE: the literal token strings were lost in this copy of the script; the
# angle-bracket values below are assumed placeholders -- substitute the model's
# actual audio-delimiter tokens.
start_audio_token = "<soa>"
end_audio_token = "<eoa>"
end_sequence_token = "<eos>"

# Constants
n_codebooks = 3
max_seq_length = 1024
top_k = 20


def convert_to_16_bit_wav(data):
    if data.dtype == np.float32:
        data = data / np.abs(data).max()
        data = data * 32767
        data = data.astype(np.int16)
    elif data.dtype == np.int32:
        data = data / 65536  # 2**16: scale int32 down to the int16 range
        data = data.astype(np.int16)
    elif data.dtype == np.int16:
        pass
    elif data.dtype == np.uint8:
        data = data.astype(np.int32) * 257 - 32768  # map 0..255 to -32768..32767
        data = data.astype(np.int16)
    else:
        raise ValueError("Audio data cannot be converted to 16-bit int format.")
    return data


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the language model (set load_in_8bit=True to quantize with bitsandbytes)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    cache_dir=".",
    load_in_8bit=False,  # INT8 loading is disabled here; enable it to save GPU memory
    device_map="auto",   # Automatically map the model to available devices
)

# Configuration for the SpeechTokenizer quantizer
config_path = "audiotokenizer/speechtokenizer_hubert_avg_config.json"
ckpt_path = "audiotokenizer/SpeechTokenizer.pt"
quantizer = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
quantizer.eval()


# Freeze all layers of the quantizer
def freeze_entire_model(model):
    for n, p in model.named_parameters():
        p.requires_grad = False
    return model


for n, child in quantizer.named_children():
    child.to(device)
    child = freeze_entire_model(child)


# Create padding tokens for audio by encoding a silent (all-zero) signal
def get_audio_padding_tokens(quantizer):
    audio = torch.zeros((1, 1, 1)).to(device)
    codes = quantizer.encode(audio)  # [n_q, B, T]; squeeze the batch dim
    del audio
    torch.cuda.empty_cache()
    return {"audio_tokens": codes.squeeze(1)}


# Decode audio from generated tokens
def decode_audio(tokens, quantizer, pad_tokens, n_original_tokens):
    # Locate the audio span between the start- and end-of-audio tokens
    start = torch.nonzero(tokens == tokenizer(start_audio_token)["input_ids"][-1])
    end = torch.nonzero(tokens == tokenizer(end_audio_token)["input_ids"][-1])
    start = start[0, -1] + 1 if len(start) else 0
    end = end[0, -1] if len(end) else tokens.shape[-1]

    # Map token ids back to raw codebook indices (audio ids occupy the last 1024 vocabulary slots)
    audio_tokens = tokens[start:end] % n_original_tokens

    # Pad to a multiple of n_codebooks so the frame-major stream can be reshaped
    remainder = audio_tokens.shape[-1] % n_codebooks
    if remainder:
        audio_tokens = torch.cat([audio_tokens, pad_tokens[remainder:n_codebooks]], dim=0)

    transposed = audio_tokens.view(-1, n_codebooks).t()
    codes = transposed.view(n_codebooks, 1, -1).to(device)

    audio = quantizer.decode(codes).squeeze(0)
    torch.cuda.empty_cache()

    xp = str(uuid4()) + ".wav"
    AudioSignal(audio.detach().cpu().numpy(), quantizer.sample_rate).write(xp)
    return xp
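
# Illustrative sketch (not called anywhere in this app; the helper name is ours,
# not part of the model code): decode_audio above assumes the model emits audio
# tokens frame by frame, i.e. the flat stream is [t0_c0, t0_c1, t0_c2, t1_c0, ...]
# for n_codebooks codebooks. The helper packs a toy [n_q, T] code matrix into that
# order and checks that the same reshape decode_audio uses recovers it.
def _codebook_interleave_roundtrip(n_q: int = n_codebooks, n_frames: int = 4) -> bool:
    codes = torch.arange(n_q * n_frames).view(n_q, n_frames)  # toy [n_q, T] code matrix
    flat = codes.t().reshape(-1)                              # frame-major flattening (model output order)
    recovered = flat.view(-1, n_q).t()                        # same reshape decode_audio performs
    return torch.equal(recovered, codes)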

# Inference functions
def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
    print(text)
    text_tokenized = tokenizer(str(text), return_tensors="pt")
    text_input_tokens = text_tokenized["input_ids"].to(device)

    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)

    # Prompt = text followed by the start-of-audio token
    text_tokens = torch.cat([text_input_tokens, soa], dim=1)
    attention_mask = torch.ones(text_tokens.size(), device=device)

    output_audio_tokens = model.generate(
        text_tokens,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        top_k=top_k,
        do_sample=True,
    )

    padding_tokens = get_audio_padding_tokens(quantizer)["audio_tokens"].to(device)
    audio_signal = decode_audio(output_audio_tokens[0], quantizer, padding_tokens.t()[0], len(tokenizer) - 1024)
    return audio_signal


def infer_audio_to_text(audio_path, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
    audio_data, sample_rate = torchaudio.load(audio_path)
    if audio_data.size(0) > 1:  # downmix multi-channel input to mono
        audio_data = audio_data.mean(dim=0, keepdim=True)
    if sample_rate != quantizer.sample_rate:  # match the rate SpeechTokenizer expects
        audio_data = torchaudio.functional.resample(audio_data, sample_rate, quantizer.sample_rate)

    audio = audio_data.view(1, 1, -1).float().to(device)
    codes = quantizer.encode(audio)  # [n_q, B, T]

    # Keep only the first n_codebooks_a codebooks and shift the codes into the
    # audio region of the vocabulary (the last 1024 slots)
    n_codebooks_a = 1
    raw_audio_tokens = codes[:n_codebooks_a] + len(tokenizer) - 1024

    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    audio_tokens = torch.cat([soa, raw_audio_tokens.view(1, -1), eoa], dim=1)
    attention_mask = torch.ones(audio_tokens.size(), device=device)

    output_text_tokens = model.generate(
        audio_tokens,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        top_k=top_k,
        do_sample=True,
    )

    # Keep only ids below the start-of-audio token id, i.e. ordinary text tokens
    output_text_tokens = output_text_tokens.cpu()[0]
    output_text_tokens = output_text_tokens[output_text_tokens < tokenizer(start_audio_token)["input_ids"][-1]]
    decoded_text = tokenizer.decode(output_text_tokens, skip_special_tokens=True)
    return decoded_text
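
# Quick sanity sketch (ours, not part of the original app): audio codes live in the
# last 1024 slots of the tokenizer vocabulary. infer_audio_to_text adds the offset
# len(tokenizer) - 1024 when building the prompt, and decode_audio strips it again
# with a modulo, so the two mappings are inverses for any codebook index in [0, 1024).
def _vocab_offset_roundtrip(code: int = 123) -> bool:
    offset = len(tokenizer) - 1024    # first vocabulary id reserved for audio codes
    token_id = code + offset          # shift applied in infer_audio_to_text
    return token_id % offset == code  # inverse mapping applied in decode_audio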

# Functions for the Gradio interface
def infer_text_to_audio_gr(text):
    audio_signal = infer_text_to_audio(text.strip().upper(), model, tokenizer, quantizer)
    return audio_signal


def infer_audio_to_text_gr(audio_path):
    generated_text = infer_audio_to_text(audio_path, model, tokenizer, quantizer)
    return generated_text


# Gradio interfaces
text_to_audio_interface = gr.Interface(
    fn=infer_text_to_audio_gr,
    inputs=gr.Textbox(label="Input Text"),
    outputs=gr.Audio(label="Audio Answer"),
    title="T2S",
    description="Model in text-to-audio mode",
    allow_flagging="never",
)

audio_to_text_interface = gr.Interface(
    fn=infer_audio_to_text_gr,
    inputs=gr.Audio(type="filepath", label="Input Audio"),
    outputs=gr.Textbox(label="Text Answer"),
    title="S2T",
    description="Model in audio-to-text mode",
    allow_flagging="never",
)

# Gradio Demo
# demo = gr.TabbedInterface([text_to_audio_interface, audio_to_text_interface], ["Text - Audio", "Audio - Text"])

# Custom CSS for centered links
custom_css = """
"""

# Gradio description with centered links
description = f"""
# **Salt: Speech And Language Transformer**

Welcome to the demo of **Salt**, a speech and language model. Vikhr Salt is capable of both **Text-to-Speech (T2S)** and **Speech-to-Text (S2T)** tasks, making it a versatile tool for transforming language into speech and vice versa. Built on a pre-trained large language model, Vikhr Salt incorporates audio tokens using techniques such as **Encodec** and **SpeechTokenizer**, enabling robust performance across multiple modalities.

## **🛠 Features**

- **Text-to-Speech (T2S)**: Enter text and generate high-quality audio output.
- **Speech-to-Text (S2T)**: Upload an audio file and convert it into accurate text.

## **🚀 Try it out**

Explore the tabs to try the **Text - Audio** and **Audio - Text** modes!

### **📄 Preprint**

[Read the paper](https://docs.google.com/document/d/1ZvV47W4BCyZM_JfDC1BKj-0ozwPck5t2yNB8jORVshI/edit?usp=sharing)

### **📂 Code**

[Explore the code](https://github.com/VikhrModels/Vikhr4o)
"""

with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Tabs():
        with gr.TabItem("Text - Audio"):
            gr.Markdown("### Text-to-Speech (T2S) Mode")
            input_text = gr.Textbox(label="Input Text")
            output_audio = gr.Audio(label="Audio Answer")
            generate_button = gr.Button("Generate")
            # Use the *_gr wrappers: they supply model, tokenizer, and quantizer
            generate_button.click(infer_text_to_audio_gr, inputs=input_text, outputs=output_audio)
        with gr.TabItem("Audio - Text"):
            gr.Markdown("### Speech-to-Text (S2T) Mode")
            input_audio = gr.Audio(type="filepath", label="Input Audio")
            output_text = gr.Textbox(label="Text Answer")
            generate_button = gr.Button("Generate")
            generate_button.click(infer_audio_to_text_gr, inputs=input_audio, outputs=output_text)

# Launch the demo
demo.launch(share=True)