Updated Chipper model
Browse files
- app.py +4 -4
- logits_ngrams.py +1 -1
app.py
CHANGED
@@ -13,7 +13,7 @@ from logits_ngrams import NoRepeatNGramLogitsProcessor, get_table_token_ids
|
|
13 |
def run_prediction(sample, model, processor, mode):
|
14 |
|
15 |
skip_tokens = get_table_token_ids(processor)
|
16 |
-
no_repeat_ngram_size =
|
17 |
|
18 |
if mode == "OCR":
|
19 |
prompt = "<s><s_pretraining>"
|
@@ -35,9 +35,9 @@ def run_prediction(sample, model, processor, mode):
|
|
35 |
decoder_input_ids=processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device),
|
36 |
logits_processor=[NoRepeatNGramLogitsProcessor(no_repeat_ngram_size, skip_tokens)],
|
37 |
do_sample=True,
|
38 |
-
top_p=0.92,
|
39 |
top_k=5,
|
40 |
-
no_repeat_ngram_size=
|
41 |
num_beams=3,
|
42 |
output_attentions=False,
|
43 |
output_hidden_states=False,
|
@@ -81,7 +81,7 @@ else:
|
|
81 |
st.image(image, caption='Your target document')
|
82 |
|
83 |
with st.spinner(f'Processing the document ...'):
|
84 |
-
pre_trained_model = "unstructuredio/chipper-fast-fine-tuning"
|
85 |
processor = DonutProcessor.from_pretrained(pre_trained_model, token=os.environ['HF_TOKEN'])
|
86 |
|
87 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
13 |
def run_prediction(sample, model, processor, mode):
|
14 |
|
15 |
skip_tokens = get_table_token_ids(processor)
|
16 |
+
no_repeat_ngram_size = 15
|
17 |
|
18 |
if mode == "OCR":
|
19 |
prompt = "<s><s_pretraining>"
|
|
|
35 |
decoder_input_ids=processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device),
|
36 |
logits_processor=[NoRepeatNGramLogitsProcessor(no_repeat_ngram_size, skip_tokens)],
|
37 |
do_sample=True,
|
38 |
+
top_p=0.92,
|
39 |
top_k=5,
|
40 |
+
no_repeat_ngram_size=25,
|
41 |
num_beams=3,
|
42 |
output_attentions=False,
|
43 |
output_hidden_states=False,
|
|
|
81 |
st.image(image, caption='Your target document')
|
82 |
|
83 |
with st.spinner(f'Processing the document ...'):
|
84 |
+
pre_trained_model = "unstructuredio/chipper-fast-fine-tuning-oct-23-release"
|
85 |
processor = DonutProcessor.from_pretrained(pre_trained_model, token=os.environ['HF_TOKEN'])
|
86 |
|
87 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
logits_ngrams.py
CHANGED
@@ -59,5 +59,5 @@ def _calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len
|
|
59 |
|
60 |
|
61 |
def get_table_token_ids(processor):
|
62 |
-
|
63 |
|
|
|
59 |
|
60 |
|
61 |
def get_table_token_ids(processor):
    """Return the ids of added-vocabulary tokens that start with "<t" or "</t".

    The prefixes presumably target table markup tags (e.g. "<table>", "<tr>",
    "</td>") -- NOTE(review): any other added token beginning with "<t" or
    "</t" is matched as well; confirm against the tokenizer's added vocab.

    Args:
        processor: Object exposing ``tokenizer.get_added_vocab()``, which must
            return a mapping of token string -> token id.

    Returns:
        A set containing the token id of every matching added token.
    """
    # str.startswith accepts a tuple, so both prefixes are tested in one call.
    table_prefixes = ("<t", "</t")
    added_vocab = processor.tokenizer.get_added_vocab()
    return {
        token_id
        for token, token_id in added_vocab.items()
        if token.startswith(table_prefixes)
    }
|
63 |
|