|
import logging
from os import environ
from typing import Optional

import numpy as np
import pdf2image
import torch

from transformers import (
    LayoutLMv2FeatureExtractor,
    LayoutLMv2ForQuestionAnswering,
    LayoutLMv2Processor,
    LayoutLMv2TokenizerFast,
)
from transformers.tokenization_utils_base import BatchEncoding
from transformers.utils import TensorType

# LayoutLMv2FeatureExtractor applies Tesseract OCR by default (apply_ocr=True),
# so each call returns resized page images together with the recognized words
# and their normalized bounding boxes.
feature_extractor = LayoutLMv2FeatureExtractor()


class NoOCRReaderFound(Exception):
    def __init__(self, e):
        self.e = e

    def __str__(self):
        return f"Could not load OCR Reader: {self.e}"


def pdf_to_image(b: bytes) -> dict:
    """Render each page of a PDF byte string to RGB and OCR it."""
    images = [page.convert("RGB") for page in pdf2image.convert_from_bytes(b)]
    encoded_inputs = feature_extractor(images)
    logger.debug('feature_extractor: %s', encoded_inputs.keys())
    data = {}
    data['image'] = encoded_inputs.pixel_values
    data['words'] = encoded_inputs.words
    data['boxes'] = encoded_inputs.boxes
    return data

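# Hedged usage sketch (not part of the original handler): feed a local PDF
# through pdf_to_image. The function name and the path "sample.pdf" are
# hypothetical; pdf2image additionally needs poppler installed on the host.
def _demo_pdf_to_image(path: str = "sample.pdf") -> dict:
    with open(path, "rb") as f:
        page_data = pdf_to_image(f.read())
    # One entry per page under each key: 'image', 'words', 'boxes'.
    return page_data

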
def setup_logger(which_logger: Optional[str] = None) -> logging.Logger:
    lib_level = logging.DEBUG
    root_level = logging.INFO
    log_format = '%(asctime)s - %(process)d - %(levelname)s - %(funcName)s - %(message)s'
    logging.basicConfig(
        filename=environ.get('LOG_FILE_PATH_LAYOUTLM_V2'),
        format=log_format,
        datefmt='%d-%b-%y %H:%M:%S',
        level=root_level,
        force=True,
    )
    log = logging.getLogger(which_logger)
    log.setLevel(lib_level)
    return log


logger = setup_logger(__name__)


class Funcs:

    @staticmethod
    def unnormalize_box(bbox, width, height):
        # LayoutLM boxes are normalized to a 0-1000 grid; scale them back to
        # pixel coordinates, e.g. unnormalize_box([250, 500, 750, 1000], 800, 600)
        # -> [200.0, 300.0, 600.0, 600.0].
        return [
            width * (bbox[0] / 1000),
            height * (bbox[1] / 1000),
            width * (bbox[2] / 1000),
            height * (bbox[3] / 1000),
        ]

    @staticmethod
    def num_spans(encoding: BatchEncoding) -> int:
        # When a long document is tokenized with overflowing tokens, one page
        # can yield several spans; each span is a row of input_ids.
        return len(encoding["input_ids"])

    @staticmethod
    def p_mask(num_spans: int, encoding: BatchEncoding) -> list:
        # Mask every token that is not document text (sequence id 1); those
        # positions can never contain the answer.
        return [
            [tok != 1 for tok in encoding.sequence_ids(span_id)]
            for span_id in range(num_spans)
        ]

    @staticmethod
    def token_start_end(encoding, tokenizer):
        sequence_ids = encoding.sequence_ids()

        # The question occupies sequence id 0 and the document words occupy
        # sequence id 1; find the first and last document tokens.
        token_start_index = 0
        while sequence_ids[token_start_index] != 1:
            token_start_index += 1

        token_end_index = len(encoding.input_ids) - 1
        while sequence_ids[token_end_index] != 1:
            token_end_index -= 1

        logger.debug('Token start index: %s', token_start_index)
        logger.debug('Token end index: %s', token_end_index)
        logger.debug('token_start_end: %s', tokenizer.decode(
            encoding.input_ids[token_start_index:token_end_index + 1]))
        return token_start_index, token_end_index

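    # Illustrative sketch (values made up): if encoding.sequence_ids() is
    # [None, 0, 0, None, 1, 1, 1, None], token_start_end returns (4, 6),
    # the positions of the first and last document tokens.
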
    @staticmethod
    def reconstruct_answer(word_idx_start, word_idx_end, token_start_index,
                           token_end_index, encoding, tokenizer):
        # Map word-level answer indices onto token positions by scanning the
        # word ids of the document span from both ends; token_start_index and
        # token_end_index come from token_start_end above.
        word_ids = encoding.word_ids()[token_start_index:token_end_index + 1]
        logger.debug('Word ids: %s', word_ids)
        start_position, end_position = None, None
        for word_id in word_ids:
            if word_id == word_idx_start:
                start_position = token_start_index
                break
            token_start_index += 1

        for word_id in word_ids[::-1]:
            if word_id == word_idx_end:
                end_position = token_end_index
                break
            token_end_index -= 1

        logger.debug('Reconstructed answer: %s', tokenizer.decode(
            encoding.input_ids[start_position:end_position + 1]))
        return start_position, end_position

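    # Illustrative sketch: if the document span covers token positions 10..13
    # with word_ids [3, 4, 5, 6], then reconstruct_answer(4, 5, 10, 13, ...)
    # yields token positions (11, 12).
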
    @staticmethod
    def sigmoid(_outputs):
        return 1.0 / (1.0 + np.exp(-_outputs))

    @staticmethod
    def softmax(_outputs):
        # Numerically stable softmax: subtract the row max before exponentiating.
        maxes = np.max(_outputs, axis=-1, keepdims=True)
        shifted_exp = np.exp(_outputs - maxes)
        return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)


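# Hedged sanity-check sketch (not part of the original handler; the function
# name is illustrative): softmax rows should sum to 1 and sigmoid(0) is 0.5.
def _check_numeric_helpers() -> None:
    logits = np.array([[1.0, 2.0, 3.0]])
    assert abs(Funcs.softmax(logits).sum() - 1.0) < 1e-6
    assert Funcs.sigmoid(np.array(0.0)) == 0.5

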
class EndpointHandler:
    def __init__(self, path: str = "./"):
        self.model = LayoutLMv2ForQuestionAnswering.from_pretrained(path)
        self.tokenizer = LayoutLMv2TokenizerFast.from_pretrained(path)
        self.processor = LayoutLMv2Processor.from_pretrained(
            path,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data: dict[str, bytes]):
        """
        Args:
            data (:obj:`dict`):
                contains the raw PDF file as bytes under the "inputs" key
        """
        pdf_bytes = data.pop("inputs", data)
        images = [page.convert("RGB") for page in pdf2image.convert_from_bytes(pdf_bytes)]

        question = "what is the bill date"  # hard-coded demo question
        with torch.no_grad():
            for image in images:
                encoding = self.processor(
                    image,
                    question,
                    truncation=True,
                    return_tensors=TensorType.PYTORCH,
                )
                logger.debug('encoding: %s', encoding.keys())

                outputs = self.model(**encoding)

                start_logits = outputs.start_logits
                end_logits = outputs.end_logits

                # Greedy span decoding: argmax over the start and end logits.
                predicted_start_idx = start_logits.argmax(-1).item()
                predicted_end_idx = end_logits.argmax(-1).item()

                predicted_answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx:predicted_end_idx + 1]
                predicted_answer = self.processor.tokenizer.decode(predicted_answer_tokens)

                # Hard-coded target positions; supplying start/end positions
                # makes the model also return a loss for this example.
                target_start_index = torch.tensor([7])
                target_end_index = torch.tensor([14])
                outputs = self.model(**encoding, start_positions=target_start_index, end_positions=target_end_index)

                logger.info(f'''
                START
                predicted_start_idx: {predicted_start_idx}
                predicted_end_idx: {predicted_end_idx}
                ---
                answer: {predicted_answer}
                END''')
        return {'data': 'success'}
|
|
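# Hedged usage sketch (not part of the original handler): run the handler
# against a local checkpoint directory with raw PDF bytes. The checkpoint
# path "./" and the file "sample.pdf" are hypothetical.
if __name__ == "__main__":
    handler = EndpointHandler(path="./")
    with open("sample.pdf", "rb") as f:
        result = handler({"inputs": f.read()})
    print(result)  # {'data': 'success'}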