ucaslcl committed
Commit d57c223
1 Parent(s): 35a1672

Update modeling_GOT.py

Files changed (1)
  1. modeling_GOT.py +518 -64
modeling_GOT.py CHANGED
@@ -1,16 +1,145 @@
1
- from transformers import AutoConfig, AutoModelForCausalLM, \
2
- Qwen2Config, Qwen2Model, Qwen2ForCausalLM, \
3
- CLIPVisionModel, CLIPImageProcessor
4
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
5
  from typing import List, Optional, Tuple, Union
6
- from transformers.cache_utils import Cache, DynamicCache
 
 
 
7
  import torch
8
  import torch.nn as nn
9
- import torch.nn.functional as F
10
  from torch.nn import CrossEntropyLoss
11
- from GOT.utils.constants import *
12
- from GOT.model.vision_encoder.vary_b import build_vary_vit_b
13
- from GOT.model.plug.blip_process import BlipImageEvalProcessor
14
 
15
  class GOTConfig(Qwen2Config):
16
  model_type = "GOT"
@@ -22,7 +151,7 @@ class GOTQwenModel(Qwen2Model):
22
  def __init__(self, config: Qwen2Config):
23
  super(GOTQwenModel, self).__init__(config)
24
 
25
- self.vision_tower_high = build_vary_vit_b()
26
 
27
  self.mm_projector_vary = nn.Linear(1024, 1024)
28
 
@@ -38,13 +167,8 @@ class GOTQwenModel(Qwen2Model):
38
  device="cuda"
39
  ):
40
 
41
- # Vary old codes, not use in GOT
42
- image_processor = BlipImageEvalProcessor(image_size=1024)
43
- # 1024*1024
44
-
45
- image_processor_high = BlipImageEvalProcessor(image_size=1024)
46
-
47
 
 
48
 
49
  self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
50
 
@@ -55,20 +179,17 @@ class GOTQwenModel(Qwen2Model):
55
 
56
  self.config.vision_tower = vision_tower
57
  self.config.image_token_len = image_token_len
58
- # self.config.use_im_start_end = use_im_start_end
59
  self.config.use_im_start_end = True
60
 
61
  self.config.vision_select_layer = vision_select_layer
62
  self.config.freeze_vision_tower = freeze_vision_tower
63
 
64
  return dict(
65
- image_processor=image_processor,
66
  image_processor_high=image_processor_high,
67
  image_token_len=image_token_len,
68
  )
69
 
70
- # def get_input_embeddings(self, x):
71
- # return self.wte(x)
72
 
73
  def forward(
74
  self,
@@ -98,9 +219,6 @@ class GOTQwenModel(Qwen2Model):
98
 
99
 
100
  if vision_tower_high is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
101
- # if True:
102
- # assert type(images) is list, ValueError("To fit both interleave and conversation, images must be list of batches of images")
103
- # print(im)
104
  use_im_start_end = getattr(self.config, "use_im_start_end", -1)
105
 
106
  vision_select_layer = getattr(self.config, "vision_select_layer", -1)
@@ -115,31 +233,20 @@ class GOTQwenModel(Qwen2Model):
115
 
116
  im_end_token = 151858
117
 
118
-
119
-
120
  image_features = []
121
 
122
- print(images.shape)
123
  for image in images:
124
- P, C, H, W = image[1].shape
125
- # with torch.set_grad_enabled(True):
126
- # # print(image[1].shape)
127
- # cnn_feature = vision_tower_high(image[1])
128
- # cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256 1024
129
- # # image_features.append(cnn_feature)
130
- # image_features_2.append(cnn_feature)
131
  if P == 1:
132
  with torch.set_grad_enabled(False):
133
- # print(image[1].shape)
134
- cnn_feature = vision_tower_high(image[1])
135
  cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024
136
- # image_features.append(cnn_feature)
137
- # image_features_2.append(cnn_feature)
138
  image_feature = self.mm_projector_vary(cnn_feature)
139
  image_features.append(image_feature)
140
 
141
  else:
142
- image_patches = torch.unbind(image[1])
143
  image_patches_features = []
144
  for image_patch in image_patches:
145
  image_p = torch.stack([image_patch])
@@ -149,21 +256,15 @@ class GOTQwenModel(Qwen2Model):
149
  image_feature_p = self.mm_projector_vary(cnn_feature_p)
150
  image_patches_features.append(image_feature_p)
151
  image_feature = torch.cat(image_patches_features, dim=1)
152
- # print(P)
153
- # print(image_feature.shape)
154
- # exit()
155
  image_features.append(image_feature)
156
 
157
 
158
-
159
  dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
160
- # dummy_image_features_2 = self.mm_projector_vary(dummy_image_features_2)
161
  dummy_image_features = dummy_image_features_2
162
  use_im_start_end = True
163
  new_input_embeds = []
164
  for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
165
  if (cur_input_ids == im_patch_token).sum() == 0:
166
- # multimodal LLM, but the current sample is not multimodal
167
  cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
168
  new_input_embeds.append(cur_input_embeds)
169
  continue
@@ -222,11 +323,6 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
222
  def get_model(self):
223
  return self.model
224
 
225
- # def _set_gradient_checkpointing(self, module, value=False):
226
- # if isinstance(module, GOTQwenModel):
227
- # module.gradient_checkpointing = value
228
- # @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
229
- # @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
230
  def forward(
231
  self,
232
  input_ids: torch.LongTensor = None,
@@ -248,12 +344,6 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
248
  )
249
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
250
 
251
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
252
- # print(input_ids)
253
- # print(len(images))
254
-
255
- # print(inputs_embeds)
256
-
257
  outputs = self.model(
258
  input_ids=input_ids,
259
  past_key_values=past_key_values,
@@ -268,7 +358,6 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
268
 
269
  )
270
 
271
-
272
  hidden_states = outputs[0]
273
  logits = self.lm_head(hidden_states)
274
  logits = logits.float()
@@ -368,24 +457,389 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
368
  ):
369
  config = self.get_model().config
370
 
371
- # add image patch token <image>
372
- # tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
373
  self.resize_token_embeddings(len(tokenizer))
374
- # config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
375
 
376
  config.im_patch_token = 151859
377
 
378
  config.use_im_start_end = True
379
 
380
- # add image start token <im_start> and end token <im_end>
381
  if config.use_im_start_end:
382
- # num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
383
  self.resize_token_embeddings(len(tokenizer))
384
- # config.im_start_token, config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
385
-
386
  config.im_start_token, config.im_end_token = 151857, 151858
387
388
 
389
- AutoConfig.register("GOT", GOTConfig)
390
- AutoModelForCausalLM.register(GOTConfig, GOTQwenForCausalLM)
391
1
+ from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, StoppingCriteria, TextStreamer
 
 
2
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
3
  from typing import List, Optional, Tuple, Union
4
+ from transformers.cache_utils import Cache
5
+ import requests
6
+ from PIL import Image
7
+ from io import BytesIO
8
  import torch
9
  import torch.nn as nn
 
10
  from torch.nn import CrossEntropyLoss
11
+ from .got_vision_b import build_GOT_vit_b
12
+ from torchvision import transforms
13
+ from torchvision.transforms.functional import InterpolationMode
14
+ import dataclasses
15
+ from megfile import smart_open
16
+
17
+ DEFAULT_IMAGE_TOKEN = "<image>"
18
+ DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
19
+ DEFAULT_IM_START_TOKEN = '<img>'
20
+ DEFAULT_IM_END_TOKEN = '</img>'
21
+
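Note: these placeholder tokens are what the visual features get spliced into. For the default 256-token budget, the prompt prefix that chat() assembles further down looks like this (sketch only, taken from the logic later in this file):

    # Sketch of the prompt prefix chat() builds for a single image (image_token_len = 256).
    image_token_len = 256
    qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + '\n' + 'OCR: '
    # -> '<img><imgpad>...<imgpad></img>\nOCR: ' with exactly 256 <imgpad> placeholders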
22
+ from enum import auto, Enum
23
+ class SeparatorStyle(Enum):
24
+ """Different separator style."""
25
+ SINGLE = auto()
26
+ TWO = auto()
27
+ MPT = auto()
28
+
29
+
30
+ @dataclasses.dataclass
31
+ class Conversation:
32
+ """A class that keeps all conversation history."""
33
+ system: str
34
+ roles: List[str]
35
+ messages: List[List[str]]
36
+ offset: int
37
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
38
+ sep: str = "<|im_end|>"
39
+ sep2: str = None
40
+ version: str = "Unknown"
41
+
42
+ skip_next: bool = False
43
+
44
+ def get_prompt(self):
45
+ if self.sep_style == SeparatorStyle.SINGLE:
46
+ ret = self.system + self.sep + '\n'
47
+ for role, message in self.messages:
48
+ if message:
49
+ if type(message) is tuple:
50
+ message, _, _ = message
51
+ ret += role + ": " + message + self.sep
52
+ else:
53
+ ret += role + ":"
54
+ return ret
55
+ elif self.sep_style == SeparatorStyle.TWO:
56
+ seps = [self.sep, self.sep2]
57
+ ret = self.system + seps[0]
58
+ for i, (role, message) in enumerate(self.messages):
59
+ if message:
60
+ if type(message) is tuple:
61
+ message, _, _ = message
62
+ ret += role + ": " + message + seps[i % 2]
63
+ else:
64
+ ret += role + ":"
65
+ return ret
66
+ if self.sep_style == SeparatorStyle.MPT:
67
+ if self.system:
68
+ ret = self.system + self.sep
69
+ else:
70
+ ret = ''
71
+ for role, message in self.messages:
72
+ if message:
73
+ if type(message) is tuple:
74
+ message, _, _ = message
75
+ ret += role + message + self.sep
76
+ else:
77
+ ret += role
78
+ return ret
79
+ else:
80
+ raise ValueError(f"Invalid style: {self.sep_style}")
81
+
82
+
83
+ def append_message(self, role, message):
84
+ self.messages.append([role, message])
85
+
86
+ def copy(self):
87
+ return Conversation(
88
+ system=self.system,
89
+ roles=self.roles,
90
+ messages=[[x, y] for x, y in self.messages],
91
+ offset=self.offset,
92
+ sep_style=self.sep_style,
93
+ sep=self.sep,
94
+ sep2=self.sep2)
95
+
96
+
97
+
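Note: Conversation and SeparatorStyle only do string templating. A minimal sketch of the MPT-style prompt they render, using the same roles and separator as chat() below (the user message here is just an example):

    # Illustrative: render an MPT-style prompt with the classes defined above.
    conv = Conversation(
        system="<|im_start|>system\nYou should follow the instructions carefully and explain your answers in detail.",
        roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
        version="mpt",
        messages=[],
        offset=0,
        sep_style=SeparatorStyle.MPT,
        sep="<|im_end|>",
    )
    conv.append_message(conv.roles[0], "OCR: ")
    conv.append_message(conv.roles[1], None)   # empty slot for the assistant reply
    print(conv.get_prompt())
    # <|im_start|>system\n...<|im_end|><|im_start|>user\nOCR: <|im_end|><|im_start|>assistant\n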
98
+ class KeywordsStoppingCriteria(StoppingCriteria):
99
+ def __init__(self, keywords, tokenizer, input_ids):
100
+ self.keywords = keywords
101
+ self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
102
+ self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]
103
+ self.tokenizer = tokenizer
104
+ self.start_len = None
105
+ self.input_ids = input_ids
106
+
107
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
108
+ if self.start_len is None:
109
+ self.start_len = self.input_ids.shape[1]
110
+ else:
111
+ for keyword_id in self.keyword_ids:
112
+ if output_ids[0, -1] == keyword_id:
113
+ return True
114
+ outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
115
+ for keyword in self.keywords:
116
+ if keyword in outputs:
117
+ return True
118
+ return False
119
+
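Note: this criterion stops generation as soon as one of the keyword token ids (or the decoded keyword string) shows up after the prompt. A hedged sketch of how it is typically wired into generate(); tokenizer, model and input_ids are assumed to exist already:

    # Illustrative fragment: stop once the '<|im_end|>' separator is produced.
    stop_str = '<|im_end|>'
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
    output_ids = model.generate(
        input_ids,
        do_sample=False,
        max_new_tokens=4096,
        stopping_criteria=[stopping_criteria],
    )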
120
+
121
+ class GOTImageEvalProcessor:
122
+ def __init__(self, image_size=384, mean=None, std=None):
123
+ if mean is None:
124
+ mean = (0.48145466, 0.4578275, 0.40821073)
125
+ if std is None:
126
+ std = (0.26862954, 0.26130258, 0.27577711)
127
+
128
+ self.normalize = transforms.Normalize(mean, std)
129
+
130
+ self.transform = transforms.Compose(
131
+ [
132
+ transforms.Resize(
133
+ (image_size, image_size), interpolation=InterpolationMode.BICUBIC
134
+ ),
135
+ transforms.ToTensor(),
136
+ self.normalize,
137
+ ]
138
+ )
139
+ def __call__(self, item):
140
+ return self.transform(item)
141
+
142
+
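Note: this processor is a plain torchvision pipeline (bicubic square resize, ToTensor, CLIP-style normalization). A quick sketch of what it yields for the high-resolution branch; the file name is a placeholder:

    # Illustrative: preprocessing output shape for one page image.
    from PIL import Image

    processor = GOTImageEvalProcessor(image_size=1024)
    image = Image.open('sample_page.png').convert('RGB')   # placeholder path
    tensor = processor(image)
    print(tensor.shape)   # torch.Size([3, 1024, 1024]), normalized per channel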
143
 
144
  class GOTConfig(Qwen2Config):
145
  model_type = "GOT"
 
151
  def __init__(self, config: Qwen2Config):
152
  super(GOTQwenModel, self).__init__(config)
153
 
154
+ self.vision_tower_high = build_GOT_vit_b()
155
 
156
  self.mm_projector_vary = nn.Linear(1024, 1024)
157
 
 
167
  device="cuda"
168
  ):
169
 
 
 
 
 
 
 
170
 
171
+ image_processor_high = GOTImageEvalProcessor(image_size=1024)
172
 
173
  self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
174
 
 
179
 
180
  self.config.vision_tower = vision_tower
181
  self.config.image_token_len = image_token_len
182
+
183
  self.config.use_im_start_end = True
184
 
185
  self.config.vision_select_layer = vision_select_layer
186
  self.config.freeze_vision_tower = freeze_vision_tower
187
 
188
  return dict(
 
189
  image_processor_high=image_processor_high,
190
  image_token_len=image_token_len,
191
  )
192
 
 
 
193
 
194
  def forward(
195
  self,
 
219
 
220
 
221
  if vision_tower_high is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
 
 
 
222
  use_im_start_end = getattr(self.config, "use_im_start_end", -1)
223
 
224
  vision_select_layer = getattr(self.config, "vision_select_layer", -1)
 
233
 
234
  im_end_token = 151858
235
 
 
 
236
  image_features = []
237
 
238
+ print(images)
239
  for image in images:
240
+ P, C, H, W = image.shape
 
 
 
 
 
 
241
  if P == 1:
242
  with torch.set_grad_enabled(False):
243
+ cnn_feature = vision_tower_high(image)
 
244
  cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024
 
 
245
  image_feature = self.mm_projector_vary(cnn_feature)
246
  image_features.append(image_feature)
247
 
248
  else:
249
+ image_patches = torch.unbind(image)
250
  image_patches_features = []
251
  for image_patch in image_patches:
252
  image_p = torch.stack([image_patch])
 
256
  image_feature_p = self.mm_projector_vary(cnn_feature_p)
257
  image_patches_features.append(image_feature_p)
258
  image_feature = torch.cat(image_patches_features, dim=1)
 
 
 
259
  image_features.append(image_feature)
260
 
261
 
 
262
  dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
 
263
  dummy_image_features = dummy_image_features_2
264
  use_im_start_end = True
265
  new_input_embeds = []
266
  for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
267
  if (cur_input_ids == im_patch_token).sum() == 0:
 
268
  cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
269
  new_input_embeds.append(cur_input_embeds)
270
  continue
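Note: each 1024x1024 crop run through vision_tower_high ends up as 256 visual tokens of width 1024 after the flatten/permute and mm_projector_vary steps above, one token per <imgpad> position. A shape-only sketch with a random stand-in feature map (the 1024x16x16 encoder output shape is inferred from the '256*1024' comment, not stated in the diff):

    # Shape bookkeeping only; random tensor standing in for the high-res CNN feature map.
    import torch
    import torch.nn as nn

    cnn_feature = torch.randn(1, 1024, 16, 16)               # assumed encoder output shape
    cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1)    # (1, 256, 1024)
    mm_projector_vary = nn.Linear(1024, 1024)
    image_feature = mm_projector_vary(cnn_feature)           # (1, 256, 1024): one vector per <imgpad> slot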
 
323
  def get_model(self):
324
  return self.model
325
 
 
 
 
 
 
326
  def forward(
327
  self,
328
  input_ids: torch.LongTensor = None,
 
344
  )
345
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
346
 
 
 
 
 
 
 
347
  outputs = self.model(
348
  input_ids=input_ids,
349
  past_key_values=past_key_values,
 
358
 
359
  )
360
 
 
361
  hidden_states = outputs[0]
362
  logits = self.lm_head(hidden_states)
363
  logits = logits.float()
 
457
  ):
458
  config = self.get_model().config
459
 
460
+
 
461
  self.resize_token_embeddings(len(tokenizer))
 
462
 
463
  config.im_patch_token = 151859
464
 
465
  config.use_im_start_end = True
466
 
 
467
  if config.use_im_start_end:
 
468
  self.resize_token_embeddings(len(tokenizer))
 
 
469
  config.im_start_token, config.im_end_token = 151857, 151858
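Note: the hard-coded ids replace the removed convert_tokens_to_ids lookups. A hedged sanity check, assuming the shipped tokenizer already contains the three special tokens defined at the top of the file:

    # Sanity check (assumption: the released GOT tokenizer maps the tokens to these ids).
    assert tokenizer.convert_tokens_to_ids('<img>') == 151857
    assert tokenizer.convert_tokens_to_ids('</img>') == 151858
    assert tokenizer.convert_tokens_to_ids('<imgpad>') == 151859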
470
 
471
+ def load_image(self, image_file):
472
+ if image_file.startswith('http') or image_file.startswith('https'):
473
+ response = requests.get(image_file)
474
+ image = Image.open(BytesIO(response.content)).convert('RGB')
475
+ else:
476
+ image = Image.open(image_file).convert('RGB')
477
+ return image
478
+
479
+ def disable_torch_init(self):
480
+ """
481
+ Disable the redundant torch default initialization to accelerate model creation.
482
+ """
483
+ import torch
484
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
485
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
486
+
487
+ def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None):
488
+
489
+ self.disable_torch_init()
490
+
491
+
492
+ image_processor_high = GOTImageEvalProcessor(image_size=1024)
493
+
494
+ use_im_start_end = True
495
+
496
+ image_token_len = 256
497
+
498
+ image = self.load_image(image_file)
499
+
500
+ w, h = image.size
501
+
502
+ if ocr_type == 'format':
503
+ qs = 'OCR with format: '
504
+ else:
505
+ qs = 'OCR: '
506
+
507
+ if ocr_box:
508
+ bbox = eval(ocr_box)
509
+ if len(bbox) == 2:
510
+ bbox[0] = int(bbox[0]/w*1000)
511
+ bbox[1] = int(bbox[1]/h*1000)
512
+ if len(bbox) == 4:
513
+ bbox[0] = int(bbox[0]/w*1000)
514
+ bbox[1] = int(bbox[1]/h*1000)
515
+ bbox[2] = int(bbox[2]/w*1000)
516
+ bbox[3] = int(bbox[3]/h*1000)
517
+ if ocr_type == 'format':
518
+ qs = str(bbox) + ' ' + 'OCR with format: '
519
+ else:
520
+ qs = str(bbox) + ' ' + 'OCR: '
521
+
522
+ if ocr_color:
523
+ if ocr_type == 'format':
524
+ qs = '[' + ocr_color + ']' + ' ' + 'OCR with format: '
525
+ else:
526
+ qs = '[' + ocr_color + ']' + ' ' + 'OCR: '
527
+
528
+ if use_im_start_end:
529
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
530
+ else:
531
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
532
+
533
+
534
+ conv_mpt = Conversation(
535
+ system="""<|im_start|>system
536
+ You should follow the instructions carefully and explain your answers in detail.""",
537
+ # system = None,
538
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
539
+ version="mpt",
540
+ messages=(),
541
+ offset=0,
542
+ sep_style=SeparatorStyle.MPT,
543
+ sep="<|im_end|>",
544
+ )
545
+
546
+ conv = conv_mpt.copy()
547
+ conv.append_message(conv.roles[0], qs)
548
+ conv.append_message(conv.roles[1], None)
549
+ prompt = conv.get_prompt()
550
+
551
+ print(prompt)
552
+
553
+ inputs = tokenizer([prompt])
554
 
555
+ image_tensor_1 = image_processor_high(image)
 
556
 
557
+ input_ids = torch.as_tensor(inputs.input_ids).cuda()
558
+
559
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
560
+ keywords = [stop_str]
561
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
562
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
563
+
564
+
565
+ with torch.autocast("cuda", dtype=torch.bfloat16):
566
+ output_ids = self.generate(
567
+ input_ids,
568
+ images=[image_tensor_1.unsqueeze(0).half().cuda()],
569
+ do_sample=False,
570
+ num_beams = 1,
571
+ no_repeat_ngram_size = 20,
572
+ streamer=streamer,
573
+ max_new_tokens=4096,
574
+ stopping_criteria=[stopping_criteria]
575
+ )
576
+
577
+
578
+ if render:
579
+ print('==============rendering===============')
580
+ from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
581
+
582
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
583
+
584
+ if outputs.endswith(stop_str):
585
+ outputs = outputs[:-len(stop_str)]
586
+ outputs = outputs.strip()
587
+
588
+ if '**kern' in outputs:
589
+ import verovio
590
+ from cairosvg import svg2png
591
+ import cv2
592
+ import numpy as np
593
+ tk = verovio.toolkit()
594
+ tk.loadData(outputs)
595
+ tk.setOptions({"pageWidth": 2100, "footer": 'none',
596
+ 'barLineWidth': 0.5, 'beamMaxSlope': 15,
597
+ 'staffLineWidth': 0.2, 'spacingStaff': 6})
598
+ tk.getPageCount()
599
+ svg = tk.renderToSVG()
600
+ svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
601
+
602
+ svg_to_html(svg, save_render_file)
603
+
604
+ if ocr_type == 'format' and '**kern' not in outputs:
605
+
606
+
607
+ if '\\begin{tikzpicture}' not in outputs:
608
+ html_path_2 = save_render_file
609
+ right_num = outputs.count('\\right')
610
+ left_num = outputs.count('\left')
611
+
612
+ if right_num != left_num:
613
+ outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
614
+
615
+
616
+ outputs = outputs.replace('"', '``').replace('$', '')
617
+
618
+ outputs_list = outputs.split('\n')
619
+ gt= ''
620
+ for out in outputs_list:
621
+ gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
622
+
623
+ gt = gt[:-2]
624
+
625
+
626
+ lines = content_mmd_to_html
627
+ lines = lines.split("const text =")
628
+ new_web = lines[0] + 'const text =' + gt + lines[1]
629
+
630
+ else:
631
+ html_path_2 = save_render_file
632
+ outputs = outputs.translate(translation_table)
633
+ outputs_list = outputs.split('\n')
634
+ gt= ''
635
+ for out in outputs_list:
636
+ if out:
637
+ if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
638
+ while out[-1] == ' ':
639
+ out = out[:-1]
640
+ if out is None:
641
+ break
642
+
643
+ if out:
644
+ if out[-1] != ';':
645
+ gt += out[:-1] + ';\n'
646
+ else:
647
+ gt += out + '\n'
648
+ else:
649
+ gt += out + '\n'
650
+
651
+
652
+ lines = tik_html
653
+ lines = lines.split("const text =")
654
+ new_web = lines[0] + gt + lines[1]
655
+
656
+ with smart_open(html_path_2, 'w') as web_f_new:
657
+ web_f_new.write(new_web)
658
+
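A note on the ocr_box / ocr_color arguments of chat() above: the box is passed as a string (it is eval()'d), given in original-image pixels, and rescaled to a 0-1000 grid before being prepended to the OCR prompt. Hedged example calls; file names are placeholders and the output is streamed to stdout by the TextStreamer:

    # Illustrative calls to the chat() method defined above.
    model.chat(tokenizer, 'receipt.png', ocr_type='format', ocr_box='[100, 200, 600, 900]')   # region-restricted OCR
    model.chat(tokenizer, 'slide.png', ocr_type='ocr', ocr_color='red')                       # color-guided OCR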
659
+ def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
660
+
661
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
662
+ best_ratio_diff = float('inf')
663
+ best_ratio = (1, 1)
664
+ area = width * height
665
+ for ratio in target_ratios:
666
+ target_aspect_ratio = ratio[0] / ratio[1]
667
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
668
+ if ratio_diff < best_ratio_diff:
669
+ best_ratio_diff = ratio_diff
670
+ best_ratio = ratio
671
+ elif ratio_diff == best_ratio_diff:
672
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
673
+ best_ratio = ratio
674
+ # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
675
+ return best_ratio
676
+
677
+ orig_width, orig_height = image.size
678
+ aspect_ratio = orig_width / orig_height
679
+
680
+ # calculate the existing image aspect ratio
681
+ target_ratios = set(
682
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
683
+ i * j <= max_num and i * j >= min_num)
684
+ # print(target_ratios)
685
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
686
+
687
+ # find the closest aspect ratio to the target
688
+ target_aspect_ratio = find_closest_aspect_ratio(
689
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
690
+
691
+ # print(target_aspect_ratio)
692
+ # calculate the target width and height
693
+ target_width = image_size * target_aspect_ratio[0]
694
+ target_height = image_size * target_aspect_ratio[1]
695
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
696
+
697
+ # resize the image
698
+ resized_img = image.resize((target_width, target_height))
699
+ processed_images = []
700
+ for i in range(blocks):
701
+ box = (
702
+ (i % (target_width // image_size)) * image_size,
703
+ (i // (target_width // image_size)) * image_size,
704
+ ((i % (target_width // image_size)) + 1) * image_size,
705
+ ((i // (target_width // image_size)) + 1) * image_size
706
+ )
707
+ # split the image
708
+ split_img = resized_img.crop(box)
709
+ processed_images.append(split_img)
710
+ assert len(processed_images) == blocks
711
+ if use_thumbnail and len(processed_images) != 1:
712
+ thumbnail_img = image.resize((image_size, image_size))
713
+ processed_images.append(thumbnail_img)
714
+ return processed_images
715
+
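Note: dynamic_preprocess picks the crop grid whose aspect ratio is closest to the page's, resizes the page to that grid, cuts it into image_size x image_size tiles, and (by default) appends a whole-page thumbnail. A worked example with made-up dimensions:

    # Hypothetical 1000x2000 page: the closest grid is (1, 2), so it is resized to 1024x2048
    # and split into two 1024x1024 crops plus one 1024x1024 thumbnail.
    from PIL import Image

    page = Image.new('RGB', (1000, 2000))
    tiles = model.dynamic_preprocess(page, min_num=1, max_num=6, image_size=1024, use_thumbnail=True)
    print(len(tiles), [t.size for t in tiles])   # 3 [(1024, 1024), (1024, 1024), (1024, 1024)]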
716
+
717
+ def chat_plus(self, tokenizer, image_file_list, render=False, save_render_file=None):
718
+ # Model
719
+ self.disable_torch_init()
720
+ multi_page=False
721
+
722
+
723
+ image_processor_high = GOTImageEvalProcessor(image_size=1024)
724
+
725
+ use_im_start_end = True
726
+
727
+
728
+ image_token_len = 256
729
+
730
+ image_list = []
731
+
732
+ if len(image_file_list)>1:
733
+ multi_page = True
734
+
735
+ if multi_page:
736
+ qs = 'OCR with format across multi pages: '
737
+ # only for png files
738
+ import glob
739
+ # from natsort import natsorted
740
+ # patches = glob.glob(image_file + '/*png')
741
+ patches = image_file_list
742
+ # patches = natsorted(patches)
743
+ sub_images = []
744
+ for sub_image in patches:
745
+ sub_images.append(self.load_image(sub_image))
746
+
747
+ ll = len(patches)
748
+ print(patches)
749
+ print("len ll: ", ll)
750
+
751
+ else:
752
+ qs = 'OCR with format upon the patch reference: '
753
+ img = self.load_image(image_file_list[0])
754
+ sub_images = self.dynamic_preprocess(img)
755
+ ll = len(sub_images)
756
+
757
+ for image in sub_images:
758
+ image_tensor_1 = image_processor_high(image)
759
+ image_list.append(image_tensor_1)
760
+
761
+
762
+ image_list = torch.stack(image_list)
763
+
764
+ print('====new images batch size======: ',image_list.shape)
765
+
766
+
767
+ if use_im_start_end:
768
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + qs
769
+ else:
770
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
771
+
772
+
773
+ conv_mpt = Conversation(
774
+ system="""<|im_start|>system
775
+ You should follow the instructions carefully and explain your answers in detail.""",
776
+ # system = None,
777
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
778
+ version="mpt",
779
+ messages=(),
780
+ offset=0,
781
+ sep_style=SeparatorStyle.MPT,
782
+ sep="<|im_end|>",
783
+ )
784
+
785
+ conv = conv_mpt.copy()
786
+ conv.append_message(conv.roles[0], qs)
787
+ conv.append_message(conv.roles[1], None)
788
+ prompt = conv.get_prompt()
789
+
790
+ print(prompt)
791
+
792
+ inputs = tokenizer([prompt])
793
+
794
+ input_ids = torch.as_tensor(inputs.input_ids).cuda()
795
+
796
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
797
+ keywords = [stop_str]
798
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
799
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
800
+
801
+
802
+ with torch.autocast("cuda", dtype=torch.bfloat16):
803
+ output_ids = self.generate(
804
+ input_ids,
805
+ images=[image_list.half().cuda()],
806
+ do_sample=False,
807
+ num_beams = 1,
808
+ # no_repeat_ngram_size = 20,
809
+ streamer=streamer,
810
+ max_new_tokens=4096,
811
+ stopping_criteria=[stopping_criteria]
812
+ )
813
+
814
+ if render:
815
+ print('==============rendering===============')
816
+ from .render_tools import content_mmd_to_html
817
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
818
+
819
+ if outputs.endswith(stop_str):
820
+ outputs = outputs[:-len(stop_str)]
821
+ outputs = outputs.strip()
822
+
823
+ html_path_2 = save_render_file
824
+ right_num = outputs.count('\\right')
825
+ left_num = outputs.count('\left')
826
+
827
+ if right_num != left_num:
828
+ outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
829
+
830
+
831
+ outputs = outputs.replace('"', '``').replace('$', '')
832
+
833
+ outputs_list = outputs.split('\n')
834
+ gt= ''
835
+ for out in outputs_list:
836
+ gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
837
+
838
+ gt = gt[:-2]
839
+
840
+ lines = content_mmd_to_html
841
+ lines = lines.split("const text =")
842
+ new_web = lines[0] + 'const text =' + gt + lines[1]
843
+
844
+ with smart_open(html_path_2, 'w') as web_f_new:
845
+ web_f_new.write(new_web)
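After this change the OCR entry points live directly on the remote-code class. A hedged end-to-end sketch of typical usage; the repo id and image paths are assumptions, not part of this diff:

    # Illustrative usage via trust_remote_code; names below are placeholders/assumptions.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = 'ucaslcl/GOT-OCR2_0'   # assumed Hub repo id
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).eval().cuda()

    # Single image, plain OCR; the result is streamed to stdout by TextStreamer.
    model.chat(tokenizer, 'page.png', ocr_type='ocr')

    # Formatted OCR over several pages, rendered to an HTML file.
    model.chat_plus(tokenizer, ['page_1.png', 'page_2.png'], render=True, save_render_file='./out.html')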