import argparse
import ast
import os
import re
import string
from io import BytesIO

import requests
import torch
from PIL import Image
from transformers import (AutoTokenizer, AutoModelForCausalLM,
                          CLIPVisionModel, CLIPImageProcessor,
                          StoppingCriteria, TextStreamer)

from .conversation import conv_templates, SeparatorStyle
from .utils import KeywordsStoppingCriteria, disable_torch_init
from .blip_process import BlipImageEvalProcessor
from .GOT_ocr_2_0 import GOTQwenModel, GOTQwenForCausalLM, GOTConfig

# Full-width (CJK) punctuation mapped to its ASCII equivalent; used via
# str.translate() to normalize model output before rendering.
punctuation_dict = {
    "，": ",",
    "。": ".",
}


def svg_to_html(svg_content, output_filename):
    """Wrap SVG markup in a minimal standalone HTML page and write it to disk."""
    html_content = f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>SVG Embedded in HTML</title>
    </head>
    <body>
        <svg width="2100" height="15000" xmlns="http://www.w3.org/2000/svg">
        {svg_content}
        </svg>
    </body>
    </html>
    """

    with open(output_filename, 'w') as file:
        file.write(html_content)
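
# Example (hypothetical path, for illustration only); note that open() will not
# create the target directory, so ./results must already exist:
#     svg_to_html(rendered_svg, "./results/demo.html")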


DEFAULT_IMAGE_TOKEN = '<image>'
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'

translation_table = str.maketrans(punctuation_dict)
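# For instance (illustrative only):
#     '图1，结果。'.translate(translation_table)  ->  '图1,结果.'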


def load_image(image_file):
    """Load an image from a local path or an http(s) URL and convert to RGB."""
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def eval_model(model_name, image_file, ocr_type, ocr_box='', ocr_color='', render=False):
    # Skip redundant default weight initialization to speed up loading.
    disable_torch_init()

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # 151643 is the Qwen tokenizer's <|endoftext|> id, reused as pad token.
    model = GOTQwenForCausalLM.from_pretrained(
        model_name,
        low_cpu_mem_usage=True,
        device_map='cuda',
        use_safetensors=True,
        pad_token_id=151643,
    ).eval()

    model.to(device='cuda', dtype=torch.bfloat16)

    # Base and high-resolution branches both take 1024x1024 inputs here.
    image_processor = BlipImageEvalProcessor(image_size=1024)
    image_processor_high = BlipImageEvalProcessor(image_size=1024)

    use_im_start_end = True
    image_token_len = 256

    image = load_image(image_file)
    w, h = image.size

    # Build the text query; box/color prefixes steer fine-grained OCR.
    if ocr_type == 'format':
        qs = 'OCR with format: '
    else:
        qs = 'OCR: '

    if ocr_box:
        # Parse the box literal safely, then normalize coordinates to a
        # 0-1000 grid relative to the image size.
        bbox = list(ast.literal_eval(ocr_box))
        if len(bbox) == 2:
            bbox[0] = int(bbox[0] / w * 1000)
            bbox[1] = int(bbox[1] / h * 1000)
        if len(bbox) == 4:
            bbox[0] = int(bbox[0] / w * 1000)
            bbox[1] = int(bbox[1] / h * 1000)
            bbox[2] = int(bbox[2] / w * 1000)
            bbox[3] = int(bbox[3] / h * 1000)
        if ocr_type == 'format':
            qs = str(bbox) + ' ' + 'OCR with format: '
        else:
            qs = str(bbox) + ' ' + 'OCR: '

    if ocr_color:
        if ocr_type == 'format':
            qs = '[' + ocr_color + ']' + ' ' + 'OCR with format: '
        else:
            qs = '[' + ocr_color + ']' + ' ' + 'OCR: '

    # Prepend the image placeholder tokens that the vision features replace.
    if use_im_start_end:
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
    else:
        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

    conv_mode = "mpt"
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    print(prompt)

    inputs = tokenizer([prompt])

    # Two copies of the image: one per vision branch.
    image_1 = image.copy()
    image_tensor = image_processor(image)
    image_tensor_1 = image_processor_high(image_1)

    input_ids = torch.as_tensor(inputs.input_ids).cuda()

    # Stop at the conversation separator and stream tokens as they decode.
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.autocast("cuda", dtype=torch.bfloat16):
        output_ids = model.generate(
            input_ids,
            images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],
            do_sample=False,
            num_beams=1,
            no_repeat_ngram_size=20,
            streamer=streamer,
            max_new_tokens=4096,
            stopping_criteria=[stopping_criteria],
        )

    if render:
        print('==============rendering===============')

        # Decode only the newly generated tokens and strip the stop string.
        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()

        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()

        if '**kern' in outputs:
            # Sheet-music output (Humdrum **kern): render it with Verovio.
            # Imported lazily so the dependency is only needed with --render.
            import verovio

            tk = verovio.toolkit()
            tk.loadData(outputs)
            tk.setOptions({"pageWidth": 2100, "footer": 'none',
                           'barLineWidth': 0.5, 'beamMaxSlope': 15,
                           'staffLineWidth': 0.2, 'spacingStaff': 6})
            tk.getPageCount()
            svg = tk.renderToSVG()
            svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")

            svg_to_html(svg, "./results/demo.html")

        if ocr_type == 'format' and '**kern' not in outputs:

            if '\\begin{tikzpicture}' not in outputs:
                # Markdown/LaTeX output: inject it into the MMD render template.
                html_path = "./render_tools/content-mmd-to-html.html"
                html_path_2 = "./results/demo.html"
                right_num = outputs.count('\\right')
                left_num = outputs.count('\\left')

                # If \left/\right are unbalanced, fall back to plain delimiters
                # so the rendered math does not break.
                if right_num != left_num:
                    outputs = (outputs.replace('\\left(', '(').replace('\\right)', ')')
                                      .replace('\\left[', '[').replace('\\right]', ']')
                                      .replace('\\left{', '{').replace('\\right}', '}')
                                      .replace('\\left|', '|').replace('\\right|', '|')
                                      .replace('\\left.', '.').replace('\\right.', '.'))

                outputs = outputs.replace('"', '``').replace('$', '')

                # Re-emit the text as a JavaScript string-concatenation literal
                # to splice into the template's `const text =` assignment.
                outputs_list = outputs.split('\n')
                gt = ''
                for out in outputs_list:
                    gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'

                gt = gt[:-2]

                with open(html_path, 'r') as web_f:
                    lines = web_f.read()
                    lines = lines.split("const text =")
                    new_web = lines[0] + 'const text =' + gt + lines[1]
            else:
                # TikZ output: inject it into the TikZ render template instead.
                html_path = "./render_tools/tikz.html"
                html_path_2 = "./results/demo.html"
                outputs = outputs.translate(translation_table)
                outputs_list = outputs.split('\n')
                gt = ''
                for out in outputs_list:
                    if out:
                        if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
                            out = out.rstrip()
                            if not out:
                                continue
                            # Make sure each TikZ statement is terminated with ';'.
                            if out[-1] != ';':
                                gt += out + ';\n'
                            else:
                                gt += out + '\n'
                        else:
                            gt += out + '\n'

                with open(html_path, 'r') as web_f:
                    lines = web_f.read()
                    lines = lines.split("const text =")
                    new_web = lines[0] + 'const text =' + gt + lines[1]

            with open(html_path_2, 'w') as web_f_new:
                web_f_new.write(new_web)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    parser.add_argument("--image-file", type=str, required=True)
    parser.add_argument("--type", type=str, required=True)
    parser.add_argument("--box", type=str, default='')
    parser.add_argument("--color", type=str, default='')
    parser.add_argument("--render", action='store_true')
    args = parser.parse_args()

    # eval_model() takes individual parameters, so unpack the parsed
    # namespace instead of passing it whole.
    eval_model(args.model_name, args.image_file, args.type,
               ocr_box=args.box, ocr_color=args.color, render=args.render)
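
# Example invocation (hypothetical module path and weights, for illustration):
#     python -m GOT.demo.run_ocr --model-name /path/to/GOT-OCR2_0 \
#         --image-file ./images/page.png --type format --render
# The relative imports above mean this script must be run as a module.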