linm1 committed on
Commit
302615c
1 Parent(s): 1e3ede4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -162
app.py CHANGED
@@ -1,163 +1,162 @@
1
- import gradio as gr
2
- import os
3
- from dotenv import load_dotenv
4
- import base64
5
- from io import BytesIO
6
- from mistralai import Mistral
7
- from pydantic import BaseModel, Field
8
- from datasets import load_dataset
9
- from PIL import Image
10
- import json
11
- import sqlite3
12
- from datetime import datetime
13
-
14
- # Load the dataset
15
- ds = load_dataset("svjack/pokemon-blip-captions-en-zh")
16
- ds = ds["train"]
17
-
18
- # Load environment variables
19
- api_key = os.environ.get('MISTRAL_API_KEY')
20
-
21
- if not api_key:
22
- raise ValueError("MISTRAL_API_KEY is not set in the environment variables.")
23
-
24
- # Create sample history
25
- hist = [str({"en": ds[i]["en_text"], "zh": ds[i]["zh_text"]}) for i in range(8)]
26
- hist_str = "\n".join(hist)
27
-
28
- # Define the Caption model
29
- class Caption(BaseModel):
30
- en: str = Field(...,
31
- description="English caption of image",
32
- max_length=84)
33
- zh: str = Field(...,
34
- description="Chinese caption of image",
35
- max_length=64)
36
-
37
- # Initialize the Mistral client
38
- client = Mistral(api_key=api_key)
39
-
40
- def generate_caption(image):
41
- # Convert image to base64
42
- buffered = BytesIO()
43
- image.save(buffered, format="JPEG")
44
- base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
45
-
46
- messages = [
47
- {
48
- "role": "system",
49
- "content": f'''
50
- You are a highly accurate image to caption transformer.
51
- Describe the image content in English and Chinese respectively. Make sure to FOCUS on item CATEGORY and COLOR!
52
- Do NOT provide NAMES! KEEP it SHORT!
53
- While adhering to the following JSON schema: {Caption.model_json_schema()}
54
- Following are some samples you should adhere to for style and tone:
55
- {hist_str}
56
- '''
57
- },
58
- {
59
- "role": "user",
60
- "content": [
61
- {
62
- "type": "text",
63
- "text": "Describe the image in English and Chinese"
64
- },
65
- {
66
- "type": "image_url",
67
- "image_url": f"data:image/jpeg;base64,{base64_image}"
68
- }
69
- ]
70
- }
71
- ]
72
-
73
- chat_response = client.chat.complete(
74
- model="pixtral-12b-2409",
75
- messages=messages,
76
- response_format = {
77
- "type": "json_object",
78
- }
79
- )
80
-
81
- response_content = chat_response.choices[0].message.content
82
-
83
- try:
84
- caption_dict = json.loads(response_content)
85
- return Caption(**caption_dict)
86
- except json.JSONDecodeError as e:
87
- print(f"Error decoding JSON: {e}")
88
- return None
89
-
90
- # Initialize SQLite database
91
- def init_db():
92
- conn = sqlite3.connect('feedback.db')
93
- c = conn.cursor()
94
- c.execute('''CREATE TABLE IF NOT EXISTS thumbs_up
95
- (id INTEGER PRIMARY KEY AUTOINCREMENT,
96
- timestamp TEXT,
97
- input_data TEXT,
98
- output_data TEXT)''')
99
- conn.commit()
100
- conn.close()
101
-
102
- init_db()
103
-
104
- def process_image(image):
105
- if image is None:
106
- return "Please upload an image first."
107
-
108
- result = generate_caption(image)
109
-
110
- if result:
111
- return f"English caption: {result.en}\nChinese caption: {result.zh}"
112
- else:
113
- return "Failed to generate caption. Please check the API call or network connectivity."
114
-
115
- def thumbs_up(image, caption):
116
- # Convert image to base64 string for storage
117
- buffered = BytesIO()
118
- image.save(buffered, format="JPEG")
119
- img_str = base64.b64encode(buffered.getvalue()).decode()
120
-
121
- conn = sqlite3.connect('feedback.db')
122
- c = conn.cursor()
123
- c.execute("INSERT INTO thumbs_up (timestamp, input_data, output_data) VALUES (?, ?, ?)",
124
- (datetime.now().isoformat(), img_str, caption))
125
- conn.commit()
126
- conn.close()
127
- print(f"Thumbs up data saved to database.")
128
- return gr.Notification("Thank you for your feedback!", type="success")
129
-
130
- # Create Gradio interface
131
- custom_css = """
132
- .highlight-btn {
133
- background-color: #3498db !important;
134
- border-color: #3498db !important;
135
- color: white !important;
136
- }
137
- .highlight-btn:hover {
138
- background-color: #2980b9 !important;
139
- border-color: #2980b9 !important;
140
- }
141
- """
142
-
143
- with gr.Blocks() as iface:
144
- gr.Markdown("# Image Captioner")
145
- gr.Markdown("Upload an image to generate captions in English and Chinese. Use the 'Thumbs Up' button if you like the result!")
146
-
147
- with gr.Row():
148
- with gr.Column(scale=1):
149
- input_image = gr.Image(type="pil")
150
- with gr.Row():
151
- clear_btn = gr.Button("Clear")
152
- submit_btn = gr.Button("Submit", elem_classes=["highlight-btn"])
153
-
154
- with gr.Column(scale=1):
155
- output_text = gr.Textbox()
156
- thumbs_up_btn = gr.Button("Thumbs Up")
157
-
158
- clear_btn.click(fn=lambda: None, inputs=None, outputs=input_image)
159
- submit_btn.click(fn=process_image, inputs=input_image, outputs=output_text)
160
- thumbs_up_btn.click(fn=thumbs_up, inputs=[input_image, output_text], outputs=None)
161
-
162
- # Launch the interface
163
  iface.launch(share=True)
 
