Harshithtd commited on
Commit
97429f0
1 Parent(s): 7a3b1ec

Create app.py

Files changed (1):
  app.py  +104 -0
app.py ADDED
@@ -0,0 +1,104 @@
import base64
import os
from io import BytesIO

import gradio as gr
import requests
from huggingface_hub import login
from PIL import Image
from vllm import LLM, SamplingParams

login(os.environ["HF_TOKEN"])

# Replace with the model you would like to use
# (e.g. "mistral-community/pixtral-12b-240910")
repo_id = "mistralai/Pixtral-12B-2409"
sampling_params = SamplingParams(max_tokens=8192, temperature=0.7)
max_tokens_per_img = 4096
max_img_per_msg = 5

llm = LLM(
    model=repo_id,  # name or path of your model
    tokenizer_mode="mistral",
    max_model_len=65536,
    max_num_batched_tokens=max_img_per_msg * max_tokens_per_img,
    limit_mm_per_prompt={"image": max_img_per_msg},
)


def encode_image(image: Image.Image, image_format="PNG") -> str:
    """Serialize a PIL image to a base64-encoded string."""
    im_file = BytesIO()
    image.save(im_file, format=image_format)
    im_bytes = im_file.getvalue()
    im_64 = base64.b64encode(im_bytes).decode("utf-8")
    return im_64


# @spaces.GPU  # [uncomment to use ZeroGPU]
def infer(image_url, prompt, progress=gr.Progress(track_tqdm=True)):
    # Download the image, resize it to a fixed resolution, and re-encode it
    # as a base64 data URL so it can be passed inline in the chat message.
    image = Image.open(BytesIO(requests.get(image_url).content))
    image = image.resize((3844, 2408))
    new_image_url = f"data:image/png;base64,{encode_image(image, image_format='PNG')}"

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": new_image_url}},
            ],
        },
    ]

    outputs = llm.chat(messages, sampling_params=sampling_params)

    return outputs[0].outputs[0].text


examples = [["https://picsum.photos/id/237/200/300", "What do you see in this image?"]]

css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Mistral Pixtral 12B")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt",
                container=False,
            )

        with gr.Row():
            image_url = gr.Text(
                label="Image URL",
                show_label=False,
                max_lines=1,
                placeholder="Enter your image URL",
                container=False,
            )

        with gr.Row():
            run_button = gr.Button("Run", scale=0)

        result = gr.Textbox(show_label=False)

        gr.Examples(
            examples=examples,
            inputs=[image_url, prompt],
        )

    gr.on(
        triggers=[run_button.click, image_url.submit, prompt.submit],
        fn=infer,
        inputs=[image_url, prompt],
        outputs=[result],
    )

demo.queue().launch()
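
One detail worth noting when reusing this app: the image is handed to vLLM not as the remote URL the user typed, but as a base64 data URL built by encode_image. Below is a minimal, self-contained sketch of that step; it copies encode_image rather than importing app.py (importing the app would load the model and launch the interface), and the tiny red test image is a hypothetical stand-in for a downloaded photo.

import base64
from io import BytesIO

from PIL import Image


def encode_image(image: Image.Image, image_format="PNG") -> str:
    # Same helper as in app.py: serialize a PIL image to a base64 string.
    im_file = BytesIO()
    image.save(im_file, format=image_format)
    return base64.b64encode(im_file.getvalue()).decode("utf-8")


img = Image.new("RGB", (64, 64), color="red")  # hypothetical stand-in image
data_url = f"data:image/png;base64,{encode_image(img)}"
print(data_url[:40])  # -> data:image/png;base64,iVBORw0KGgo...

The resulting data URL is exactly what infer() places in the "image_url" field of the chat message, so no external image host is involved after the initial download.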