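"""Multi-model Gradio playground.

Wires a collection of Hugging Face pipelines (Stable Diffusion, FLUX, SDXL,
MusicGen, Whisper, StarCoder, Meta-Llama, Mistral, GPT-2 XL, MiniCPM-V) into a
tabbed Gradio app, caching generated artifacts in Redis and mirroring them to
Google Cloud Storage.
"""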
import io
import json
import os
import pickle
import time

import gradio as gr
import redis
import torch
import torchaudio
from audiocraft.models import musicgen
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    FluxPipeline,
    DiffusionPipeline,
)
from google.cloud import storage
from huggingface_hub import HfFolder
from PIL import Image
from tqdm import tqdm
from transformers import (
    pipeline as transformers_pipeline,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    GPT2Tokenizer,
    GPT2LMHeadModel,
)
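# Required environment variables (assumed to be provided by the hosting
# environment): HF_TOKEN, REDIS_HOST, REDIS_PORT, REDIS_PASSWORD,
# GCS_CREDENTIALS (service-account JSON as a string), GCS_BUCKET_NAME.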
hf_token = os.getenv("HF_TOKEN")
redis_host = os.getenv("REDIS_HOST")
redis_port = int(os.getenv("REDIS_PORT", 6379))
redis_password = os.getenv("REDIS_PASSWORD")
gcs_credentials = json.loads(os.getenv("GCS_CREDENTIALS"))
gcs_bucket_name = os.getenv("GCS_BUCKET_NAME")
HfFolder.save_token(hf_token)
storage_client = storage.Client.from_service_account_info(gcs_credentials)
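# Redis helpers: connect with indefinite retry, ping before use, and fall back
# to a fresh connection when the ping fails.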
def connect_to_redis():
while True:
try:
redis_client = redis.Redis(
host=redis_host, port=redis_port, password=redis_password
)
redis_client.ping()
return redis_client
except (
redis.exceptions.ConnectionError,
redis.exceptions.TimeoutError,
BrokenPipeError,
) as e:
print(f"Connection to Redis failed: {e}. Retrying in 1 second...")
time.sleep(1)
def reconnect_if_needed(redis_client):
try:
redis_client.ping()
except (
redis.exceptions.ConnectionError,
redis.exceptions.TimeoutError,
BrokenPipeError,
):
print("Reconnecting to Redis...")
return connect_to_redis()
return redis_client
def load_object_from_redis(key):
redis_client = connect_to_redis()
redis_client = reconnect_if_needed(redis_client)
try:
obj_data = redis_client.get(key)
return pickle.loads(obj_data) if obj_data else None
except (pickle.PickleError, redis.exceptions.RedisError) as e:
print(f"Failed to load object from Redis: {e}")
return None
def save_object_to_redis(key, obj):
redis_client = connect_to_redis()
redis_client = reconnect_if_needed(redis_client)
try:
redis_client.set(key, pickle.dumps(obj))
except redis.exceptions.RedisError as e:
print(f"Failed to save object to Redis: {e}")
def upload_to_gcs(bucket_name, blob_name, data):
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(blob_name)
blob.upload_from_string(data)
def download_from_gcs(bucket_name, blob_name):
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(blob_name)
return blob.download_as_bytes()
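# Model cache: try Redis first; on a miss, download from the Hub and persist
# the pickled model to both Redis and GCS. Note this assumes the Redis
# instance accepts multi-gigabyte values, which is unusual for managed Redis.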
def get_model_or_download(model_id, redis_key, loader_func):
model = load_object_from_redis(redis_key)
if model:
return model
try:
with tqdm(total=1, desc=f"Downloading {model_id}") as pbar:
model = loader_func(model_id, torch_dtype=torch.float16)
pbar.update(1)
save_object_to_redis(redis_key, model)
model_bytes = pickle.dumps(model)
upload_to_gcs(gcs_bucket_name, redis_key, model_bytes)
return model
except Exception as e:
print(f"Failed to load or save model: {e}")
return None
def generate_image(prompt):
    redis_key = f"generated_image:{prompt}"
    image_bytes = load_object_from_redis(redis_key)
    if not image_bytes:
        try:
            with tqdm(total=1, desc="Generating image") as pbar:
                image = text_to_image_pipeline(prompt).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            image.save(buffered, format="JPEG")
            image_bytes = buffered.getvalue()
            save_object_to_redis(redis_key, image_bytes)
            upload_to_gcs(gcs_bucket_name, redis_key, image_bytes)
        except Exception as e:
            print(f"Failed to generate image: {e}")
            return None
    # gr.Image(type="pil") expects a PIL image, not raw JPEG bytes.
    return Image.open(io.BytesIO(image_bytes))
def edit_image_with_prompt(image, prompt, strength=0.75):
    # gr.Image(type="pil") hands us a PIL image; include its pixels in the
    # cache key so that different source images do not collide.
    redis_key = f"edited_image:{hash(image.tobytes())}:{prompt}:{strength}"
    edited_image_bytes = load_object_from_redis(redis_key)
    if not edited_image_bytes:
        try:
            with tqdm(total=1, desc="Editing image") as pbar:
                edited_image = img2img_pipeline(
                    prompt=prompt, image=image, strength=strength
                ).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            edited_image.save(buffered, format="JPEG")
            edited_image_bytes = buffered.getvalue()
            save_object_to_redis(redis_key, edited_image_bytes)
            upload_to_gcs(gcs_bucket_name, redis_key, edited_image_bytes)
        except Exception as e:
            print(f"Failed to edit image: {e}")
            return None
    return Image.open(io.BytesIO(edited_image_bytes))
def generate_song(prompt, duration=10):
    redis_key = f"generated_song:{prompt}:{duration}"
    song_bytes = load_object_from_redis(redis_key)
    if not song_bytes:
        try:
            with tqdm(total=1, desc="Generating song") as pbar:
                # MusicGen takes the duration via generation params and a
                # batch of text descriptions, not a plain callable interface.
                music_gen.set_generation_params(duration=duration)
                wav = music_gen.generate([prompt])
                pbar.update(1)
            buffered = io.BytesIO()
            torchaudio.save(
                buffered, wav[0].cpu(), music_gen.sample_rate, format="wav"
            )
            song_bytes = buffered.getvalue()
            save_object_to_redis(redis_key, song_bytes)
            upload_to_gcs(gcs_bucket_name, redis_key, song_bytes)
        except Exception as e:
            print(f"Failed to generate song: {e}")
            return None
    # gr.Audio(type="numpy") expects a (sample_rate, waveform) tuple.
    waveform, sr = torchaudio.load(io.BytesIO(song_bytes))
    return sr, waveform.squeeze(0).numpy()
def generate_text(prompt):
redis_key = f"generated_text:{prompt}"
text = load_object_from_redis(redis_key)
if not text:
try:
with tqdm(total=1, desc="Generating text") as pbar:
text = text_gen_pipeline(prompt, max_new_tokens=256)[0][
"generated_text"
].strip()
pbar.update(1)
save_object_to_redis(redis_key, text)
upload_to_gcs(gcs_bucket_name, redis_key, text.encode())
except Exception as e:
print(f"Failed to generate text: {e}")
return None
return text
def generate_flux_image(prompt):
    redis_key = f"generated_flux_image:{prompt}"
    flux_image_bytes = load_object_from_redis(redis_key)
    if not flux_image_bytes:
        try:
            with tqdm(total=1, desc="Generating FLUX image") as pbar:
                flux_image = flux_pipeline(
                    prompt,
                    guidance_scale=0.0,
                    num_inference_steps=4,
                    max_sequence_length=256,
                    generator=torch.Generator("cpu").manual_seed(0),
                ).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            flux_image.save(buffered, format="JPEG")
            flux_image_bytes = buffered.getvalue()
            save_object_to_redis(redis_key, flux_image_bytes)
            upload_to_gcs(gcs_bucket_name, redis_key, flux_image_bytes)
        except Exception as e:
            print(f"Failed to generate FLUX image: {e}")
            return None
    return Image.open(io.BytesIO(flux_image_bytes))
def generate_code(prompt):
redis_key = f"generated_code:{prompt}"
code = load_object_from_redis(redis_key)
if not code:
try:
with tqdm(total=1, desc="Generating code") as pbar:
inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to(
starcoder_model.device
)
outputs = starcoder_model.generate(inputs, max_new_tokens=256)
                code = starcoder_tokenizer.decode(
                    outputs[0], skip_special_tokens=True
                )
pbar.update(1)
save_object_to_redis(redis_key, code)
upload_to_gcs(gcs_bucket_name, redis_key, code.encode())
except Exception as e:
print(f"Failed to generate code: {e}")
return None
return code
def test_model_meta_llama():
redis_key = "meta_llama_test_response"
response = load_object_from_redis(redis_key)
if not response:
try:
messages = [
{
"role": "system",
"content": "You are a pirate chatbot who always responds in pirate speak!",
},
{"role": "user", "content": "Who are you?"},
]
with tqdm(total=1, desc="Testing Meta-Llama") as pbar:
response = meta_llama_pipeline(messages, max_new_tokens=256)[0][
"generated_text"
].strip()
pbar.update(1)
save_object_to_redis(redis_key, response)
upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
except Exception as e:
print(f"Failed to test Meta-Llama: {e}")
return None
return response
def generate_image_sdxl(prompt):
redis_key = f"generated_image_sdxl:{prompt}"
image_bytes = load_object_from_redis(redis_key)
if not image_bytes:
try:
with tqdm(total=1, desc="Generating SDXL image") as pbar:
image = base(
prompt=prompt,
num_inference_steps=40,
denoising_end=0.8,
output_type="latent",
).images
image = refiner(
prompt=prompt,
num_inference_steps=40,
denoising_start=0.8,
image=image,
).images[0]
pbar.update(1)
buffered = io.BytesIO()
image.save(buffered, format="JPEG")
image_bytes = buffered.getvalue()
save_object_to_redis(redis_key, image_bytes)
upload_to_gcs(gcs_bucket_name, redis_key, image_bytes)
except Exception as e:
print(f"Failed to generate SDXL image: {e}")
return None
    # gr.Image(type="pil") expects a PIL image, not raw JPEG bytes.
    return Image.open(io.BytesIO(image_bytes))
def generate_musicgen_melody(prompt):
    redis_key = f"generated_musicgen_melody:{prompt}"
    song_bytes = load_object_from_redis(redis_key)
    if not song_bytes:
        try:
            with tqdm(total=1, desc="Generating MusicGen melody") as pbar:
                melody, sr = torchaudio.load("./assets/bach.mp3")
                # One prompt, so a batch of one chroma-conditioning signal.
                wav = music_gen_melody.generate_with_chroma(
                    [prompt], melody[None], sr
                )
                pbar.update(1)
            buffered = io.BytesIO()
            torchaudio.save(
                buffered, wav[0].cpu(), music_gen_melody.sample_rate, format="wav"
            )
            song_bytes = buffered.getvalue()
            save_object_to_redis(redis_key, song_bytes)
            upload_to_gcs(gcs_bucket_name, redis_key, song_bytes)
        except Exception as e:
            print(f"Failed to generate MusicGen melody: {e}")
            return None
    waveform, sr = torchaudio.load(io.BytesIO(song_bytes))
    return sr, waveform.squeeze(0).numpy()
def generate_musicgen_large(prompt):
    redis_key = f"generated_musicgen_large:{prompt}"
    song_bytes = load_object_from_redis(redis_key)
    if not song_bytes:
        try:
            with tqdm(total=1, desc="Generating MusicGen large") as pbar:
                wav = music_gen_large.generate([prompt])
                pbar.update(1)
            buffered = io.BytesIO()
            torchaudio.save(
                buffered, wav[0].cpu(), music_gen_large.sample_rate, format="wav"
            )
            song_bytes = buffered.getvalue()
            save_object_to_redis(redis_key, song_bytes)
            upload_to_gcs(gcs_bucket_name, redis_key, song_bytes)
        except Exception as e:
            print(f"Failed to generate MusicGen large: {e}")
            return None
    waveform, sr = torchaudio.load(io.BytesIO(song_bytes))
    return sr, waveform.squeeze(0).numpy()
def transcribe_audio(audio_sample):
    # gr.Audio(type="numpy") delivers a (sample_rate, data) tuple.
    sr, data = audio_sample
    redis_key = f"transcribed_audio:{hash(data.tobytes())}"
    text = load_object_from_redis(redis_key)
    if not text:
        try:
            audio = data.astype("float32")
            if audio.ndim > 1:
                audio = audio.mean(axis=1)  # downmix stereo to mono
            audio /= 32768.0  # 16-bit PCM -> [-1.0, 1.0]
            with tqdm(total=1, desc="Transcribing audio") as pbar:
                text = whisper_pipeline(
                    {"raw": audio, "sampling_rate": sr}, batch_size=8
                )["text"]
                pbar.update(1)
            save_object_to_redis(redis_key, text)
            upload_to_gcs(gcs_bucket_name, redis_key, text.encode())
        except Exception as e:
            print(f"Failed to transcribe audio: {e}")
            return None
    return text
def generate_mistral_instruct(prompt):
redis_key = f"generated_mistral_instruct:{prompt}"
response = load_object_from_redis(redis_key)
if not response:
try:
conversation = [{"role": "user", "content": prompt}]
with tqdm(total=1, desc="Generating Mistral Instruct response") as pbar:
inputs = mistral_instruct_tokenizer.apply_chat_template(
conversation,
tools=tools,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
)
inputs.to(mistral_instruct_model.device)
outputs = mistral_instruct_model.generate(
**inputs, max_new_tokens=1000
)
response = mistral_instruct_tokenizer.decode(
outputs[0], skip_special_tokens=True
)
pbar.update(1)
save_object_to_redis(redis_key, response)
upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
except Exception as e:
print(f"Failed to generate Mistral Instruct response: {e}")
return None
return response
def generate_mistral_nemo(prompt):
redis_key = f"generated_mistral_nemo:{prompt}"
response = load_object_from_redis(redis_key)
if not response:
try:
conversation = [{"role": "user", "content": prompt}]
with tqdm(total=1, desc="Generating Mistral Nemo response") as pbar:
inputs = mistral_nemo_tokenizer.apply_chat_template(
conversation,
tools=tools,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
)
inputs.to(mistral_nemo_model.device)
outputs = mistral_nemo_model.generate(**inputs, max_new_tokens=1000)
response = mistral_nemo_tokenizer.decode(
outputs[0], skip_special_tokens=True
)
pbar.update(1)
save_object_to_redis(redis_key, response)
upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
except Exception as e:
print(f"Failed to generate Mistral Nemo response: {e}")
return None
return response
def generate_gpt2_xl(prompt):
redis_key = f"generated_gpt2_xl:{prompt}"
response = load_object_from_redis(redis_key)
if not response:
try:
with tqdm(total=1, desc="Generating GPT-2 XL response") as pbar:
inputs = gpt2_xl_tokenizer(prompt, return_tensors="pt")
outputs = gpt2_xl_model(**inputs)
response = gpt2_xl_tokenizer.decode(
outputs[0][0], skip_special_tokens=True
)
pbar.update(1)
save_object_to_redis(redis_key, response)
upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
except Exception as e:
print(f"Failed to generate GPT-2 XL response: {e}")
return None
return response
def answer_question_minicpm(image, question):
    # gr.Image(type="pil") hands us a PIL image; key the cache on its pixels.
    image = image.convert("RGB")
    redis_key = f"minicpm_answer:{hash(image.tobytes())}:{question}"
    answer = load_object_from_redis(redis_key)
    if not answer:
        try:
            with tqdm(total=1, desc="Answering question with MiniCPM") as pbar:
                msgs = [{"role": "user", "content": [image, question]}]
                answer = minicpm_model.chat(
                    image=None, msgs=msgs, tokenizer=minicpm_tokenizer
                )
pbar.update(1)
save_object_to_redis(redis_key, answer)
upload_to_gcs(gcs_bucket_name, redis_key, answer.encode())
except Exception as e:
print(f"Failed to answer question with MiniCPM: {e}")
return None
return answer
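# Model and pipeline initialisation. Everything below loads eagerly at import
# time; keeping all of these models resident assumes a host with very large
# GPU memory.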
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
text_to_image_pipeline = get_model_or_download(
    "stabilityai/stable-diffusion-2",
    "text_to_image_model",
    StableDiffusionPipeline.from_pretrained,
)
img2img_pipeline = get_model_or_download(
    "CompVis/stable-diffusion-v1-4",
    "img2img_model",
    StableDiffusionImg2ImgPipeline.from_pretrained,
)
flux_pipeline = get_model_or_download(
    "black-forest-labs/FLUX.1-schnell", "flux_model", FluxPipeline.from_pretrained
)
# Move the diffusion pipelines onto the GPU when one is available; they are
# loaded in float16, which is impractical on CPU.
for _pipe in (text_to_image_pipeline, img2img_pipeline, flux_pipeline):
    if _pipe is not None:
        _pipe.to(device)
text_gen_pipeline = transformers_pipeline(
    "text-generation",
    model="google/gemma-2-9b",
    tokenizer="google/gemma-2-9b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
music_gen = load_object_from_redis("music_gen") or musicgen.MusicGen.get_pretrained(
    "facebook/musicgen-melody", device=device
)
meta_llama_pipeline = get_model_or_download(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "meta_llama_model",
    # get_model_or_download calls loader_func(model_id, torch_dtype=...), so
    # adapt the transformers pipeline factory to that signature.
    lambda model_id, **kwargs: transformers_pipeline(
        "text-generation", model=model_id, device_map="auto", **kwargs
    ),
)
starcoder_model = AutoModelForCausalLM.from_pretrained(
    "bigcode/starcoder", torch_dtype=torch.float16
).to(device)
starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")
base = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True,
).to(device)
refiner = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0",
text_encoder_2=base.text_encoder_2,
vae=base.vae,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
).to(device)
music_gen_melody = musicgen.MusicGen.get_pretrained(
    "facebook/musicgen-melody", device=device
)
music_gen_melody.set_generation_params(duration=8)
music_gen_large = musicgen.MusicGen.get_pretrained(
    "facebook/musicgen-large", device=device
)
music_gen_large.set_generation_params(duration=8)
whisper_pipeline = transformers_pipeline(
"automatic-speech-recognition",
model="openai/whisper-small",
chunk_length_s=30,
device=device,
)
mistral_instruct_model = AutoModelForCausalLM.from_pretrained(
"mistralai/Mistral-Large-Instruct-2407",
torch_dtype=torch.bfloat16,
device_map="auto",
)
mistral_instruct_tokenizer = AutoTokenizer.from_pretrained(
"mistralai/Mistral-Large-Instruct-2407"
)
mistral_nemo_model = AutoModelForCausalLM.from_pretrained(
"mistralai/Mistral-Nemo-Instruct-2407",
torch_dtype=torch.bfloat16,
device_map="auto",
)
mistral_nemo_tokenizer = AutoTokenizer.from_pretrained(
"mistralai/Mistral-Nemo-Instruct-2407"
)
gpt2_xl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-xl")
# The LM-head variant is required for text generation; GPT2Model only
# returns hidden states.
gpt2_xl_model = GPT2LMHeadModel.from_pretrained("gpt2-xl").to(device)
minicpm_model = (
    AutoModel.from_pretrained(
        "openbmb/MiniCPM-V-2_6",
        trust_remote_code=True,
        attn_implementation="sdpa",
        torch_dtype=torch.bfloat16,
    )
    .eval()
    .to(device)
)
minicpm_tokenizer = AutoTokenizer.from_pretrained(
"openbmb/MiniCPM-V-2_6", trust_remote_code=True
)
tools = [] # Define any tools needed for Mistral models
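# Gradio UI: one Interface per capability, combined into the TabbedInterface
# at the bottom of the file.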
gen_image_tab = gr.Interface(
    fn=generate_image,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Image(type="pil"),
    title="Generate Image",
)
edit_image_tab = gr.Interface(
fn=edit_image_with_prompt,
inputs=[
gr.Image(type="pil", label="Image:"),
gr.Textbox(label="Prompt:"),
gr.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:"),
],
outputs=gr.Image(type="pil"),
title="Edit Image",
)
generate_song_tab = gr.Interface(
fn=generate_song,
inputs=[
gr.Textbox(label="Prompt:"),
gr.Slider(5, 60, 10, step=1, label="Duration (s):"),
],
outputs=gr.Audio(type="numpy"),
title="Generate Songs",
)
generate_text_tab = gr.Interface(
fn=generate_text,
inputs=gr.Textbox(label="Prompt:"),
outputs=gr.Textbox(label="Generated Text:"),
title="Generate Text",
)
generate_flux_image_tab = gr.Interface(
fn=generate_flux_image,
inputs=gr.Textbox(label="Prompt:"),
outputs=gr.Image(type="pil"),
title="Generate FLUX Images",
)
generate_code_tab = gr.Interface(
fn=generate_code,
inputs=gr.Textbox(label="Prompt:"),
outputs=gr.Textbox(label="Generated Code:"),
title="Generate Code",
)
model_meta_llama_test_tab = gr.Interface(
fn=test_model_meta_llama,
inputs=None,
outputs=gr.Textbox(label="Model Output:"),
title="Test Meta-Llama",
)
generate_image_sdxl_tab = gr.Interface(
fn=generate_image_sdxl,
inputs=gr.Textbox(label="Prompt:"),
outputs=gr.Image(type="pil"),
title="Generate SDXL Image",
)
generate_musicgen_melody_tab = gr.Interface(
fn=generate_musicgen_melody,
inputs=gr.Textbox(label="Prompt:"),
outputs=gr.Audio(type="numpy"),
title="Generate MusicGen Melody",
)
generate_musicgen_large_tab = gr.Interface(
fn=generate_musicgen_large,
inputs=gr.Textbox(label="Prompt:"),
outputs=gr.Audio(type="numpy"),
title="Generate MusicGen Large",
)
transcribe_audio_tab = gr.Interface(
fn=transcribe_audio,
inputs=gr.Audio(type="numpy", label="Audio Sample:"),
outputs=gr.Textbox(label="Transcribed Text:"),
title="Transcribe Audio",
)
generate_mistral_instruct_tab = gr.Interface(
fn=generate_mistral_instruct,
inputs=gr.Textbox(label="Prompt:"),
outputs=gr.Textbox(label="Mistral Instruct Response:"),
title="Generate Mistral Instruct Response",
)
generate_mistral_nemo_tab = gr.Interface(
fn=generate_mistral_nemo,
inputs=gr.Textbox(label="Prompt:"),
outputs=gr.Textbox(label="Mistral Nemo Response:"),
title="Generate Mistral Nemo Response",
)
generate_gpt2_xl_tab = gr.Interface(
fn=generate_gpt2_xl,
inputs=gr.Textbox(label="Prompt:"),
outputs=gr.Textbox(label="GPT-2 XL Response:"),
title="Generate GPT-2 XL Response",
)
answer_question_minicpm_tab = gr.Interface(
fn=answer_question_minicpm,
inputs=[
gr.Image(type="pil", label="Image:"),
gr.Textbox(label="Question:"),
],
outputs=gr.Textbox(label="MiniCPM Answer:"),
title="Answer Question with MiniCPM",
)
app = gr.TabbedInterface(
[
gen_image_tab,
edit_image_tab,
generate_song_tab,
generate_text_tab,
generate_flux_image_tab,
generate_code_tab,
model_meta_llama_test_tab,
generate_image_sdxl_tab,
generate_musicgen_melody_tab,
generate_musicgen_large_tab,
transcribe_audio_tab,
generate_mistral_instruct_tab,
generate_mistral_nemo_tab,
generate_gpt2_xl_tab,
answer_question_minicpm_tab,
],
[
"Generate Image",
"Edit Image",
"Generate Song",
"Generate Text",
"Generate FLUX Image",
"Generate Code",
"Test Meta-Llama",
"Generate SDXL Image",
"Generate MusicGen Melody",
"Generate MusicGen Large",
"Transcribe Audio",
"Generate Mistral Instruct Response",
"Generate Mistral Nemo Response",
"Generate GPT-2 XL Response",
"Answer Question with MiniCPM",
],
)
if __name__ == "__main__":
    app.launch(share=True)