1
+ import gradio as gr
2
+ import os
3
+ import base64
4
+ from io import BytesIO
5
+ from mistralai import Mistral
6
+ from pydantic import BaseModel, Field
7
+ from datasets import load_dataset
8
+ from PIL import Image
9
+ import json
10
+ import sqlite3
11
+ from datetime import datetime
12
+
13
# Few-shot caption source: only the "train" split of the bilingual dataset is used.
ds = load_dataset("svjack/pokemon-blip-captions-en-zh")["train"]

# The Mistral API key must come from the environment; fail fast when it is absent.
api_key = os.environ.get('MISTRAL_API_KEY')
if not api_key:
    raise ValueError("MISTRAL_API_KEY is not set in the environment variables.")

# Stringify the first eight en/zh caption pairs; they are embedded in the
# system prompt, one per line, as style samples.
hist = [
    str({"en": row["en_text"], "zh": row["zh_text"]})
    for row in (ds[i] for i in range(8))
]
hist_str = "\n".join(hist)
26
+
27
+ # Define the Caption model
28
class Caption(BaseModel):
    # Bilingual caption pair that the model's JSON reply is validated against.
    # Documented with comments rather than a docstring because a class
    # docstring would be emitted into model_json_schema(), and that schema is
    # interpolated into the system prompt.

    # English caption, limited to 84 characters.
    en: str = Field(..., description="English caption of image", max_length=84)
    # Chinese caption, limited to 64 characters.
    zh: str = Field(..., description="Chinese caption of image", max_length=64)
35
+
36
# Module-wide Mistral client, created once with the API key read from the environment.
client = Mistral(api_key=api_key)
38
+
39
def generate_caption(image):
    """Ask the Pixtral model for an English and a Chinese caption of *image*.

    The image is JPEG-encoded, embedded as a base64 data URL and sent together
    with a system prompt that contains the ``Caption`` JSON schema and the
    sample captions in ``hist_str``.

    Args:
        image: Image object exposing ``mode``, ``convert`` and ``save``
            (a PIL image, as delivered by the Gradio input component).

    Returns:
        A ``Caption`` built from the model's JSON reply, or ``None`` when the
        reply is not valid JSON or does not satisfy the ``Caption`` model.
    """
    # JPEG cannot store an alpha channel or a palette, so normalise such
    # images before encoding instead of letting ``save`` fail.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')

    system_prompt = (
        "\n"
        "You are a highly accurate image to caption transformer.\n"
        "Describe the image content in English and Chinese respectively. Make sure to FOCUS on item CATEGORY and COLOR!\n"
        "Do NOT provide NAMES! KEEP it SHORT!\n"
        f"While adhering to the following JSON schema: {Caption.model_json_schema()}\n"
        "Following are some samples you should adhere to for style and tone:\n"
        f"{hist_str}\n"
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Describe the image in English and Chinese",
                },
                {
                    "type": "image_url",
                    "image_url": f"data:image/jpeg;base64,{base64_image}",
                },
            ],
        },
    ]

    chat_response = client.chat.complete(
        model="pixtral-12b-2409",
        messages=messages,
        response_format={"type": "json_object"},
    )
    response_content = chat_response.choices[0].message.content

    try:
        caption_dict = json.loads(response_content)
    except (json.JSONDecodeError, TypeError) as e:
        # TypeError covers a reply whose content is not a string at all.
        print(f"Error decoding JSON: {e}")
        return None

    try:
        return Caption(**caption_dict)
    except (ValueError, TypeError) as e:
        # pydantic's ValidationError derives from ValueError (over-long or
        # missing fields); TypeError covers a JSON reply that is not an object.
        print(f"Invalid caption payload: {e}")
        return None
