import gradio as gr
import random
import re
import string

import soundfile as sf
import spaces
import torch
from datasets import load_dataset
from PIL import Image
from transformers import (
    AutoTokenizer,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    SpeechT5Processor,
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    pipeline,
)

# --- IMAGE CAPTIONING ---
def load_caption_model():
    return VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# gr.NO_RELOAD keeps the heavy model load out of Gradio's hot-reload cycle
if gr.NO_RELOAD:
    llm_model = load_caption_model()
    llm_model.to(device)

max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}


def predict_step(image_paths):
    # open each image and normalize it to RGB
    images = []
    for image_path in image_paths:
        i_image = Image.open(image_path)
        if i_image.mode != "RGB":
            i_image = i_image.convert(mode="RGB")
        images.append(i_image)

    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    output_ids = llm_model.generate(pixel_values, **gen_kwargs)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return [pred.strip() for pred in preds]


# --- TEXT TO SPEECH ---
# load the processor
def load_processor():
    return SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")


# load the model
def load_speech_model():
    return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)


# load the vocoder, which turns the model's spectrogram output into a waveform
def load_vocoder():
    return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)


# this dataset provides the speaker x-vector embeddings
def load_embeddings_dataset():
    return load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")


# speaker ids (row indices) in the embeddings dataset
speakers = {
    "awb": 0,     # Scottish male
    "bdl": 1138,  # US male
    "clb": 2271,  # US female
    "jmk": 3403,  # Canadian male
    "ksp": 4535,  # Indian male
    "rms": 5667,  # US male
    "slt": 6799,  # US female
}


def save_text_to_speech(text, speaker=None):
    # preprocess the text
    inputs = processor(text=text, return_tensors="pt").to(device)
    if speaker is not None:
        # load the x-vector containing the speaker's voice characteristics
        speaker_embeddings = torch.tensor(embeddings_dataset[speaker]["xvector"]).unsqueeze(0).to(device)
    else:
        # a random 512-dim vector, i.e. a random voice
        speaker_embeddings = torch.randn((1, 512)).to(device)

    # generate speech with the model and vocoder
    speech = speech_model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)

    if speaker is not None:
        # use the speaker's ID in the filename
        output_filename = f"{speaker}-{'-'.join(text.split()[:6])}.mp3"
    else:
        # no speaker given: use a random string in the filename
        random_str = "".join(random.sample(string.ascii_letters + string.digits, k=5))
        output_filename = f"{random_str}-{'-'.join(text.split()[:6])}.mp3"

    # save the generated speech at a 16 kHz sampling rate
    # (writing .mp3 via soundfile requires libsndfile >= 1.1.0; fall back to .wav otherwise)
    sf.write(output_filename, speech.cpu().numpy(), samplerate=16000)
    # return the filename for reference
    return output_filename


def load_text_generator():
    return pipeline("text-generation", model="gpt2")  # uses GPT-2


if gr.NO_RELOAD:
    processor = load_processor()
    speech_model = load_speech_model()
    vocoder = load_vocoder()
    embeddings_dataset = load_embeddings_dataset()
    gen = load_text_generator()


def gradio_predict(image):
    if image is None:
        return ""
    image_path = "temp_image.jpg"
    image.save(image_path)  # save the uploaded image temporarily
    prediction = predict_step([image_path])
    return prediction[0].capitalize() if prediction else "Prediction failed."


def remove_last_incomplete_sentence(text):
    # find all sentences ending with ., !, or ?
    sentences = re.findall(r"[^.!?]*[.!?]", text, re.DOTALL)
    # if no complete sentence is found, return the original text
    if not sentences:
        return text
    # join the complete sentences
    return "".join(sentences).strip()


@spaces.GPU()
def get_story(pred):
    generated = gen(pred, max_length=100)[0]["generated_text"]
    cleaned_text = remove_last_incomplete_sentence(generated)
    audio_filename = save_text_to_speech(cleaned_text, speaker=speakers["slt"])
    return cleaned_text, audio_filename


# --- FRONT END ---
DESCRIPTION = """
# PictoVerse
### Dive into the multiverse of storytelling with PictoVerse, where every image unveils an array of parallel dimensions. PictoVerse crafts captivating narratives from your photos, each set in a distinct universe of its own.
"""

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Image")
            clear_button = gr.Button("Clear")
        with gr.Column(scale=4):
            output_text = gr.Textbox(label="Prediction")
            gen_text = gr.Textbox(label="Generated Story")
            output_audio = gr.Audio(label="Audio")
            button1 = gr.Button("Generate Story and Audio")

    button1.click(fn=get_story, inputs=output_text, outputs=[gen_text, output_audio])
    input_image.change(fn=gradio_predict, inputs=input_image, outputs=output_text)
    clear_button.click(
        lambda: (None, "", "", None),
        inputs=[],
        outputs=[input_image, output_text, gen_text, output_audio],
    )

demo.launch()
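
# --- USAGE NOTES (assumptions added for reference, not from the original source) ---
# The filename app.py is assumed; rename the commands to match your script.
# Local run:   python app.py   then open http://127.0.0.1:7860 (Gradio's default port)
# Hot reload:  gradio app.py   (the gr.NO_RELOAD blocks keep model loads out of the reload loop)
# Rough, unpinned dependency list; adjust to your environment:
#   pip install gradio transformers torch pillow datasets soundfile sentencepiece spaces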