# gfgf / akn.py
# Ffftdtd5dtft's picture
# Create akn.py
# 485b791 verified
import hashlib
import io
import json
import os
import pickle
import tempfile
import time
import wave

import gradio as gr
import torch
from audiocraft.data.audio import audio_read
from audiocraft.models import musicgen
from diffusers import (
    DPMSolverMultistepScheduler,
    DiffusionPipeline,
    FluxPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionPipeline,
)
from diffusers.utils import export_to_video
from google.cloud import storage
from huggingface_hub import HfApi, HfFolder, snapshot_download
from PIL import Image
from tqdm import tqdm
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    GPT2Model,
    GPT2Tokenizer,
    pipeline as transformers_pipeline,
)
# --- Runtime configuration --------------------------------------------------
# Every setting comes from the environment. Fail fast with one readable
# message instead of the opaque TypeError that json.loads(None) would raise.
_missing_env = [
    name
    for name in ("HF_TOKEN", "GCS_CREDENTIALS", "GCS_BUCKET_NAME")
    if not os.getenv(name)
]
if _missing_env:
    raise RuntimeError(
        "Missing required environment variable(s): " + ", ".join(_missing_env)
    )

hf_token = os.getenv("HF_TOKEN")
# GCS_CREDENTIALS holds the service-account key as a JSON document.
gcs_credentials = json.loads(os.getenv("GCS_CREDENTIALS"))
gcs_bucket_name = os.getenv("GCS_BUCKET_NAME")

# Persist the Hub token so later from_pretrained() calls can authenticate.
HfFolder.save_token(hf_token)

# Module-level bucket handle shared by every cache helper below.
storage_client = storage.Client.from_service_account_info(gcs_credentials)
bucket = storage_client.bucket(gcs_bucket_name)
def load_object_from_gcs(blob_name):
    """Fetch and unpickle the object stored under ``blob_name``.

    Returns the unpickled object, or ``None`` when no such blob exists in
    the module-level ``bucket``.

    NOTE(review): ``pickle.loads`` can execute code embedded in the payload,
    so this is only safe while nobody untrusted can write to the bucket.
    """
    cached_blob = bucket.blob(blob_name)
    if not cached_blob.exists():
        return None
    payload = cached_blob.download_as_bytes()
    return pickle.loads(payload)
def save_object_to_gcs(blob_name, obj):
    """Pickle ``obj`` and upload it as ``blob_name`` in the module-level bucket."""
    target_blob = bucket.blob(blob_name)
    serialized = pickle.dumps(obj)
    target_blob.upload_from_string(serialized)
def get_model_or_download(model_id, blob_name, loader_func):
    """Return a model from the GCS cache, loading and caching it on a miss.

    Args:
        model_id: Identifier passed as the first positional argument to
            ``loader_func``.
        blob_name: GCS object name under which the pickled model is cached.
        loader_func: Callable invoked as
            ``loader_func(model_id, torch_dtype=torch.float16)``.

    Returns:
        The cached or freshly loaded model, or ``None`` if loading failed.
    """
    model = load_object_from_gcs(blob_name)
    # Compare against None explicitly: a valid cached object could still
    # evaluate as falsy (e.g. if it defines __len__ or __bool__).
    if model is not None:
        return model

    try:
        with tqdm(total=1, desc=f"Downloading {model_id}") as pbar:
            model = loader_func(model_id, torch_dtype=torch.float16)
            pbar.update(1)
    except Exception as e:
        print(f"Failed to load or save model: {e}")
        return None

    # Caching is best-effort: a failed upload must not throw away a model
    # that was loaded successfully.
    try:
        save_object_to_gcs(blob_name, model)
    except Exception as e:
        print(f"Failed to load or save model: {e}")
    return model
def generate_image(prompt):
    """Render ``prompt`` with the text-to-image pipeline, cached in GCS.

    Returns the image encoded as JPEG bytes (from the cache when present),
    or ``None`` if generation, encoding or the cache upload raised.
    """
    cache_key = f"diffusers/generated_image:{prompt}"
    cached_jpeg = load_object_from_gcs(cache_key)
    if cached_jpeg:
        return cached_jpeg

    try:
        with tqdm(total=1, desc="Generating image") as progress:
            picture = text_to_image_pipeline(prompt).images[0]
            progress.update(1)
        jpeg_buffer = io.BytesIO()
        picture.save(jpeg_buffer, format="JPEG")
        jpeg_bytes = jpeg_buffer.getvalue()
        save_object_to_gcs(cache_key, jpeg_bytes)
    except Exception as e:
        print(f"Failed to generate image: {e}")
        return None
    return jpeg_bytes
def edit_image_with_prompt(image_bytes, prompt, strength=0.75):
    """Apply a prompt-guided img2img edit and return the result as JPEG bytes.

    Args:
        image_bytes: Source image, either encoded image bytes or a
            ``PIL.Image.Image`` (what the ``gr.Image(type="pil")`` input in
            this file delivers).
        prompt: Text prompt forwarded to ``img2img_pipeline``.
        strength: Edit strength forwarded to ``img2img_pipeline``.

    Returns:
        JPEG-encoded bytes of the edited image, or ``None`` when no image was
        supplied or any step failed.
    """
    if image_bytes is None:
        return None

    try:
        if isinstance(image_bytes, Image.Image):
            image = image_bytes
        else:
            image = Image.open(io.BytesIO(image_bytes))
        image = image.convert("RGB")
        # The cache key must identify the source image as well as the prompt,
        # otherwise two different images edited with the same prompt collide.
        fingerprint = hashlib.sha256()
        fingerprint.update(repr(image.size).encode("utf-8"))
        fingerprint.update(image.tobytes())
        image_digest = fingerprint.hexdigest()
    except Exception as e:
        print(f"Failed to edit image: {e}")
        return None

    blob_name = f"diffusers/edited_image:{image_digest}:{prompt}:{strength}"
    edited_image_bytes = load_object_from_gcs(blob_name)
    if not edited_image_bytes:
        try:
            with tqdm(total=1, desc="Editing image") as pbar:
                edited_image = img2img_pipeline(
                    prompt=prompt, image=image, strength=strength
                ).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            edited_image.save(buffered, format="JPEG")
            edited_image_bytes = buffered.getvalue()
            save_object_to_gcs(blob_name, edited_image_bytes)
        except Exception as e:
            print(f"Failed to edit image: {e}")
            return None
    return edited_image_bytes