88
+
89
+ # Initialize SQLite database
90
def init_db():
    """Create the ``thumbs_up`` feedback table in ``feedback.db`` if it is missing.

    The table has an auto-incrementing ``id`` plus three text columns:
    ``timestamp``, ``input_data`` and ``output_data``. Calling the function
    again on an existing database is a no-op.
    """
    conn = sqlite3.connect('feedback.db')
    try:
        # The connection's context manager commits on success and rolls back
        # on error; it does not close, hence the surrounding try/finally.
        with conn:
            conn.execute('''CREATE TABLE IF NOT EXISTS thumbs_up
                            (id INTEGER PRIMARY KEY AUTOINCREMENT,
                             timestamp TEXT,
                             input_data TEXT,
                             output_data TEXT)''')
    finally:
        conn.close()
100
+
101
# Make sure the feedback table exists before any button handler writes to it.
init_db()
102
+
103
def process_image(image):
    """Gradio submit handler: caption *image* and format the result for display.

    Args:
        image: The uploaded image, or ``None`` when nothing has been uploaded.

    Returns:
        A user-facing string: both captions on success, a request to upload an
        image when *image* is ``None``, or a failure message when no caption
        could be produced.
    """
    if image is None:
        return "Please upload an image first."

    try:
        result = generate_caption(image)
    except Exception as e:
        # UI boundary: the failure message below already tells the user to
        # check the API call / network, so report the error instead of
        # letting it escape the handler.
        print(f"Caption generation failed: {e}")
        result = None

    if not result:
        return "Failed to generate caption. Please check the API call or network connectivity."

    return f"English caption: {result.en}\nChinese caption: {result.zh}"
113
+
114
def thumbs_up(image, caption):
    """Gradio click handler: store the image/caption pair as positive feedback.

    The image is JPEG-encoded and base64-stringified, then inserted into the
    ``thumbs_up`` table of ``feedback.db`` together with the caption text and
    the current local timestamp in ISO format.

    Args:
        image: The image currently shown in the input component, or ``None``
            when nothing has been uploaded.
        caption: The text currently shown in the output box.

    Returns:
        None. The handler is wired with ``outputs=None``; user feedback is
        given through Gradio toast messages.
    """
    if image is None:
        # Nothing to store; tell the user rather than failing on image.save.
        gr.Warning("Please upload an image first.")
        return

    # JPEG cannot store an alpha channel or a palette.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    # Convert image to base64 string for storage
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    conn = sqlite3.connect('feedback.db')
    try:
        # Commits on success, rolls back on error; closing is done in finally.
        with conn:
            conn.execute(
                "INSERT INTO thumbs_up (timestamp, input_data, output_data) VALUES (?, ?, ?)",
                (datetime.now().isoformat(), img_str, caption),
            )
    finally:
        conn.close()

    print("Thumbs up data saved to database.")
    # Gradio shows a toast when gr.Info is called inside an event handler.
    gr.Info("Thank you for your feedback!")
128
+
129
+ # Create Gradio interface
130
# Styling for the primary ("Submit") button, selected via elem_classes below.
custom_css = """
.highlight-btn {
    background-color: #3498db !important;
    border-color: #3498db !important;
    color: white !important;
}
.highlight-btn:hover {
    background-color: #2980b9 !important;
    border-color: #2980b9 !important;
}
"""

# The stylesheet has to be handed to Blocks; otherwise the .highlight-btn
# rules above never reach the page.
with gr.Blocks(css=custom_css) as iface:
    gr.Markdown("# Image Captioner")
    gr.Markdown("Upload an image to generate captions in English and Chinese. Use the 'Thumbs Up' button if you like the result!")

    with gr.Row():
        # Left column: image upload plus the Clear / Submit buttons.
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", elem_classes=["highlight-btn"])

        # Right column: generated captions and the feedback button.
        with gr.Column(scale=1):
            output_text = gr.Textbox()
            thumbs_up_btn = gr.Button("Thumbs Up")

    # Returning None to the image component empties it.
    clear_btn.click(fn=lambda: None, inputs=None, outputs=input_image)
    submit_btn.click(fn=process_image, inputs=input_image, outputs=output_text)
    thumbs_up_btn.click(fn=thumbs_up, inputs=[input_image, output_text], outputs=None)

# Launch the interface
iface.launch(share=True)