P01yH3dr0n committed
Commit
86f7b58
1 Parent(s): c33c656

tagger support

Files changed (3)
  1. app.py +4 -0
  2. requirements.txt +4 -1
  3. tagger.py +297 -0
app.py CHANGED
@@ -9,6 +9,7 @@ from huggingface_hub import HfApi, snapshot_download
 from pnginfo import read_info_from_image, send_paras
 from images_history import img_history_ui
 from director_tools import director_ui, send_outputs
+from tagger import tagger_ui
 from utils import set_token, generate_novelai_image, image_from_bytes, get_remain_anlas, calculate_cost
 
 client_config = toml.load("config.toml")['client']
@@ -258,6 +259,8 @@ def ui():
         from_t2i, send_i2i, send_inp, send_vib, in_image, out_image, d_index = director_ui()
     with gr.TabItem("图片信息读取"):
         png2main, png_items, info, read_image = util_ui()
+    with gr.TabItem("Tagger反推"):
+        tags, tagger2main = tagger_ui()
     with gr.TabItem("云端图片浏览") as tab:
         gallery, h_index, gal2main, gal_items, history2ref, history2i2i, history2inp, history2dtl = img_history_ui(tab)
     with gr.TabItem("设置"):
@@ -278,6 +281,7 @@ def ui():
     send_jump_select(history2inp, send_outputs, [gallery, h_index], paras[22], "client_ui_main", others[6], gal_items, "inp_block")
     send_jump_select(history2ref, (lambda l, i: None if i == -1 else [l[i]]), [gallery, h_index], paras[15], "client_ui_main", others[5], gal_items)
     send_and_jump(history2dtl, send_outputs, [gallery, h_index], in_image, "client_ui_dtool", gal_items)
+    send_and_jump(tagger2main, (lambda x: x), tags, paras[0], "client_ui_main", tags)
     read_image.change(read_info_from_image, inputs=read_image, outputs=[info, png_items])
     return website
requirements.txt CHANGED
@@ -4,4 +4,7 @@ pillow
 numpy
 gradio==4.38.1
 toml
-piexif
+piexif
+pillow>=9.0.0
+onnxruntime>=1.12.0
+huggingface-hub
tagger.py ADDED
@@ -0,0 +1,297 @@
+import gradio as gr
+import huggingface_hub
+import numpy as np
+import onnxruntime as rt
+import pandas as pd
+from PIL import Image
+
+# Dataset v3 series of models:
+SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
+CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
+VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
+VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
+EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
+
+# Dataset v2 series of models:
+MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
+SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
+CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
+CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
+VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
+
+# Files to download from the repos
+MODEL_FILENAME = "model.onnx"
+LABEL_FILENAME = "selected_tags.csv"
+
+# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
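+# These emoticon tags keep their underscores; load_labels() below converts "_" to " " for every other tag.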
+kaomojis = [
+    "0_0",
+    "(o)_(o)",
+    "+_+",
+    "+_-",
+    "._.",
+    "<o>_<o>",
+    "<|>_<|>",
+    "=_=",
+    ">_<",
+    "3_3",
+    "6_9",
+    ">_o",
+    "@_@",
+    "^_^",
+    "o_o",
+    "u_u",
+    "x_x",
+    "|_|",
+    "||_||",
+]
+
+def load_labels(dataframe) -> tuple[list[str], list[int], list[int], list[int]]:
+    name_series = dataframe["name"]
+    name_series = name_series.map(
+        lambda x: x.replace("_", " ") if x not in kaomojis else x
+    )
+    tag_names = name_series.tolist()
+
+    rating_indexes = list(np.where(dataframe["category"] == 9)[0])
+    general_indexes = list(np.where(dataframe["category"] == 0)[0])
+    character_indexes = list(np.where(dataframe["category"] == 4)[0])
+    return tag_names, rating_indexes, general_indexes, character_indexes
+
+
+def mcut_threshold(probs):
+    """
+    Maximum Cut Thresholding (MCut)
+    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
+    for Multi-label Classification. In 11th International Symposium, IDA 2012
+    (pp. 172-183).
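+
+    Example: for probs = [0.9, 0.8, 0.3, 0.1] the largest gap sits between
+    0.8 and 0.3, so the threshold is (0.8 + 0.3) / 2 = 0.55 and only the
+    first two labels survive.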
+    """
+    sorted_probs = probs[probs.argsort()[::-1]]
+    difs = sorted_probs[:-1] - sorted_probs[1:]
+    t = difs.argmax()
+    thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
+    return thresh
+
+
+class Predictor:
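+    """Download a tagger repo on demand and cache the last-loaded ONNX session."""
+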
+    def __init__(self):
+        self.model_target_size = None
+        self.last_loaded_repo = None
+
+    def download_model(self, model_repo):
+        csv_path = huggingface_hub.hf_hub_download(
+            model_repo,
+            LABEL_FILENAME,
+        )
+        model_path = huggingface_hub.hf_hub_download(
+            model_repo,
+            MODEL_FILENAME,
+        )
+        return csv_path, model_path
+
+    def load_model(self, model_repo):
+        if model_repo == self.last_loaded_repo:
+            return
+
+        csv_path, model_path = self.download_model(model_repo)
+
+        tags_df = pd.read_csv(csv_path)
+        sep_tags = load_labels(tags_df)
+
+        self.tag_names = sep_tags[0]
+        self.rating_indexes = sep_tags[1]
+        self.general_indexes = sep_tags[2]
+        self.character_indexes = sep_tags[3]
+
+        model = rt.InferenceSession(model_path)
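+        # The model input is NHWC and square, so one spatial dim gives the target size.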
+        _, height, width, _ = model.get_inputs()[0].shape
+        self.model_target_size = height
+
+        self.last_loaded_repo = model_repo
+        self.model = model
+
+    def prepare_image(self, image):
+        target_size = self.model_target_size
+
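+        # Composite any alpha channel onto a white background, then drop to RGB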
+        canvas = Image.new("RGBA", image.size, (255, 255, 255))
+        canvas.alpha_composite(image)
+        image = canvas.convert("RGB")
+
+        # Pad image to square
+        image_shape = image.size
+        max_dim = max(image_shape)
+        pad_left = (max_dim - image_shape[0]) // 2
+        pad_top = (max_dim - image_shape[1]) // 2
+
+        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
+        padded_image.paste(image, (pad_left, pad_top))
+
+        # Resize
+        if max_dim != target_size:
+            padded_image = padded_image.resize(
+                (target_size, target_size),
+                Image.BICUBIC,
+            )
+
+        # Convert to numpy array
+        image_array = np.asarray(padded_image, dtype=np.float32)
+
+        # Convert PIL-native RGB to BGR
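+        # (these taggers expect OpenCV-style BGR input, matching the upstream wd14 preprocessing)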
+        image_array = image_array[:, :, ::-1]
+
+        return np.expand_dims(image_array, axis=0)
+
+    def predict(
+        self,
+        image,
+        model_repo,
+        general_thresh,
+        general_mcut_enabled,
+        character_thresh,
+        character_mcut_enabled,
+    ):
+        self.load_model(model_repo)
+
+        image = self.prepare_image(image)
+
+        input_name = self.model.get_inputs()[0].name
+        label_name = self.model.get_outputs()[0].name
+        preds = self.model.run([label_name], {input_name: image})[0]
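+        # preds has shape (1, num_tags): one confidence score per label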
+
+        labels = list(zip(self.tag_names, preds[0].astype(float)))
+
+        # Rating tags (category 9): keep all four scores; the UI highlights the max
+        ratings_names = [labels[i] for i in self.rating_indexes]
+        rating = dict(ratings_names)
+
+        # General tags (category 0): pick any where prediction confidence > threshold
+        general_names = [labels[i] for i in self.general_indexes]
+
+        if general_mcut_enabled:
+            general_probs = np.array([x[1] for x in general_names])
+            general_thresh = mcut_threshold(general_probs)
+
+        general_res = [x for x in general_names if x[1] > general_thresh]
+        general_res = dict(general_res)
+
+        # Character tags (category 4): pick any where prediction confidence > threshold
+        character_names = [labels[i] for i in self.character_indexes]
+
+        if character_mcut_enabled:
+            character_probs = np.array([x[1] for x in character_names])
+            character_thresh = mcut_threshold(character_probs)
+            character_thresh = max(0.15, character_thresh)
+
+        character_res = [x for x in character_names if x[1] > character_thresh]
+        character_res = dict(character_res)
+
+        sorted_general_strings = sorted(
+            general_res.items(),
+            key=lambda x: x[1],
+            reverse=True,
+        )
+        sorted_general_strings = [x[0] for x in sorted_general_strings]
+        sorted_general_strings = (
+            ", ".join(sorted_general_strings).replace("(", "\\(").replace(")", "\\)")
+        )
+
+        return sorted_general_strings, rating, character_res, general_res
+
+
+def tagger_ui():
+    predictor = Predictor()
+
+    dropdown_list = [
+        SWINV2_MODEL_DSV3_REPO,
+        CONV_MODEL_DSV3_REPO,
+        VIT_MODEL_DSV3_REPO,
+        VIT_LARGE_MODEL_DSV3_REPO,
+        EVA02_LARGE_MODEL_DSV3_REPO,
+        MOAT_MODEL_DSV2_REPO,
+        SWIN_MODEL_DSV2_REPO,
+        CONV_MODEL_DSV2_REPO,
+        CONV2_MODEL_DSV2_REPO,
+        VIT_MODEL_DSV2_REPO,
+    ]
+
+    with gr.Row():
+        with gr.Column(variant="panel"):
+            image = gr.Image(type="pil", image_mode="RGBA", label="Input")
+            model_repo = gr.Dropdown(
+                dropdown_list,
+                value=SWINV2_MODEL_DSV3_REPO,
+                label="模型",
+            )
+            with gr.Row():
+                general_thresh = gr.Slider(
+                    0,
+                    1,
+                    step=0.05,
+                    value=0.35,
+                    label="一般Tag阈值",
+                    scale=3,
+                )
+                general_mcut_enabled = gr.Checkbox(
+                    value=False,
+                    label="使用MCut阈值",
+                    scale=1,
+                )
+            with gr.Row():
+                character_thresh = gr.Slider(
+                    0,
+                    1,
+                    step=0.05,
+                    value=0.85,
+                    label="角色Tags阈值",
+                    scale=3,
+                )
+                character_mcut_enabled = gr.Checkbox(
+                    value=False,
+                    label="使用MCut阈值",
+                    scale=1,
+                )
+            with gr.Row():
+                clear = gr.ClearButton(
+                    components=[
+                        image,
+                        model_repo,
+                        general_thresh,
+                        general_mcut_enabled,
+                        character_thresh,
+                        character_mcut_enabled,
+                    ],
+                    variant="secondary",
+                    size="lg",
+                )
+                submit = gr.Button(value="提交", variant="primary", size="lg")
+        with gr.Column(variant="panel"):
+            sorted_general_strings = gr.Textbox(label="输出 (字符串)")
+            send_btn = gr.Button("发送到文生图", visible=False)
+            rating = gr.Label(label="分级")
+            character_res = gr.Label(label="输出 (角色)")
+            general_res = gr.Label(label="输出 (Tag)")
+            clear.add(
+                [
+                    sorted_general_strings,
+                    rating,
+                    character_res,
+                    general_res,
+                ]
+            )
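+    # Reveal the send button only when the tag string is non-empty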
+    sorted_general_strings.change(
+        lambda s: gr.Button(visible=bool(s)),
+        inputs=sorted_general_strings,
+        outputs=send_btn,
+    )
+    submit.click(
+        predictor.predict,
+        inputs=[
+            image,
+            model_repo,
+            general_thresh,
+            general_mcut_enabled,
+            character_thresh,
+            character_mcut_enabled,
+        ],
+        outputs=[sorted_general_strings, rating, character_res, general_res],
+    )
+    return sorted_general_strings, send_btn
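
For reference, a minimal sketch of driving the Predictor outside the Gradio UI; the file name test.png is illustrative, and the thresholds mirror the UI defaults:

    from PIL import Image
    from tagger import Predictor, SWINV2_MODEL_DSV3_REPO

    predictor = Predictor()
    tags, rating, characters, general = predictor.predict(
        Image.open("test.png").convert("RGBA"),  # prepare_image expects RGBA
        SWINV2_MODEL_DSV3_REPO,
        general_thresh=0.35,
        general_mcut_enabled=False,
        character_thresh=0.85,
        character_mcut_enabled=False,
    )
    print(tags)    # escaped, comma-separated general tags
    print(rating)  # dict mapping each rating tag to its score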