def generate_song(prompt, duration=10):
    """Generate music for ``prompt`` and return it as WAV bytes, cached in GCS.

    Args:
        prompt: Text description passed to the MusicGen model.
        duration: Clip length in seconds, applied through
            ``set_generation_params``.

    Returns:
        16-bit PCM WAV bytes, or ``None`` if generation or caching failed.
    """
    blob_name = f"music/generated_song:{prompt}:{duration}"
    song_bytes = load_object_from_gcs(blob_name)
    if not song_bytes:
        try:
            with tqdm(total=1, desc="Generating song") as pbar:
                # MusicGen objects are not callable: the duration is a
                # generation parameter and generate() takes a list of prompts.
                music_gen.set_generation_params(duration=duration)
                wav = music_gen.generate([prompt])
                pbar.update(1)
            # wav[0] is the first clip as a float tensor [channels, samples]
            # (per the audiocraft API); convert it to 16-bit PCM.
            pcm = (wav[0].detach().cpu().float().clamp(-1.0, 1.0) * 32767).to(
                torch.int16
            )
            buffered = io.BytesIO()
            with wave.open(buffered, "wb") as wav_file:
                wav_file.setnchannels(pcm.shape[0])
                wav_file.setsampwidth(2)
                wav_file.setframerate(music_gen.sample_rate)
                # WAV frames are interleaved, i.e. [samples, channels] order.
                wav_file.writeframes(pcm.t().contiguous().numpy().tobytes())
            song_bytes = buffered.getvalue()
            save_object_to_gcs(blob_name, song_bytes)
        except Exception as e:
            print(f"Failed to generate song: {e}")
            return None
    return song_bytes
def generate_text(prompt):
    """Generate text for ``prompt`` with ``text_gen_pipeline``, cached in GCS.

    Returns the stripped ``generated_text`` of the first pipeline result, or
    ``None`` if generation or the cache upload raised.
    """
    cache_key = f"transformers/generated_text:{prompt}"
    cached_text = load_object_from_gcs(cache_key)
    if cached_text:
        return cached_text

    try:
        with tqdm(total=1, desc="Generating text") as progress:
            results = text_gen_pipeline(prompt, max_new_tokens=256)
            generated = results[0]["generated_text"].strip()
            progress.update(1)
        save_object_to_gcs(cache_key, generated)
    except Exception as e:
        print(f"Failed to generate text: {e}")
        return None
    return generated
def generate_flux_image(prompt):
    """Render ``prompt`` with the FLUX pipeline and return JPEG bytes.

    Results are cached in GCS by prompt. Generation is seeded with a fixed
    CPU generator (seed 0), so a given prompt is reproducible.

    Returns:
        JPEG-encoded bytes, or ``None`` if generation or caching failed.
    """
    blob_name = f"diffusers/generated_flux_image:{prompt}"
    flux_image_bytes = load_object_from_gcs(blob_name)
    if not flux_image_bytes:
        try:
            with tqdm(total=1, desc="Generating FLUX image") as pbar:
                flux_image = flux_pipeline(
                    prompt,
                    guidance_scale=0.0,
                    num_inference_steps=4,
                    # FluxPipeline names this argument max_sequence_length;
                    # "max_length" is rejected as an unexpected keyword.
                    max_sequence_length=256,
                    generator=torch.Generator("cpu").manual_seed(0),
                ).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            flux_image.save(buffered, format="JPEG")
            flux_image_bytes = buffered.getvalue()
            save_object_to_gcs(blob_name, flux_image_bytes)
        except Exception as e:
            print(f"Failed to generate flux image: {e}")
            return None
    return flux_image_bytes
def generate_code(prompt):
    """Complete ``prompt`` with the StarCoder model, cached in GCS by prompt.

    Returns the decoded first output sequence, or ``None`` if generation or
    the cache upload raised.
    """
    cache_key = f"transformers/generated_code:{prompt}"
    cached_code = load_object_from_gcs(cache_key)
    if cached_code:
        return cached_code

    try:
        with tqdm(total=1, desc="Generating code") as progress:
            prompt_ids = starcoder_tokenizer.encode(prompt, return_tensors="pt")
            completion_ids = starcoder_model.generate(
                prompt_ids, max_new_tokens=256
            )
            completion = starcoder_tokenizer.decode(completion_ids[0])
            progress.update(1)
        save_object_to_gcs(cache_key, completion)
    except Exception as e:
        print(f"Failed to generate code: {e}")
        return None
    return completion
