farrosalferro24 commited on
Commit
09773e9
β€’
1 Parent(s): 1de2fb0

Initial commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/booth_yellowvan.jpg filter=lfs diff=lfs merge=lfs -text
37
+ examples/bucket_cyclist.jpg filter=lfs diff=lfs merge=lfs -text
38
+ examples/bus_luggage.jpg filter=lfs diff=lfs merge=lfs -text
39
+ examples/little_girl.jpg filter=lfs diff=lfs merge=lfs -text
chat_gecko.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import os
4
+ import time
5
+ from PIL import Image
6
+ import functools
7
+ import torch
8
+ import matplotlib.pyplot as plt
9
+ import re
10
+ import ast
11
+
12
+ from model import GeckoForConditionalGeneration, GeckoConfig, GeckoProcessor, chat_gecko, chat_gecko_stream
13
+ from model.conversation import conv_templates
14
+ from typing import List
15
+
16
+ from io import StringIO
17
+ import sys
18
+
19
+ class Capturing(list):
20
+ def __enter__(self):
21
+ self._stdout = sys.stdout
22
+ sys.stdout = self._stringio = StringIO()
23
+ return self
24
+ def __exit__(self, *args):
25
+ self.extend(self._stringio.getvalue().splitlines())
26
+ del self._stringio # free up some memory
27
+ sys.stdout = self._stdout
28
+
29
+
30
+ # initialization
31
+ topk = 1
32
+ keyword_criteria = 'word'
33
+ positional_information = 'explicit'
34
+ vision_feature_select_strategy = 'cls'
35
+ patch_picking_strategy = 'last_layer'
36
+ cropping_method = 'naive'
37
+ crop_size = 384
38
+ visualize_topk_patches = False
39
+ print_keyword=True
40
+ print_topk_patches = True
41
+
42
+ torch_dtype = torch.float16
43
+ attn_implementation = 'flash_attention_2'
44
+ device_map = 'cuda'
45
+
46
+ conv_template = conv_templates['llama_3']
47
+
48
+ model = 'TIGER-Lab/Mantis-8B-siglip-llama3'
49
+ config = GeckoConfig.from_pretrained(model,
50
+ topk=topk,
51
+ visualize_topk_patches=visualize_topk_patches,
52
+ keyword_criteria=keyword_criteria,
53
+ positional_information=positional_information,
54
+ vision_feature_select_strategy=vision_feature_select_strategy,
55
+ patch_picking_strategy=patch_picking_strategy,
56
+ print_keyword=print_keyword)
57
+ processor = GeckoProcessor.from_pretrained(model, config=config, use_keyword=True, cropping_method=cropping_method, crop_size=crop_size)
58
+ model = GeckoForConditionalGeneration.from_pretrained(
59
+ model, config=config, torch_dtype=torch_dtype,
60
+ attn_implementation=attn_implementation, device_map=device_map)
61
+ model.load_text_encoder(processor)
62
+
63
+ @spaces.GPU
64
+ def generate_stream(text:str, images:List[Image.Image], history: List[dict], **kwargs):
65
+ global processor, model
66
+ model = model.to("cuda")
67
+ if not images:
68
+ images = None
69
+ # print(history)
70
+ print(f'length of images: {len(images)}')
71
+ generator, print_kw, inputs = chat_gecko_stream(text, images, model, processor, history=history, **kwargs)
72
+ texts = []
73
+ # for text, history in chat_gecko_stream(text, images, model, processor, history=history, **kwargs):
74
+ # yield text
75
+ for text, history in generator:
76
+ texts.append(text)
77
+
78
+ # return text
79
+ return texts, print_kw, inputs
80
+
81
+ @spaces.GPU
82
+ def generate(text:str, images:List[Image.Image], history: List[dict], **kwargs):
83
+ global processor, model
84
+ model = model.to("cuda")
85
+ if not images:
86
+ images = None
87
+ generated_text, history = chat_gecko(text, images, model, processor, history=history, **kwargs)
88
+ return generated_text
89
+
90
+ def enable_next_image(uploaded_images, image):
91
+ uploaded_images.append(image)
92
+ return uploaded_images, gr.MultimodalTextbox(value=None, interactive=False)
93
+
94
+ def add_message(history, message):
95
+ if message["files"]:
96
+ for file in message["files"]:
97
+ history.append([(file,), None])
98
+ if message["text"]:
99
+ history.append([message["text"], None])
100
+ return history, gr.MultimodalTextbox(value=None)
101
+
102
+ def print_like_dislike(x: gr.LikeData):
103
+ print(x.index, x.value, x.liked)
104
+
105
+ def get_chat_history(history):
106
+ chat_history = []
107
+ user_role = conv_template.roles[0]
108
+ assistant_role = conv_template.roles[1]
109
+ for i, message in enumerate(history):
110
+ if isinstance(message[0], str):
111
+ chat_history.append({"role": user_role, "text": message[0]})
112
+ if i != len(history) - 1:
113
+ assert message[1], "The bot message is not provided, internal error"
114
+ chat_history.append({"role": assistant_role, "text": message[1]})
115
+ else:
116
+ assert not message[1], "the bot message internal error, get: {}".format(message[1])
117
+ chat_history.append({"role": assistant_role, "text": ""})
118
+ return chat_history
119
+
120
+ def get_chat_images(history):
121
+ images = []
122
+ for message in history:
123
+ if isinstance(message[0], tuple):
124
+ images.extend(message[0])
125
+ return images
126
+
127
+ def bot(history, topk=None):
128
+ print(history)
129
+ cur_messages = {"text": "", "images": []}
130
+ for message in history[::-1]:
131
+ if message[1]:
132
+ break
133
+ if isinstance(message[0], str):
134
+ cur_messages["text"] = message[0] + " " + cur_messages["text"]
135
+ elif isinstance(message[0], tuple):
136
+ cur_messages["images"].extend(message[0])
137
+ cur_messages["text"] = cur_messages["text"].strip()
138
+ cur_messages["images"] = cur_messages["images"][::-1]
139
+ if not cur_messages["text"]:
140
+ raise gr.Error("Please enter a message")
141
+ if cur_messages['text'].count("<image>") < len(cur_messages['images']):
142
+ gr.Warning("The number of images uploaded is more than the number of <image> placeholders in the text. Will automatically prepend <image> to the text.")
143
+ cur_messages['text'] = "<image> "* (len(cur_messages['images']) - cur_messages['text'].count("<image>")) + cur_messages['text']
144
+ history[-1][0] = cur_messages["text"]
145
+ if cur_messages['text'].count("<image>") > len(cur_messages['images']):
146
+ gr.Warning("The number of images uploaded is less than the number of <image> placeholders in the text. Will automatically remove extra <image> placeholders from the text.")
147
+ cur_messages['text'] = cur_messages['text'][::-1].replace("<image>"[::-1], "", cur_messages['text'].count("<image>") - len(cur_messages['images']))[::-1]
148
+ history[-1][0] = cur_messages["text"]
149
+
150
+
151
+
152
+ chat_history = get_chat_history(history)
153
+ chat_images = get_chat_images(history)
154
+
155
+ generation_kwargs = {
156
+ "max_new_tokens": 4096,
157
+ "num_beams": 1,
158
+ "do_sample": False,
159
+ "topk": topk,
160
+ }
161
+
162
+ response = generate_stream(None, chat_images, chat_history, **generation_kwargs)
163
+ num_images = len(response[2].pixel_values)
164
+ coords = response[1][-num_images:]
165
+ print_kw = '\n'.join(response[1][:-num_images-1])
166
+ patches_fig = plot_patches(response[2])
167
+ topk_patches_fig = plot_topk_patches(response[2], coords)
168
+ for _output in response[0]:
169
+ history[-1][1] = _output
170
+ time.sleep(0.05)
171
+ yield history, print_kw, patches_fig, topk_patches_fig
172
+
173
+ def plot_patches(inputs):
174
+ pixel_value = inputs.pixel_values[0].cpu().numpy()
175
+ x, y = inputs.coords[0][-1][0] + 1, inputs.coords[0][-1][1] + 1
176
+
177
+ fig, axes = plt.subplots(y, x, figsize=(x * 4, y * 4))
178
+ for i in range(y):
179
+ for j in range(x):
180
+ axes[i, j].imshow(pixel_value[1 + i * x + j].transpose(1, 2, 0))
181
+ axes[i, j].axis('off')
182
+
183
+ return fig
184
+
185
+ def plot_topk_patches(inputs, selected_coords):
186
+ selected_coords_list = []
187
+ for selected_coord in selected_coords:
188
+ match = re.search(r"\[\[.*\]\]", selected_coord)
189
+ if match:
190
+ coordinates_str = match.group(0)
191
+ # Convert the string representation of the list to an actual list
192
+ coordinates = ast.literal_eval(coordinates_str)
193
+ selected_coords_list.append(coordinates)
194
+ num_images = len(selected_coords_list)
195
+ fig, axes = plt.subplots(num_images, len(selected_coords_list[0])+1, figsize=((len(selected_coords_list[0])+1) * 10, num_images * 10))
196
+ if num_images == 1:
197
+ xmax = inputs.coords[0][-1][0] + 1
198
+ for j in range(len(selected_coords_list[0])+1):
199
+ if j == 0:
200
+ axes[j].imshow(inputs.pixel_values[0][0].cpu().numpy().transpose(1, 2, 0))
201
+ axes[j].axis('off')
202
+ continue
203
+ x, y = selected_coords_list[0][j-1][0], selected_coords_list[0][j-1][1]
204
+ axes[j].imshow(inputs.pixel_values[0][1 + y * xmax + x].cpu().numpy().transpose(1, 2, 0))
205
+ axes[j].axis('off')
206
+ else:
207
+ for i in range(num_images):
208
+ xmax = inputs.coords[i][-1][0] + 1
209
+ for j in range(len(selected_coords_list[0])+1):
210
+ if j == 0:
211
+ axes[i, j].imshow(inputs.pixel_values[i][0].cpu().numpy().transpose(1, 2, 0))
212
+ continue
213
+ x, y = selected_coords_list[i][j-1][0], selected_coords_list[i][j-1][1]
214
+ axes[i, j].imshow(inputs.pixel_values[i][1 + y * xmax + x].cpu().numpy().transpose(1, 2, 0))
215
+ axes[i, j].axis('off')
216
+
217
+ return fig
218
+
219
+
220
+ def build_demo():
221
+ with gr.Blocks() as demo:
222
+
223
+ # gr.Markdown(""" # Mantis
224
+ # Mantis is a multimodal conversational AI model that can chat with users about images and text. It's optimized for multi-image reasoning, where inverleaved text and images can be used to generate responses.
225
+ # ### [Paper](https://arxiv.org/abs/2405.01483) | [Github](https://github.com/TIGER-AI-Lab/Mantis) | [Models](https://huggingface.co/collections/TIGER-Lab/mantis-6619b0834594c878cdb1d6e4) | [Dataset](https://huggingface.co/datasets/TIGER-Lab/Mantis-Instruct) | [Website](https://tiger-ai-lab.github.io/Mantis/)
226
+ # """)
227
+
228
+ # gr.Markdown("""## Chat with Mantis
229
+ # Mantis supports interleaved text-image input format, where you can simply use the placeholder `<image>` to indicate the position of uploaded images.
230
+ # The model is optimized for multi-image reasoning, while preserving the ability to chat about text and images in a single conversation.
231
+ # (The model currently serving is [πŸ€— TIGER-Lab/Mantis-8B-siglip-llama3](https://huggingface.co/TIGER-Lab/Mantis-8B-siglip-llama3))
232
+ # """)
233
+
234
+ chatbot = gr.Chatbot(line_breaks=True)
235
+ chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload images. Please use <image> to indicate the position of uploaded images", show_label=True)
236
+
237
+ chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
238
+ print_kw = gr.Textbox(label="keywords")
239
+ depict_patches = gr.Plot(label="image patches", format="png")
240
+ depict_topk_patches = gr.Plot(label="top-k image patches", format="png")
241
+
242
+
243
+ # with gr.Accordion(label='Advanced options', open=False):
244
+ # temperature = gr.Slider(
245
+ # label='Temperature',
246
+ # minimum=0.1,
247
+ # maximum=2.0,
248
+ # step=0.1,
249
+ # value=0.2,
250
+ # interactive=True
251
+ # )
252
+ # top_p = gr.Slider(
253
+ # label='Top-p',
254
+ # minimum=0.05,
255
+ # maximum=1.0,
256
+ # step=0.05,
257
+ # value=1.0,
258
+ # interactive=True
259
+ # )
260
+ topk = gr.Slider(
261
+ label='Top-k',
262
+ minimum=1,
263
+ maximum=10,
264
+ step=1,
265
+ value=1,
266
+ interactive=True)
267
+
268
+ bot_msg = chat_msg.success(bot, chatbot,
269
+ chatbot, api_name="bot_response")
270
+
271
+ chatbot.like(print_like_dislike, None, None)
272
+
273
+ with gr.Row():
274
+ send_button = gr.Button("Send")
275
+ clear_button = gr.ClearButton([chatbot, chat_input])
276
+
277
+ send_button.click(
278
+ add_message, [chatbot, chat_input], [chatbot, chat_input]
279
+ ).then(
280
+ bot,
281
+ [chatbot, topk],
282
+ [chatbot, print_kw, depict_patches, depict_topk_patches], api_name="bot_response"
283
+ )
284
+
285
+ gr.Examples(
286
+ examples=[
287
+ {
288
+ "text": open("gradio/examples/little_girl.txt").read(),
289
+ "files": ["gradio/examples/little_girl.jpg"]
290
+ },
291
+ {
292
+ "text": open("gradio/examples/bus_luggage.txt").read(),
293
+ "files": ["gradio/examples/bus_luggage.jpg"]
294
+ },
295
+ ],
296
+ inputs=[chat_input],
297
+ )
298
+
299
+ # gr.Markdown("""
300
+ # ## Citation
301
+ # ```
302
+ # @article{jiang2024mantis,
303
+ # title={MANTIS: Interleaved Multi-Image Instruction Tuning},
304
+ # author={Jiang, Dongfu and He, Xuan and Zeng, Huaye and Wei, Con and Ku, Max and Liu, Qian and Chen, Wenhu},
305
+ # journal={arXiv preprint arXiv:2405.01483},
306
+ # year={2024}
307
+ # }
308
+ # ```""")
309
+ return demo
310
+
311
+
312
+ if __name__ == "__main__":
313
+ demo = build_demo()
314
+ demo.launch(share=False)
315
+
examples/booth_yellowvan.jpg ADDED

