qtoino committed on
Commit
7974bc5
1 Parent(s): 9b65e39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -67
app.py CHANGED
@@ -1,75 +1,30 @@
1
  import gradio as gr
2
- import requests
3
- from PIL import Image
4
  from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
5
- import spaces
6
 
7
- @spaces.GPU
8
- def infer_infographics(image, question):
9
- model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-ai2d-base")
10
- processor = Pix2StructProcessor.from_pretrained("google/pix2struct-ai2d-base")
11
 
12
- inputs = processor(images=image, text=question, return_tensors="pt")
 
13
 
 
 
 
14
  predictions = model.generate(**inputs)
15
  return processor.decode(predictions[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- @spaces.GPU
18
- def infer_ui(image, question):
19
- model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-screen2words-base")
20
- processor = Pix2StructProcessor.from_pretrained("google/pix2struct-screen2words-base")
21
-
22
- inputs = processor(images=image,text=question, return_tensors="pt")
23
-
24
- predictions = model.generate(**inputs)
25
- return processor.decode(predictions[0], skip_special_tokens=True)
26
-
27
- @spaces.GPU
28
- def infer_chart(image, question):
29
- model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-chartqa-base")
30
- processor = Pix2StructProcessor.from_pretrained("google/pix2struct-chartqa-base")
31
-
32
- inputs = processor(images=image, text=question, return_tensors="pt")
33
-
34
- predictions = model.generate(**inputs)
35
- return processor.decode(predictions[0], skip_special_tokens=True)
36
-
37
- @spaces.GPU
38
- def infer_doc(image, question):
39
- model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-docvqa-base")
40
- processor = Pix2StructProcessor.from_pretrained("google/pix2struct-docvqa-base")
41
- inputs = processor(images=image, text=question, return_tensors="pt")
42
- predictions = model.generate(**inputs)
43
- return processor.decode(predictions[0], skip_special_tokens=True)
44
-
45
- css = """
46
- #mkd {
47
- height: 500px;
48
- overflow: auto;
49
- border: 1px solid #ccc;
50
- }
51
- """
52
-
53
- with gr.Blocks(css=css) as demo:
54
- gr.HTML("<h1><center>Pix2Struct 📄<center><h1>")
55
- gr.HTML("<h3><center>Pix2Struct is a powerful backbone for visual question answering. ⚡</h3>")
56
- gr.HTML("<h3><center>This app has base version of the model. For better performance, use large checkpoints.<h3>")
57
-
58
- with gr.Row():
59
- with gr.Column():
60
- input_img = gr.Image(label="Input Document")
61
- question = gr.Text(label="Question")
62
- submit_btn = gr.Button(value="Submit")
63
- output = gr.Text(label="Answer")
64
- gr.Examples(
65
- [["docvqa_example.png", "How many items are sold?"]],
66
- inputs = [input_img, question],
67
- outputs = [output],
68
- fn=infer_doc,
69
- cache_examples=True,
70
- label='Click on any Examples below to get Document Question Answering results quickly 👇'
71
- )
72
-
73
- submit_btn.click(infer_doc, [input_img, question], [output])
74
-
75
- demo.launch(debug=True)
 
1
  import gradio as gr
2
+ # from PIL import Image
 
3
  from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
 
4
 
 
 
 
 
5
 
6
# Checkpoint fine-tuned for document visual question answering (DocVQA).
# Named once so the model and processor are guaranteed to use the same weights.
CHECKPOINT = "google/pix2struct-docvqa-base"

# Load the model and processor a single time at import, so every request
# reuses them instead of reloading from disk on each call.
model = Pix2StructForConditionalGeneration.from_pretrained(CHECKPOINT)
processor = Pix2StructProcessor.from_pretrained(CHECKPOINT)
 
9
def process_document(image, question):
    """Answer a natural-language question about a document image.

    Args:
        image: Input document image as delivered by the Gradio "image"
            component (presumably a PIL image or numpy array — the
            processor accepts either; TODO confirm against the widget).
        question: Question text about the document's contents.

    Returns:
        The model's decoded answer string, with special tokens stripped.
    """
    # Tokenize/patch the image together with the question prompt.
    inputs = processor(images=image, text=question, return_tensors="pt")
    predictions = model.generate(**inputs)
    return processor.decode(predictions[0], skip_special_tokens=True)
14
# UI copy shown on the demo page.
description = "Demo for pix2struct fine-tuned on DocVQA (document visual question answering). To use it, simply upload your image and type a question and click 'submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/pdf/2210.03347.pdf' target='_blank'>PIX2STRUCT: SCREENSHOT PARSING AS PRETRAINING FOR VISUAL LANGUAGE UNDERSTANDING</a></p>"

demo = gr.Interface(
    fn=process_document,
    inputs=["image", "text"],
    outputs="text",
    title="Demo: pix2struct for DocVQA",
    description=description,
    article=article,
    examples=[["example_1.png", "When is the coffee break?"], ["example_2.jpeg", "What's the population of Stoddard?"]],
    cache_examples=False)

# `enable_queue=True` was deprecated in Gradio 3.x and removed from the
# gr.Interface constructor in 4.x (it raises TypeError there); enable the
# request queue explicitly instead.
demo.queue()
demo.launch()
30