def test_model_meta_llama():
    """Run a fixed pirate-chatbot conversation through ``meta_llama_pipeline``.

    The reply is cached in GCS under a constant key, so the model is only
    queried until one run succeeds.

    Returns:
        The assistant reply as a stripped string, or ``None`` on failure.
    """
    blob_name = "transformers/meta_llama_test_response"
    response = load_object_from_gcs(blob_name)
    if not response:
        try:
            messages = [
                {
                    "role": "system",
                    "content": "You are a pirate chatbot who always responds in pirate speak!",
                },
                {"role": "user", "content": "Who are you?"},
            ]
            with tqdm(total=1, desc="Testing Meta-Llama") as pbar:
                generated = meta_llama_pipeline(messages, max_new_tokens=256)[0][
                    "generated_text"
                ]
                # For chat-style input the text-generation pipeline returns
                # the whole conversation as a list of message dicts; the
                # model's reply is the last entry.
                if isinstance(generated, list):
                    generated = generated[-1]["content"]
                response = generated.strip()
                pbar.update(1)
            save_object_to_gcs(blob_name, response)
        except Exception as e:
            print(f"Failed to test Meta-Llama: {e}")
            return None
    return response
def generate_image_sdxl(prompt):
    """Render ``prompt`` with the SDXL base + refiner pair, cached in GCS.

    ``base`` runs the first 80% of a 40-step schedule and emits latents;
    ``refiner`` picks up at 0.8 and produces the final image.

    Returns JPEG bytes, or ``None`` if any step raised.
    """
    cache_key = f"diffusers/generated_image_sdxl:{prompt}"
    cached_jpeg = load_object_from_gcs(cache_key)
    if cached_jpeg:
        return cached_jpeg

    try:
        with tqdm(total=1, desc="Generating SDXL image") as progress:
            latents = base(
                prompt=prompt,
                num_inference_steps=40,
                denoising_end=0.8,
                output_type="latent",
            ).images
            refined = refiner(
                prompt=prompt,
                num_inference_steps=40,
                denoising_start=0.8,
                image=latents,
            ).images[0]
            progress.update(1)
        jpeg_buffer = io.BytesIO()
        refined.save(jpeg_buffer, format="JPEG")
        jpeg_bytes = jpeg_buffer.getvalue()
        save_object_to_gcs(cache_key, jpeg_bytes)
    except Exception as e:
        print(f"Failed to generate SDXL image: {e}")
        return None
    return jpeg_bytes
def _wav_tensor_to_bytes(wav, sample_rate):
    """Encode a float waveform tensor [channels, samples] as 16-bit PCM WAV bytes."""
    pcm = (wav.detach().cpu().float().clamp(-1.0, 1.0) * 32767).to(torch.int16)
    buffered = io.BytesIO()
    with wave.open(buffered, "wb") as wav_file:
        wav_file.setnchannels(pcm.shape[0])
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        # WAV frames are interleaved, i.e. [samples, channels] order.
        wav_file.writeframes(pcm.t().contiguous().numpy().tobytes())
    return buffered.getvalue()


def generate_musicgen_melody(prompt):
    """Generate music for ``prompt`` conditioned on the bundled Bach melody.

    The melody is read from ``./assets/bach.mp3``. Results are cached in GCS
    by prompt.

    Returns:
        16-bit PCM WAV bytes, or ``None`` if generation or caching failed.
    """
    blob_name = f"music/generated_musicgen_melody:{prompt}"
    song_bytes = load_object_from_gcs(blob_name)
    if not song_bytes:
        try:
            with tqdm(total=1, desc="Generating MusicGen melody") as pbar:
                melody, sr = audio_read("./assets/bach.mp3")
                # One description needs exactly one melody: add a batch
                # dimension of 1 rather than expanding to 3.
                wav = music_gen_melody.generate_with_chroma(
                    [prompt], melody[None], sr
                )
                pbar.update(1)
            song_bytes = _wav_tensor_to_bytes(
                wav[0], music_gen_melody.sample_rate
            )
            save_object_to_gcs(blob_name, song_bytes)
        except Exception as e:
            print(f"Failed to generate MusicGen melody: {e}")
            return None
    return song_bytes


def generate_musicgen_large(prompt):
    """Generate music for ``prompt`` with the large MusicGen model.

    Results are cached in GCS by prompt.

    Returns:
        16-bit PCM WAV bytes, or ``None`` if generation or caching failed.
    """
    blob_name = f"music/generated_musicgen_large:{prompt}"
    song_bytes = load_object_from_gcs(blob_name)
    if not song_bytes:
        try:
            with tqdm(total=1, desc="Generating MusicGen large") as pbar:
                wav = music_gen_large.generate([prompt])
                pbar.update(1)
            song_bytes = _wav_tensor_to_bytes(wav[0], music_gen_large.sample_rate)
            save_object_to_gcs(blob_name, song_bytes)
        except Exception as e:
            print(f"Failed to generate MusicGen large: {e}")
            return None
    return song_bytes
def transcribe_audio(audio_sample):
    """Transcribe audio with ``whisper_pipeline``, cached in GCS by content.

    Args:
        audio_sample: Either a NumPy array of samples, or a
            ``(sample_rate, samples)`` tuple as delivered by the
            ``gr.Audio(type="numpy")`` input in this file. ``None`` is
            accepted and yields ``None``.

    Returns:
        The transcribed text, or ``None`` if there was no input or
        transcription failed.
    """
    if audio_sample is None:
        return None

    if isinstance(audio_sample, tuple):
        sample_rate, samples = audio_sample
    else:
        sample_rate, samples = None, audio_sample

    # hash() of bytes is salted per interpreter process, so it cannot serve
    # as a persistent cache key; use a stable content digest instead.
    fingerprint = hashlib.sha256()
    fingerprint.update(f"{sample_rate}:".encode("utf-8"))
    fingerprint.update(samples.tobytes())
    blob_name = f"transformers/transcribed_audio:{fingerprint.hexdigest()}"

    text = load_object_from_gcs(blob_name)
    if not text:
        try:
            with tqdm(total=1, desc="Transcribing audio") as pbar:
                if sample_rate is None:
                    model_input = samples.copy()
                else:
                    waveform = samples.astype("float32")
                    if samples.dtype.kind == "i":
                        # Scale signed integer PCM into the [-1, 1] range.
                        waveform /= float(2 ** (8 * samples.dtype.itemsize - 1))
                    if waveform.ndim > 1:
                        # Down-mix (samples, channels) audio to mono.
                        waveform = waveform.mean(axis=1)
                    model_input = {"raw": waveform, "sampling_rate": sample_rate}
                text = whisper_pipeline(model_input, batch_size=8)["text"]
                pbar.update(1)
            save_object_to_gcs(blob_name, text)
        except Exception as e:
            print(f"Failed to transcribe audio: {e}")
            return None
    return text
