File size: 9,085 Bytes
91f15c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ce11a7
74db01d
91f15c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9db781
91f15c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b315d19
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
# This is a sample Python script.

# Press ⌃R to execute it or replace it with your code.
# Press Double ⇧ to search everywhere for classes, files, tool windows, actions, and settings.
import base64
import datetime
import json

import cv2
import requests
from PIL import Image, ImageDraw, ImageFont, ImageOps
import numpy as np
from io import BytesIO
import time

# Absolute path of the image fed to the detection/segmentation helpers below.
main_image_path = "/Users/aaron/Documents/temp/16pic_2415206_s.png"
# SECURITY(review): hard-coded Hugging Face token committed to source — revoke
# it and load from an environment variable or secrets store instead.
API_TOKEN = "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"
headers = {"Authorization": f"Bearer {API_TOKEN}"}
# API_URL = "https://api-inference.huggingface.co/models/hustvl/yolos-tiny"
# Object-detection endpoint used by query().
API_URL = "https://api-inference.huggingface.co/models/facebook/detr-resnet-50"
# NOTE(review): API_OBJECT_URL is commented out but still referenced by
# queryObjectDetection() — calling that function raises NameError as written.
# API_OBJECT_URL = "https://api-inference.huggingface.co/models/microsoft/resnet-50"
# Panoptic-segmentation endpoint used by send_request_to_api().
API_SEGMENTATION_URL = "https://api-inference.huggingface.co/models/facebook/detr-resnet-50-panoptic"
API_SEGMENTATION_URL_2 = "https://api-inference.huggingface.co/models/nvidia/segformer-b0-finetuned-ade-512-512"

# Directory where intermediate cropped images are written.
temp_dir = "/Users/aaron/Documents/temp/imageai/"


def query(filename):
    """POST the raw bytes of *filename* to the DETR detection endpoint.

    Returns the decoded JSON response: a list of detections on success, or a
    dict containing an "error" key while the model is loading/unavailable.
    """
    with open(filename, "rb") as f:
        data = f.read()
    # Fix: a timeout was missing (queryObjectDetection already uses one);
    # without it a stalled connection hangs the script indefinitely.
    response = requests.request("POST", API_URL, headers=headers, data=data, timeout=30)
    return json.loads(response.content.decode("utf-8"))


def queryObjectDetection(filename):
    """POST the raw bytes of *filename* to the object-classification endpoint.

    Bug fix: the module-level API_OBJECT_URL constant is commented out, so the
    original code raised NameError on every call. The URL is restored here as
    a function-local constant (value taken from the commented-out module line).

    Returns the decoded JSON response (list of label/score dicts, or an
    error dict while the model is loading).
    """
    api_object_url = "https://api-inference.huggingface.co/models/microsoft/resnet-50"
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.request("POST", api_object_url, headers=headers, data=data, timeout=6)
    print(response)
    return json.loads(response.content.decode("utf-8"))


def getImageSegmentation():
    """Run detection on the module-level image and return the parsed result."""
    result = query(main_image_path)
    print(result)
    return result


def crop_image(box):
    """Crop the module-level image to the rectangle described by *box*.

    *box* is a mapping with 'xmin', 'ymin', 'xmax', 'ymax' pixel coordinates.
    Returns the cropped PIL image.
    """
    source = Image.open(main_image_path)
    region = (box['xmin'], box['ymin'], box['xmax'], box['ymax'])
    return source.crop(region)


# # 示例
# image_path = "path/to/your/image.jpg"
# box = {'xmin': 186, 'ymin': 75, 'xmax': 252, 'ymax': 123}
#
# cropped_image = crop_image(image_path, box)
# cropped_image.show()  # 显示裁剪后的图片
# cropped_image.save("path/to/save/cropped_image.jpg")  # 保存裁剪后的图片