Git LFS Details

  • SHA256: 26560e898003ef04f807bd997744a1b15b2cfd1235b15069206808d6ce38932f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.65 MB
examples/booth_yellowvan.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ <image>
2
+ Is the telephone booth on the left or right side of the yellow van?
3
+ (A) right
4
+ (B) left
5
+ Answer with the option's letter from the given choices directly.
examples/bucket_cyclist.jpg ADDED

Git LFS Details

  • SHA256: 4f170930283c75eafb65efa7cf3b6531417d59a2dbfe69d82d2b8cbf96f0a508
  • Pointer size: 132 Bytes
  • Size of remote file: 1.7 MB
examples/bucket_cyclist.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ <image>
2
+ Is the bucket on the left or right side of the cyclist?
3
+ (A) right
4
+ (B) left
5
+ Answer with the option's letter from the given choices directly.
examples/bus_luggage.jpg ADDED

Git LFS Details

  • SHA256: 6f56c068c6b9b03743d7650406c6560dcdc81c26349b47b3bed65d2902cdd842
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
examples/bus_luggage.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ <image>
2
+ Is the blue luggage on the left or right side of the bus?
3
+ (A) right
4
+ (B) left
5
+ Answer with the option's letter from the given choices directly.
examples/little_girl.jpg ADDED

Git LFS Details

  • SHA256: 847ba9aa5edd9f28c2d40d4a39c7a74bc9a4cb5918ec7627f1b70f46d5fe954a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.57 MB
examples/little_girl.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <image>
2
+ What is the color of the little girl's shirt?
3
+ (A) yellow
4
+ (B) pink
5
+ (C) white
6
+ (D) black
7
+ Answer with the option's letter from the given choices directly.
model/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .modelling_gecko import GeckoForConditionalGeneration
2
+ from .processing_gecko import GeckoProcessor
3
+ from .configuration_gecko import GeckoConfig
4
+ from .utils import chat_gecko, chat_gecko_stream
model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (409 Bytes). View file
 
model/__pycache__/configuration_gecko.cpython-310.pyc ADDED
Binary file (3.8 kB). View file
 
model/__pycache__/conversation.cpython-310.pyc ADDED
Binary file (14.6 kB). View file
 
model/__pycache__/modelling_gecko.cpython-310.pyc ADDED
Binary file (24.3 kB). View file
 
model/__pycache__/processing_gecko.cpython-310.pyc ADDED
Binary file (16.9 kB). View file
 
model/__pycache__/utils.cpython-310.pyc ADDED
Binary file (6.13 kB). View file
 