def generate_mistral_instruct(prompt):
    """Answer ``prompt`` with the Mistral Instruct model, cached in GCS.

    The prompt is sent as a single user turn through the tokenizer's chat
    template, together with the module-level ``tools`` list.

    Returns:
        The decoded first output sequence (special tokens skipped), or
        ``None`` if generation or caching failed.
    """
    blob_name = f"transformers/generated_mistral_instruct:{prompt}"
    response = load_object_from_gcs(blob_name)
    if not response:
        try:
            conversation = [{"role": "user", "content": prompt}]
            with tqdm(total=1, desc="Generating Mistral Instruct response") as pbar:
                inputs = mistral_instruct_tokenizer.apply_chat_template(
                    conversation,
                    tools=tools,
                    add_generation_prompt=True,
                    return_dict=True,
                    return_tensors="pt",
                )
                # The model is loaded with device_map="auto"; the encoded
                # prompt has to live on the model's device before generate().
                inputs = inputs.to(mistral_instruct_model.device)
                outputs = mistral_instruct_model.generate(
                    **inputs, max_new_tokens=1000
                )
                response = mistral_instruct_tokenizer.decode(
                    outputs[0], skip_special_tokens=True
                )
                pbar.update(1)
            save_object_to_gcs(blob_name, response)
        except Exception as e:
            print(f"Failed to generate Mistral Instruct response: {e}")
            return None
    return response
def generate_mistral_nemo(prompt):
    """Answer ``prompt`` with the Mistral Nemo model, cached in GCS.

    The prompt is sent as a single user turn through the tokenizer's chat
    template, together with the module-level ``tools`` list.

    Returns:
        The decoded first output sequence (special tokens skipped), or
        ``None`` if generation or caching failed.
    """
    blob_name = f"transformers/generated_mistral_nemo:{prompt}"
    response = load_object_from_gcs(blob_name)
    if not response:
        try:
            conversation = [{"role": "user", "content": prompt}]
            with tqdm(total=1, desc="Generating Mistral Nemo response") as pbar:
                inputs = mistral_nemo_tokenizer.apply_chat_template(
                    conversation,
                    tools=tools,
                    add_generation_prompt=True,
                    return_dict=True,
                    return_tensors="pt",
                )
                # The model is loaded with device_map="auto"; the encoded
                # prompt has to live on the model's device before generate().
                inputs = inputs.to(mistral_nemo_model.device)
                outputs = mistral_nemo_model.generate(**inputs, max_new_tokens=1000)
                response = mistral_nemo_tokenizer.decode(
                    outputs[0], skip_special_tokens=True
                )
                pbar.update(1)
            save_object_to_gcs(blob_name, response)
        except Exception as e:
            print(f"Failed to generate Mistral Nemo response: {e}")
            return None
    return response
def generate_gpt2_xl(prompt):
    """Run ``prompt`` through the GPT-2 XL model, cached in GCS by prompt.

    Returns the decoded string, or ``None`` if any step raised.

    NOTE(review): this decodes element ``[0][0]`` of a plain forward pass
    instead of calling ``generate()``. Confirm that ``gpt2_xl_model`` really
    yields token ids there; if it does not, decoding raises and this
    function always returns ``None``.
    """
    cache_key = f"transformers/generated_gpt2_xl:{prompt}"
    cached_reply = load_object_from_gcs(cache_key)
    if cached_reply:
        return cached_reply

    try:
        with tqdm(total=1, desc="Generating GPT-2 XL response") as progress:
            encoded_prompt = gpt2_xl_tokenizer(prompt, return_tensors="pt")
            model_output = gpt2_xl_model(**encoded_prompt)
            first_sequence = model_output[0][0]
            reply = gpt2_xl_tokenizer.decode(
                first_sequence, skip_special_tokens=True
            )
            progress.update(1)
        save_object_to_gcs(cache_key, reply)
    except Exception as e:
        print(f"Failed to generate GPT-2 XL response: {e}")
        return None
    return reply
def store_user_question(question):
    """Append ``question`` as one line to the ``user_questions.txt`` log in GCS.

    The blob is mirrored to a local file of the same name in the working
    directory: downloaded if it exists, appended to, then uploaded again.
    """
    log_name = "user_questions.txt"
    remote_log = bucket.blob(log_name)
    if remote_log.exists():
        remote_log.download_to_filename(log_name)
    with open(log_name, "a") as log_file:
        log_file.write(question + "\n")
    remote_log.upload_from_filename(log_name)
def retrain_models():
    """Placeholder hook: retraining is not implemented, so this does nothing."""
    return None
