# Source page header (scraped): author Samet Yilmaz, commit c32151f "Reorganize", 2.79 kB.
import os
from vllm import LLM, SamplingParams
import gradio as gr
from PIL import Image
from io import BytesIO
import base64
import requests
from huggingface_hub import login
import os
# Authenticate with the Hugging Face Hub before any model download.
# Raises KeyError if the HF_TOKEN environment variable is not set.
login(os.environ["HF_TOKEN"])

# Name or path of the model to serve. Replace with the model you would like
# to use; it is passed to vLLM below, so changing it here takes effect.
repo_id = "mistralai/Pixtral-12B-2409"

# Decoding settings applied to every chat request made by `infer`.
sampling_params = SamplingParams(max_tokens=8192, temperature=0.7)

# Multimodal budget: at most `max_img_per_msg` images per prompt, each
# allotted up to `max_tokens_per_img` tokens. Their product is used as the
# per-step batched-token budget for the engine.
max_tokens_per_img = 4096
max_img_per_msg = 5

# NOTE(review): max_num_batched_tokens (5 * 4096 = 20480) is smaller than
# max_model_len (65536); depending on the installed vLLM version this needs
# chunked prefill to be enabled — confirm against the vLLM release in use.
llm = LLM(
    model=repo_id,
    tokenizer_mode="mistral",
    max_model_len=65536,
    max_num_batched_tokens=max_img_per_msg * max_tokens_per_img,
    limit_mm_per_prompt={"image": max_img_per_msg},
)
def encode_image(image: Image.Image, image_format="PNG") -> str:
    """Serialize *image* in *image_format* and return it as base64 text.

    The result is the bare base64 payload (no ``data:`` prefix), decoded to
    a UTF-8 ``str`` so it can be interpolated into a data URL.
    """
    with BytesIO() as buffer:
        image.save(buffer, format=image_format)
        raw = buffer.getvalue()
    return base64.b64encode(raw).decode("utf-8")
# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(image_url, prompt, progress=gr.Progress(track_tqdm=True)):
    """Fetch the image at *image_url*, ask the model *prompt* about it, return the reply.

    Args:
        image_url: HTTP(S) URL of the image to send to the model.
        prompt: Text sent alongside the image in a single user message.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio
            mirror tqdm progress bars during the call. Not used directly.

    Returns:
        The text of the first completion returned by ``llm.chat``.

    Raises:
        gr.Error: If the URL is empty, cannot be fetched, returns an HTTP
            error status, or does not contain a readable image. Gradio shows
            the message to the user instead of a raw traceback.
    """
    url = (image_url or "").strip()
    if not url:
        raise gr.Error("Please enter an image URL.")

    try:
        # The timeout keeps an unresponsive host from blocking the worker
        # forever; raise_for_status stops HTML error pages reaching PIL.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy; decode now so corrupt data fails inside this try.
        image.load()
    except (requests.RequestException, OSError, Image.DecompressionBombError) as err:
        raise gr.Error(f"Could not load an image from {url}: {err}") from err

    # PNG cannot store some modes (e.g. CMYK from print-oriented JPEGs);
    # convert those to RGB so the PNG encoding below succeeds.
    if image.mode not in {"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"}:
        image = image.convert("RGB")

    # The image is forwarded at its native size and aspect ratio: forcing a
    # single fixed size would stretch every image of a different shape.
    # NOTE(review): assumes the vLLM/Mistral image preprocessing rescales
    # oversized images itself — confirm for very large inputs.
    new_image_url = f"data:image/png;base64,{encode_image(image, image_format='PNG')}"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": new_image_url}},
            ],
        },
    ]
    outputs = llm.chat(messages, sampling_params=sampling_params)
    return outputs[0].outputs[0].text
# Clickable example inputs; each row is [image_url, prompt], matching the
# order of the `inputs` the examples are bound to.
examples = [
    ["https://picsum.photos/id/237/200/300", "What do you see in this image?"],
]

# Stylesheet for the page: centre the main column and cap its width.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 640px;\n"
    "}\n"
)
# Build the UI: a centred column with a prompt box, an image-URL box, a Run
# button, a text output, and clickable examples.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Page heading (plain string: there is nothing to interpolate).
        gr.Markdown("""
# Mistral Pixtral 12B
""")
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt",
                container=False,
            )
        with gr.Row():
            image_url = gr.Text(
                label="Image URL",
                show_label=False,
                max_lines=1,
                placeholder="Enter your image URL",
                container=False,
            )
        with gr.Row():
            run_button = gr.Button("Run", scale=0)
        result = gr.Textbox(show_label=False)
        # Example rows are [image_url, prompt], in the same order as `inputs`.
        gr.Examples(
            examples=examples,
            inputs=[image_url, prompt],
        )
    # Run `infer` on a button click or when either textbox is submitted.
    gr.on(
        triggers=[run_button.click, image_url.submit, prompt.submit],
        fn=infer,
        inputs=[image_url, prompt],
        outputs=[result],
    )
# Start the server only when executed as a script, so importing this module
# does not launch the web app as a side effect.
if __name__ == "__main__":
    demo.queue().launch()