Update modeling_GOT.py

modeling_GOT.py  +181 -1
@@ -541,7 +541,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
             offset=0,
             sep_style=SeparatorStyle.MPT,
             sep="<|im_end|>",
-            )
+            )
 
         conv = conv_mpt.copy()
         conv.append_message(conv.roles[0], qs)
@@ -657,3 +657,183 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         # with open(html_path_2, 'w') as web_f_new:
         #     web_f_new.write(new_web)
 
+    def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
+
+        def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+            best_ratio_diff = float('inf')
+            best_ratio = (1, 1)
+            area = width * height
+            for ratio in target_ratios:
+                target_aspect_ratio = ratio[0] / ratio[1]
+                ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+                if ratio_diff < best_ratio_diff:
+                    best_ratio_diff = ratio_diff
+                    best_ratio = ratio
+                elif ratio_diff == best_ratio_diff:
+                    if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                        best_ratio = ratio
+            # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
+            return best_ratio
+
+        orig_width, orig_height = image.size
+        aspect_ratio = orig_width / orig_height
+
+        # calculate the existing image aspect ratio
+        target_ratios = set(
+            (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+            i * j <= max_num and i * j >= min_num)
+        # print(target_ratios)
+        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+        # find the closest aspect ratio to the target
+        target_aspect_ratio = find_closest_aspect_ratio(
+            aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+        # print(target_aspect_ratio)
+        # calculate the target width and height
+        target_width = image_size * target_aspect_ratio[0]
+        target_height = image_size * target_aspect_ratio[1]
+        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+        # resize the image
+        resized_img = image.resize((target_width, target_height))
+        processed_images = []
+        for i in range(blocks):
+            box = (
+                (i % (target_width // image_size)) * image_size,
+                (i // (target_width // image_size)) * image_size,
+                ((i % (target_width // image_size)) + 1) * image_size,
+                ((i // (target_width // image_size)) + 1) * image_size
+            )
+            # split the image
+            split_img = resized_img.crop(box)
+            processed_images.append(split_img)
+        assert len(processed_images) == blocks
+        if use_thumbnail and len(processed_images) != 1:
+            thumbnail_img = image.resize((image_size, image_size))
+            processed_images.append(thumbnail_img)
+        return processed_images
+
+
+    def chat_crop(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, multi_page=False):
+        # Model
+        self.disable_torch_init()
+
+
+        image_processor_high = GOTImageEvalProcessor(image_size=1024)
+
+        use_im_start_end = True
+
+
+        image_token_len = 256
+
+        image_list = []
+
+        if multi_page:
+            qs = 'OCR with format across multi pages: '
+            # only for png files
+            import glob
+            from natsort import natsorted
+            patches = glob.glob(image_file + '/*png')
+            patches = natsorted(patches)
+            sub_images = []
+            for sub_image in patches:
+                sub_images.append(self.load_image(sub_image))
+
+            ll = len(patches)
+
+        else:
+            qs = 'OCR with format upon the patch reference: '
+            img = self.load_image(image_file)
+            sub_images = self.dynamic_preprocess(img)
+            ll = len(sub_images)
+
+        for image in sub_images:
+            image_tensor_1 = image_processor_high(image)
+            image_list.append(image_tensor_1)
+
+
+        image_list = torch.stack(image_list)
+
+        print('====new images batch size======: ', image_list.shape)
+
+
+        if use_im_start_end:
+            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + qs
+        else:
+            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
+
+
+        conv_mpt = Conversation(
+            system="""<|im_start|>system
+        You should follow the instructions carefully and explain your answers in detail.""",
+            # system = None,
+            roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+            version="mpt",
+            messages=(),
+            offset=0,
+            sep_style=SeparatorStyle.MPT,
+            sep="<|im_end|>",
+            )
+
+        conv = conv_mpt.copy()
+        conv.append_message(conv.roles[0], qs)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+
+        inputs = tokenizer([prompt])
+
+        input_ids = torch.as_tensor(inputs.input_ids).cuda()
+
+        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+        keywords = [stop_str]
+        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+
+        with torch.autocast("cuda", dtype=torch.bfloat16):
+            output_ids = self.generate(
+                input_ids,
+                images=[(image_list.half().cuda(), image_list.half().cuda())],
+                do_sample=False,
+                num_beams = 1,
+                # no_repeat_ngram_size = 20,
+                streamer=streamer,
+                max_new_tokens=4096,
+                stopping_criteria=[stopping_criteria]
+                )
+
+        # if render:
+        #     print('==============rendering===============')
+        #     outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
+
+        #     if outputs.endswith(stop_str):
+        #         outputs = outputs[:-len(stop_str)]
+        #     outputs = outputs.strip()
+
+        #     html_path = "./render_tools/" + "/content-mmd-to-html.html"
+        #     html_path_2 = "./results/demo.html"
+        #     right_num = outputs.count('\\right')
+        #     left_num = outputs.count('\left')
+
+        #     if right_num != left_num:
+        #         outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
+
+
+        #     outputs = outputs.replace('"', '``').replace('$', '')
+
+        #     outputs_list = outputs.split('\n')
+        #     gt= ''
+        #     for out in outputs_list:
+        #         gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
+
+        #     gt = gt[:-2]
+
+        #     with open(html_path, 'r') as web_f:
+        #         lines = web_f.read()
+        #         lines = lines.split("const text =")
+        #         new_web = lines[0] + 'const text =' + gt + lines[1]
+
+        #     with open(html_path_2, 'w') as web_f_new:
+        #         web_f_new.write(new_web)