def generate_text_to_video_ms_1_7b(prompt, num_frames=200):
    """Generate a video for ``prompt`` and return the encoded file's bytes.

    Results are cached in GCS by prompt and frame count.

    Args:
        prompt: Text description passed to the text-to-video pipeline.
        num_frames: Number of frames requested from the pipeline.

    Returns:
        Bytes of the exported video file, or ``None`` if any step failed.
    """
    blob_name = f"diffusers/text_to_video_ms_1_7b:{prompt}:{num_frames}"
    video_bytes = load_object_from_gcs(blob_name)
    if not video_bytes:
        video_path = None
        try:
            with tqdm(total=1, desc="Generating video") as pbar:
                # NOTE(review): some diffusers releases return batched
                # frames here (needing ``.frames[0]``) - confirm against the
                # installed version.
                video_frames = text_to_video_ms_1_7b_pipeline(
                    prompt, num_inference_steps=25, num_frames=num_frames
                ).frames
                pbar.update(1)
            video_path = export_to_video(video_frames)
            with open(video_path, "rb") as f:
                video_bytes = f.read()
            save_object_to_gcs(blob_name, video_bytes)
        except Exception as e:
            print(f"Failed to generate video: {e}")
            return None
        finally:
            # Remove the temporary export even when reading or uploading
            # failed, so failed runs do not pile up video files on disk.
            if video_path is not None and os.path.exists(video_path):
                os.remove(video_path)
    return video_bytes
def generate_text_to_video_ms_1_7b_short(prompt):
    """Generate a default-length video for ``prompt`` and return its bytes.

    Uses ``text_to_video_ms_1_7b_short_pipeline`` without an explicit frame
    count. Results are cached in GCS by prompt.

    Returns:
        Bytes of the exported video file, or ``None`` if any step failed.
    """
    blob_name = f"diffusers/text_to_video_ms_1_7b_short:{prompt}"
    video_bytes = load_object_from_gcs(blob_name)
    if not video_bytes:
        video_path = None
        try:
            with tqdm(total=1, desc="Generating short video") as pbar:
                # NOTE(review): some diffusers releases return batched
                # frames here (needing ``.frames[0]``) - confirm against the
                # installed version.
                video_frames = text_to_video_ms_1_7b_short_pipeline(
                    prompt, num_inference_steps=25
                ).frames
                pbar.update(1)
            video_path = export_to_video(video_frames)
            with open(video_path, "rb") as f:
                video_bytes = f.read()
            save_object_to_gcs(blob_name, video_bytes)
        except Exception as e:
            print(f"Failed to generate short video: {e}")
            return None
        finally:
            # Remove the temporary export even when reading or uploading
            # failed, so failed runs do not pile up video files on disk.
            if video_path is not None and os.path.exists(video_path):
                os.remove(video_path)
    return video_bytes
# --- Models served through the GCS pickle cache -----------------------------
# get_model_or_download() returns None when loading fails, so any of these
# names may be None at runtime.
text_to_image_pipeline = get_model_or_download(
    "stabilityai/stable-diffusion-2",
    "diffusers/text_to_image_model",
    StableDiffusionPipeline.from_pretrained,
)
img2img_pipeline = get_model_or_download(
    "CompVis/stable-diffusion-v1-4",
    "diffusers/img2img_model",
    StableDiffusionImg2ImgPipeline.from_pretrained,
)
flux_pipeline = get_model_or_download(
    "black-forest-labs/FLUX.1-schnell",
    "diffusers/flux_model",
    FluxPipeline.from_pretrained,
)

# Plain text-generation pipeline used by generate_text(); not cached in GCS.
text_gen_pipeline = transformers_pipeline(
    "text-generation", model="google/gemma-2-9b", tokenizer="google/gemma-2-9b"
)

# MusicGen model used by generate_song(): prefer the pickled copy in GCS and
# fall back to downloading the pretrained "melody" checkpoint.
music_gen = (
    load_object_from_gcs("music/music_gen")
    or musicgen.MusicGen.get_pretrained("melody")
)

meta_llama_pipeline = get_model_or_download(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "transformers/meta_llama_model",
    # transformers.pipeline() takes the *task* as its first positional
    # argument, so the model id has to be passed by keyword; handing
    # transformers_pipeline over directly would use the model id as the task.
    lambda model_id, **kwargs: transformers_pipeline(
        "text-generation", model=model_id, **kwargs
    ),
)
# --- Models loaded eagerly at import time (no GCS cache) --------------------
# Every statement below downloads and/or instantiates a model as soon as the
# module is imported; their relative order is preserved as-is.

