JGKaaij commited on
Commit
f377df5
1 Parent(s): ada6118

Upload flask_kosmos2.py

Browse files

A flask server you can run to prompt kosmos2.

Files changed (1) hide show
  1. flask_kosmos2.py +51 -0
flask_kosmos2.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from transformers import AutoProcessor, AutoModelForVision2Seq
3
+ from flask import Flask, request, jsonify
4
+
5
+ app = Flask(__name__)
6
+
7
+ model = AutoModelForVision2Seq.from_pretrained("ydshieh/kosmos-2-patch14-224", trust_remote_code=True)
8
+ processor = AutoProcessor.from_pretrained("ydshieh/kosmos-2-patch14-224", trust_remote_code=True)
9
+
10
+
11
+ @app.route('/process_grounding_prompt', methods=['POST'])
12
+ def process_prompt():
13
+ try:
14
+ # Get the uploaded image data from the POST request
15
+ uploaded_file = request.files['image']
16
+ prompt = request.form.get('prompt')
17
+ image = Image.open(uploaded_file.stream)
18
+
19
+ inputs = processor(text='<grounding>'+prompt, images=image, return_tensors="pt")
20
+
21
+ generated_ids = model.generate(
22
+ pixel_values=inputs["pixel_values"],
23
+ input_ids=inputs["input_ids"][:, :-1],
24
+ attention_mask=inputs["attention_mask"][:, :-1],
25
+ img_features=None,
26
+ img_attn_mask=inputs["img_attn_mask"][:, :-1],
27
+ use_cache=True,
28
+ max_new_tokens=64,
29
+ )
30
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
31
+
32
+ # Specify `cleanup_and_extract=False` in order to see the raw model generation.
33
+ processed_text = processor.post_process_generation(generated_text, cleanup_and_extract=False)
34
+
35
+ # print(processed_text)
36
+ # `<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863></object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911></object>.`
37
+
38
+ # By default, the generated text is cleanup and the entities are extracted.
39
+ processed_text, entities = processor.post_process_generation(generated_text)
40
+
41
+ print(processed_text)
42
+ # `An image of a snowman warming himself by a fire.`
43
+
44
+ print(entities)
45
+ return jsonify({"message": processed_text, 'entities': entities})
46
+ except Exception as e:
47
+ return jsonify({"error": str(e)})
48
+
49
+
50
+ if __name__ == '__main__':
51
+ app.run(host='localhost', port=8005)