# Press the green button in the gutter to run the script.
# if __name__ == '__main__':
#     data = getImageSegmentation()
#     for item in data:
#         box = item['box']
#         cropped_image = crop_image(box)
#         temp_image_path = temp_dir + str(int(datetime.datetime.now().timestamp() * 1000000)) + ".png"
#         print(temp_image_path)
#         cropped_image.save(temp_image_path)
#         object_data = queryObjectDetection(temp_image_path)
#         print(object_data)
#         flag = False
#         for obj in object_data:
#             # 检查字典中是否包含 'error' 键
#             if 'error' in obj and obj['error'] is not None:
#                 flag = True
#                 print("找到了一个包含 'error' 键的字典,且其值不为 None")
#             else:
#                 print("字典不包含 'error' 键,或其值为 None")
#         if flag:
#             continue
#         item['label'] = object_data[0]['label']
#     print(data)
#
#     ###下面就是画个图,和上面住流程无关,仅仅用于测试
#     image = Image.open(main_image_path)
#     draw = ImageDraw.Draw(image)
#
#     # 设置边框颜色和字体
#     border_color = (255, 0, 0)  # 红色
#     text_color = (255, 255, 255)  # 白色
#     font = ImageFont.truetype("Geneva.ttf", 12)  # 使用 系统Geneva 字体,大小为 8
#
#     # 遍历对象列表,画边框和标签
#     for obj in data:
#         label = obj['label']
#         box = obj['box']
#         xmin, ymin, xmax, ymax = box['xmin'], box['ymin'], box['xmax'], box['ymax']
#
#         # 画边框
#         draw.rectangle([xmin, ymin, xmax, ymax], outline=border_color, width=2)
#
#         # 画标签
#         text_size = draw.textsize(label, font=font)
#         draw.rectangle([xmin, ymin, xmin + text_size[0], ymin + text_size[1]], fill=border_color)
#         draw.text((xmin, ymin), label, font=font, fill=text_color)
#
#     image.show()


import numpy as np
from PIL import Image
import gradio as gr


def send_request_to_api(img_byte_arr, max_retries=3, wait_time=60):
    """POST PNG bytes to the panoptic-segmentation endpoint, retrying on error.

    The Hugging Face inference API returns a JSON dict with an "error" key
    while the model is warming up; a valid result is a JSON list of segments.

    Parameters
    ----------
    img_byte_arr : bytes
        Encoded image payload.
    max_retries : int
        Number of error responses to tolerate before giving up.
    wait_time : float
        Seconds to sleep between attempts (model warm-up can take ~1 min).

    Raises
    ------
    Exception
        When no valid response is obtained after max_retries attempts.
    """
    for _attempt in range(max_retries):
        response = requests.request("POST", API_SEGMENTATION_URL, headers=headers, data=img_byte_arr)
        response_content = response.content.decode("utf-8")

        # Bug fix: the original substring test `"error" in response_content`
        # also matched valid payloads that merely contain the word "error"
        # (e.g. inside a label), discarding good results, and non-JSON bodies
        # would have crashed json.loads. Parse first, then check the API's
        # actual error shape (a dict with an "error" key).
        try:
            json_obj = json.loads(response_content)
        except ValueError:
            json_obj = None  # non-JSON body (e.g. HTML error page) → retry

        if json_obj is None or (isinstance(json_obj, dict) and "error" in json_obj):
            print(f"Error: {response_content}")
            time.sleep(wait_time)
        else:
            return json_obj

    raise Exception("Failed to get a valid response from the API after multiple retries.")