# Code model and tokenizer used by generate_code().
starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder")
starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")
# SDXL base pipeline used by generate_image_sdxl().
base = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True,
)
# SDXL refiner; it reuses the base pipeline's second text encoder and VAE.
refiner = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0",
text_encoder_2=base.text_encoder_2,
vae=base.vae,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
)
# MusicGen variants used by generate_musicgen_melody() / generate_musicgen_large();
# both are configured for a duration of 8.
music_gen_melody = musicgen.MusicGen.get_pretrained("melody")
music_gen_melody.set_generation_params(duration=8)
music_gen_large = musicgen.MusicGen.get_pretrained("large")
music_gen_large.set_generation_params(duration=8)
# Speech-recognition pipeline used by transcribe_audio().
whisper_pipeline = transformers_pipeline(
"automatic-speech-recognition",
model="openai/whisper-small",
chunk_length_s=30,
)
# Mistral chat models and tokenizers used by generate_mistral_instruct() and
# generate_mistral_nemo().
mistral_instruct_model = AutoModelForCausalLM.from_pretrained(
"mistralai/Mistral-Large-Instruct-2407",
torch_dtype=torch.bfloat16,
device_map="auto",
)
mistral_instruct_tokenizer = AutoTokenizer.from_pretrained(
"mistralai/Mistral-Large-Instruct-2407"
)
mistral_nemo_model = AutoModelForCausalLM.from_pretrained(
"mistralai/Mistral-Nemo-Instruct-2407",
torch_dtype=torch.bfloat16,
device_map="auto",
)
mistral_nemo_tokenizer = AutoTokenizer.from_pretrained(
"mistralai/Mistral-Nemo-Instruct-2407"
)
# GPT-2 XL tokenizer and model used by generate_gpt2_xl().
# NOTE(review): generate_gpt2_xl() decodes this model's raw forward output;
# GPT2Model looks like the headless class - confirm a language-model-head
# class is not what was intended.
gpt2_xl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-xl")
gpt2_xl_model = GPT2Model.from_pretrained("gpt2-xl")
# The pipelines below are handed to Gradio interfaces further down the file.
llama_3_groq_70b_tool_use_pipeline = transformers_pipeline(
"text-generation", model="Groq/Llama-3-Groq-70B-Tool-Use"
)
phi_3_5_mini_instruct_model = AutoModelForCausalLM.from_pretrained(
"microsoft/Phi-3.5-mini-instruct", torch_dtype="auto", trust_remote_code=True
)
phi_3_5_mini_instruct_tokenizer = AutoTokenizer.from_pretrained(
"microsoft/Phi-3.5-mini-instruct"
)
phi_3_5_mini_instruct_pipeline = transformers_pipeline(
"text-generation",
model=phi_3_5_mini_instruct_model,
tokenizer=phi_3_5_mini_instruct_tokenizer,
)
meta_llama_3_1_8b_pipeline = transformers_pipeline(
"text-generation",
model="meta-llama/Meta-Llama-3.1-8B",
model_kwargs={"torch_dtype": torch.bfloat16},
)
meta_llama_3_1_70b_pipeline = transformers_pipeline(
"text-generation",
model="meta-llama/Meta-Llama-3.1-70B",
model_kwargs={"torch_dtype": torch.bfloat16},
)
# NOTE(review): "your/medical_text_summarization_model" reads like a
# placeholder id - confirm it resolves, otherwise this load fails at import.
medical_text_summarization_pipeline = transformers_pipeline(
"summarization", model="your/medical_text_summarization_model"
)
bart_large_cnn_summarization_pipeline = transformers_pipeline(
"summarization", model="facebook/bart-large-cnn"
)
# FLUX.1-dev pipeline with model CPU offload enabled.
flux_1_dev_pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
flux_1_dev_pipeline.enable_model_cpu_offload()
# Gemma 2 family; the 9B base checkpoint is also loaded as text_gen_pipeline.
gemma_2_9b_pipeline = transformers_pipeline("text-generation", model="google/gemma-2-9b")
gemma_2_9b_it_pipeline = transformers_pipeline(
"text-generation",
model="google/gemma-2-9b-it",
model_kwargs={"torch_dtype": torch.bfloat16},
)
gemma_2_2b_pipeline = transformers_pipeline("text-generation", model="google/gemma-2-2b")
gemma_2_2b_it_pipeline = transformers_pipeline(
"text-generation",
model="google/gemma-2-2b-it",
model_kwargs={"torch_dtype": torch.bfloat16},
)
# Tokenizer and model used by generate_gemma_2_27b().
gemma_2_27b_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b")
gemma_2_27b_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-27b")
gemma_2_27b_it_pipeline = transformers_pipeline(
"text-generation",
model="google/gemma-2-27b-it",
model_kwargs={"torch_dtype": torch.bfloat16},
)
# Text-to-video pipeline used by generate_text_to_video_ms_1_7b(): the
# scheduler is swapped for DPMSolverMultistepScheduler, then CPU offload and
# VAE slicing are enabled.
text_to_video_ms_1_7b_pipeline = DiffusionPipeline.from_pretrained(
"damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
)
text_to_video_ms_1_7b_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
text_to_video_ms_1_7b_pipeline.scheduler.config
)
text_to_video_ms_1_7b_pipeline.enable_model_cpu_offload()
text_to_video_ms_1_7b_pipeline.enable_vae_slicing()
# Second instance of the same checkpoint for the "short" variant; VAE slicing
# is not enabled on this one.
text_to_video_ms_1_7b_short_pipeline = DiffusionPipeline.from_pretrained(
"damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
)
text_to_video_ms_1_7b_short_pipeline.scheduler = (
DPMSolverMultistepScheduler.from_config(
text_to_video_ms_1_7b_short_pipeline.scheduler.config
)
)
text_to_video_ms_1_7b_short_pipeline.enable_model_cpu_offload()
# Tool definitions passed to apply_chat_template() by the Mistral functions;
# empty, i.e. no tools are offered to the models.
tools = []
def _returns_pil_image(generate_fn):
    """Adapt a function returning encoded image bytes to return a PIL image.

    The image generators in this file return JPEG bytes (or ``None``), which
    a ``gr.Image`` output cannot render; decode them first. Any value that
    is not bytes is passed through unchanged.
    """

    def adapter(*args):
        result = generate_fn(*args)
        if isinstance(result, (bytes, bytearray)):
            return Image.open(io.BytesIO(result))
        return result

    adapter.__name__ = generate_fn.__name__
    return adapter


