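# Streamlit demo for Chipper, Unstructured.io's OCR-free document understanding
# transformer. Run with: streamlit run app.py
# The HF_TOKEN environment variable must be set; it is used to download the
# model and processor from the Hugging Face Hub.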
import numpy as np
import torch
from torch import nn
import streamlit as st
import os
from PIL import Image
from io import BytesIO
import transformers
from transformers import VisionEncoderDecoderModel, DonutProcessor
from huggingface_hub import hf_hub_download
from logits_ngrams import NoRepeatNGramLogitsProcessor, get_table_token_ids

def run_prediction(sample, model, processor, mode):
    # Token ids allowed to repeat (table structure tokens) despite the
    # no-repeat-n-gram constraint below
    skip_tokens = get_table_token_ids(processor)
    no_repeat_ngram_size = 15

    # Select the task prompt: plain OCR vs. hierarchical element annotation
    if mode == "OCR":
        prompt = "<s><s_pretraining>"
    else:
        prompt = "<s><s_hierarchical>"

    print("prompt:", prompt)
    print("no_repeat_ngram_size:", no_repeat_ngram_size)

    pixel_values = processor(
        np.array(sample, np.float32),
        return_tensors="pt",
    ).pixel_values

    transformers.set_seed(42)
    with torch.no_grad():
        outputs = model.generate(
            pixel_values.to(device),
            decoder_input_ids=processor.tokenizer(
                prompt, add_special_tokens=False, return_tensors="pt"
            ).input_ids.to(device),
            logits_processor=[
                NoRepeatNGramLogitsProcessor(no_repeat_ngram_size, skip_tokens)
            ],
            do_sample=True,
            top_p=0.92,
            top_k=5,
            no_repeat_ngram_size=no_repeat_ngram_size,
            num_beams=3,
            output_attentions=False,
            output_hidden_states=False,
        )

    # Decode the generated token ids back into text
    prediction = processor.batch_decode(outputs)[0]
    print(prediction)
    return prediction

logo = Image.open("./rsz_unstructured_logo.png")
st.image(logo)
st.markdown('''
### Chipper
Chipper is an OCR-free Document Understanding Transformer. It was pre-trained on over 1M documents from public sources and fine-tuned on a wide range of document types.
At [Unstructured.io](https://github.com/Unstructured-IO/unstructured) we are on a mission to build custom preprocessing pipelines for labeling, training, or production machine-learning pipelines.
Come join us in our public repos and contribute! Every contribution and piece of feedback is highly valued by the community.
''')
image_upload = None
with st.sidebar:
    # file upload
    uploaded_file = st.file_uploader("Upload a document")
    if uploaded_file is not None:
        # To read file as bytes:
        image_bytes_data = uploaded_file.getvalue()
        image_upload = Image.open(BytesIO(image_bytes_data))

    mode = st.selectbox('Mode', ('OCR', 'Element annotation'), index=1)
# Fall back to the bundled sample document if nothing was uploaded
if image_upload:
    image = image_upload
else:
    image = Image.open("./document.png")
st.image(image, caption='Your target document')
with st.spinner('Processing the document ...'):
    pre_trained_model = "unstructuredio/chipper-v3"
    processor = DonutProcessor.from_pretrained(pre_trained_model, token=os.environ['HF_TOKEN'])
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if 'model' in st.session_state:
        # Reuse the model cached in the session across Streamlit reruns
        model = st.session_state['model']
    else:
        model = VisionEncoderDecoderModel.from_pretrained(pre_trained_model, token=os.environ['HF_TOKEN'])

        # The checkpoint ships its language-model head as a separate file with a
        # low-rank factorization; rebuild that structure before loading the weights.
        lm_head_file = hf_hub_download(
            repo_id=pre_trained_model, filename="lm_head.pth", token=os.environ['HF_TOKEN']
        )
        rank = 128
        model.decoder.lm_head = nn.Sequential(
            nn.Linear(model.decoder.lm_head.weight.shape[1], rank, bias=False),
            nn.Linear(rank, rank, bias=False),
            nn.Linear(rank, model.decoder.lm_head.weight.shape[0], bias=True),
        )
        model.decoder.lm_head.load_state_dict(torch.load(lm_head_file))

        model.eval()
        model.to(device)
        st.session_state['model'] = model

    st.info('Parsing document')
    parsed_info = run_prediction(image.convert("RGB"), model, processor, mode)
    st.text('\nDocument:')
    st.text_area('Output text', value=parsed_info, height=800)