def getSegmentationMaskImage(input_img, blur_kernel_size=21):
    """Segment *input_img* via the HF API; return one annotated image per segment.

    For each segment the model returns, the whole image is blurred except the
    segment's mask region, the segment label is drawn in red, and a red line
    connects the label text to the top of the mask.

    Parameters
    ----------
    input_img : PIL.Image.Image
        Image to process; resized in place to at most 600 px wide.
    blur_kernel_size : int
        Gaussian kernel size (must be odd, per OpenCV's requirement).

    Returns
    -------
    list[PIL.Image.Image]
        One composited image per segment whose label is not "LABEL*".
    """
    # Downscale so upload and inference stay fast; thumbnail preserves aspect.
    target_width = 600
    aspect_ratio = float(input_img.height) / float(input_img.width)
    target_height = int(target_width * aspect_ratio)
    input_img.thumbnail((target_width, target_height))

    img_byte_arr = BytesIO()
    input_img.save(img_byte_arr, format='PNG')
    img_byte_arr = img_byte_arr.getvalue()
    json_obj = send_request_to_api(img_byte_arr)
    print(json_obj)

    # Work on a copy so the caller's image object is not further mutated.
    original_image = input_img.copy()
    if original_image.mode != 'RGBA':
        original_image = original_image.convert('RGBA')

    # Perf fix: the blurred copy depends only on the source image and the
    # kernel size, so compute it once — the original recomputed cvtColor +
    # GaussianBlur inside the per-segment loop.
    original_image_cv2 = cv2.cvtColor(np.array(original_image.convert('RGB')), cv2.COLOR_RGB2BGR)
    blurred_image_cv2 = cv2.GaussianBlur(original_image_cv2, (blur_kernel_size, blur_kernel_size), 0)
    blurred_image = Image.fromarray(cv2.cvtColor(blurred_image_cv2, cv2.COLOR_BGR2RGB)).convert(original_image.mode)

    output_images = []
    for item in json_obj:
        label = item['label']

        # Skip anonymous classes the model names "LABEL_<n>".
        if label.startswith("LABEL"):
            continue

        # The API returns each mask as a base64-encoded PNG.
        mask_image = Image.open(BytesIO(base64.b64decode(item['mask'])))

        # Keep the masked region sharp; show the blurred copy elsewhere.
        process_image = Image.composite(original_image, blurred_image, mask_image)

        draw = ImageDraw.Draw(process_image)
        font = ImageFont.load_default()
        text_position = (10, 30)
        draw.text(text_position, label, font=font, fill=(255, 0, 0))

        # Robustness fix: getbbox() returns None for an all-black mask; the
        # original crashed indexing it. Skip only the pointer line then.
        mask_bbox = mask_image.getbbox()
        if mask_bbox is not None:
            mask_top_center_x = (mask_bbox[0] + mask_bbox[2]) // 2
            mask_top_center_y = mask_bbox[1]

            # NOTE(review): draw.textsize was removed in Pillow 10 — switch to
            # draw.textbbox if upgrading; kept as-is to match the file's
            # current Pillow usage.
            text_width, text_height = draw.textsize(label, font=font)
            text_bottom_center_x = text_position[0] + text_width // 2
            text_bottom_center_y = text_position[1] + text_height

            # Red guide line from the label down to the top of the mask.
            draw.line([(text_bottom_center_x, text_bottom_center_y), (mask_top_center_x, mask_top_center_y)],
                      fill=(255, 0, 0), width=2)

        output_images.append(process_image)
    return output_images

def sepia(input_img):
    """Gradio handler: segment *input_img*, stack the annotated results vertically.

    Accepts either a uint8 array or a float array in [0, 1] (gradio versions
    differ in what they pass) and returns a single numpy image.
    """
    # Robustness fix: accept any float dtype in [0, 1], not just float32 —
    # gradio may hand over float64.
    if np.issubdtype(input_img.dtype, np.floating) and np.max(input_img) <= 1.0:
        input_img = (input_img * 255).astype(np.uint8)

    pil_img = Image.fromarray(input_img)
    output_images = getSegmentationMaskImage(pil_img)

    # Robustness fix: np.vstack raises on an empty list (e.g. every segment
    # was filtered out); fall back to the (resized) input image.
    if not output_images:
        return np.array(pil_img)
    return np.vstack([np.array(img) for img in output_images])


def imageDemo():
    """Build and launch the Gradio web UI for the segmentation pipeline."""
    iface = gr.Interface(
        sepia,
        gr.Image(shape=None),
        gr.outputs.Image(label="Processed Images", type="numpy"),
        title='Image Processing Demo',
    )
    iface.launch()


# Entry point: start the Gradio web demo.
if __name__ == '__main__':
    imageDemo()



#######---------gif输出方式
# def sepia(input_img):
#     input_img = Image.fromarray((input_img * 255).astype(np.uint8))
#
#     output_images = getSegmentationMaskImage(input_img)
#
#     # 生成GIF动画
#     buffered = BytesIO()
#     output_images[0].save(buffered, format='GIF', save_all=True, append_images=output_images[1:], duration=3000, loop=0)
#     gif_str = base64.b64encode(buffered.getvalue()).decode()
#     return f'<img src="data:image/gif;base64,{gif_str}" width="400" />'
#
#
# def imageDemo():
#     demo = gr.Interface(sepia, gr.Image(shape=(200, 200)), gr.outputs.HTML(label="Processed Animation"), title='Sepia Filter Demo')
#     demo.launch()