# --- Gradio tabs backed by the cached generator functions -------------------
gen_image_tab = gr.Interface(
    fn=_returns_pil_image(generate_image),
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Image(type="pil"),
    title="Generate Image",
)
edit_image_tab = gr.Interface(
    fn=_returns_pil_image(edit_image_with_prompt),
    inputs=[
        gr.Image(type="pil", label="Image:"),
        gr.Textbox(label="Prompt:"),
        gr.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:"),
    ],
    outputs=gr.Image(type="pil"),
    title="Edit Image",
)
generate_song_tab = gr.Interface(
    fn=generate_song,
    inputs=[
        gr.Textbox(label="Prompt:"),
        gr.Slider(5, 60, 10, step=1, label="Duration (s):"),
    ],
    outputs=gr.Audio(type="numpy"),
    title="Generate Songs",
)
generate_text_tab = gr.Interface(
    fn=generate_text,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Generated Text:"),
    title="Generate Text",
)
generate_flux_image_tab = gr.Interface(
    fn=_returns_pil_image(generate_flux_image),
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Image(type="pil"),
    title="Generate FLUX Images",
)
generate_code_tab = gr.Interface(
    fn=generate_code,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Generated Code:"),
    title="Generate Code",
)
model_meta_llama_test_tab = gr.Interface(
    fn=test_model_meta_llama,
    inputs=None,
    outputs=gr.Textbox(label="Model Output:"),
    title="Test Meta-Llama",
)
generate_image_sdxl_tab = gr.Interface(
    fn=_returns_pil_image(generate_image_sdxl),
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Image(type="pil"),
    title="Generate SDXL Image",
)
generate_musicgen_melody_tab = gr.Interface(
    fn=generate_musicgen_melody,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Audio(type="numpy"),
    title="Generate MusicGen Melody",
)
generate_musicgen_large_tab = gr.Interface(
    fn=generate_musicgen_large,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Audio(type="numpy"),
    title="Generate MusicGen Large",
)
transcribe_audio_tab = gr.Interface(
    fn=transcribe_audio,
    inputs=gr.Audio(type="numpy", label="Audio Sample:"),
    outputs=gr.Textbox(label="Transcribed Text:"),
    title="Transcribe Audio",
)
generate_mistral_instruct_tab = gr.Interface(
    fn=generate_mistral_instruct,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Mistral Instruct Response:"),
    title="Generate Mistral Instruct Response",
)
generate_mistral_nemo_tab = gr.Interface(
    fn=generate_mistral_nemo,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Mistral Nemo Response:"),
    title="Generate Mistral Nemo Response",
)
generate_gpt2_xl_tab = gr.Interface(
    fn=generate_gpt2_xl,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="GPT-2 XL Response:"),
    title="Generate GPT-2 XL Response",
)
# ``answer_question_minicpm`` is neither defined nor imported in this module,
# so referencing it unconditionally raises NameError at import time and stops
# the whole app from starting. Build the tab only when a handler exists;
# otherwise the name is bound to None.
_answer_question_minicpm = globals().get("answer_question_minicpm")
answer_question_minicpm_tab = (
    gr.Interface(
        fn=_answer_question_minicpm,
        inputs=[
            gr.Image(type="pil", label="Image:"),
            gr.Textbox(label="Question:"),
        ],
        outputs=gr.Textbox(label="MiniCPM Answer:"),
        title="Answer Question with MiniCPM",
    )
    if _answer_question_minicpm is not None
    else None
)
def _generated_text_fn(pipe):
    """Wrap a text-generation pipeline so it returns only the generated string.

    Called directly, the pipeline returns a list of result dicts, which a
    Textbox would display as its Python repr.
    """

    def run(prompt):
        return pipe(prompt)[0]["generated_text"]

    return run


def _summary_text_fn(pipe):
    """Wrap a summarization pipeline so it returns only the summary string."""

    def run(document):
        return pipe(document)[0]["summary_text"]

    return run


def _first_image_fn(pipe):
    """Wrap a diffusers pipeline so it returns the first generated image."""

    def run(prompt):
        return pipe(prompt).images[0]

    return run


# --- Gradio tabs backed directly by pipelines -------------------------------
llama_3_groq_70b_tool_use_tab = gr.Interface(
    fn=_generated_text_fn(llama_3_groq_70b_tool_use_pipeline),
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Llama 3 Groq 70B Tool Use Response:"),
    title="Llama 3 Groq 70B Tool Use",
)
phi_3_5_mini_instruct_tab = gr.Interface(
    fn=_generated_text_fn(phi_3_5_mini_instruct_pipeline),
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Phi 3.5 Mini Instruct Response:"),
    title="Phi 3.5 Mini Instruct",
)
meta_llama_3_1_8b_tab = gr.Interface(
    fn=_generated_text_fn(meta_llama_3_1_8b_pipeline),
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Meta Llama 3.1 8B Response:"),
    title="Meta Llama 3.1 8B",
)
meta_llama_3_1_70b_tab = gr.Interface(
    fn=_generated_text_fn(meta_llama_3_1_70b_pipeline),
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Meta Llama 3.1 70B Response:"),
    title="Meta Llama 3.1 70B",
)
medical_text_summarization_tab = gr.Interface(
    fn=_summary_text_fn(medical_text_summarization_pipeline),
    inputs=[gr.Textbox(label="Medical Document:")],
    outputs=gr.Textbox(label="Medical Text Summarization:"),
    title="Medical Text Summarization",
)
bart_large_cnn_summarization_tab = gr.Interface(
    fn=_summary_text_fn(bart_large_cnn_summarization_pipeline),
    inputs=[gr.Textbox(label="Article:")],
    outputs=gr.Textbox(label="Bart Large CNN Summarization:"),
    title="Bart Large CNN Summarization",
)
flux_1_dev_tab = gr.Interface(
    fn=_first_image_fn(flux_1_dev_pipeline),
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Image(type="pil"),
    title="FLUX 1 Dev",
)
gemma_2_9b_tab = gr.Interface(
    fn=_generated_text_fn(gemma_2_9b_pipeline),
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Gemma 2 9B Response:"),
    title="Gemma 2 9B",
)
gemma_2_9b_it_tab = gr.Interface(
    fn=_generated_text_fn(gemma_2_9b_it_pipeline),
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Gemma 2 9B IT Response:"),
    title="Gemma 2 9B IT",
)
gemma_2_2b_tab = gr.Interface(
    fn=_generated_text_fn(gemma_2_2b_pipeline),
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Gemma 2 2B Response:"),
    title="Gemma 2 2B",
)
gemma_2_2b_it_tab = gr.Interface(
    fn=_generated_text_fn(gemma_2_2b_it_pipeline),
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Gemma 2 2B IT Response:"),
    title="Gemma 2 2B IT",
)
def generate_gemma_2_27b(prompt):
    """Generate up to 32 new tokens for ``prompt`` with the Gemma 2 27B model.

    Returns the decoded first output sequence. Nothing is cached and
    exceptions propagate to the caller.
    """
    encoded_prompt = gemma_2_27b_tokenizer(prompt, return_tensors="pt")
    generated_ids = gemma_2_27b_model.generate(**encoded_prompt, max_new_tokens=32)
    first_sequence = generated_ids[0]
    return gemma_2_27b_tokenizer.decode(first_sequence)