model/configuration_gecko.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ from transformers.utils import logging
3
+ from transformers.models.auto import CONFIG_MAPPING
4
+
5
+ logger = logging.get_logger(__name__)
6
+
7
+ class GeckoConfig(PretrainedConfig):
8
+ r"""
9
+ This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate an
10
+ Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration
11
+ with the defaults will yield a similar configuration to that of the Llava-9B.
12
+
13
+ e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b)
14
+
15
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
16
+ documentation from [`PretrainedConfig`] for more information.
17
+
18
+ Args:
19
+ vision_config (`LlavaVisionConfig`, *optional*):
20
+ Custom vision config or dict
21
+ text_config (`Union[AutoConfig, dict]`, *optional*):
22
+ The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
23
+ ignore_index (`int`, *optional*, defaults to -100):
24
+ The ignore index for the loss function.
25
+ image_token_index (`int`, *optional*, defaults to 32000):
26
+ The image token index to encode the image prompt.
27
+ projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
28
+ The activation function used by the multimodal projector.
29
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
30
+ The feature selection strategy used to select the vision feature from the CLIP backbone.
31
+ vision_feature_layer (`int`, *optional*, defaults to -2):
32
+ The index of the layer to select the vision feature.
33
+ vocab_size (`int`, *optional*, defaults to 32000):
34
+ Vocabulary size of the Llava model. Defines the number of different tokens that can be represented by the
35
+ `inputs_ids` passed when calling [`~LlavaForConditionalGeneration`]
36
+ """
37
+
38
+ model_type = "gecko"
39
+ is_composition = False
40
+
41
+ def __init__(
42
+ self,
43
+ vision_config=None,
44
+ text_config=None,
45
+ ignore_index=-100,
46
+ image_token_index=32000,
47
+ projector_hidden_act="gelu",
48
+ vision_feature_select_strategy="cls",
49
+ patch_picking_strategy="across_layers",
50
+ vision_feature_layer=-2,
51
+ vocab_size=32000,
52
+ topk=4,
53
+ keyword_criteria="template",
54
+ positional_information="explicit",
55
+ visualize_patches=False,
56
+ visualize_topk_patches=False,
57
+ print_keyword=False,
58
+ print_topk_patches=False,
59
+ **kwargs,
60
+ ):
61
+ self.ignore_index = ignore_index
62
+ self.image_token_index = image_token_index
63
+ self.projector_hidden_act = projector_hidden_act
64
+ self.vision_feature_layer = vision_feature_layer
65
+ self.vision_feature_select_strategy = vision_feature_select_strategy
66
+ self.patch_picking_strategy = patch_picking_strategy
67
+ self.vocab_size = vocab_size
68
+ self.topk = topk
69
+ self.vision_config = vision_config
70
+ self.text_config = text_config
71
+ self.keyword_criteria = keyword_criteria
72
+ self.positional_information = positional_information
73
+ self.visualize_patches = visualize_patches
74
+ self.visualize_topk_patches = visualize_topk_patches
75
+ self.print_keyword = print_keyword
76
+ self.print_topk_patches = print_topk_patches
77
+
78
+ if isinstance(self.vision_config, dict):
79
+ vision_config["model_type"] = (
80
+ vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
81
+ )
82
+ self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
83
+ elif vision_config is None:
84
+ self.vision_config = CONFIG_MAPPING["clip_vision_model"](
85
+ intermediate_size=4096,
86
+ hidden_size=1024,
87
+ patch_size=14,
88
+ image_size=336,
89
+ num_hidden_layers=24,
90
+ num_attention_heads=16,
91
+ vocab_size=32000,
92
+ projection_dim=768,
93
+ )
94
+ self.vocab_size = self.vocab_size
95
+
96
+ self.text_config = text_config
97
+
98
+ if isinstance(self.text_config, dict):
99
+ text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
100
+ self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
101
+ self.vocab_size = self.text_config.vocab_size
102
+ elif text_config is None:
103
+ self.text_config = CONFIG_MAPPING["llama"]()
104
+
105
+ super().__init__(**kwargs)
model/conversation.py ADDED
@@ -0,0 +1,527 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+
5
+
6
+ class SeparatorStyle(Enum):
7
+ """Different separator style."""
8
+ SINGLE = auto()
9
+ TWO = auto()
10
+ MPT = auto()
11
+ PLAIN = auto()
12
+ LLAMA_2 = auto()
13
+ LLAMA_3 = auto()
14
+ MFuyu = auto()
15
+ PHI_3 = auto()
16
+
17
+
18
+ @dataclasses.dataclass
19
+ class Conversation:
20
+ """A class that keeps all conversation history."""
21
+ system: str
22
+ roles: List[str]
23
+ messages: List[List[str]]
24
+ offset: int
25
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
26
+ sep: str = "###"
27
+ sep2: str = None
28
+ version: str = "Unknown"
29
+
30
+ skip_next: bool = False
31
+
32
+ def get_prompt(self):
33
+ messages = self.messages
34
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
35
+
36
+ messages = self.messages.copy()
37
+ init_role, init_msg = messages[0].copy()
38
+ init_msg = init_msg[0].replace("<image>", "").strip()
39
+ if 'mmtag' in self.version:
40
+ messages[0] = (init_role, init_msg)
41
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
42
+ messages.insert(1, (self.roles[1], "Received."))
43
+ else:
44
+ messages[0] = (init_role, "<image>" + init_msg)
45
+ if self.sep_style == SeparatorStyle.SINGLE:
46
+ ret = self.system + self.sep
47
+ for role, message in messages:
48
+ if message:
49
+ if type(message) is tuple:
50
+ message, _, _ = message
51
+ ret += role + ": " + message + self.sep
52
+ else:
53
+ ret += role + ":"
54
+ elif self.sep_style == SeparatorStyle.TWO:
55
+ seps = [self.sep, self.sep2]
56
+ ret = self.system + seps[0]
57
+ for i, (role, message) in enumerate(messages):
58
+ if message:
59
+ if type(message) is tuple:
60
+ message, _, _ = message
61
+ ret += role + ": " + message + seps[i % 2]
62
+ else:
63
+ ret += role + ":"
64
+ elif self.sep_style == SeparatorStyle.MPT:
65
+ ret = self.system + self.sep
66
+ for role, message in messages:
67
+ if message:
68
+ if type(message) is tuple:
69
+ message, _, _ = message
70
+ ret += role + message + self.sep
71
+ else:
72
+ ret += role
73
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
74
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
75
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
76
+ ret = ""
77
+
78
+ for i, (role, message) in enumerate(messages):
79
+ if i == 0:
80
+ assert message, "first message should not be none"
81
+ assert role == self.roles[0], "first message should come from user"
82
+ if message:
83
+ if type(message) is tuple:
84
+ message, _, _ = message
85
+ if i == 0: message = wrap_sys(self.system) + message
86
+ if i % 2 == 0:
87
+ message = wrap_inst(message)
88
+ ret += self.sep + message
89
+ else:
90
+ ret += " " + message + " " + self.sep2
91
+ else:
92
+ ret += ""
93
+ ret = ret.lstrip(self.sep)
94
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
95
+ ret = self.system + self.sep
96
+ for role, message in messages:
97
+ if message:
98
+ if type(message) is tuple:
99
+ message, _, _ = message
100
+ ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" + message + self.sep
101
+ else:
102
+ ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
103
+ elif self.sep_style == SeparatorStyle.MFuyu:
104
+ seps = [self.sep, self.sep2]
105
+ ret = self.system + "\n"
106
+ for i, (role, message) in enumerate(messages):
107
+ if message:
108
+ if type(message) is tuple:
109
+ message, _, _ = message
110
+ ret += role + ": " + message + seps[i % 2]
111
+ else:
112
+ ret += role + ":"
113
+ elif self.sep_style == SeparatorStyle.PLAIN:
114
+ seps = [self.sep, self.sep2]
115
+ ret = self.system
116
+ for i, (role, message) in enumerate(messages):
117
+ if message:
118
+ if type(message) is tuple:
119
+ message, _, _ = message
120
+ ret += message + seps[i % 2]
121
+ else:
122
+ ret += ""
123
+ elif self.sep_style == SeparatorStyle.PHI_3:
124
+ ret = self.system + self.sep + '\n'
125
+ for role, message in messages:
126
+ if message:
127
+ if type(message) is tuple:
128
+ message, _, _ = message
129
+ ret += f"<|{role}|>\n" + message + self.sep + '\n'
130
+ else:
131
+ ret += f"<|{role}|>\n"
132
+ else:
133
+ raise ValueError(f"Invalid style: {self.sep_style}")
134
+
135
+ return ret
136
+
137
+ def generate_keyword_prompt(self, messages=None):
138
+ messages = messages if messages is not None else self.messages[-2][1]
139
+ system_prompt = """Use the text below as an example to generate your answers to the user's query. Give the answer in the same format.
140
+
141
+ Example starts:
142
+ ```
143
+ User: What is/are the object(s) that being asked in below question? Also give some useful visual features that best describes each object in a photo.
144
+ 'What kind of drink can we buy from that vending machine?'
145
+ Assistant: The object being asked is vending machine. Several visual features of the object are:
146
+
147
+ 'vending machine':
148
+
149
+ * typically have a large, upright, rectangular shape.
150
+ * usually have a large glass or transparent plastic front
151
+ * often feature logos, product images, and labels on their exterior
152
+ * Most are metallic and have a dominant color (often bright or neutral)
153
+ ```
154
+ Example ends
155
+
156
+ Example starts:
157
+ ```
158
+ User: What is/are the object(s) that being asked in below question? Also give some useful visual features that best describes each object in a photo.
159
+ 'Is the wallet on the left or right side of the keyboard?'
160
+ Assistant: The objects being asked are wallet and keyboard. Several visual features of the objects are:
161
+
162
+ 'wallet':
163
+
164
+ * typically have a compact, flat, rectangular shape.
165
+ * can be made from various materials including leather, synthetic fabric, or even metal for hard cases.
166
+ * generally small enough to fit in a pocket or a small bag.
167
+ * come in a wide range of colors, from classic black or brown to vibrant hues and patterns.
168
+
169
+ 'keyboard':
170
+
171
+ * typically feature a rectangular array of keys in a grid layout.
172
+ * can be made from plastic, metal, or other materials.
173
+ * come in various colors, although black and white are most common.
174
+ * may have a visible USB cable or may be identified as wireless if there is no cable connected.
175
+ ```
176
+ Example ends
177
+ Please generate answer in the SAME FORMAT as shown in the above examples. Your response must have an equal number of features for each object in the question.
178
+ Please ensure to cover all significant visual features.
179
+ """
180
+ user_prompt = f"""What is/are the object(s) that being asked in below question? Also give some useful visual features that best describes each object in a photo.
181
+ '{messages}'"""
182
+
183
+ prompt = f"""<|start_header_id|>system<|end_header_id|>
184
+
185
+ {system_prompt}{self.sep}
186
+
187
+ <|start_header_id|>user<|end_header_id|>
188
+
189
+ {user_prompt}
190
+
191
+ <|start_header_id|>assistant<|end_header_id|>"""
192
+ return prompt
193
+
194
+ def append_message(self, role, message):
195
+ self.messages.append([role, message])
196
+
197
+ def get_images(self, return_pil=False):
198
+ images = []
199
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
200
+ if i % 2 == 0:
201
+ if type(msg) is tuple:
202
+ import base64
203
+ from io import BytesIO
204
+ from PIL import Image
205
+ msg, image, image_process_mode = msg
206
+ if image_process_mode == "Pad":
207
+ def expand2square(pil_img, background_color=(122, 116, 104)):
208
+ width, height = pil_img.size
209
+ if width == height:
210
+ return pil_img
211
+ elif width > height:
212
+ result = Image.new(pil_img.mode, (width, width), background_color)
213
+ result.paste(pil_img, (0, (width - height) // 2))
214
+ return result
215
+ else:
216
+ result = Image.new(pil_img.mode, (height, height), background_color)
217
+ result.paste(pil_img, ((height - width) // 2, 0))
218
+ return result
219
+ image = expand2square(image)
220
+ elif image_process_mode in ["Default", "Crop"]:
221
+ pass
222
+ elif image_process_mode == "Resize":
223
+ image = image.resize((336, 336))
224
+ else:
225
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
226
+ max_hw, min_hw = max(image.size), min(image.size)
227
+ aspect_ratio = max_hw / min_hw
228
+ max_len, min_len = 800, 400
229
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
230
+ longest_edge = int(shortest_edge * aspect_ratio)
231
+ W, H = image.size
232
+ if longest_edge != max(image.size):
233
+ if H > W:
234
+ H, W = longest_edge, shortest_edge
235
+ else:
236
+ H, W = shortest_edge, longest_edge
237
+ image = image.resize((W, H))
238
+ if return_pil:
239
+ images.append(image)
240
+ else:
241
+ buffered = BytesIO()
242
+ image.save(buffered, format="PNG")
243
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
244
+ images.append(img_b64_str)
245
+ return images
246
+
247
+ def to_gradio_chatbot(self):
248
+ ret = []
249
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
250
+ if i % 2 == 0:
251
+ if type(msg) is tuple:
252
+ import base64
253
+ from io import BytesIO
254
+ msg, image, image_process_mode = msg
255
+ max_hw, min_hw = max(image.size), min(image.size)
256
+ aspect_ratio = max_hw / min_hw
257
+ max_len, min_len = 800, 400
258
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
259
+ longest_edge = int(shortest_edge * aspect_ratio)
260
+ W, H = image.size
261
+ if H > W:
262
+ H, W = longest_edge, shortest_edge
263
+ else:
264
+ H, W = shortest_edge, longest_edge
265
+ image = image.resize((W, H))
266
+ buffered = BytesIO()
267
+ image.save(buffered, format="JPEG")
268
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
269
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
270
+ msg = img_str + msg.replace('<image>', '').strip()
271
+ ret.append([msg, None])
272
+ else:
273
+ ret.append([msg, None])
274
+ else:
275
+ ret[-1][-1] = msg
276
+ return ret
277
+
278
+ def copy(self):
279
+ return Conversation(
280
+ system=self.system,
281
+ roles=self.roles,
282
+ messages=[[x, y] for x, y in self.messages],
283
+ offset=self.offset,
284
+ sep_style=self.sep_style,
285
+ sep=self.sep,
286
+ sep2=self.sep2,
287
+ version=self.version)
288
+
289
+ def dict(self):
290
+ if len(self.get_images()) > 0:
291
+ return {
292
+ "system": self.system,
293
+ "roles": self.roles,
294
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
295
+ "offset": self.offset,
296
+ "sep": self.sep,
297
+ "sep2": self.sep2,
298
+ }
299
+ return {
300
+ "system": self.system,
301
+ "roles": self.roles,
302
+ "messages": self.messages,
303
+ "offset": self.offset,
304
+ "sep": self.sep,
305
+ "sep2": self.sep2,
306
+ }
307
+
308
+
309
+ conv_vicuna_v0 = Conversation(
310
+ system="A chat between a curious human and an artificial intelligence assistant. "
311
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
312
+ roles=("Human", "Assistant"),
313
+ messages=(
314
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
315
+ ("Assistant",
316
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
317
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
318
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
319
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
320
+ "renewable and non-renewable energy sources:\n"
321
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
322
+ "energy sources are finite and will eventually run out.\n"
323
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
324
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
325
+ "and other negative effects.\n"
326
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
327
+ "have lower operational costs than non-renewable sources.\n"
328
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
329
+ "locations than non-renewable sources.\n"
330
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
331
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
332
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
333
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
334
+ ),
335
+ offset=2,
336
+ sep_style=SeparatorStyle.SINGLE,
337
+ sep="###",
338
+ )
339
+
340
+ conv_vicuna_v1 = Conversation(
341
+ system="A chat between a curious user and an artificial intelligence assistant. "
342
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
343
+ roles=("USER", "ASSISTANT"),
344
+ version="v1",
345
+ messages=(),
346
+ offset=0,
347
+ sep_style=SeparatorStyle.TWO,
348
+ sep=" ",
349
+ sep2="</s>",
350
+ )
351
+
352
+ conv_llama_2 = Conversation(
353
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
354
+
355
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
356
+ roles=("USER", "ASSISTANT"),
357
+ version="llama_v2",
358
+ messages=(),
359
+ offset=0,
360
+ sep_style=SeparatorStyle.LLAMA_2,
361
+ sep="<s>",
362
+ sep2="</s>",
363
+ )
364
+
365
+ conv_llava_llama_2 = Conversation(
366
+ system="You are a helpful language and vision assistant. "
367
+ "You are able to understand the visual content that the user provides, "
368
+ "and assist the user with a variety of tasks using natural language.",
369
+ roles=("USER", "ASSISTANT"),
370
+ version="llama_v2",
371
+ messages=(),
372
+ offset=0,
373
+ sep_style=SeparatorStyle.LLAMA_2,
374
+ sep="<s>",
375
+ sep2="</s>",
376
+ )
377
+
378
+ conv_mpt = Conversation(
379
+ system="""<|im_start|>system
380
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
381
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
382
+ version="mpt",
383
+ messages=(),
384
+ offset=0,
385
+ sep_style=SeparatorStyle.MPT,
386
+ sep="<|im_end|>",
387
+ )
388
+
389
+ conv_llava_plain = Conversation(
390
+ system="",
391
+ roles=("", ""),
392
+ messages=(
393
+ ),
394
+ offset=0,
395
+ sep_style=SeparatorStyle.PLAIN,
396
+ sep="\n",
397
+ )
398
+
399
+ conv_llava_v0 = Conversation(
400
+ system="A chat between a curious human and an artificial intelligence assistant. "
401
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
402
+ roles=("Human", "Assistant"),
403
+ messages=(
404
+ ),
405
+ offset=0,
406
+ sep_style=SeparatorStyle.SINGLE,
407
+ sep="###",
408
+ )
409
+
410
+ conv_llava_v0_mmtag = Conversation(
411
+ system="A chat between a curious user and an artificial intelligence assistant. "
412
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
413
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
414
+ roles=("Human", "Assistant"),
415
+ messages=(
416
+ ),
417
+ offset=0,
418
+ sep_style=SeparatorStyle.SINGLE,
419
+ sep="###",
420
+ version="v0_mmtag",
421
+ )
422
+
423
+ conv_llava_v1 = Conversation(
424
+ system="A chat between a curious human and an artificial intelligence assistant. "
425
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
426
+ roles=("USER", "ASSISTANT"),
427
+ version="v1",
428
+ messages=(),
429
+ offset=0,
430
+ sep_style=SeparatorStyle.TWO,
431
+ sep=" ",
432
+ sep2="</s>",
433
+ )
434
+
435
+ conv_llava_v1_mmtag = Conversation(
436
+ system="A chat between a curious user and an artificial intelligence assistant. "
437
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
438
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
439
+ roles=("USER", "ASSISTANT"),
440
+ messages=(),
441
+ offset=0,
442
+ sep_style=SeparatorStyle.TWO,
443
+ sep=" ",
444
+ sep2="</s>",
445
+ version="v1_mmtag",
446
+ )
447
+
448
+ conv_mfuyu_v1 = Conversation(
449
+ system="You are a helpful language and vision assistant. "
450
+ "You are able to understand the visual content that the user provides, "
451
+ "and assist the user with a variety of tasks using natural language.",
452
+ roles=("USER", "ASSISTANT"),
453
+ version="v1",
454
+ messages=(),
455
+ offset=0,
456
+ sep_style=SeparatorStyle.MFuyu,
457
+ sep="<0x04>", # begin of answer token
458
+ sep2="|ENDOFTEXT|",
459
+ ) # copied from conv_vicuna_v1
460
+
461
+ conv_mllava_v1_mmtag = Conversation(
462
+ system="A chat between a curious user and an artificial intelligence assistant. "
463
+ "The assistant is able to understand the multiple visual contents that the user provides, and assist the user with a variety of tasks using natural language."
464
+ "Each visual content will be provided with the following format: <Image>visual content</Image>.",
465
+ roles=("USER", "ASSISTANT"),
466
+ messages=(),
467
+ offset=0,
468
+ sep_style=SeparatorStyle.SINGLE,
469
+ sep="</s>",
470
+ version="v1_mmtag",
471
+ )
472
+
473
+ conv_mllava_v1 = Conversation(
474
+ system="A chat between a curious human and an artificial intelligence assistant. "
475
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
476
+ roles=("USER", "ASSISTANT"),
477
+ version="v1",
478
+ messages=(),
479
+ offset=0,
480
+ sep_style=SeparatorStyle.SINGLE,
481
+ sep="</s>",
482
+ )
483
+
484
+ conv_llama_3 = Conversation(
485
+ system="<|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.",
486
+ roles=("user", "assistant"),
487
+ messages=(),
488
+ offset=0,
489
+ sep_style=SeparatorStyle.LLAMA_3,
490
+ sep="<|eot_id|>",
491
+ )
492
+
493
+ conv_phi_3 = Conversation(
494
+ system='<s><|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions.',
495
+ roles=('<|user|>', '<|assistant|>'),
496
+ messages=(),
497
+ offset=0,
498
+ sep_style=SeparatorStyle.PHI_3,
499
+ sep='<|end|>'
500
+ )
501
+
502
+ default_conversation = conv_mfuyu_v1
503
+ conv_templates = {
504
+ "default": conv_vicuna_v0,
505
+ "v0": conv_vicuna_v0,
506
+ "v1": conv_vicuna_v1,
507
+ "vicuna_v1": conv_vicuna_v1,
508
+ "llama_2": conv_llama_2,
509
+
510
+ "plain": conv_llava_plain,
511
+ "v0_plain": conv_llava_plain,
512
+ "llava_v0": conv_llava_v0,
513
+ "v0_mmtag": conv_llava_v0_mmtag,
514
+ "llava_v1": conv_llava_v1,
515
+ "v1_mmtag": conv_llava_v1_mmtag,
516
+ "llava_llama_2": conv_llava_llama_2,
517
+ "llama_3": conv_llama_3,
518
+ "mllava_v1": conv_mllava_v1,
519
+ "mllava_v1_mmtag": conv_mllava_v1_mmtag,
520
+ "phi_3": conv_phi_3,
521
+
522
+ "mpt": conv_mpt,
523
+ }
524
+
525
+
526
+ if __name__ == "__main__":
527
+ print(default_conversation.get_prompt())
model/modelling_gecko.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Tuple, Union, Dict
4
+ from copy import deepcopy
5
+
6
+ import re
7
+ import math
8
+ import torch
9
+ import torch.utils.checkpoint
10
+ from torch import nn
11
+ import matplotlib.pyplot as plt
12
+
13
+ from transformers import PreTrainedModel
14
+ from transformers.activations import ACT2FN
15
+ from transformers.cache_utils import Cache
16
+ from transformers.modeling_outputs import ModelOutput
17
+ from transformers.utils import (
18
+ add_start_docstrings,
19
+ add_start_docstrings_to_model_forward,
20
+ logging,
21
+ replace_return_docstrings,
22
+ )
23
+ from transformers.models.auto import AutoModel, AutoModelForCausalLM
24
+
25
+ from .configuration_gecko import GeckoConfig
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ _CONFIG_FOR_DOC = "GeckoConfig"
30
+
31
+ @dataclass
32
+ class GeckoCausalLMOutputWithPast(ModelOutput):
33
+ """
34
+ Base class for Llava causal language model (or autoregressive) outputs.
35
+
36
+ Args:
37
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
38
+ Language modeling loss (for next-token prediction).
39
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
40
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
41
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
42
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
43
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
44
+
45
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
46
+ `past_key_values` input) to speed up sequential decoding.
47
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
48
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
49
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
50
+
51
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
52
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
53
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
54
+ sequence_length)`.
55
+
56
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
57
+ heads.
58
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
59
+ Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
60
+ sequence_length, hidden_size)`.
61
+
62
+ image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
63
+ """
64
+
65
+ loss: Optional[torch.FloatTensor] = None
66
+ logits: torch.FloatTensor = None
67
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
68
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
69
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
70
+ image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
71
+
72
+ class GeckoPreTrainedModel(PreTrainedModel):
73
+ config_class = GeckoConfig
74
+ base_model_prefix = "model"
75
+ supports_gradient_checkpointing = True
76
+ _no_split_modules = ["GeckoVisionAttention"]
77
+ _skip_keys_device_placement = "past_key_values"
78
+ _supports_flash_attn_2 = True
79
+
80
+ def _init_weights(self, module):
81
+ std = (
82
+ self.config.intializer_range if hasattr(self.config, "intializer_range") else self.config.text_config.initializer_range
83
+ )
84
+
85
+ if hasattr(module, "class_embedding"):
86
+ module.class_embedding.data.normal_(mean=0.0, std=std)
87
+
88
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
89
+ module.weight.data.normal_(mean=0.0, std=std)
90
+ if module.bias is not None:
91
+ module.bias.data.zero_()
92
+ elif isinstance(module, nn.Embedding):
93
+ module.weight.data.normal_(mean=0.0, std=std)
94
+ if module.padding_idx is not None:
95
+ module.weight.data[module.padding_idx].zero_()
96
+
97
+ @property
98
+ def _supports_sdpa(self):
99
+ return self.language_model._supports_sdpa
100
+
101
+ class PositionalEncoding2D(nn.Module):
102
+ def __init__(self, config: GeckoConfig):
103
+ """
104
+ :param channels: The last dimension of the tensor you want to apply pos emb to.
105
+ """
106
+ super(PositionalEncoding2D, self).__init__()
107
+ if config.positional_information == "2d_before":
108
+ channels = config.vision_config.hidden_size
109
+ else:
110
+ channels = config.text_config.hidden_size
111
+ self.org_channels = channels
112
+ channels = int(math.ceil(channels / 4) * 2)
113
+ self.channels = channels
114
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
115
+ self.register_buffer("inv_freq", inv_freq)
116
+ self.register_buffer("cached_penc", None, persistent=False)
117
+
118
+ def get_emb(self, sin_inp):
119
+ """
120
+ Gets a base embedding for one dimension with sin and cos intertwined
121
+ """
122
+ emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
123
+ return torch.flatten(emb, -2, -1)
124
+
125
+ def forward(self, tensor):
126
+ """
127
+ :param tensor: A 4d tensor of size (x, y, num_tokens, ch)
128
+ :return: Positional Encoding Matrix of size (x, y, num_tokens, ch)
129
+ """
130
+ if len(tensor.shape) != 4:
131
+ raise RuntimeError("The input tensor has to be 4d!")
132
+
133
+ if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
134
+ return self.cached_penc
135
+
136
+ self.cached_penc = None
137
+ x, y, num_tokens, orig_ch = tensor.shape
138
+ pos_x = torch.arange(x, device=tensor.device, dtype=self.inv_freq.dtype)
139
+ pos_y = torch.arange(y, device=tensor.device, dtype=self.inv_freq.dtype)
140
+ sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
141
+ sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
142
+ emb_x = self.get_emb(sin_inp_x).unsqueeze(1)
143
+ emb_y = self.get_emb(sin_inp_y)
144
+ emb = torch.zeros(
145
+ (x, y, self.channels * 2),
146
+ device=tensor.device,
147
+ dtype=tensor.dtype,
148
+ )
149
+ emb[:, :, : self.channels] = emb_x
150
+ emb[:, :, self.channels : 2 * self.channels] = emb_y
151
+
152
+ self.cached_penc = emb[:, :, None, :orig_ch].repeat(1, 1, num_tokens, 1)
153
+ return self.cached_penc
154
+
155
+ class GeckoMultiModalProjector(nn.Module):
156
+ def __init__(self, config: GeckoConfig):
157
+ super().__init__()
158
+ self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
159
+ self.act = ACT2FN[config.projector_hidden_act]
160
+ self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
161
+
162
+ def forward(self, image_features):
163
+ hidden_states = self.linear_1(image_features)
164
+ hidden_states = self.act(hidden_states)
165
+ hidden_states = self.linear_2(hidden_states)
166
+ return hidden_states
167
+
168
+ class GeckoForConditionalGeneration(GeckoPreTrainedModel):
169
+ def __init__(self, config: GeckoConfig, vision_tower=None, language_model=None, multimodal_projector=None):
170
+ super().__init__(config)
171
+ self.vision_tower = AutoModel.from_config(config.vision_config) if vision_tower is None else vision_tower
172
+ self.positional_encoding = PositionalEncoding2D(config) if '2d' in config.positional_information else None
173
+ self.multi_modal_projector = GeckoMultiModalProjector(config)
174
+ self.vocab_size = config.vocab_size
175
+ self.language_model = AutoModelForCausalLM.from_config(
176
+ config.text_config, attn_implementation=config._attn_implementation
177
+ ) if language_model is None else language_model
178
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
179
+ self.post_init()
180
+
181
+ def load_text_encoder(self, processor):
182
+ self.tokenizer = processor.tokenizer
183
+ self.clip_tokenizer = processor.clip_tokenizer
184
+ self.eos_token_id = [self.tokenizer.eos_token_id, self.tokenizer.convert_tokens_to_ids("<|eot_id|>")]
185
+ self.encoder_type = self.config.vision_config.model_type
186
+ if 'clip' in self.encoder_type:
187
+ self.encoder = AutoModel.from_pretrained('openai/clip-vit-large-patch14-336', torch_dtype=self.dtype, device_map=self.device)
188
+ elif 'siglip' in self.encoder_type:
189
+ self.encoder = AutoModel.from_pretrained("google/siglip-so400m-patch14-384", torch_dtype=self.dtype, device_map=self.device)
190
+ else:
191
+ raise ValueError(f"Vision model {self.config.vision_config.model_type} is not supported.")
192
+
193
+ def get_input_embeddings(self):
194
+ return self.language_model.get_input_embeddings()
195
+
196
+ def set_input_embeddings(self, value):
197
+ self.language_model.set_input_embeddings(value)
198
+
199
+ def get_output_embeddings(self):
200
+ return self.language_model.get_output_embeddings()
201
+
202
+ def set_output_embeddings(self, new_embeddings):
203
+ self.language_model.set_output_embeddings(new_embeddings)
204
+
205
+ def set_decoder(self, decoder):
206
+ self.language_model.set_decoder(decoder)
207
+
208
+ def get_decoder(self):
209
+ return self.language_model.get_decoder()
210
+
211
+ def tie_weights(self):
212
+ return self.language_model.tie_weights()
213
+
214
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
215
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
216
+ # update vocab size
217
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
218
+ self.config.vocab_size = model_embeds.num_embeddings
219
+ self.vocab_size = model_embeds.num_embeddings
220
+ return model_embeds
221
+
222
+ # def _get_highest_similarity(self, cls_token, keyword_hidden_states, top_patches):
223
+ # num_patches, embed_dim = cls_token.shape
224
+ # batch_size, sequence_length, hidden_size = keyword_hidden_states.shape
225
+ # assert embed_dim == hidden_size, f"The embedding dimension of cls token and keyword hidden states do not match. Dimension of cls token: {embed_dim} and dimension of keyword hidden states: {hidden_size}."
226
+ # keyword_hidden_states = keyword_hidden_states.squeeze(0)
227
+
228
+ # # calculate the similarity between the cls token and the keyword hidden states
229
+ # similarity_score = torch.matmul(cls_token, keyword_hidden_states.T) # shape: (num_patches, sequence_length)
230
+ # similarity_score = similarity_score.mean(dim=1) # shape: (num_patches)
231
+ # # take the index of the patch with the highest similarity score
232
+ # patch_index = torch.topk(similarity_score, top_patches).indices
233
+ # return patch_index
234
+
235
+ # def _select_patches(self, image_features, keyword_hidden_states, top_patches=1):
236
+ # selected_patches = []
237
+ # # iterate through each image
238
+ # for image in image_features:
239
+ # if keyword_hidden_states is not None:
240
+ # # take the first token of each patch
241
+ # cls_token = image[:, 0, :].squeeze(1)
242
+ # # get the index of the patch with the highest similarity score
243
+ # patch_index = self._get_highest_similarity(cls_token, keyword_hidden_states, top_patches)
244
+ # else:
245
+ # top_patches = image.shape[0]
246
+ # patch_index = torch.arange(top_patches)
247
+ # # select the patch with the highest similarity score
248
+ # if self.multimodal_projector == 'mlp':
249
+ # image = image[patch_index, 1:, :].reshape(-1, image.shape[-1]).type(self.vision_tower.dtype)
250
+ # elif self.multimodal_projector == 'perceiver':
251
+ # image = image[patch_index, :, :].reshape(-1, image.shape[-1]).type(self.vision_tower.dtype)
252
+ # else:
253
+ # raise ValueError(f"Multimodal projector {self.multimodal_projector} is not supported.")
254
+ # selected_patches.append(image)
255
+ # return selected_patches # shape: list with shape of num_images, each element of shape (num_tokens * num_patches_i, embed_dim)
256
+
257
+ # def _input_to_vision_tower(self, pixel_values):
258
+ # output = []
259
+ # for i in range(len(pixel_values)):
260
+ # num_patches = pixel_values[i].shape[0]
261
+ # pixel_batch_size = 2
262
+ # processed_pixel_values
263
+
264
+ # def _input_to_multimodal_projector(self, selected_image_features):
265
+ # output = []
266
+ # for selected_image in selected_image_features:
267
+ # selected_image = self.multi_modal_projector(selected_image)
268
+ # output.append(selected_image)
269
+ # return output # shape: list with shape of num_images, each element of shape (num_patches_i, num_tokens, embed_dim) where i is the index of the image
270
+
271
+ # def _process_keyword_input(self, keyword_input_ids, maximum_keyword_tokens=10):
272
+ # self.language_model.eval()
273
+ # with torch.no_grad():
274
+ # output_ids = self.language_model.generate(input_ids=keyword_input_ids, return_dict_in_generate=True, max_new_tokens=maximum_keyword_tokens)
275
+ # output_ids = output_ids.sequences[:, keyword_input_ids.shape[-1]:]
276
+
277
+ # self.language_model.train()
278
+ # # conditions
279
+ # if output_ids[0, 0:2].tolist() == [35581, 25]: # condition where the output is in the form Keyword: <keyword>
280
+ # keyword_ids = output_ids[:, 2:-1]
281
+ # if keyword_ids[0, 0].item() == 482:
282
+ # return None
283
+ # return self.get_input_embeddings()(keyword_ids)
284
+ # else: # output
285
+ # return None
286
+
287
+ def generate_keywords(self, keywords_text, criteria='template'):
288
+ keywords_text = keywords_text.lstrip('\n')
289
+ first_sentence = keywords_text.split('.')[0] + '.'
290
+ if re.search(r'are (.+?)\.', first_sentence):
291
+ objects = re.search(r'are (.+?)\.', first_sentence).group(1).split(' and ')
292
+ elif re.search(r'is (.+?)\.', first_sentence):
293
+ objects = [re.search(r'is (.+?)\.', first_sentence).group(1)]
294
+ else:
295
+ objects = []
296
+
297
+ def generate_template(object, description):
298
+ if object[0] in ['a', 'e', 'i', 'o', 'u']:
299
+ return f'An {object}, which {description}'
300
+ else:
301
+ return f'A {object}, which {description}'
302
+
303
+ descriptions = []
304
+ keywords = []
305
+ for i, obj in enumerate(objects):
306
+ keywords.append(obj)
307
+ if criteria == 'word':
308
+ descriptions.append([obj])
309
+ elif criteria == 'template':
310
+ descriptions.append([f'a photo of {obj}'])
311
+ elif criteria == 'description':
312
+ # pattern = rf"'{obj}':(.*?)('|\Z)"
313
+ # match = re.search(pattern, keywords_text, re.DOTALL)
314
+ # if match:
315
+ # # Extract the feature keywords_text and clean it up
316
+ # feature_text = match.group(1).strip()
317
+ # # Split on new lines and strip each line
318
+ # feature_list = [generate_template(obj, line.strip('* ').strip()) for line in feature_text.split('\n') if line.strip()]
319
+ # descriptions.append(feature_list)
320
+ # The problem of the above code is that it does not work for the case where the object is not found in the text
321
+ # make it more general
322
+ features = re.findall(r"\* (.+)", keywords_text, re.MULTILINE)
323
+ descriptions.append([generate_template(obj, feature) for feature in features[i * len(features) // len(objects): (i + 1) * len(features) // len(objects)]])
324
+
325
+ else:
326
+ raise ValueError(f'invalid criteria: {criteria}')
327
+
328
+ return keywords, descriptions
329
+
330
+ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
331
+ num_images = len(image_features)
332
+ num_image_tokens = torch.tensor([x.shape[0] for x in image_features], device=self.vision_tower.device, dtype=torch.int64) # total image tokens
333
+ embed_dim = image_features[0].shape[-1]
334
+ batch_size, sequence_length = input_ids.shape
335
+ left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
336
+ # 1. Create a mask to know where special image tokens are
337
+ special_image_token_mask = input_ids == self.config.image_token_index
338
+ # num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
339
+ # Compute the maximum embed dimension
340
+ # max_embed_dim = (num_special_image_tokens.max() * (num_image_tokens - 1)) + sequence_length
341
+ max_embed_dim = torch.sum(num_image_tokens) - num_images + sequence_length
342
+ batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
343
+ _, image_indices = torch.where(input_ids == self.config.image_token_index)
344
+
345
+ # 2. Compute the positions where text should be written
346
+ # Calculate new positions for text tokens in merged image-text sequence.
347
+ # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
348
+ # `torch.cumsum` computes how each image token shifts subsequent text token positions.
349
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
350
+ image_token_mask = special_image_token_mask * 1
351
+ image_token_mask[0, image_indices] = num_image_tokens - 1
352
+ # for i, index in enumerate(image_indices):
353
+ # special_image_token_mask[0, index] = num_image_tokens[i] - 1
354
+ new_token_positions = torch.cumsum((image_token_mask) + 1, -1) - 1
355
+ # new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
356
+ nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
357
+ if left_padding:
358
+ new_token_positions += nb_image_pad[:, None] # offset for left padding
359
+ text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
360
+
361
+ # 3. Create the full embedding, already padded to the maximum position
362
+ final_embedding = torch.zeros(
363
+ batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
364
+ )
365
+ final_attention_mask = torch.zeros(
366
+ batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
367
+ )
368
+ if labels is not None:
369
+ final_labels = torch.full(
370
+ (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
371
+ )
372
+ # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
373
+ # set the corresponding tensors into their correct target device.
374
+ target_device = inputs_embeds.device
375
+ batch_indices, non_image_indices, text_to_overwrite = (
376
+ batch_indices.to(target_device),
377
+ non_image_indices.to(target_device),
378
+ text_to_overwrite.to(target_device),
379
+ )
380
+ attention_mask = attention_mask.to(target_device)
381
+
382
+ # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
383
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
384
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
385
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
386
+ if labels is not None:
387
+ final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
388
+
389
+ # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
390
+ image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
391
+ image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
392
+
393
+ if image_to_overwrite.sum() != torch.sum(num_image_tokens):
394
+ raise ValueError(
395
+ f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
396
+ f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
397
+ )
398
+
399
+ final_embedding[image_to_overwrite] = torch.cat([image_patches for image_patches in image_features], dim=0).to(target_device)
400
+ # final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
401
+ final_attention_mask |= image_to_overwrite
402
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
403
+
404
+ if labels is None:
405
+ final_labels = None
406
+
407
+ return final_embedding, final_attention_mask, final_labels, position_ids
408
+
409
+ def forward(
410
+ self,
411
+ input_ids: torch.LongTensor = None,
412
+ pixel_values: List[torch.FloatTensor] = None,
413
+ coords: List[torch.FloatTensor] = None,
414
+ attention_mask: Optional[torch.Tensor] = None,
415
+ position_ids: Optional[torch.LongTensor] = None,
416
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
417
+ inputs_embeds: Optional[torch.FloatTensor] = None,
418
+ keyword_prompt_input_ids: torch.LongTensor = None,
419
+ vision_feature_select_strategy: Optional[str] = None,
420
+ vision_feature_layer: Optional[int] = None,
421
+ patch_picking_strategy: Optional[str] = None,
422
+ topk: Optional[int] = None,
423
+ keyword_criteria: Optional[str] = None,
424
+ positional_information: Optional[str] = None,
425
+ labels: Optional[torch.LongTensor] = None,
426
+ use_cache: Optional[bool] = None,
427
+ output_attentions: Optional[bool] = None,
428
+ output_hidden_states: Optional[bool] = None,
429
+ return_dict: Optional[bool] = None,
430
+ visualize_patches: Optional[bool] = None,
431
+ visualize_topk_patches: Optional[bool] = None,
432
+ print_keyword: Optional[bool] = None,
433
+ print_topk_patches: Optional[bool] = None,
434
+ ) -> Union[Tuple, GeckoCausalLMOutputWithPast]:
435
+ """
436
+ Parameters:
437
+ text_inputs: Dict
438
+ Output of tokenizer for text data. A dictionary containing the following keys:
439
+ - input_ids: torch.LongTensor of shape (batch_size, sequence_length)
440
+ - attention_mask: torch.LongTensor of shape (batch_size, sequence_length)
441
+ - token_type_ids: torch.LongTensor of shape (batch_size, sequence_length)
442
+ keyword_inputs: Dict
443
+ Output of tokenizer for keyword data. A dictionary containing the following keys:
444
+ - input_ids: torch.LongTensor of shape (batch_size, sequence_length)
445
+ - attention_mask: torch.LongTensor of shape (batch_size, sequence_length)
446
+ - token_type_ids: torch.LongTensor of shape (batch_size, sequence_length)
447
+ image_inputs: Dict
448
+ Output of ImageProcessor for image data. A dictionary containing the following keys:
449
+ - pixel_values: torch.FloatTensor of shape (num_images, num_patches, num_tokens, embed_dim)
450
+ - coords: List of shape (batch_size, num_images)
451
+ """
452
+ # processing image and text inputs
453
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
454
+ output_hidden_states = (
455
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
456
+ )
457
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
458
+ vision_feature_layer = (
459
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
460
+ )
461
+ vision_feature_select_strategy = (
462
+ vision_feature_select_strategy if vision_feature_select_strategy is not None else self.config.vision_feature_select_strategy
463
+ )
464
+ patch_picking_strategy = patch_picking_strategy if patch_picking_strategy is not None else self.config.patch_picking_strategy
465
+ topk = topk if topk is not None else self.config.topk
466
+ keyword_criteria = keyword_criteria if keyword_criteria is not None else self.config.keyword_criteria
467
+ positional_information = positional_information if positional_information is not None else self.config.positional_information
468
+ visualize_patches = visualize_patches if visualize_patches is not None else self.config.visualize_patches
469
+ visualize_topk_patches = visualize_topk_patches if visualize_topk_patches is not None else self.config.visualize_topk_patches
470
+ print_keyword = print_keyword if print_keyword is not None else self.config.print_keyword
471
+ print_topk_patches = print_topk_patches if print_topk_patches is not None else self.config.print_topk_patches
472
+
473
+ if inputs_embeds is None:
474
+ # 1. Extra the input embeddings
475
+ inputs_embeds = self.get_input_embeddings()(input_ids)
476
+
477
+ # 2. Merge text and images
478
+ if pixel_values is not None and input_ids.shape[1] != 1:
479
+
480
+ with torch.no_grad():
481
+ keyword_input_ids = self.language_model.generate(keyword_prompt_input_ids, return_dict_in_generate=True, max_new_tokens=1024, eos_token_id=self.eos_token_id)
482
+ keyword_input_ids = keyword_input_ids.sequences[:, keyword_prompt_input_ids.shape[-1]:]
483
+ keyword_text = self.tokenizer.decode(keyword_input_ids[0], skip_special_tokens=True)
484
+
485
+ # print(keyword_text)
486
+ generated_keywords, generated_descriptions = self.generate_keywords(keyword_text, criteria=keyword_criteria)
487
+
488
+ all_text_features = []
489
+ for descriptions in generated_descriptions:
490
+ one_text_features = []
491
+ for description in descriptions:
492
+ keyword_ids = self.clip_tokenizer(description, return_tensors='pt')
493
+ keyword_ids = {k: v.to(self.device) for k, v in keyword_ids.items()}
494
+ text_features = self.encoder.get_text_features(**keyword_ids)
495
+ one_text_features.append(text_features / text_features.norm(p=2, dim=-1, keepdim=True))
496
+ all_text_features.append(torch.cat(one_text_features, dim=0))
497
+
498
+ pixel_values = [pixel_value.to(self.vision_tower.device, dtype=self.vision_tower.dtype) for pixel_value in pixel_values]
499
+ selected_image_features = []
500
+ selected_coords = []
501
+ for p, pixel_value in enumerate(pixel_values): # iterate through each image
502
+ print_keyword_text = f'Keywords (criteria: {keyword_criteria}):\n'
503
+ all_hidden_states = self.vision_tower(pixel_value, output_hidden_states=True).hidden_states # tuple of size (num_layers, num_patch, num_tokens, vison_embed_dim)
504
+ if patch_picking_strategy == 'last_layer':
505
+ hidden_states = [all_hidden_states[-1]]
506
+ elif patch_picking_strategy == 'across_layers':
507
+ hidden_states = deepcopy(all_hidden_states)
508
+ top_patches = [0]
509
+ for i, text_feature in enumerate(all_text_features):
510
+ print_keyword_text += f' {i+1}: ' + "\n ".join(generated_descriptions[i]) + '\n'
511
+ top_index = []
512
+ for hidden_state in hidden_states: # iterate through each layer
513
+ if 'clip' in self.encoder_type:
514
+ if vision_feature_select_strategy == 'cls':
515
+ image_features = self.encoder.visual_projection(self.encoder.vision_model.post_layernorm(hidden_state[1:, 0, :])) # (num_patch-1, embed_dim)
516
+ elif vision_feature_select_strategy == 'image_features':
517
+ image_features = self.encoder.visual_projection(self.encoder.vision_model.post_layernorm(hidden_state[1:, 1:, :])) # (num_patch-1 * num_tokens, embed_dim)
518
+ num_tokens = hidden_state.shape[1] - 1
519
+ elif 'siglip' in self.encoder_type:
520
+ if vision_feature_select_strategy == 'cls':
521
+ image_features = self.encoder.vision_model.head(self.encoder.vision_model.post_layernorm(hidden_state[1:, :, :])) # (num_patch-1, embed_dim)
522
+ elif vision_feature_select_strategy == 'image_features':
523
+ image_features = self.encoder.vision_model.post_layernorm(hidden_state[1:, :, :]) # (num_patch-1 * num_tokens, embed_dim)
524
+ num_tokens = hidden_state.shape[1]
525
+ image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
526
+
527
+ if vision_feature_select_strategy == 'cls':
528
+ similarity_score = torch.matmul(image_features, text_feature.T).mean(dim=1) # (num_patch-1)
529
+ if patch_picking_strategy == 'across_layers':
530
+ index = torch.topk(similarity_score, 1).indices
531
+ top_index.append(index.item()+1)
532
+ elif patch_picking_strategy == 'last_layer':
533
+ index = torch.topk(similarity_score, math.ceil(topk / len(all_text_features))).indices + 1 # take top k patches
534
+ top_index += index.tolist()
535
+ elif vision_feature_select_strategy == 'image_features':
536
+ image_features = image_features.flatten(0, 1)
537
+ similarity_score = torch.matmul(image_features, text_feature.T).mean(dim=1) # (num_patch-1 * num_tokens)
538
+ index = torch.topk(similarity_score, 100).indices # take top 100 tokens
539
+ patch_index = torch.floor(index / num_tokens) # get the patch index
540
+ count = torch.nn.functional.one_hot(patch_index.to(torch.int64)).sum(dim=0) # count the occurrences of each patch
541
+ if patch_picking_strategy == 'across_layers':
542
+ top_count = torch.topk(count, 1).indices # take top 1
543
+ top_index.append(top_count.item()+1)
544
+ elif patch_picking_strategy == 'last_layer':
545
+ top_count = torch.topk(count, math.ceil(topk / len(all_text_features))).indices + 1
546
+ top_index += top_count.tolist()
547
+
548
+ if visualize_patches and patch_picking_strategy == 'across_layers':
549
+ if 'clip' in self.encoder_type:
550
+ (x, y) = (5, 5)
551
+ elif 'siglip' in self.encoder_type:
552
+ (x, y) = (7, 4)
553
+ fig, axs = plt.subplots(y, x, figsize=(x * 2, y * 2))
554
+ fig.suptitle(f'keyword: {generated_keywords[i]}')
555
+ for k, index in enumerate(top_index):
556
+ axs[k // x, k % x].imshow(pixel_value[index].to(torch.float32).cpu().numpy().transpose(1, 2, 0))
557
+ axs[k // x, k % x].set_title(f'Layer {k}')
558
+ axs[k // x, k % x].axis('off')
559
+ plt.show()
560
+ if patch_picking_strategy == 'across_layers':
561
+ top_patches += torch.topk(torch.bincount(torch.tensor(top_index, dtype=torch.int64)), math.ceil(topk / len(all_text_features))).indices.to(dtype=torch.int64).tolist()
562
+ elif patch_picking_strategy == 'last_layer':
563
+ top_patches += top_index
564
+ topk_patches = list(set(top_patches))
565
+ if visualize_topk_patches:
566
+ fig, axs = plt.subplots(1, len(topk_patches), figsize=(len(topk_patches) * 2, 2))
567
+ fig.suptitle(f'top-{len(topk_patches)} patches')
568
+ for m, topk_patch in enumerate(topk_patches):
569
+ axs[m].imshow(pixel_value[topk_patch].to(torch.float32).cpu().numpy().transpose(1, 2, 0))
570
+ axs[m].axis('off')
571
+ plt.show()
572
+
573
+ if 'clip' in self.encoder_type:
574
+ selected_image_features.append(all_hidden_states[vision_feature_layer][topk_patches, 1:, :])
575
+ elif 'siglip' in self.encoder_type:
576
+ selected_image_features.append(all_hidden_states[vision_feature_layer][topk_patches, :, :])
577
+ selected_coords.append([coords[p][q-1] for q in topk_patches[1:]])
578
+ # if isinstance(pixel_values, list):
579
+ # pixel_values = torch.cat([x for x in pixel_values if x is not None], dim=0)
580
+ if print_keyword:
581
+ print(print_keyword_text)
582
+ multimodal_projector_features = []
583
+
584
+ for x, (selected_image_feature, selected_coord) in enumerate(zip(selected_image_features, selected_coords)):
585
+ print(f'image {x+1}: {selected_coord}')
586
+ if '2d' in positional_information:
587
+ max_width = max(selected_coord, key= lambda x: x[0])[0] + 1
588
+ max_height = max(selected_coord, key= lambda x: x[1])[1] + 1
589
+ positional_encoding = self.positional_encoding(torch.ones((max_width, max_height, selected_image_feature.shape[1], self.positional_encoding.org_channels), dtype=self.dtype, device=self.device))
590
+ accumulate = []
591
+ for i, top_patch in enumerate(selected_image_feature):
592
+ if positional_information == '2d_before' and i != 0:
593
+ top_patch += positional_encoding[selected_coord[i-1][0], selected_coord[i-1][1], :, :]
594
+ aligned_image_feature = self.multi_modal_projector(top_patch)
595
+ if positional_information == '2d_after' and i != 0:
596
+ aligned_image_feature += positional_encoding[selected_coord[i-1][0], selected_coord[i-1][1], :, :]
597
+ accumulate.append(aligned_image_feature)
598
+ if i == 0:
599
+ accumulate.append(self.get_input_embeddings()(self.tokenizer(', ', padding=False, truncation=False, max_length=None, return_tensors='pt')['input_ids'].to(device=self.device)[0, 1:]))
600
+ continue
601
+ if positional_information == 'explicit':
602
+ accumulate.append(self.get_input_embeddings()(self.tokenizer(f' at {str(selected_coord[i-1])}, ', padding=False, truncation=False, max_length=None, return_tensors='pt')['input_ids'].to(device=self.device)[0, 1:]))
603
+ else:
604
+ accumulate.append(self.get_input_embeddings()(self.tokenizer(f', ', padding=False, truncation=False, max_length=None, return_tensors='pt')['input_ids'].to(device=self.device)[0, 1:]))
605
+ multimodal_projector_features.append(torch.cat(accumulate, dim=0)) # dimension of (num_selected_patch * num_tokens-1 + num_selected_patch * sep_len - 1) -> (num_selected_patch * num_tokens - 1) as sep_len = 1
606
+
607
+ assert len(selected_image_features) == len(multimodal_projector_features), f"The number of selected image features and image features do not match. Dimension of selected image features: {len(selected_image_features)} and dimension of image features: {len(multimodal_projector_features)}."
608
+ # print(multimodal_projector_features[0].shape)
609
+ inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
610
+ multimodal_projector_features, inputs_embeds, input_ids, attention_mask, labels
611
+ )
612
+ if labels is None:
613
+ labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
614
+ else:
615
+ # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
616
+ # generation with cache
617
+ if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
618
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
619
+ # that are set to 0
620
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
621
+
622
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
623
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
624
+
625
+ # Get the target length
626
+ target_seqlen = first_layer_past_key_value.shape[-1] + 1
627
+
628
+ extended_attention_mask = torch.ones(
629
+ (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
630
+ dtype=attention_mask.dtype,
631
+ device=attention_mask.device,
632
+ )
633
+
634
+ # Filter out only the tokens that can be un-attended, this can happen
635
+ # if one uses Llava + Fused modules where the cache on the
636
+ # first iteration is already big enough, or if one passes custom cache
637
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
638
+ new_batch_index = batch_index[valid_indices]
639
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
640
+
641
+ # Zero-out the places where we don't need to attend
642
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
643
+
644
+ attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
645
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
646
+
647
+ outputs = self.language_model(
648
+ attention_mask=attention_mask,
649
+ position_ids=position_ids,
650
+ past_key_values=past_key_values,
651
+ inputs_embeds=inputs_embeds,
652
+ use_cache=use_cache,
653
+ output_attentions=output_attentions,
654
+ output_hidden_states=output_hidden_states,
655
+ return_dict=return_dict,
656
+ )
657
+
658
+ logits = outputs[0]
659
+
660
+ batch_shift = 100
661
+ loss = None
662
+ if labels is not None:
663
+ # Shift so that tokens < n predict n
664
+ if attention_mask is not None:
665
+ shift_attention_mask = attention_mask[..., 1:]
666
+ logits_shape = logits.shape
667
+ labels_shape = labels.shape
668
+ shift_attention_mask_shape = shift_attention_mask.shape
669
+ for i in range(0, shift_attention_mask.shape[1], batch_shift):
670
+ shift_logits = logits[..., i:min(i+batch_shift, logits_shape[1]-1), :][shift_attention_mask[..., i:min(i+batch_shift, shift_attention_mask_shape[1])].to(logits.device) != 0].contiguous()
671
+ shift_labels = labels[..., i+1:min(i+batch_shift+1, labels_shape[1])][shift_attention_mask[..., i:min(i+batch_shift, shift_attention_mask_shape[1])].to(labels.device) != 0].contiguous()
672
+ else:
673
+ shift_logits = logits[..., :-1, :].contiguous()
674
+ shift_labels = labels[..., 1:].contiguous()
675
+ # Flatten the tokens
676
+ loss_fct = nn.CrossEntropyLoss()
677
+ loss = loss_fct(
678
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
679
+ )
680
+
681
+ if not return_dict:
682
+ output = (logits,) + outputs[1:]
683
+ return (loss,) + output if loss is not None else output
684
+
685
+ return GeckoCausalLMOutputWithPast(
686
+ loss=loss,
687
+ logits=logits,
688
+ past_key_values=outputs.past_key_values,
689
+ hidden_states=outputs.hidden_states,
690
+ attentions=outputs.attentions,
691
+ )
692
+
693
+ def prepare_inputs_for_generation(
694
+ self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, keyword_prompt_input_ids=None, coords=None, **kwargs
695
+ ):
696
+
697
+ if past_key_values is not None:
698
+ if isinstance(past_key_values, Cache):
699
+ cache_length = past_key_values.get_seq_length()
700
+ past_length = past_key_values.seen_tokens
701
+ else:
702
+ cache_length = past_length = past_key_values[0][0].shape[2]
703
+
704
+ # Keep only the unprocessed tokens:
705
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
706
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
707
+ # input)
708
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
709
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
710
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
711
+ # input_ids based on the past_length.
712
+ elif past_length < input_ids.shape[1]:
713
+ input_ids = input_ids[:, past_length:]
714
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
715
+ elif self.config.image_token_index in input_ids:
716
+ input_ids = input_ids[:, input_ids.shape[1] - 1 :]
717
+ # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
718
+ # older attention values, as their corresponding values are not part of the input.
719
+ if cache_length < past_length and attention_mask is not None:
720
+ attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
721
+
722
+ position_ids = kwargs.get("position_ids", None)
723
+ if attention_mask is not None and position_ids is None:
724
+ # create position_ids on the fly for batch generation
725
+ position_ids = attention_mask.long().cumsum(-1) - 1
726
+ position_ids.masked_fill_(attention_mask == 0, 1)
727
+ if past_key_values:
728
+ position_ids = position_ids[:, -input_ids.shape[1] :]
729
+
730
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
731
+ if inputs_embeds is not None and past_key_values is None:
732
+ model_inputs = {"inputs_embeds": inputs_embeds}
733
+ else:
734
+ model_inputs = {"input_ids": input_ids}
735
+
736
+ model_inputs.update(
737
+ {
738
+ "position_ids": position_ids,
739
+ "past_key_values": past_key_values,
740
+ "use_cache": kwargs.get("use_cache"),
741
+ "attention_mask": attention_mask,
742
+ "pixel_values": pixel_values,
743
+ "keyword_prompt_input_ids": keyword_prompt_input_ids,
744
+ "coords": coords,
745
+ "topk": kwargs.get("topk"),
746
+ "vision_feature_select_strategy": kwargs.get("vision_feature_select_strategy"),
747
+ "vision_feature_layer": kwargs.get("vision_feature_layer"),
748
+ "patch_picking_strategy": kwargs.get("patch_picking_strategy"),
749
+ "keyword_criteria": kwargs.get("keyword_criteria"),
750
+ "positional_information": kwargs.get("positional_information"),
751
+ "visualize_patches": kwargs.get("visualize_patches"),
752
+ "visualize_topk_patches": kwargs.get("visualize_topk_patches"),
753
+ "print_keyword": kwargs.get("print_keyword"),
754
+ "print_topk_patches": kwargs.get("print_topk_patches"),
755
+ }
756
+ )
757
+ return model_inputs
758
+
759
+ def _reorder_cache(self, *args, **kwargs):
760
+ return self.language_model._reorder_cache(*args, **kwargs)
model/multimodal_encoder.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is referenced from https://github.com/dhansmair/flamingo-mini
2
+
3
+ import torch
4
+ from einops import rearrange, repeat
5
+ from einops_exts import rearrange_many
6
+ from torch import einsum, nn
7
+ import math
8
+ import torch.nn.functional as F
9
+ from .configuration_gecko import GeckoConfig
10
+ from transformers.activations import ACT2FN
11
+ from torch.nn.init import trunc_normal_
12
+ from functools import partial
13
+
14
+ def feed_forward_layer(dim: int, mult: int = 4, activation: str = 'gelu'):
15
+ """Feed forward layer with given activation function"""
16
+
17
+ activations = dict(gelu=nn.GELU, relu=nn.ReLU)
18
+ assert activation in activations, f'activation can only be one of {activations.keys()}'
19
+
20
+ inner_dim = int(dim * mult)
21
+ return nn.Sequential(
22
+ nn.LayerNorm(dim),
23
+ nn.Linear(dim, inner_dim, bias=False),
24
+ activations[activation](),
25
+ nn.Linear(inner_dim, dim, bias=False),
26
+ )
27
+
28
+ class PerceiverAttentionLayer(nn.Module):
29
+ """Perceiver Attention Layer"""
30
+
31
+ def __init__(self, dim: int, dim_head: int = 64, heads: int = 8):
32
+ super().__init__()
33
+ self.scale = dim_head**-0.5
34
+ self.heads = heads
35
+ self.dim_head = dim_head
36
+ inner_dim = dim_head * heads
37
+
38
+ # trainable components of PerceiverAttentionLayer
39
+ self.norm_media = nn.LayerNorm(dim)
40
+ self.norm_latents = nn.LayerNorm(dim)
41
+
42
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
43
+ self.to_k = nn.Linear(dim, inner_dim, bias=False)
44
+ self.to_v = nn.Linear(dim, inner_dim, bias=False)
45
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
46
+
47
+ def forward(self, features, latents):
48
+ """Latent vectors are cross-attending to the visual features x
49
+
50
+ Args:
51
+ features: Batch of visual features with shape (batch_size, n_tokens, dim)
52
+ latents: Latent learnt vectors which are used to compute queries with shape (batch_size, n_latents, dim)
53
+
54
+ Returns:
55
+ Attention score with shape (batch_size, n_latents, dim)
56
+ """
57
+ assert features.ndim == 3
58
+ assert latents.ndim == 3
59
+ assert features.shape[0] == latents.shape[0]
60
+ assert features.shape[2] == latents.shape[2]
61
+
62
+ n_heads = self.heads
63
+ n_batch, n_features, dim = features.shape
64
+ n_queries = latents.shape[1]
65
+
66
+ # Layer normalization
67
+ x = self.norm_media(features)
68
+ latents = self.norm_latents(latents)
69
+
70
+ # Compute the queries from the latents, for all attention heads simultaneously
71
+ q = self.to_q(latents)
72
+ q = rearrange(q, 'b q (h d) -> b h q d', h=n_heads)
73
+ assert q.shape == torch.Size([n_batch, n_heads, n_queries, self.dim_head])
74
+
75
+ # Keys and values for all attention heads
76
+ kv_input = torch.cat((x, latents), dim=-2)
77
+ n_features_latents = n_features + n_queries
78
+ k = self.to_k(kv_input)
79
+ v = self.to_v(kv_input)
80
+
81
+ k, v = rearrange_many((k, v), 'b f (h d) -> b h f d', h=n_heads)
82
+ assert v.shape == torch.Size([n_batch, n_heads, n_features_latents, self.dim_head])
83
+
84
+ q = q * self.scale
85
+
86
+ # Attention scores
87
+ sim = einsum('b h q d, b h f d -> b h q f', q, k)
88
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
89
+ alphas = sim.softmax(dim=-1)
90
+
91
+ out = einsum('b h q f, b h f v -> b h q v', alphas, v)
92
+ out = rearrange(out, 'b h q v -> b q (h v)')
93
+
94
+ return self.to_out(out)
95
+
96
+ class GeckoResamplerProjector(nn.Module):
97
+ """Perceiver Resampler with multi-head attention layer"""
98
+
99
+ def __init__(
100
+ self,
101
+ config: GeckoConfig,
102
+ num_queries: int = 64,
103
+ depth: int = 2,
104
+ dim_head: int = 32,
105
+ heads: int = 4,
106
+ ff_mult: int = 2,
107
+ ):
108
+ super().__init__()
109
+
110
+ self.dim = config.text_config.hidden_size
111
+ self.num_queries = num_queries
112
+
113
+ self.latents = nn.Parameter(torch.randn(self.num_queries, self.dim)) # type: ignore[reportPrivateUsage]
114
+
115
+ self.linear = nn.Linear(config.vision_config.hidden_size, self.dim)
116
+
117
+ self.layers = nn.ModuleList([])
118
+ for _ in range(depth):
119
+ self.layers.append(
120
+ nn.ModuleList(
121
+ [
122
+ PerceiverAttentionLayer(dim=self.dim, dim_head=dim_head, heads=heads),
123
+ feed_forward_layer(dim=self.dim, mult=ff_mult, activation=config.projector_hidden_act),
124
+ ]
125
+ )
126
+ )
127
+
128
+ # Layer normalization takes as input the query vector length
129
+ self.norm = nn.LayerNorm(self.dim)
130
+
131
+ def forward(self, x_f: torch.Tensor):
132
+ """Run perceiver resampler on the input visual embeddings
133
+
134
+ Args:
135
+ x_f: Input visual embeddings of shape (batch_size, num_tokens, d_visual)
136
+
137
+ Returns:
138
+ Resampler features of shape (batch_size, num_queries, d_visual)
139
+ """
140
+ assert x_f.ndim == 3
141
+
142
+ x_f = self.linear(x_f)
143
+
144
+ batch_size, num_tokens, dim = x_f.shape
145
+
146
+ assert dim == self.dim
147
+
148
+ # Copy the latents for every element in the batch
149
+ x = repeat(self.latents, 'q d -> b q d', b=batch_size)
150
+
151
+ # Apply attention and feed forward layer
152
+ for attn, ffw in self.layers:
153
+ x = x + attn(x_f, x)
154
+ x = x + ffw(x)
155
+
156
+ assert x.shape == torch.Size([batch_size, self.num_queries, self.dim])
157
+
158
+ norm = self.norm(x)
159
+ return norm
160
+
161
+ class GeckoMLPProjector(nn.Module):
162
+ def __init__(self, config: GeckoConfig):
163
+ super().__init__()
164
+ self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size)
165
+ self.act = ACT2FN[config.projector_hidden_act]
166
+ self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
167
+
168
+ def forward(self, image_features):
169
+ hidden_states = self.linear_1(image_features)
170
+ hidden_states = self.act(hidden_states)
171
+ hidden_states = self.linear_2(hidden_states)
172
+ return hidden_states
model/processing_gecko.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Optional, Union, Dict
3
+
4
+ import torch
5
+ from PIL import Image
6
+ import logging
7
+
8
+ import os
9
+ import json
10
+ import re
11
+ from transformers.feature_extraction_sequence_utils import BatchFeature
12
+ from transformers.image_utils import ImageInput
13
+ from transformers import ProcessorMixin, ImageProcessingMixin, AutoImageProcessor, AutoTokenizer, AutoProcessor
14
+ from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
15
+ from transformers.utils import TensorType
16
+ from transformers.processing_utils import transformers_module
17
+ from transformers.utils.hub import is_remote_url, download_url, cached_file, is_offline_mode
18
+ from transformers.utils import IMAGE_PROCESSOR_NAME
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class GeckoProcessor(ProcessorMixin):
24
+ attributes = ["image_processor", "tokenizer"]
25
+ image_processor_class = ("CLIPImageProcessor", "SiglipImageProcessor")
26
+ tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast", "PreTrainedTokenizerFast")
27
+
28
+ def __init__(self, image_processor=None, tokenizer=None, use_keyword=False, crop_size=336, cropping_method='dynamic', **kwargs):
29
+ super().__init__(image_processor, tokenizer)
30
+ self.crop_size = crop_size if crop_size is not None else int(image_processor.size['height'])
31
+ self.use_keyword = use_keyword
32
+ self.image_token_index = None
33
+ self.cropping_method = cropping_method
34
+ self.load_clip_tokenizer()
35
+
36
+ def load_clip_tokenizer(self):
37
+ if 'clip' in self.image_processor.image_processor_type.lower():
38
+ self.clip_tokenizer = AutoTokenizer.from_pretrained('openai/clip-vit-large-patch14-336')
39
+ elif 'siglip' in self.image_processor.image_processor_type.lower():
40
+ self.clip_tokenizer = AutoTokenizer.from_pretrained("google/siglip-so400m-patch14-384")
41
+ else:
42
+ raise ValueError(f"Invalid image processor type: {self.image_processor.image_processor_type}")
43
+
44
+ def process_images(self, images: List[Image.Image]):
45
+ # create documentation
46
+ """
47
+ Parameters:
48
+ images: List[Image.Image]
49
+ List of PIL images to be processed
50
+ Returns:
51
+ Dict[str, torch.Tensor]:
52
+ pixel_values: List[torch.Tensor]
53
+ Pixel values of the images. Has shape (num_images, num_patches, num_channels, height, width)
54
+ coords: List[List[List[int]]]
55
+ Coordinates of the cropped images. Has shape (num_images, num_patches, 2)
56
+ """
57
+
58
+ pixel_values = []
59
+ coords = []
60
+
61
+ for image in images:
62
+ outputs, coord = self.dynamic_preprocess(image)
63
+ pixel_values.append(outputs)
64
+ coords.append(coord)
65
+
66
+ return {"pixel_values": pixel_values, "coords": coords}
67
+
68
+ def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
69
+ best_ratio_diff = float('inf')
70
+ best_ratio = (1, 1)
71
+ area = width * height
72
+ for ratio in target_ratios:
73
+ target_aspect_ratio = ratio[0] / ratio[1]
74
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
75
+ if ratio_diff < best_ratio_diff:
76
+ best_ratio_diff = ratio_diff
77
+ best_ratio = ratio
78
+ elif ratio_diff == best_ratio_diff:
79
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
80
+ best_ratio = ratio
81
+ return best_ratio
82
+
83
+ def dynamic_preprocess(self, image):
84
+ orig_width, orig_height = image.size
85
+ aspect_ratio = orig_width / orig_height
86
+
87
+ if self.cropping_method == 'dynamic':
88
+ max_num = math.ceil(orig_width / self.crop_size) * math.ceil(orig_height / self.crop_size)
89
+
90
+ # calculate the existing image aspect ratio
91
+ target_ratios = set(
92
+ (i, j) for n in range(1, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
93
+ i * j <= max_num and i * j >= 1)
94
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
95
+
96
+ # find the closest aspect ratio to the target
97
+ target_aspect_ratio = self.find_closest_aspect_ratio(
98
+ aspect_ratio, target_ratios, orig_width, orig_height, self.crop_size)
99
+ # if target_aspect_ratio[0] * target_aspect_ratio[1] <= 25:
100
+ # target_aspect_ratio = (int(1.5 * target_aspect_ratio[0]), int(1.5 * target_aspect_ratio[1]))
101
+
102
+ elif self.cropping_method == 'naive':
103
+ target_aspect_ratio = (orig_width // self.crop_size, orig_height // self.crop_size)
104
+ # print(target_aspect_ratio)
105
+ # if target_aspect_ratio[0] * target_aspect_ratio[1] <= 25:
106
+ # target_aspect_ratio = (2 * orig_width // self.crop_size, 2 * orig_height // self.crop_size)
107
+ # print(target_aspect_ratio)
108
+ else:
109
+ raise ValueError(f"Invalid cropping method: {self.cropping_method}")
110
+
111
+ # calculate the target width and height
112
+ target_width = self.crop_size * target_aspect_ratio[0]
113
+ target_height = self.crop_size * target_aspect_ratio[1]
114
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
115
+
116
+ # add whole image
117
+ processed_images = []
118
+ processed_images.append(image.resize((self.crop_size, self.crop_size)))
119
+ coords = []
120
+ if blocks == 1:
121
+ return self.image_processor(images=processed_images, return_tensors='pt')['pixel_values'], coords
122
+
123
+ # resize the image
124
+ resized_img = image.resize((target_width, target_height))
125
+ for i in range(blocks):
126
+ x0 = (i % (target_width // self.crop_size))
127
+ y0 = (i // (target_width // self.crop_size))
128
+ x1 = ((i % (target_width // self.crop_size)) + 1)
129
+ y1 = ((i // (target_width // self.crop_size)) + 1)
130
+
131
+ box = (
132
+ x0 * self.crop_size,
133
+ y0 * self.crop_size,
134
+ x1 * self.crop_size,
135
+ y1 * self.crop_size
136
+ )
137
+ split_img = resized_img.crop(box)
138
+ processed_images.append(split_img)
139
+
140
+ coords.append([x0, y0])
141
+
142
+ # box = (
143
+ # (i % (target_width // self.crop_size)) * self.crop_size,
144
+ # (i // (target_width // self.crop_size)) * self.crop_size,
145
+ # ((i % (target_width // self.crop_size)) + 1) * self.crop_size,
146
+ # ((i // (target_width // self.crop_size)) + 1) * self.crop_size
147
+ # )
148
+ # split the image
149
+
150
+ assert len(processed_images) == blocks + 1
151
+
152
+ return self.image_processor(images=processed_images, return_tensors='pt')['pixel_values'], coords
153
+
154
+
155
+ def preprocess_interleaved_images_and_text(
156
+ self,
157
+ text,
158
+ images=None,
159
+ ):
160
+ """
161
+ Args:
162
+ text (`str`, `List[str]`):
163
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
164
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
165
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
166
+ text can contain <image> tokens as the placeholder for the image(s) to be inserted.
167
+ images (`PIL.Image.Image`, `List[PIL.Image.Image]`, `List[List[PIL.Image.Image]]`):
168
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
169
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
170
+ number of channels, H and W are image height and width.
171
+ the number of the images should match the number of <image> tokens in the text.
172
+
173
+ """
174
+ assert text is not None, "text cannot be None."
175
+
176
+ if images is not None:
177
+ if isinstance(images, Image.Image):
178
+ images = [images]
179
+ if isinstance(images, list) and isinstance(images[0], Image.Image):
180
+ if isinstance(text, str):
181
+ images = [images]
182
+ elif isinstance(text, list):
183
+ if len(text) != len(images):
184
+ raise ValueError("Invalid input text. Number of texts does not match number of images.")
185
+ images = [[image] for image in images]
186
+ if isinstance(text, str):
187
+ num_images = len(images[0])
188
+ num_image_tokens = text.count("<image>")
189
+ if num_image_tokens < num_images:
190
+ # prepend empty image tokens to text
191
+ if "USER:" in text:
192
+ text = text.replace("USER:", "USER:" + "<image>" * (num_images - num_image_tokens), 1)
193
+ elif "Human:" in text:
194
+ text = text.replace("Human:", "Human:" + "<image>" * (num_images - num_image_tokens), 1)
195
+ elif "HUMAN:" in text:
196
+ text = text.replace("HUMAN:", "HUMAN:" + "<image>" * (num_images - num_image_tokens), 1)
197
+ else:
198
+ text = "<image>" * (num_images - num_image_tokens) + text
199
+ # logger.warning("Image Tokens <image> are not provided in the text. Automatically prepending them before the text. This might cause model to behave unexpectedly.")
200
+ elif num_image_tokens > num_images:
201
+ text = text.split("<image>")
202
+ for i, t in enumerate(text):
203
+ if i < num_images:
204
+ text[i] = t + "<image>"
205
+ text = "".join(text)
206
+ logger.warning(f"Number of <image> tokens: {num_image_tokens} exceeds number of images: {num_images}. Automatically removing extra tokens at the end of the text.")
207
+ # raise ValueError("Invalid input text. Number of <image> tokens exceeds number of images.")
208
+ texts = [text]
209
+ elif isinstance(text, list):
210
+ if not isinstance(text[0], str):
211
+ raise ValueError("Invalid input text. Each element of text must be a string.")
212
+ for i, t in enumerate(text):
213
+ num_image_tokens = t.count("<image>")
214
+ num_images = len(images[i])
215
+ if num_image_tokens < num_images:
216
+ # prepend empty image tokens to text
217
+ if "USER:" in t:
218
+ t = t.replace("USER:", "USER:" + "<image>" * (num_images - num_image_tokens), 1)
219
+ elif "Human:" in t:
220
+ t = t.replace("Human:", "Human:" + "<image>" * (num_images - num_image_tokens), 1)
221
+ elif "HUMAN:" in t:
222
+ t = t.replace("HUMAN:", "HUMAN:" + "<image>" * (num_images - num_image_tokens), 1)
223
+ else:
224
+ t = "<image>" * (num_images - num_image_tokens) + t
225
+ # logger.warning("Image Tokens <image> are not provided in the text. Automatically prepending them before the text. This might cause model to behave unexpectedly.")
226
+ elif num_image_tokens > num_images:
227
+ t = t.split("<image>")
228
+ for j, s in enumerate(t):
229
+ if j < num_images:
230
+ t[j] = s + "<image>"
231
+ t = "".join(t)
232
+ logger.warning(f"Number of <image> tokens: {num_image_tokens} exceeds number of images: {num_images}. Automatically removing extra tokens at the end of the text.")
233
+ # raise ValueError("Invalid input text. Number of <image> tokens exceeds number of images.")
234
+ text[i] = t
235
+ texts = text
236
+ else:
237
+ raise ValueError("Invalid input text. text must be a string or a list of strings.")
238
+ assert all([t.count("<image>") == len(images_per_text) for t, images_per_text in zip(texts, images)]), "Number of <image> tokens in text does not match number of images."
239
+ # add image denotation in text before each <image> as "(image {i}: <image>)"
240
+ for i, t in enumerate(texts):
241
+ for j in range(len(images[i])):
242
+ t = t.replace("<image>", f"(image {j+1}: <Image><IMAGE></Image>)", 1)
243
+ t = t.replace("<IMAGE>", "<image>")
244
+ texts[i] = t
245
+
246
+ else:
247
+ if isinstance(text, str):
248
+ texts = [text]
249
+ elif isinstance(text, list):
250
+ if not isinstance(text[0], str):
251
+ raise ValueError("Invalid input text. Each element of text must be a string.")
252
+ texts = text
253
+ else:
254
+ raise ValueError("Invalid input text. text must be a string or a list of strings.")
255
+
256
+ return texts, images
257
+
258
+ def __call__(
259
+ self,
260
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
261
+ keywords_text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
262
+ images: ImageInput = None,
263
+ padding: Union[bool, str, PaddingStrategy] = False,
264
+ truncation: Union[bool, str, TruncationStrategy] = None,
265
+ max_length=None,
266
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
267
+ add_image_ids: bool = True,
268
+ cropping_method: str = None,
269
+ ) -> BatchFeature:
270
+ """
271
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
272
+ and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
273
+ the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
274
+ CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
275
+ of the above two methods for more information.
276
+
277
+ Args:
278
+ text (`str`, `List[str]`, `List[List[str]]`):
279
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
280
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
281
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
282
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
283
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
284
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
285
+ number of channels, H and W are image height and width.
286
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
287
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
288
+ index) among:
289
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
290
+ sequence if provided).
291
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
292
+ acceptable input length for the model if that argument is not provided.
293
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
294
+ lengths).
295
+ max_length (`int`, *optional*):
296
+ Maximum length of the returned list and optionally padding length (see above).
297
+ truncation (`bool`, *optional*):
298
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
299
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
300
+ If set, will return tensors of a particular framework. Acceptable values are:
301
+
302
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
303
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
304
+ - `'np'`: Return NumPy `np.ndarray` objects.
305
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
306
+
307
+ Returns:
308
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
309
+
310
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
311
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
312
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
313
+ `None`).
314
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. Have shape of (num_images, num_patches, num_tokens, embed_dim)
315
+ - **coords** -- Coordinates of the cropped images. Returned when `images` is not `None`. Have shape of (num_images, num_patches, 2)
316
+ """
317
+
318
+ if cropping_method is not None:
319
+ self.cropping_method = cropping_method
320
+
321
+ if not self.image_token_index:
322
+ self.image_token_index = self.tokenizer.convert_tokens_to_ids("<image>")
323
+
324
+ if add_image_ids:
325
+ text, images = self.preprocess_interleaved_images_and_text(text, images)
326
+
327
+ text_inputs = self.tokenizer(
328
+ text,
329
+ padding=padding,
330
+ truncation=truncation,
331
+ max_length=max_length,
332
+ return_tensors=return_tensors,
333
+ )
334
+
335
+ if self.use_keyword and keywords_text is not None:
336
+ keywords_prompt_input_ids = self.tokenizer(keywords_text,
337
+ padding=padding,
338
+ truncation=truncation,
339
+ max_length=max_length,
340
+ return_tensors=return_tensors)['input_ids']
341
+ else:
342
+ keywords_prompt_input_ids = None
343
+
344
+ if images is not None:
345
+ input_ids = text_inputs["input_ids"]
346
+ num_image_tokens = torch.sum(input_ids == self.image_token_index, dim=-1)
347
+ for i, num_image_token in enumerate(num_image_tokens):
348
+ if num_image_token < len(images[i]):
349
+ images[i] = images[i][:num_image_token]
350
+ print(f"{len(images[i]) - num_image_token} ({len(images[i])} in total) image tokens in the text are truncated due to the max sequence length; removing the extra images.")
351
+ # flatten images
352
+ images = [image for images_per_text in images for image in images_per_text]
353
+ image_inputs = self.process_images(images)
354
+ else:
355
+ image_inputs = {"pixel_values": None, "coords": None}
356
+
357
+ return BatchFeature(data={**text_inputs, **image_inputs, "keyword_prompt_input_ids": keywords_prompt_input_ids})
358
+
359
+ def batch_decode(self, *args, **kwargs):
360
+ return self.tokenizer.batch_decode(*args, **kwargs)
361
+
362
+ def decode(self, *args, **kwargs):
363
+ return self.tokenizer.decode(*args, **kwargs)
364
+
365
+ @property
366
+ def model_input_names(self):
367
+ tokenizer_input_names = self.tokenizer.model_input_names
368
+ image_processor_input_names = self.image_processor.model_input_names
369
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
370
+
371
+ def _right_pad_inputs_with_attention_mask(self, model_inputs: List[Dict]):
372
+ results = {}
373
+ assert len(model_inputs) == 1, "This method only supports a single input, but get {} inputs".format(len(model_inputs))
374
+ for k in model_inputs[0].keys():
375
+ if k == "pixel_values" or k == "coords":
376
+ results[k] = model_inputs[0][k] if model_inputs[0][k] is not None else None
377
+ else:
378
+ results[k] = torch.cat([model_inputs[0][k]], dim=0) if model_inputs[0][k] is not None else None
379
+ return results
380
+
381
+ @classmethod
382
+ def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
383
+ args = []
384
+
385
+ cache_dir = kwargs.pop("cache_dir", None)
386
+ force_download = kwargs.pop("force_download", False)
387
+ resume_download = kwargs.pop("resume_download", False)
388
+ proxies = kwargs.pop("proxies", None)
389
+ token = kwargs.pop("token", None)
390
+ local_files_only = kwargs.pop("local_files_only", False)
391
+ revision = kwargs.pop("revision", None)
392
+ subfolder = kwargs.pop("subfolder", "")
393
+
394
+ from_pipeline = kwargs.pop("_from_pipeline", None)
395
+ from_auto_class = kwargs.pop("_from_auto", False)
396
+
397
+ user_agent = {"file_type": "processor", "from_auto_class": from_auto_class}
398
+ if from_pipeline is not None:
399
+ user_agent["using_pipeline"] = from_pipeline
400
+
401
+ if is_offline_mode() and not local_files_only:
402
+ logger.info("Offline mode: forcing local_files_only=True")
403
+ local_files_only = True
404
+
405
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
406
+ is_local = os.path.isdir(pretrained_model_name_or_path)
407
+ if os.path.isdir(pretrained_model_name_or_path):
408
+ processor_file = os.path.join(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME)
409
+ if os.path.isfile(pretrained_model_name_or_path):
410
+ resolved_processor_file = pretrained_model_name_or_path
411
+ is_local = True
412
+ elif is_remote_url(pretrained_model_name_or_path):
413
+ processor_file = pretrained_model_name_or_path
414
+ resolved_processor_file = download_url(pretrained_model_name_or_path)
415
+ else:
416
+ processor_file = IMAGE_PROCESSOR_NAME
417
+ try:
418
+ # Load from local folder or from cache or download from model Hub and cache
419
+ resolved_processor_file = cached_file(
420
+ pretrained_model_name_or_path,
421
+ processor_file,
422
+ cache_dir=cache_dir,
423
+ force_download=force_download,
424
+ proxies=proxies,
425
+ resume_download=resume_download,
426
+ local_files_only=local_files_only,
427
+ token=token,
428
+ user_agent=user_agent,
429
+ revision=revision,
430
+ subfolder=subfolder,
431
+ _raise_exceptions_for_missing_entries=True,
432
+ )
433
+ except EnvironmentError:
434
+ # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
435
+ # the original exception.
436
+ raise
437
+ except Exception:
438
+ # For any other exception, we throw a generic error.
439
+ raise EnvironmentError(
440
+ f"Can't load processor for '{pretrained_model_name_or_path}'. If you were trying to load"
441
+ " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
442
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
443
+ f" directory containing a {IMAGE_PROCESSOR_NAME} file"
444
+ )
445
+
446
+ # Existing processors on the Hub created before #27761 being merged don't have `processor_config.json` (if not
447
+ # updated afterward), and we need to keep `from_pretrained` work. So here it fallbacks to the empty dict.
448
+ # (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception)
449
+ # However, for models added in the future, we won't get the expected error if this file is missing.
450
+ if resolved_processor_file is None:
451
+ image_processor_dict = {}
452
+
453
+ try:
454
+ # Load processor dict
455
+ with open(resolved_processor_file, "r", encoding="utf-8") as reader:
456
+ text = reader.read()
457
+ image_processor_dict = json.loads(text)
458
+
459
+ except json.JSONDecodeError:
460
+ raise EnvironmentError(
461
+ f"It looks like the config file at '{resolved_processor_file}' is not a valid JSON file."
462
+ )
463
+
464
+ for attribute_name in cls.attributes:
465
+ class_name = getattr(cls, f"{attribute_name}_class")
466
+ if isinstance(class_name, tuple):
467
+ if attribute_name == "tokenizer":
468
+ classes = tuple(getattr(transformers_module, n) if n is not None else None for n in class_name)
469
+ use_fast = kwargs.get("use_fast", True)
470
+ if use_fast and classes[1] is not None:
471
+ attribute_class = classes[1]
472
+ else:
473
+ attribute_class = classes[0]
474
+ elif attribute_name == "image_processor":
475
+ image_processor_type = image_processor_dict.get("image_processor_type", None)
476
+ if image_processor_type is not None:
477
+ assert image_processor_type in class_name, f"Invalid image processor type: {image_processor_type}"
478
+ attribute_class = getattr(transformers_module, image_processor_type)
479
+ else:
480
+ attribute_class = getattr(transformers_module, class_name[0])
481
+ else:
482
+ raise ValueError(f"Invalid attribute name: {attribute_name}")
483
+ else:
484
+ attribute_class = getattr(transformers_module, class_name)
485
+
486
+ args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
487
+ return args
model/utils.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+ import torch
3
+ from .modelling_gecko import GeckoForConditionalGeneration
4
+ from .processing_gecko import GeckoProcessor
5
+ from .conversation import conv_llama_3 as default_conv, conv_templates
6
+ import transformers
7
+
8
+ from typing import List, Tuple, Union
9
+ from io import StringIO
10
+ import sys
11
+
12
+ class Capturing(list):
13
+ def __enter__(self):
14
+ self._stdout = sys.stdout
15
+ sys.stdout = self._stringio = StringIO()
16
+ return self
17
+ def __exit__(self, *args):
18
+ self.extend(self._stringio.getvalue().splitlines())
19
+ del self._stringio # free up some memory
20
+ sys.stdout = self._stdout
21
+
22
+
23
+ def chat_gecko(
24
+ text:str,
25
+ images: List[Union[PIL.Image.Image, str]],
26
+ model:GeckoForConditionalGeneration,
27
+ processor:GeckoProcessor,
28
+ max_input_length:int=None,
29
+ history:List[dict]=None,
30
+ **kwargs) -> Tuple[str, List[dict]]:
31
+
32
+ if "llama-3" in model.language_model.name_or_path.lower():
33
+ conv = conv_templates['llama_3']
34
+ terminators = [
35
+ processor.tokenizer.eos_token_id,
36
+ processor.tokenizer.convert_tokens_to_ids("<|eot_id|>")
37
+ ]
38
+ else:
39
+ conv = default_conv
40
+ terminators = None
41
+
42
+ kwargs["eos_token_id"] = terminators
43
+ conv = conv.copy()
44
+ conv.messages = []
45
+ if history is not None:
46
+ for message in history:
47
+ assert message["role"] in conv.roles
48
+ conv.append_message(message["role"], message["text"])
49
+ if text:
50
+ assert conv.messages[-1][0] == conv.roles[1], "The last message in the history should be the assistant, if the given text is not empty"
51
+ conv.append_message(conv.roles[0], text)
52
+ conv.append_message(conv.roles[1], "")
53
+ history.append({"role": conv.roles[0], "text": text})
54
+ history.append({"role": conv.roles[1], "text": ""})
55
+ else:
56
+ if conv.messages[-1][0] == conv.roles[1]:
57
+ assert conv.messages[-1][1] == "", "No user message should be provided"
58
+ else:
59
+ assert conv.messages[-1][0] == conv.roles[0], "The last message in the history should be the user, if the given text is empty"
60
+ conv.append_message(conv.roles[0], "")
61
+ history.append({"role": conv.roles[0], "text": ""})
62
+ else:
63
+ history = []
64
+ history.append({"role": conv.roles[0], "text": text})
65
+ history.append({"role": conv.roles[1], "text": ""})
66
+ conv.append_message(conv.roles[0], text)
67
+ conv.append_message(conv.roles[1], "")
68
+ assert conv.messages[-1][0] == conv.roles[1] and conv.messages[-1][1] == "", "Format check"
69
+ assert history[-1]["role"] == conv.roles[1] and history[-1]["text"] == "", "Format check"
70
+
71
+ keyword_prompt = conv.generate_keyword_prompt(text.split("\n")[len(images)])
72
+ prompt = conv.get_prompt()
73
+ if images:
74
+ for i in range(len(images)):
75
+ if isinstance(images[i], str):
76
+ images[i] = PIL.Image.open(images[i]).convert("RGB")
77
+
78
+ inputs = processor(images=images, text=prompt, keywords_text=keyword_prompt, return_tensors="pt", truncation=True, max_length=max_input_length)
79
+ for k, v in inputs.items():
80
+ if v is not None:
81
+ if isinstance(v, torch.Tensor):
82
+ inputs[k] = v.to(model.device)
83
+ elif isinstance(v, list):
84
+ if k == 'coords':
85
+ continue
86
+ inputs[k] = [x.to(model.device) for x in v]
87
+ elif isinstance(v, transformers.tokenization_utils_base.BatchEncoding) or isinstance(v, dict):
88
+ for key, value in v.items():
89
+ if value is not None:
90
+ if isinstance(value, list):
91
+ inputs[k][key] = [x.to(model.device) for x in value]
92
+ else:
93
+ inputs[k][key] = value.to(model.device)
94
+ else:
95
+ raise ValueError(f"Invalid input type: {type(v)}")
96
+
97
+ with torch.inference_mode():
98
+ output_ids = model.generate(**inputs, **kwargs)[0]
99
+
100
+ # remove the input tokens
101
+ generated_ids = output_ids[inputs["input_ids"].shape[-1]:]
102
+ generated_text = processor.decode(generated_ids, skip_special_tokens=True)
103
+
104
+ history[-1]["text"] = generated_text
105
+
106
+ return generated_text, history
107
+
108
+ def chat_gecko_stream(
109
+ text:str,
110
+ images: List[Union[PIL.Image.Image, str]],
111
+ model:GeckoForConditionalGeneration,
112
+ processor:GeckoProcessor,
113
+ max_input_length:int=None,
114
+ history:List[dict]=None,
115
+ **kwargs) -> Tuple[str, List[dict]]:
116
+
117
+ if "llama-3" in model.language_model.name_or_path.lower():
118
+ conv = conv_templates['llama_3']
119
+ terminators = [
120
+ processor.tokenizer.eos_token_id,
121
+ processor.tokenizer.convert_tokens_to_ids("<|eot_id|>")
122
+ ]
123
+ else:
124
+ conv = default_conv
125
+ terminators = None
126
+ kwargs["eos_token_id"] = terminators
127
+ conv = conv.copy()
128
+ conv.messages = []
129
+ if history is not None:
130
+ for message in history:
131
+ assert message["role"] in conv.roles
132
+ conv.append_message(message["role"], message["text"])
133
+ if text:
134
+ assert conv.messages[-1][0] == conv.roles[1], "The last message in the history should be the assistant, if the given text is not empty"
135
+ conv.append_message(conv.roles[0], text)
136
+ conv.append_message(conv.roles[1], "")
137
+ history.append({"role": conv.roles[0], "text": text})
138
+ history.append({"role": conv.roles[1], "text": ""})
139
+ else:
140
+ if conv.messages[-1][0] == conv.roles[1]:
141
+ assert conv.messages[-1][1] == "", "No user message should be provided"
142
+ else:
143
+ assert conv.messages[-1][0] == conv.roles[0], "The last message in the history should be the user, if the given text is empty"
144
+ conv.append_message(conv.roles[0], "")
145
+ history.append({"role": conv.roles[0], "text": ""})
146
+ else:
147
+ history = []
148
+ history.append({"role": conv.roles[0], "text": text})
149
+ history.append({"role": conv.roles[1], "text": ""})
150
+ conv.append_message(conv.roles[0], text)
151
+ conv.append_message(conv.roles[1], "")
152
+ assert conv.messages[-1][0] == conv.roles[1] and conv.messages[-1][1] == "", "Format check"
153
+ assert history[-1]["role"] == conv.roles[1] and history[-1]["text"] == "", "Format check"
154
+
155
+ if images:
156
+ for i in range(len(images)):
157
+ if isinstance(images[i], str):
158
+ images[i] = PIL.Image.open(images[i])
159
+ last_prompt = history[-2]['text'].split("?")[0]
160
+ last_prompt = last_prompt.replace('<image>', '').strip() if '<image>' in last_prompt else last_prompt.strip()
161
+ keyword_prompt = conv.generate_keyword_prompt(last_prompt.replace('<image>', '').strip()) if '<image>' in last_prompt else conv.generate_keyword_prompt(last_prompt.strip())
162
+ else:
163
+ keyword_prompt = None
164
+ prompt = conv.get_prompt()
165
+
166
+ inputs = processor(images=images, text=prompt, keywords_text=keyword_prompt, return_tensors="pt", truncation=True, max_length=max_input_length)
167
+ for k, v in inputs.items():
168
+ if v is not None:
169
+ if isinstance(v, torch.Tensor):
170
+ inputs[k] = v.to(model.device)
171
+ elif isinstance(v, list):
172
+ if k == 'coords':
173
+ continue
174
+ inputs[k] = [x.to(model.device) for x in v]
175
+ elif isinstance(v, transformers.tokenization_utils_base.BatchEncoding) or isinstance(v, dict):
176
+ for key, value in v.items():
177
+ if value is not None:
178
+ if isinstance(value, list):
179
+ inputs[k][key] = [x.to(model.device) for x in value]
180
+ else:
181
+ inputs[k][key] = value.to(model.device)
182
+ else:
183
+ raise ValueError(f"Invalid input type: {type(v)}")
184
+
185
+ from transformers import TextIteratorStreamer
186
+ from threading import Thread
187
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
188
+ kwargs["streamer"] = streamer
189
+ inputs.update(kwargs)
190
+ thread = Thread(target=model.generate, kwargs=inputs)
191
+ thread.start()
192
+ generator = []
193
+ with Capturing() as print_kw:
194
+ for _output in streamer:
195
+ history[-1]["text"] += _output
196
+ generator.append((history[-1]["text"], history))
197
+ # yield history[-1]["text"], history
198
+ return generator, print_kw, inputs
199
+