File size: 4,681 Bytes
302615c
 
 
 
 
 
 
 
 
 
 
76e65f9
302615c
 
 
 
 
76e65f9
 
 
 
 
 
 
 
 
302615c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76e65f9
 
 
 
 
302615c
76e65f9
 
 
 
 
 
 
 
 
302615c
 
 
 
 
 
 
 
 
 
 
 
 
76e65f9
302615c
 
 
76e65f9
 
302615c
 
 
 
 
 
3fe04de
302615c
 
 
 
 
 
 
 
 
 
76e65f9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import base64
import json
import os
import sqlite3
from datetime import datetime
from io import BytesIO

import gradio as gr
from datasets import load_dataset
from mistralai import Mistral
from PIL import Image
from pydantic import BaseModel, Field, ValidationError
from pymongo import MongoClient

# Load the dataset of paired English/Chinese Pokemon captions; the first few
# rows are used below as few-shot style examples in the system prompt.
ds = load_dataset("svjack/pokemon-blip-captions-en-zh")
ds = ds["train"]

# Connect to MongoDB, where thumbs-up feedback is stored.
MONGO_URI = os.environ.get('MONGO_URI')
if not MONGO_URI:
    raise ValueError("MONGO_URI is not set in the environment variables.")

# NOTE(review): `client` is rebound to the Mistral client further down in this
# file; only `db`/`collection` keep a reference to the Mongo connection.
client = MongoClient(MONGO_URI)
db = client['capimg']  # Database name
collection = db['feedback']  # Collection name

# Mistral API key, required for the Pixtral captioning calls.
api_key = os.environ.get('MISTRAL_API_KEY')

if not api_key:
    raise ValueError("MISTRAL_API_KEY is not set in the environment variables.")

# Build the few-shot sample history: 8 stringified {"en": ..., "zh": ...}
# dicts, joined into one prompt fragment.
hist = [str({"en": ds[i]["en_text"], "zh": ds[i]["zh_text"]}) for i in range(8)]
hist_str = "\n".join(hist)

# Define the Caption model
class Caption(BaseModel):
    # Structured response schema: sent to the model via model_json_schema()
    # in the system prompt, and used to validate the JSON the model returns.
    # NOTE: deliberately no class docstring — pydantic would surface it as a
    # "description" field in the JSON schema embedded in the prompt.
    # en: short English caption (hard cap of 84 characters).
    en: str = Field(...,
        description="English caption of image",
        max_length=84)
    # zh: short Chinese caption (hard cap of 64 characters).
    zh: str = Field(...,
        description="Chinese caption of image",
        max_length=64)

# Initialize the Mistral client. NOTE: this rebinds `client`, which previously
# held the MongoClient — the Mongo connection remains reachable via
# `db`/`collection` only.
client = Mistral(api_key=api_key)

def generate_caption(image):
    """Generate an English/Chinese caption pair for an image via Pixtral.

    Parameters
    ----------
    image : PIL.Image.Image
        Image uploaded through the Gradio UI.

    Returns
    -------
    Caption | None
        Validated caption pair, or None when the model response cannot be
        parsed or does not satisfy the Caption schema.
    """
    # JPEG cannot encode an alpha channel; Gradio frequently delivers
    # transparent PNGs as RGBA (or palette "P") images, which would make
    # image.save(..., format="JPEG") raise. Normalize to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Encode the image as base64 JPEG for the data-URL payload.
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')

    messages = [
        {
            "role": "system",
            "content": f'''
            You are a highly accurate image to caption transformer.
            Describe the image content in English and Chinese respectively. Make sure to FOCUS on item CATEGORY and COLOR!
            Do NOT provide NAMES! KEEP it SHORT!
            While adhering to the following JSON schema: {Caption.model_json_schema()}
            Following are some samples you should adhere to for style and tone:
            {hist_str}
            '''
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Describe the image in English and Chinese"
                },
                {
                    "type": "image_url",
                    "image_url": f"data:image/jpeg;base64,{base64_image}"
                }
            ]
        }
    ]

    # Ask for a JSON object so the response can be validated against Caption.
    chat_response = client.chat.complete(
        model="pixtral-12b-2409",
        messages=messages,
        response_format = {
          "type": "json_object",
        }
    )

    response_content = chat_response.choices[0].message.content

    try:
        caption_dict = json.loads(response_content)
        return Caption(**caption_dict)
    except (json.JSONDecodeError, ValidationError, TypeError) as e:
        # ValidationError: JSON parsed but violates the schema (e.g. caption
        # exceeds max_length). TypeError: top-level JSON is not an object.
        # All three cases degrade to None so the UI shows a friendly error.
        print(f"Error decoding JSON: {e}")
        return None

# Save user feedback to MongoDB (the sqlite3 import above is unused here)
def save_feedback(image, caption):
    """Persist an (image, caption) feedback pair to the MongoDB collection.

    Parameters
    ----------
    image : PIL.Image.Image | None
        The image the caption was generated for; None if the user clicked
        Thumbs Up before uploading anything.
    caption : str
        The caption text currently shown in the output textbox.

    Returns
    -------
    The result of a gr.Warning / gr.Info call, surfaced as a toast in the UI.
    """
    # Guard: the Thumbs Up button can be clicked before an image is uploaded,
    # in which case Gradio passes None and image.save would raise.
    if image is None:
        return gr.Warning("Please upload an image and generate a caption first.")

    # JPEG cannot store alpha; normalize RGBA/palette images before encoding.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Store the image as a base64 JPEG string alongside the caption.
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    feedback_entry = {
        "timestamp": datetime.now(),
        "input_data": img_str,
        "output_data": caption
    }

    result = collection.insert_one(feedback_entry)
    print(f"Feedback saved with id: {result.inserted_id}")
    return gr.Info("Thanks for your feedback!")

def process_image(image):
    """Gradio submit handler: caption the uploaded image.

    Returns a user-facing string — either the bilingual caption or an
    error message when no image was provided / the API call failed.
    """
    # Guard clause: nothing uploaded yet.
    if image is None:
        return "Please upload an image first."

    caption = generate_caption(image)

    # generate_caption yields None when the response could not be parsed.
    if caption is None:
        return "Failed to generate caption. Please check the API call or network connectivity."

    return f"English caption: {caption.en}\nChinese caption: {caption.zh}"

def thumbs_up(image, caption):
    """Thin wrapper forwarding a Thumbs Up click to the feedback store."""
    # Kept as a named function (rather than wiring save_feedback directly)
    # so the Gradio event bindings below read uniformly.
    return save_feedback(image, caption)

# Build the Gradio UI: image input and controls on the left, caption output
# plus a feedback button on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Image Captioner")
    gr.Markdown("Upload an image to generate captions in English and Chinese.")
    gr.Markdown("Use the 'Thumbs Up' button if you like the result!!")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil")
            with gr.Row():
                reset_button = gr.Button("Clear")
                run_button = gr.Button("Submit")

        with gr.Column(scale=1):
            caption_box = gr.Textbox()
            feedback_button = gr.Button("Thumbs Up")

    # Event wiring: Clear just empties the image component; Submit runs the
    # captioning pipeline; Thumbs Up stores the current image/caption pair.
    reset_button.click(fn=lambda: None, inputs=None, outputs=image_input)
    run_button.click(fn=process_image, inputs=image_input, outputs=caption_box)
    feedback_button.click(fn=thumbs_up, inputs=[image_input, caption_box], outputs=None)

# Launch the interface
demo.launch()