# Tabs for the Gemma 2 27B base model and its instruction-tuned variant.
gemma_2_27b_tab = gr.Interface(
    fn=generate_gemma_2_27b,
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Gemma 2 27B Response:"),
    title="Gemma 2 27B",
)
gemma_2_27b_it_tab = gr.Interface(
    # Called directly, the pipeline returns a list of result dicts; hand the
    # Textbox just the generated string.
    fn=lambda prompt: gemma_2_27b_it_pipeline(prompt)[0]["generated_text"],
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Gemma 2 27B IT Response:"),
    title="Gemma 2 27B IT",
)
def _returns_video_path(generate_fn):
    """Adapt a function returning video bytes to return a file path.

    The video generators in this file return the encoded file's bytes (or
    ``None``), while a ``gr.Video`` output expects a path; write the bytes
    to a temporary ``.mp4`` file and return its name. Any value that is not
    bytes is passed through unchanged.
    """

    def adapter(*args):
        result = generate_fn(*args)
        if not isinstance(result, (bytes, bytearray)):
            return result
        # delete=False: the file must outlive this call so Gradio can read it.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp.write(result)
        return tmp.name

    adapter.__name__ = generate_fn.__name__
    return adapter


text_to_video_ms_1_7b_tab = gr.Interface(
    fn=_returns_video_path(generate_text_to_video_ms_1_7b),
    inputs=[
        gr.Textbox(label="Prompt:"),
        gr.Slider(50, 200, 200, step=1, label="Number of Frames:"),
    ],
    outputs=gr.Video(),
    title="Text to Video MS 1.7B",
)
text_to_video_ms_1_7b_short_tab = gr.Interface(
    fn=_returns_video_path(generate_text_to_video_ms_1_7b_short),
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Video(),
    title="Text to Video MS 1.7B Short",
)
# Each interface is paired with the label shown on its tab; the order here is
# the order of the tabs in the UI.
_tab_registry = [
    (gen_image_tab, "Generate Image"),
    (edit_image_tab, "Edit Image"),
    (generate_song_tab, "Generate Song"),
    (generate_text_tab, "Generate Text"),
    (generate_flux_image_tab, "Generate FLUX Image"),
    (generate_code_tab, "Generate Code"),
    (model_meta_llama_test_tab, "Test Meta-Llama"),
    (generate_image_sdxl_tab, "Generate SDXL Image"),
    (generate_musicgen_melody_tab, "Generate MusicGen Melody"),
    (generate_musicgen_large_tab, "Generate MusicGen Large"),
    (transcribe_audio_tab, "Transcribe Audio"),
    (generate_mistral_instruct_tab, "Generate Mistral Instruct Response"),
    (generate_mistral_nemo_tab, "Generate Mistral Nemo Response"),
    (generate_gpt2_xl_tab, "Generate GPT-2 XL Response"),
    (llama_3_groq_70b_tool_use_tab, "Llama 3 Groq 70B Tool Use"),
    (phi_3_5_mini_instruct_tab, "Phi 3.5 Mini Instruct"),
    (meta_llama_3_1_8b_tab, "Meta Llama 3.1 8B"),
    (meta_llama_3_1_70b_tab, "Meta Llama 3.1 70B"),
    (medical_text_summarization_tab, "Medical Text Summarization"),
    (bart_large_cnn_summarization_tab, "Bart Large CNN Summarization"),
    (flux_1_dev_tab, "FLUX 1 Dev"),
    (gemma_2_9b_tab, "Gemma 2 9B"),
    (gemma_2_9b_it_tab, "Gemma 2 9B IT"),
    (gemma_2_2b_tab, "Gemma 2 2B"),
    (gemma_2_2b_it_tab, "Gemma 2 2B IT"),
    (gemma_2_27b_tab, "Gemma 2 27B"),
    (gemma_2_27b_it_tab, "Gemma 2 27B IT"),
    (text_to_video_ms_1_7b_tab, "Text to Video MS 1.7B"),
    (text_to_video_ms_1_7b_short_tab, "Text to Video MS 1.7B Short"),
]

app = gr.TabbedInterface(
    [tab for tab, _ in _tab_registry],
    [label for _, label in _tab_registry],
)
# share=True asks Gradio for a public share link in addition to the local server.
app.launch(share=True)