Ye27 committed on
Commit
7b58103
1 Parent(s): c20b8c1

Upload folder using huggingface_hub

README.md ADDED
@@ -0,0 +1,261 @@
1
+ ---
2
+ language: en
3
+ tags:
4
+ - multimodal
5
+ - text
6
+ - image
7
+ - image-to-text
8
+ license: cc-by-nc-4.0
9
+ datasets:
10
+ - HuggingFaceM4/OBELICS
11
+ - laion/laion2B-en
12
+ - coyo-700m
13
+ - mmc4
14
+ pipeline_tag: text-generation
15
+ inference: true
16
+ ---
17
+
18
+ <br>
19
+ <p align="center">
20
+ <img src="assets/infimm-logo.webp" alt="InfiMM-logo" width="400"></a>
21
+ </p>
22
+ <br>
23
+
24
+ # InfiMM
25
+
26
+ InfiMM, inspired by the Flamingo architecture, sets itself apart with unique training data and diverse large language models (LLMs). This approach allows InfiMM to maintain the core strengths of Flamingo while offering enhanced capabilities. As the premier open-sourced variant in this domain, InfiMM excels in accessibility and adaptability, driven by community collaboration. It's more than an emulation of Flamingo; it's an innovation in visual language processing.
27
+
28
+ Our model is another attempt to reproduce the results reported in DeepMind's Flamingo paper, "Flamingo: a Visual Language Model for Few-Shot Learning".
29
+ Compared with previous open-source attempts ([OpenFlamingo](https://github.com/mlfoundations/open_flamingo) and [IDEFICS](https://huggingface.co/blog/idefics)), InfiMM offers more flexible models, allowing for a wide range of applications.
30
+ In particular, InfiMM integrates the latest LLMs into the VLM domain and reveals the impact of LLMs of different scales and architectures.
31
+
32
+ Please note that InfiMM is currently in the beta stage, and we are continuously working on improving it.
33
+
34
+ ## Model Details
35
+
36
+ - **Developed by**: Institute of Automation, Chinese Academy of Sciences and ByteDance
37
+ - **Model Type**: Visual Language Model (VLM)
38
+ - **Language**: English
39
+ - **LLMs**: [Zephyr](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta), [LLaMA2-13B](https://ai.meta.com/llama/), [Vicuna-13B](https://huggingface.co/lmsys/vicuna-13b-v1.5)
40
+ - **Vision Model**: [EVA CLIP](https://huggingface.co/QuanSun/EVA-CLIP)
41
+ - **Language(s) (NLP):** en
42
+ - **License:** see [License section](#license)
43
+ <!---
44
+ - **Parent Models:** [QuanSun/EVA-CLIP](https://huggingface.co/QuanSun/EVA-CLIP/blob/main/EVA02_CLIP_L_336_psz14_s6B.pt) and [HuggingFaceH4/zephyr-7b-beta](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta)
45
+ -->
46
+
47
+ ## Model Family
48
+
49
+ Our release consists of several different models. Please see the details below.
50
+ | Model | LLM | Vision Encoder | IFT |
51
+ | ---------------------- | -------------- | -------------- | --- |
52
+ | InfiMM-Zephyr | Zephyr-7B-beta | ViT-L-336 | No |
53
+ | InfiMM-Llama-13B | Llama2-13B | ViT-G-224 | No |
54
+ | InfiMM-Vicuna-13B | Vicuna-13B | ViT-E-224 | No |
55
+ | InfiMM-Zephyr-Chat | Zephyr-7B-beta | ViT-L-336 | Yes |
56
+ | InfiMM-Llama-13B-Chat | Llama2-13B | ViT-G-224 | Yes |
57
+ | InfiMM-Vicuna-13B-Chat | Vicuna-13B | ViT-E-224 | Yes |
58
+
59
+ <!-- InfiMM-Zephyr-Chat is an light-weighted, open-source re-production of Flamingo-style Multimodal large language models with chat capability that takes sequences of interleaved images and texts as inputs and generates text outputs, with only 9B parameters.
60
+ -->
61
+
62
+ ## Demo
63
+
64
+ Will be released soon.
65
+
66
+ Our model adopts the Flamingo architecture, leveraging EVA CLIP as the visual encoder and employing LLaMA2, Vicuna, and Zephyr as language models. The visual and language modalities are connected through a Cross Attention module.
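+
+ For illustration, the following is a minimal PyTorch sketch of how such a gated cross-attention connection can be wired between a frozen LLM and the visual tokens produced by the vision encoder. The class name, dimensions, and single-block structure are simplifying assumptions, not the exact InfiMM implementation.
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+
+ class GatedCrossAttentionBlock(nn.Module):
+     """Toy gated cross-attention: text hidden states attend to visual tokens."""
+
+     def __init__(self, lang_dim=4096, vis_dim=1024, num_heads=8):
+         super().__init__()
+         self.attn = nn.MultiheadAttention(
+             embed_dim=lang_dim, kdim=vis_dim, vdim=vis_dim,
+             num_heads=num_heads, batch_first=True,
+         )
+         # tanh gate starts at zero so the frozen LLM is unchanged at initialization
+         self.gate = nn.Parameter(torch.zeros(1))
+
+     def forward(self, text_hidden, visual_tokens):
+         attended, _ = self.attn(text_hidden, visual_tokens, visual_tokens)
+         return text_hidden + torch.tanh(self.gate) * attended
+ ```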
67
+
68
+ ## Quickstart
69
+
70
+ Use the code below to get started with the base model:
71
+ ```python
72
+ import torch
73
+ from transformers import AutoModelForCausalLM, AutoProcessor
74
+
75
+
76
+ processor = AutoProcessor.from_pretrained("InfiMM/infimm-zephyr", trust_remote_code=True)
77
+
78
+ prompts = [
79
+ {
80
+ "role": "user",
81
+ "content": [
82
+ {"image": "assets/infimm-logo.webp"},
83
+ "Please explain this image to me.",
84
+ ],
85
+ }
86
+ ]
87
+ inputs = processor(prompts)
88
+
89
+ # use bf16
90
+ model = AutoModelForCausalLM.from_pretrained(
91
+ "InfiMM/infimm-zephyr",
92
+ local_files_only=True,  # assumes the repo is already downloaded locally; set to False to fetch from the Hub
93
+ torch_dtype=torch.bfloat16,
94
+ trust_remote_code=True,
95
+ ).eval()
96
+
97
+
98
+ inputs = inputs.to(model.device)
99
+ inputs["batch_images"] = inputs["batch_images"].to(torch.bfloat16)
100
+ generated_ids = model.generate(
101
+ **inputs,
102
+ min_generation_length=0,
103
+ max_generation_length=256,
104
+ )
105
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
106
+ print(generated_text)
107
+ ```
108
+
109
+ ## Training Details
110
+
111
+ We employed three stages to train our model: pretraining (PT), multi-task training (MTT), and instruction fine-tuning (IFT). Refer to the tables below for the detailed configuration of each stage. Because the pretraining data contain significant noise, we aimed to enhance the model's accuracy by incorporating higher-quality data. In the multi-task training (MTT) phase, we utilized substantial training data from diverse datasets. However, because the answers in these datasets mainly consist of single words or phrases, the model's conversational ability remained limited. Therefore, in the third stage, we introduced a considerable amount of image-text dialogue data (LLaVA-665k) to fine-tune the model for instruction following.
112
+
113
+ ### Pretraining (PT)
114
+
115
+ We follow training procedures similar to those used in [IDEFICS](https://huggingface.co/HuggingFaceM4/idefics-9b-instruct/blob/main/README.md).
116
+
117
+ The model is trained on a mixture of image-text pairs and unstructured multimodal web documents. All data come from public sources. Because many image URLs have expired, we were only able to download a subset of the samples. After filtering out low-quality data, the resulting data we used are listed below:
118
+
119
+ | Data Source | Type of Data | Number of Tokens in Source | Number of Images in Source | Number of Samples | Epochs |
120
+ | ---------------------------------------------------------------- | ------------------------------------- | -------------------------- | -------------------------- | ----------------- | ------ |
121
+ | [OBELICS](https://huggingface.co/datasets/HuggingFaceM4/OBELICS) | Unstructured Multimodal Web Documents | - | - | 101M | 1 |
122
+ | [MMC4](https://github.com/allenai/mmc4) | Unstructured Multimodal Web Documents | - | - | 53M | 1 |
123
+ | [LAION](https://huggingface.co/datasets/laion/laion2B-en) | Image-Text Pairs | - | 115M | 115M | 1 |
124
+ | [COYO](https://github.com/kakaobrain/coyo-dataset) | Image-Text Pairs | - | 238M | 238M | 1 |
125
+ | [LAION-COCO](https://laion.ai/blog/laion-coco/) | Image-Text Pairs | - | 140M | 140M | 1 |
126
+ | [PMD\*](https://huggingface.co/datasets/facebook/pmd) | Image-Text Pairs | - | 20M | 20M | 1 |
127
+
128
+ \*PMD is only used in models with 13B LLMs, not the 7B Zephyr model.
129
+
130
+ During pretraining on interleaved image-text samples, we apply masked cross-attention. However, we did not strictly follow Flamingo, which randomly switches each image's association between its preceding and following text with a probability of 0.5.
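+
+ As a hedged sketch of the masking just described (each text token cross-attends only to its most recent preceding image, with no random switching to the following image), one way to build such a media mask in PyTorch is shown below. The function name and tensor layout are illustrative assumptions.
+
+ ```python
+ import torch
+
+
+ def media_attention_mask(input_ids, media_token_id, num_media):
+     """Return a (B, T_txt, num_media) boolean mask where each text position
+     may attend only to the most recent preceding <image> token."""
+     media_locations = input_ids == media_token_id  # (B, T_txt)
+     media_rank = media_locations.cumsum(dim=-1)    # 1-based index of latest image seen so far
+     media_ids = torch.arange(1, num_media + 1, device=input_ids.device)
+     return media_rank.unsqueeze(-1) == media_ids   # (B, T_txt, num_media)
+ ```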
131
+
132
+ We use the following hyperparameters:
133
+ | Categories | Parameters | Value |
134
+ | ------------------------ | -------------------------- | -------------------- |
135
+ | Perceiver Resampler | Number of Layers | 6 |
136
+ | | Number of Latents | 64 |
137
+ | | Number of Heads | 16 |
138
+ | | Resampler Head Dimension | 96 |
139
+ | Training | Sequence Length | 384 (13B) / 792 (7B) |
140
+ | | Effective Batch Size | 40\*128 |
141
+ | | Max Images per Sample | 6 |
142
+ | | Weight Decay | 0.1 |
143
+ | | Optimizer | Adam(0.9, 0.999) |
144
+ | | Gradient Accumulation Step | 2 |
145
+ | Learning Rate | Initial Max | 1e-4 |
146
+ | | Decay Schedule | Constant |
147
+ | | Warmup Step rate | 0.005 |
148
+ | Large-scale Optimization | Gradient Checkpointing | False |
149
+ | | Precision | bf16 |
150
+ | | ZeRO Optimization | Stage 2 |
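+
+ For concreteness, below is a minimal sketch of the constant-with-warmup schedule implied by the table above (max LR 1e-4, warmup step rate 0.005). The optimizer setup and `total_steps` value are placeholders, not the exact training configuration.
+
+ ```python
+ import torch
+ from torch.optim.lr_scheduler import LambdaLR
+
+
+ def constant_with_warmup(optimizer, total_steps, warmup_rate=0.005):
+     """Linear warmup for warmup_rate * total_steps, then a constant LR."""
+     warmup_steps = max(1, int(total_steps * warmup_rate))
+
+     def lr_lambda(step):
+         return min(1.0, (step + 1) / warmup_steps)
+
+     return LambdaLR(optimizer, lr_lambda)
+
+
+ # illustrative usage with the pretraining settings listed above
+ model = torch.nn.Linear(8, 8)  # placeholder module
+ optimizer = torch.optim.Adam(
+     model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=0.1
+ )
+ scheduler = constant_with_warmup(optimizer, total_steps=100_000)  # total_steps is a placeholder
+ ```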
151
+
152
+ ### Multi-Task Training (MTT)
153
+
154
+ Here we use mix_cap_vqa to denote the mixed training set drawn from COCO Caption, TextCaps, VizWiz Caption, VQAv2, OK-VQA, VizWiz VQA, TextVQA, OCR-VQA, ST-VQA, DocVQA, GQA, and ScienceQA-image. For captioning samples, we add a prefix such as "Please describe the image." For QA samples, we append "Answer the question using a single word or phrase." Specifically, for VizWiz VQA we use "When the provided information is insufficient, respond with 'Unanswerable'. Answer the question using a single word or phrase.", and for ScienceQA-image we use "Answer with the option's letter from the given choices directly."
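+
+ A small sketch of how these prompts can be assembled is shown below; the helper name and task keys are hypothetical, while the prefix and suffix strings are the ones quoted above.
+
+ ```python
+ CAPTION_PREFIX = "Please describe the image."
+ QA_SUFFIX = "Answer the question using a single word or phrase."
+ VIZWIZ_SUFFIX = (
+     "When the provided information is insufficient, respond with 'Unanswerable'. "
+     "Answer the question using a single word or phrase."
+ )
+ SCIENCEQA_SUFFIX = "Answer with the option's letter from the given choices directly."
+
+
+ def build_mtt_prompt(task, question=None):
+     """Hypothetical helper that attaches the task-specific instruction."""
+     if task == "caption":
+         return CAPTION_PREFIX
+     if task == "vizwiz_vqa":
+         return f"{question} {VIZWIZ_SUFFIX}"
+     if task == "scienceqa":
+         return f"{question} {SCIENCEQA_SUFFIX}"
+     return f"{question} {QA_SUFFIX}"
+ ```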
155
+
156
+ ### Instruction Fine-Tuning (IFT)
157
+
158
+ For the instruction fine-tuning stage, we use the recently released [LLaVA-MIX-665k](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/tree/main).
159
+
160
+ We use the following hyperparameters:
161
+ | Categories | Parameters | Value |
162
+ | ------------------------ | -------------------------- | -------------------- |
163
+ | Perceiver Resampler | Number of Layers | 6 |
164
+ | | Number of Latents | 64 |
165
+ | | Number of Heads | 16 |
166
+ | | Resampler Head Dimension | 96 |
167
+ | Training | Sequence Length | 384 (13B) / 792 (7B) |
168
+ | | Effective Batch Size | 64 |
169
+ | | Max Images per Sample | 6 |
170
+ | | Weight Decay | 0.1 |
171
+ | | Optimizer | Adam(0.9, 0.999) |
172
+ | | Gradient Accumulation Step | 2 |
173
+ | Learning Rate | Initial Max | 1e-5 |
174
+ | | Decay Schedule | Constant |
175
+ | | Warmup Step rate | 0.005 |
176
+ | Large-scale Optimization | Gradient Checkpointing | False |
177
+ | | Precision | bf16 |
178
+ | | ZeRO Optimization | Stage 2 |
179
+
180
+ During IFT, as in pretraining, we keep the ViT and the LLM frozen for the chat-based LLMs (Vicuna and Zephyr). For the Llama model, we keep the LLM trainable during the IFT stage. We also apply a chat template when processing the training samples. A sketch of this freezing strategy is shown below.
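+
+ The sketch below uses the attribute names from this repo's `flamingo.py` (`vision_encoder`, `lang_encoder`, `perceiver`); the selection logic is illustrative, and the cross-attention layers inserted into the LLM, which are handled separately in the real code, are omitted here.
+
+ ```python
+ def set_ift_trainable(model, llm_is_chat_model=True):
+     """Sketch: always freeze the ViT; freeze the LLM only for chat-based LLMs."""
+     for p in model.vision_encoder.parameters():  # ViT stays frozen
+         p.requires_grad = False
+     for p in model.lang_encoder.parameters():    # frozen for Vicuna/Zephyr, trainable for Llama-2
+         p.requires_grad = not llm_is_chat_model
+     for p in model.perceiver.parameters():       # perceiver resampler is always trained
+         p.requires_grad = True
+ ```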
181
+
182
+ ## Evaluation
183
+
184
+ ### Pretraining Evaluation
185
+
186
+ We evaluate the pretrained models on the following downstream tasks: image captioning and VQA. We also compare our results with [IDEFICS](https://huggingface.co/blog/idefics).
187
+
188
+ | Model | Shots | COCO CIDEr | Flickr30K CIDEr | VQA v2 Acc | TextVQA Acc | OK-VQA Acc |
189
+ | ----------------- | ----- | ---------- | --------------- | ---------- | ----------- | ---------- |
190
+ | IDEFICS-9B | 0 | 46 | 27.3 | 50.9 | 25.9 | 38.4 |
191
+ | | 4 | 93 | 59.7 | 55.4 | 27.6 | 45.5 |
192
+ | IDEFICS-80B | 0 | 91.8 | 53.7 | 60 | 30.9 | 45.2 |
193
+ | | 4 | 110.3 | 73.7 | 64.6 | 34.4 | 52.4 |
194
+ | InfiMM-Zephyr-7B | 0 | 78.8 | 60.7 | 33.7 | 15.2 | 17.1 |
195
+ | | 4 | 108.6 | 71.9 | 59.1 | 34.3 | 50.5 |
196
+ | InfiMM-Llama2-13B | 0 | 85.4 | 54.6 | 51.6 | 24.2 | 26.4 |
197
+ | | 4 | 125.2 | 87.1 | 66.1 | 38.2 | 55.5 |
198
+ | InfiMM-Vicuna13B | 0 | 69.6 | 49.6 | 60.4 | 32.8 | 49.2 |
199
+ | | 4 | 118.1 | 81.4 | 64.2 | 38.4 | 53.7 |
200
+
201
+ ### IFT Evaluation
202
+
203
+ In our analysis, we concentrate on two primary benchmark types for evaluating MLLMs: 1) multiple-choice question answering (QA) and 2) open-ended evaluation. We have observed that the evaluation metrics for tasks such as Visual Question Answering (VQA) and TextVQA are overly sensitive to exact answer matches. This approach can be misleading, particularly when models provide synonymous but technically accurate responses. Therefore, these metrics have been omitted from our comparison for a more precise assessment. The evaluation results are shown in the table below, followed by a small illustration of this exact-match sensitivity.
204
+
205
+ | Model | ScienceQA-Img | MME | MM-VET | InfiMM-Eval | MMbench | MMMU-Val | MMMU-Test |
206
+ | ------------------- | ------------- | --------------------- | ------ | ------------ | ------- | -------- | --------- |
207
+ | Otter-9B | - | 1292/306 | 24.6 | 32.2 | - | 22.69 | - |
208
+ | IDEFICS-9B-Instruct | 60.6 | -/- | - | - | - | 24.53 | - |
209
+ | InfiMM-Zephyr-7B | 71.1 | P: 1406<br>C:327 | 32.8 | 36.0 | 59.7 | 39.4 | 35.5 |
210
+ | InfiMM-Llama-13b | 73.0 | P: 1444.5<br>C: 337.6 | 39.2 | 0.4559/0.414 | 66.4 | 39.1 | 35.2 |
211
+ | InfiMM-Vicuna-13B | 74.0 | P: 1461.2<br>C: 323.5 | 36.0 | 40.0 | 66.7 | 37.6 | 34.6 |
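+
+ The snippet below is a toy illustration (not an official metric implementation) of the exact-match sensitivity mentioned above: a synonymous but correct answer scores zero.
+
+ ```python
+ # hypothetical predictions vs. reference answers
+ predictions = {"q1": "couch", "q2": "two"}
+ references = {"q1": "sofa", "q2": "2"}
+
+ exact_match = sum(
+     predictions[k].strip().lower() == references[k].strip().lower()
+     for k in references
+ ) / len(references)
+ print(exact_match)  # 0.0, even though both answers are arguably correct
+ ```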
212
+
213
+ <!--
214
+ | Model | TextVQA (no ocr) | OK-VQA | VQAv2 | ScienceQA-Img | GQA | MME | MM-VET | MMMU | InfiMM-Eval | MMbench |
215
+ | ----------------- | ---------------- | ------ | ----- | ------------- | ---- | --------------------- | ------ | ---- | ------------ | ------- |
216
+ | InfiMM-Zephyr-7B | 36.7 | 55.4 | / | 71.1 | | P: 1406<br>C:327 | 32.8 | 39.4 | 36.0 | 59.7 |
217
+ | InfiMM-Llama-13b | 44.6 | 62.3 | 78.5 | 73.0 | 61.2 | P: 1444.5<br>C: 337.6 | 39.2 | 39.1 | 0.4559/0.414 | 66.4 |
218
+ | InfiMM-Vicuna-13B | 41.7 | 58.5 | 73.0 | 74.0 | 58.5 | P: 1461.2<br>C: 323.5 | 36.0 | 37.6 | 40.0 | 66.7 |
219
+
220
+ We select checkpoint after 1 epoch instruction fine-tuning.
221
+
222
+ | Model | <nobr>ScienceQA <br>acc.</nobr> | <nobr>MME <br>P/C</nobr> | <nobr>MM-Vet</nobr> | <nobr>InfiMM-Eval</nobr> | <nobr>MMMU (val)</nobr> |
223
+ | :------------------ | ------------------------------: | -----------------------: | ------------------: | -----------------------: | ----------------------: |
224
+ | Otter-9B | - | 1292/306 | 24.6 | 22.69 | 32.2 |
225
+ | IDEFICS-9B-Instruct | 60.6 | -/- | - | 24.53 | - |
226
+ | InfiMM-Zephyr-Chat | 71.14 | 1406/327 | 33.3 | 35.97 | 39.4 |
227
+ -->
228
+
229
+ <details>
230
+ <summary>Leaderboard Details</summary>
231
+
232
+ <img src="assets/infimm-zephyr-mmmu-val.jpeg" style="zoom:40%;" />
233
+ <br>MMMU-Val split results<br>
234
+ <img src="assets/infimm-zephyr-mmmu-test.jpeg" style="zoom:40%;" />
235
+ <br>MMMU-Test split results<br>
236
+
237
+ </details>
238
+
239
+ ## Citation
240
+
241
+ ```bibtex
+ @misc{infimm-v1,
+   title={InfiMM},
+   author={InfiMM Team},
+   year={2024}
+ }
+ ```
246
+
247
+ ## License
248
+
249
+ <a href="https://creativecommons.org/licenses/by-nc/4.0/deed.en">
250
+ <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/d/d3/Cc_by-nc_icon.svg/600px-Cc_by-nc_icon.svg.png" width="160">
251
+ </a>
252
+
253
+ This project is licensed under the **CC BY-NC 4.0** license.
254
+
255
+ The copyright of the images belongs to the original authors.
256
+
257
+ See [LICENSE](LICENSE) for more information.
258
+
259
+ ## Contact Us
260
+
261
+ Please feel free to contact us via email [[email protected]]([email protected]) if you have any questions.
added_tokens.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "<image>": 32001,
3
+ "<|endofchunk|>": 32000
4
+ }
config.json ADDED
@@ -0,0 +1,64 @@
1
+ {
2
+ "_name_or_path": "./",
3
+ "architectures": [
4
+ "InfiMMVicunaModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_infimm_vicuna.InfiMMConfig",
8
+ "AutoModelForCausalLM": "modeling_infimm_vicuna.InfiMMVicunaModel"
9
+ },
10
+ "model_type": "infimm-vicuna",
11
+ "seq_length": 1024,
12
+ "tokenizer_type": "LlamaTokenizer",
13
+ "torch_dtype": "bfloat16",
14
+ "transformers_version": "4.34.0",
15
+ "use_cache": true,
16
+ "use_flash_attn": false,
17
+ "cross_attn_every_n_layers": 4,
18
+ "use_grad_checkpoint": false,
19
+ "freeze_llm": true,
20
+ "image_token_id": 32001,
21
+ "eoc_token_id": 32000,
22
+ "visual": {
23
+ "image_size": 224,
24
+ "layers": 64,
25
+ "width": 1792,
26
+ "head_width": 112,
27
+ "mlp_ratio": 8.571428571428571,
28
+ "patch_size": 14,
29
+ "eva_model_name": "eva-clip-4b-14-x",
30
+ "drop_path_rate": 0,
31
+ "xattn": false,
32
+ "postnorm": true,
33
+ "fusedLN": false,
34
+ "embed_dim": 1024,
35
+ "patch_dropout": 0
36
+ },
37
+ "language": {
38
+ "_name_or_path": "lmsys/vicuna-13b-v1.5",
39
+ "architectures": [
40
+ "LlamaForCausalLM"
41
+ ],
42
+ "bos_token_id": 1,
43
+ "eos_token_id": 2,
44
+ "hidden_act": "silu",
45
+ "hidden_size": 5120,
46
+ "initializer_range": 0.02,
47
+ "intermediate_size": 13824,
48
+ "max_length": 4096,
49
+ "max_position_embeddings": 4096,
50
+ "model_type": "llama",
51
+ "num_attention_heads": 40,
52
+ "num_hidden_layers": 40,
53
+ "num_key_value_heads": 40,
54
+ "pad_token_id": 0,
55
+ "pretraining_tp": 1,
56
+ "rms_norm_eps": 1e-05,
57
+ "rope_scaling": null,
58
+ "tie_word_embeddings": false,
59
+ "torch_dtype": "float16",
60
+ "transformers_version": "4.34.0",
61
+ "use_cache": true,
62
+ "vocab_size": 32002
63
+ }
64
+ }
configuration_infimm_vicuna.py ADDED
@@ -0,0 +1,42 @@
1
+ # This source code is licensed under the license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+
4
+ from transformers import PretrainedConfig
5
+
6
+
7
+ class InfiMMConfig(PretrainedConfig):
8
+ model_type = "infimm"
9
+
10
+ def __init__(
11
+ self,
12
+ model_type="infimm-vicuna",
13
+ seq_length=1024,
14
+ tokenizer_type="LlamaTokenizer",
15
+ torch_dtype="bfloat16",
16
+ transformers_version="4.34.0",
17
+ use_cache=True,
18
+ use_flash_attn=False,
19
+ cross_attn_every_n_layers=2,
20
+ use_grad_checkpoint=False,
21
+ freeze_llm=True,
22
+ visual=None,
23
+ language=None,
24
+ image_token_id=None,
25
+ eoc_token_id=None,
26
+ **kwargs,
27
+ ):
28
+ self.model_type = model_type
29
+ self.seq_length = seq_length
30
+ self.tokenizer_type = tokenizer_type
31
+ self.torch_dtype = torch_dtype
32
+ self.transformers_version = transformers_version
33
+ self.use_cache = use_cache
34
+ self.use_flash_attn = use_flash_attn
35
+ self.cross_attn_every_n_layers = cross_attn_every_n_layers
36
+ self.use_grad_checkpoint = use_grad_checkpoint
37
+ self.freeze_llm = freeze_llm
38
+ self.visual = visual
39
+ self.language = language
40
+ self.image_token_id = image_token_id
41
+ self.eoc_token_id = eoc_token_id
42
+ super().__init__(**kwargs)
eva_vit.py ADDED
@@ -0,0 +1,948 @@
1
+ # --------------------------------------------------------
2
+ # Adapted from https://github.com/baaivision/EVA/blob/master/EVA-CLIP/rei/eva_clip/eva_vit_model.py
3
+ # --------------------------------------------------------
4
+ import logging
5
+ import math
6
+ import os
7
+ from dataclasses import dataclass
8
+ from functools import partial
9
+ from math import pi
10
+ from typing import Optional, Tuple, Union
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from einops import rearrange, repeat
15
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
+
17
+ if os.getenv("ENV_TYPE") == "deepspeed":
18
+ try:
19
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
20
+ except:
21
+ from torch.utils.checkpoint import checkpoint
22
+ else:
23
+ from torch.utils.checkpoint import checkpoint
24
+
25
+ try:
26
+ import xformers.ops as xops
27
+ except ImportError:
28
+ xops = None
29
+ print("Please 'pip install xformers'")
30
+
31
+
32
+ class PatchDropout(nn.Module):
33
+ """
34
+ https://arxiv.org/abs/2212.00794
35
+ """
36
+
37
+ def __init__(self, prob, exclude_first_token=True):
38
+ super().__init__()
39
+ assert 0 <= prob < 1.0
40
+ self.prob = prob
41
+ self.exclude_first_token = exclude_first_token # exclude CLS token
42
+
43
+ def forward(self, x):
44
+ if not self.training or self.prob == 0.0:
45
+ return x
46
+
47
+ if self.exclude_first_token:
48
+ cls_tokens, x = x[:, :1], x[:, 1:]
49
+ else:
50
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
51
+
52
+ batch = x.size()[0]
53
+ num_tokens = x.size()[1]
54
+
55
+ batch_indices = torch.arange(batch)
56
+ batch_indices = batch_indices[..., None]
57
+
58
+ keep_prob = 1 - self.prob
59
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
60
+
61
+ rand = torch.randn(batch, num_tokens)
62
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
63
+
64
+ x = x[batch_indices, patch_indices_keep]
65
+
66
+ if self.exclude_first_token:
67
+ x = torch.cat((cls_tokens, x), dim=1)
68
+
69
+ if self.training and os.getenv("RoPE") == "1":
70
+ return x, patch_indices_keep
71
+
72
+ return x
73
+
74
+
75
+ class DropPath(nn.Module):
76
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
77
+
78
+ def __init__(self, drop_prob=None):
79
+ super(DropPath, self).__init__()
80
+ self.drop_prob = drop_prob
81
+
82
+ def forward(self, x):
83
+ return drop_path(x, self.drop_prob, self.training)
84
+
85
+ def extra_repr(self) -> str:
86
+ return "p={}".format(self.drop_prob)
87
+
88
+
89
+ class Mlp(nn.Module):
90
+ def __init__(
91
+ self,
92
+ in_features,
93
+ hidden_features=None,
94
+ out_features=None,
95
+ act_layer=nn.GELU,
96
+ norm_layer=nn.LayerNorm,
97
+ drop=0.0,
98
+ subln=False,
99
+ ):
100
+ super().__init__()
101
+ out_features = out_features or in_features
102
+ hidden_features = hidden_features or in_features
103
+
104
+ self.fc1 = nn.Linear(in_features, hidden_features)
105
+ self.act = act_layer()
106
+
107
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
108
+
109
+ self.fc2 = nn.Linear(hidden_features, out_features)
110
+ self.drop = nn.Dropout(drop)
111
+
112
+ def forward(self, x):
113
+ x = self.fc1(x)
114
+ x = self.act(x)
115
+ # x = self.drop(x)
116
+ # dropout is commented out here to follow the original BERT implementation
117
+ x = self.ffn_ln(x)
118
+
119
+ x = self.fc2(x)
120
+ x = self.drop(x)
121
+ return x
122
+
123
+
124
+ class SwiGLU(nn.Module):
125
+ def __init__(
126
+ self,
127
+ in_features,
128
+ hidden_features=None,
129
+ out_features=None,
130
+ act_layer=nn.SiLU,
131
+ drop=0.0,
132
+ norm_layer=nn.LayerNorm,
133
+ subln=False,
134
+ ):
135
+ super().__init__()
136
+ out_features = out_features or in_features
137
+ hidden_features = hidden_features or in_features
138
+
139
+ self.w1 = nn.Linear(in_features, hidden_features)
140
+ self.w2 = nn.Linear(in_features, hidden_features)
141
+
142
+ self.act = act_layer()
143
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
144
+
145
+ self.w3 = nn.Linear(hidden_features, out_features)
146
+
147
+ self.drop = nn.Dropout(drop)
148
+
149
+ def forward(self, x):
150
+ x1 = self.w1(x)
151
+ x2 = self.w2(x)
152
+ hidden = self.act(x1) * x2
153
+ x = self.ffn_ln(hidden)
154
+ x = self.w3(x)
155
+ x = self.drop(x)
156
+ return x
157
+
158
+
159
+ class Attention(nn.Module):
160
+ def __init__(
161
+ self,
162
+ dim,
163
+ num_heads=8,
164
+ qkv_bias=False,
165
+ qk_scale=None,
166
+ attn_drop=0.0,
167
+ proj_drop=0.0,
168
+ window_size=None,
169
+ attn_head_dim=None,
170
+ xattn=False,
171
+ rope=None,
172
+ subln=False,
173
+ norm_layer=nn.LayerNorm,
174
+ ):
175
+ super().__init__()
176
+ self.num_heads = num_heads
177
+ head_dim = dim // num_heads
178
+ if attn_head_dim is not None:
179
+ head_dim = attn_head_dim
180
+ all_head_dim = head_dim * self.num_heads
181
+ self.scale = qk_scale or head_dim**-0.5
182
+
183
+ self.subln = subln
184
+ if self.subln:
185
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
186
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
187
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
188
+
189
+ else:
190
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
191
+
192
+ if qkv_bias:
193
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
194
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
195
+ else:
196
+ self.q_bias = None
197
+ self.v_bias = None
198
+
199
+ if window_size:
200
+ self.window_size = window_size
201
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
202
+ 2 * window_size[1] - 1
203
+ ) + 3
204
+ self.relative_position_bias_table = nn.Parameter(
205
+ torch.zeros(self.num_relative_distance, num_heads)
206
+ ) # 2*Wh-1 * 2*Ww-1, nH
207
+ # cls to token & token 2 cls & cls to cls
208
+
209
+ # get pair-wise relative position index for each token inside the window
210
+ coords_h = torch.arange(window_size[0])
211
+ coords_w = torch.arange(window_size[1])
212
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
213
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
214
+ relative_coords = (
215
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
216
+ ) # 2, Wh*Ww, Wh*Ww
217
+ relative_coords = relative_coords.permute(
218
+ 1, 2, 0
219
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
220
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
221
+ relative_coords[:, :, 1] += window_size[1] - 1
222
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
223
+ relative_position_index = torch.zeros(
224
+ size=(window_size[0] * window_size[1] + 1,) * 2,
225
+ dtype=relative_coords.dtype,
226
+ )
227
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
228
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
229
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
230
+ relative_position_index[0, 0] = self.num_relative_distance - 1
231
+
232
+ self.register_buffer("relative_position_index", relative_position_index)
233
+ else:
234
+ self.window_size = None
235
+ self.relative_position_bias_table = None
236
+ self.relative_position_index = None
237
+
238
+ self.attn_drop = nn.Dropout(attn_drop)
239
+ self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
240
+ # self.proj = nn.Linear(all_head_dim, all_head_dim)
241
+ self.proj = nn.Linear(all_head_dim, dim)
242
+ self.proj_drop = nn.Dropout(proj_drop)
243
+ self.xattn = xattn
244
+ self.xattn_drop = attn_drop
245
+
246
+ self.rope = rope
247
+
248
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
249
+ B, N, C = x.shape
250
+ if self.subln:
251
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
252
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
253
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
254
+
255
+ q = q.reshape(B, N, self.num_heads, -1).permute(
256
+ 0, 2, 1, 3
257
+ ) # B, num_heads, N, C
258
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
259
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
260
+ else:
261
+ qkv_bias = None
262
+ if self.q_bias is not None:
263
+ qkv_bias = torch.cat(
264
+ (
265
+ self.q_bias,
266
+ torch.zeros_like(self.v_bias, requires_grad=False),
267
+ self.v_bias,
268
+ )
269
+ )
270
+
271
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
272
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(
273
+ 2, 0, 3, 1, 4
274
+ ) # 3, B, num_heads, N, C
275
+ q, k, v = qkv[0], qkv[1], qkv[2]
276
+
277
+ if self.rope:
278
+ # slightly fast impl
279
+ q_t = q[:, :, 1:, :]
280
+ ro_q_t = self.rope(q_t)
281
+ q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
282
+
283
+ k_t = k[:, :, 1:, :]
284
+ ro_k_t = self.rope(k_t)
285
+ k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
286
+
287
+ if self.xattn:
288
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
289
+ k = k.permute(0, 2, 1, 3)
290
+ v = v.permute(0, 2, 1, 3)
291
+
292
+ x = xops.memory_efficient_attention(
293
+ q,
294
+ k,
295
+ v,
296
+ p=self.xattn_drop,
297
+ scale=self.scale,
298
+ )
299
+ x = x.reshape(B, N, -1)
300
+ x = self.inner_attn_ln(x)
301
+ x = self.proj(x)
302
+ x = self.proj_drop(x)
303
+ else:
304
+ q = q * self.scale
305
+ attn = q @ k.transpose(-2, -1)
306
+
307
+ if self.relative_position_bias_table is not None:
308
+ relative_position_bias = self.relative_position_bias_table[
309
+ self.relative_position_index.view(-1)
310
+ ].view(
311
+ self.window_size[0] * self.window_size[1] + 1,
312
+ self.window_size[0] * self.window_size[1] + 1,
313
+ -1,
314
+ ) # Wh*Ww,Wh*Ww,nH
315
+ relative_position_bias = relative_position_bias.permute(
316
+ 2, 0, 1
317
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
318
+ attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
319
+
320
+ if rel_pos_bias is not None:
321
+ attn = attn + rel_pos_bias.type_as(attn)
322
+
323
+ if attn_mask is not None:
324
+ attn_mask = attn_mask.bool()
325
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
326
+
327
+ attn = attn.softmax(dim=-1)
328
+ attn = self.attn_drop(attn)
329
+
330
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
331
+ x = self.inner_attn_ln(x)
332
+ x = self.proj(x)
333
+ x = self.proj_drop(x)
334
+ return x
335
+
336
+
337
+ class Block(nn.Module):
338
+ def __init__(
339
+ self,
340
+ dim,
341
+ num_heads,
342
+ mlp_ratio=4.0,
343
+ qkv_bias=False,
344
+ qk_scale=None,
345
+ drop=0.0,
346
+ attn_drop=0.0,
347
+ drop_path=0.0,
348
+ init_values=None,
349
+ act_layer=nn.GELU,
350
+ norm_layer=nn.LayerNorm,
351
+ window_size=None,
352
+ attn_head_dim=None,
353
+ xattn=False,
354
+ rope=None,
355
+ postnorm=False,
356
+ subln=False,
357
+ naiveswiglu=False,
358
+ ):
359
+ super().__init__()
360
+ self.norm1 = norm_layer(dim)
361
+ self.attn = Attention(
362
+ dim,
363
+ num_heads=num_heads,
364
+ qkv_bias=qkv_bias,
365
+ qk_scale=qk_scale,
366
+ attn_drop=attn_drop,
367
+ proj_drop=drop,
368
+ window_size=window_size,
369
+ attn_head_dim=attn_head_dim,
370
+ xattn=xattn,
371
+ rope=rope,
372
+ subln=subln,
373
+ norm_layer=norm_layer,
374
+ )
375
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
376
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
377
+ self.norm2 = norm_layer(dim)
378
+ mlp_hidden_dim = int(dim * mlp_ratio)
379
+
380
+ if naiveswiglu:
381
+ self.mlp = SwiGLU(
382
+ in_features=dim,
383
+ hidden_features=mlp_hidden_dim,
384
+ subln=subln,
385
+ norm_layer=norm_layer,
386
+ )
387
+ else:
388
+ self.mlp = Mlp(
389
+ in_features=dim,
390
+ hidden_features=mlp_hidden_dim,
391
+ act_layer=act_layer,
392
+ subln=subln,
393
+ drop=drop,
394
+ )
395
+
396
+ if init_values is not None and init_values > 0:
397
+ self.gamma_1 = nn.Parameter(
398
+ init_values * torch.ones((dim)), requires_grad=True
399
+ )
400
+ self.gamma_2 = nn.Parameter(
401
+ init_values * torch.ones((dim)), requires_grad=True
402
+ )
403
+ else:
404
+ self.gamma_1, self.gamma_2 = None, None
405
+
406
+ self.postnorm = postnorm
407
+
408
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
409
+ if self.gamma_1 is None:
410
+ if self.postnorm:
411
+ x = x + self.drop_path(
412
+ self.norm1(
413
+ self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
414
+ )
415
+ )
416
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
417
+ else:
418
+ x = x + self.drop_path(
419
+ self.attn(
420
+ self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask
421
+ )
422
+ )
423
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
424
+ else:
425
+ if self.postnorm:
426
+ x = x + self.drop_path(
427
+ self.gamma_1
428
+ * self.norm1(
429
+ self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
430
+ )
431
+ )
432
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
433
+ else:
434
+ x = x + self.drop_path(
435
+ self.gamma_1
436
+ * self.attn(
437
+ self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask
438
+ )
439
+ )
440
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
441
+ return x
442
+
443
+
444
+ class PatchEmbed(nn.Module):
445
+ """Image to Patch Embedding"""
446
+
447
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
448
+ super().__init__()
449
+ img_size = to_2tuple(img_size)
450
+ patch_size = to_2tuple(patch_size)
451
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
452
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
453
+ self.img_size = img_size
454
+ self.patch_size = patch_size
455
+ self.num_patches = num_patches
456
+
457
+ self.proj = nn.Conv2d(
458
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
459
+ )
460
+
461
+ def forward(self, x, **kwargs):
462
+ B, C, H, W = x.shape
463
+ # FIXME look at relaxing size constraints
464
+ assert H == self.img_size[0] and W == self.img_size[1], (
465
+ f"Input image size ({H}*{W}) doesn't match model"
466
+ f" ({self.img_size[0]}*{self.img_size[1]})."
467
+ )
468
+ x = self.proj(x).flatten(2).transpose(1, 2)
469
+ return x
470
+
471
+
472
+ class RelativePositionBias(nn.Module):
473
+ def __init__(self, window_size, num_heads):
474
+ super().__init__()
475
+ self.window_size = window_size
476
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
477
+ 2 * window_size[1] - 1
478
+ ) + 3
479
+ self.relative_position_bias_table = nn.Parameter(
480
+ torch.zeros(self.num_relative_distance, num_heads)
481
+ ) # 2*Wh-1 * 2*Ww-1, nH
482
+ # cls to token & token 2 cls & cls to cls
483
+
484
+ # get pair-wise relative position index for each token inside the window
485
+ coords_h = torch.arange(window_size[0])
486
+ coords_w = torch.arange(window_size[1])
487
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
488
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
489
+ relative_coords = (
490
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
491
+ ) # 2, Wh*Ww, Wh*Ww
492
+ relative_coords = relative_coords.permute(
493
+ 1, 2, 0
494
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
495
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
496
+ relative_coords[:, :, 1] += window_size[1] - 1
497
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
498
+ relative_position_index = torch.zeros(
499
+ size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
500
+ )
501
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
502
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
503
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
504
+ relative_position_index[0, 0] = self.num_relative_distance - 1
505
+
506
+ self.register_buffer("relative_position_index", relative_position_index)
507
+
508
+ def forward(self):
509
+ relative_position_bias = self.relative_position_bias_table[
510
+ self.relative_position_index.view(-1)
511
+ ].view(
512
+ self.window_size[0] * self.window_size[1] + 1,
513
+ self.window_size[0] * self.window_size[1] + 1,
514
+ -1,
515
+ ) # Wh*Ww,Wh*Ww,nH
516
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
517
+
518
+
519
+ class EVAVisionTransformer(nn.Module):
520
+ """Vision Transformer with support for patch or hybrid CNN input stage"""
521
+
522
+ def __init__(
523
+ self,
524
+ img_size=224,
525
+ patch_size=16,
526
+ in_chans=3,
527
+ num_classes=1000,
528
+ embed_dim=768,
529
+ depth=12,
530
+ num_heads=12,
531
+ mlp_ratio=4.0,
532
+ qkv_bias=False,
533
+ qk_scale=None,
534
+ drop_rate=0.0,
535
+ attn_drop_rate=0.0,
536
+ drop_path_rate=0.0,
537
+ norm_layer=nn.LayerNorm,
538
+ init_values=None,
539
+ patch_dropout=0.0,
540
+ use_abs_pos_emb=True,
541
+ use_rel_pos_bias=False,
542
+ use_shared_rel_pos_bias=False,
543
+ rope=False,
544
+ use_mean_pooling=True,
545
+ init_scale=0.001,
546
+ grad_checkpointing=False,
547
+ xattn=False,
548
+ postnorm=False,
549
+ pt_hw_seq_len=16,
550
+ intp_freq=False,
551
+ naiveswiglu=False,
552
+ subln=False,
553
+ ):
554
+ super().__init__()
555
+ self.image_size = img_size
556
+ self.num_classes = num_classes
557
+ self.num_features = (
558
+ self.embed_dim
559
+ ) = embed_dim # num_features for consistency with other models
560
+
561
+ self.patch_embed = PatchEmbed(
562
+ img_size=img_size,
563
+ patch_size=patch_size,
564
+ in_chans=in_chans,
565
+ embed_dim=embed_dim,
566
+ )
567
+ num_patches = self.patch_embed.num_patches
568
+
569
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
570
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
571
+ if use_abs_pos_emb:
572
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
573
+ else:
574
+ self.pos_embed = None
575
+ self.pos_drop = nn.Dropout(p=drop_rate)
576
+
577
+ if use_shared_rel_pos_bias:
578
+ self.rel_pos_bias = RelativePositionBias(
579
+ window_size=self.patch_embed.patch_shape, num_heads=num_heads
580
+ )
581
+ else:
582
+ self.rel_pos_bias = None
583
+
584
+ if rope:
585
+ half_head_dim = embed_dim // num_heads // 2
586
+ hw_seq_len = img_size // patch_size
587
+ self.rope = VisionRotaryEmbeddingFast(
588
+ dim=half_head_dim,
589
+ pt_seq_len=pt_hw_seq_len,
590
+ ft_seq_len=hw_seq_len if intp_freq else None,
591
+ # patch_dropout=patch_dropout
592
+ )
593
+ else:
594
+ self.rope = None
595
+
596
+ self.naiveswiglu = naiveswiglu
597
+
598
+ dpr = [
599
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
600
+ ] # stochastic depth decay rule
601
+ self.use_rel_pos_bias = use_rel_pos_bias
602
+ self.blocks = nn.ModuleList(
603
+ [
604
+ Block(
605
+ dim=embed_dim,
606
+ num_heads=num_heads,
607
+ mlp_ratio=mlp_ratio,
608
+ qkv_bias=qkv_bias,
609
+ qk_scale=qk_scale,
610
+ drop=drop_rate,
611
+ attn_drop=attn_drop_rate,
612
+ drop_path=dpr[i],
613
+ norm_layer=norm_layer,
614
+ init_values=init_values,
615
+ window_size=(
616
+ self.patch_embed.patch_shape if use_rel_pos_bias else None
617
+ ),
618
+ xattn=xattn,
619
+ rope=self.rope,
620
+ postnorm=postnorm,
621
+ subln=subln,
622
+ naiveswiglu=naiveswiglu,
623
+ )
624
+ for i in range(depth)
625
+ ]
626
+ )
627
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
628
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
629
+ self.head = (
630
+ nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
631
+ )
632
+
633
+ if self.pos_embed is not None:
634
+ trunc_normal_(self.pos_embed, std=0.02)
635
+
636
+ trunc_normal_(self.cls_token, std=0.02)
637
+ # trunc_normal_(self.mask_token, std=.02)
638
+
639
+ self.apply(self._init_weights)
640
+ self.fix_init_weight()
641
+
642
+ if isinstance(self.head, nn.Linear):
643
+ trunc_normal_(self.head.weight, std=0.02)
644
+ self.head.weight.data.mul_(init_scale)
645
+ self.head.bias.data.mul_(init_scale)
646
+
647
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
648
+ self.patch_dropout = (
649
+ PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
650
+ )
651
+
652
+ self.grad_checkpointing = grad_checkpointing
653
+
654
+ def fix_init_weight(self):
655
+ def rescale(param, layer_id):
656
+ param.div_(math.sqrt(2.0 * layer_id))
657
+
658
+ for layer_id, layer in enumerate(self.blocks):
659
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
660
+ if self.naiveswiglu:
661
+ rescale(layer.mlp.w3.weight.data, layer_id + 1)
662
+ else:
663
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
664
+
665
+ def get_cast_dtype(self) -> torch.dtype:
666
+ return self.blocks[0].mlp.fc2.weight.dtype
667
+
668
+ def _init_weights(self, m):
669
+ if isinstance(m, nn.Linear):
670
+ trunc_normal_(m.weight, std=0.02)
671
+ if m.bias is not None:
672
+ nn.init.constant_(m.bias, 0)
673
+ elif isinstance(m, nn.LayerNorm):
674
+ nn.init.constant_(m.bias, 0)
675
+ nn.init.constant_(m.weight, 1.0)
676
+
677
+ def get_num_layers(self):
678
+ return len(self.blocks)
679
+
680
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
681
+ assert (
682
+ unlocked_groups == 0
683
+ ), "partial locking not currently supported for this model"
684
+ for param in self.parameters():
685
+ param.requires_grad = False
686
+
687
+ @torch.jit.ignore
688
+ def set_grad_checkpointing(self, enable=True):
689
+ self.grad_checkpointing = enable
690
+
691
+ @torch.jit.ignore
692
+ def no_weight_decay(self):
693
+ return {"pos_embed", "cls_token"}
694
+
695
+ def get_classifier(self):
696
+ return self.head
697
+
698
+ def reset_classifier(self, num_classes, global_pool=""):
699
+ self.num_classes = num_classes
700
+ self.head = (
701
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
702
+ )
703
+
704
+ def forward_features(self, x, return_all_features=False, return_all_layers=False):
705
+ x = self.patch_embed(x)
706
+ batch_size, seq_len, _ = x.size()
707
+
708
+ cls_tokens = self.cls_token.expand(
709
+ batch_size, -1, -1
710
+ ) # stole cls_tokens impl from Phil Wang, thanks
711
+ x = torch.cat((cls_tokens, x), dim=1)
712
+ if self.pos_embed is not None:
713
+ x = x + self.pos_embed
714
+ x = self.pos_drop(x)
715
+
716
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
717
+ if os.getenv("RoPE") == "1":
718
+ if self.training and not isinstance(self.patch_dropout, nn.Identity):
719
+ x, patch_indices_keep = self.patch_dropout(x)
720
+ self.rope.forward = partial(
721
+ self.rope.forward, patch_indices_keep=patch_indices_keep
722
+ )
723
+ else:
724
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
725
+ x = self.patch_dropout(x)
726
+ else:
727
+ x = self.patch_dropout(x)
728
+
729
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
730
+
731
+ all_x = []
732
+ for blk in self.blocks:
733
+ if self.grad_checkpointing:
734
+ x = checkpoint(blk, x, (rel_pos_bias,))
735
+ else:
736
+ x = blk(x, rel_pos_bias=rel_pos_bias)
737
+
738
+ if return_all_layers:
739
+ all_x.append(x)
740
+
741
+ if not return_all_features:
742
+ x = self.norm(x)
743
+ if self.fc_norm is not None:
744
+ return self.fc_norm(x.mean(1))
745
+ else:
746
+ return x[:, 0]
747
+ return x if not return_all_layers else all_x
748
+
749
+ def forward(self, x, return_all_features=False, return_all_layers=False):
750
+ if return_all_features:
751
+ return self.forward_features(x, return_all_features, return_all_layers)
752
+ x = self.forward_features(x)
753
+ x = self.head(x)
754
+ return x
755
+
756
+
757
+ @dataclass
758
+ class CLIPVisionCfg:
759
+ layers: Union[Tuple[int, int, int, int], int] = 12
760
+ width: int = 768
761
+ head_width: int = 64
762
+ mlp_ratio: float = 4.0
763
+ patch_size: int = 16
764
+ image_size: Union[Tuple[int, int], int] = 224
765
+ ls_init_value: Optional[float] = None # layer scale initial value
766
+ patch_dropout: float = 0.0 # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
767
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
768
+ drop_path_rate: Optional[float] = None # drop path rate
769
+ timm_model_name: str = (
770
+ None # a valid model name overrides layers, width, patch_size
771
+ )
772
+ timm_model_pretrained: bool = (
773
+ False # use (imagenet) pretrained weights for named model
774
+ )
775
+ timm_pool: str = ( # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
776
+ "avg"
777
+ )
778
+ timm_proj: str = ( # linear projection for timm model output ('linear', 'mlp', '')
779
+ "linear"
780
+ )
781
+ timm_proj_bias: bool = False # enable bias final projection
782
+ eva_model_name: str = (
783
+ None # a valid eva model name overrides layers, width, patch_size
784
+ )
785
+ qkv_bias: bool = True
786
+ fusedLN: bool = False
787
+ embed_dim: int = 1024
788
+ xattn: bool = False
789
+ postnorm: bool = False
790
+ rope: bool = False
791
+ pt_hw_seq_len: int = 16 # 224/14
792
+ intp_freq: bool = False
793
+ naiveswiglu: bool = False
794
+ subln: bool = False
795
+
796
+
797
+ def broadcat(tensors, dim=-1):
798
+ num_tensors = len(tensors)
799
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
800
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
801
+ shape_len = list(shape_lens)[0]
802
+ dim = (dim + shape_len) if dim < 0 else dim
803
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
804
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
805
+ assert all(
806
+ [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
807
+ ), "invalid dimensions for broadcastable concatentation"
808
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
809
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
810
+ expanded_dims.insert(dim, (dim, dims[dim]))
811
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
812
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
813
+ return torch.cat(tensors, dim=dim)
814
+
815
+
816
+ def rotate_half(x):
817
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
818
+ x1, x2 = x.unbind(dim=-1)
819
+ x = torch.stack((-x2, x1), dim=-1)
820
+ return rearrange(x, "... d r -> ... (d r)")
821
+
822
+
823
+ class VisionRotaryEmbedding(nn.Module):
824
+ def __init__(
825
+ self,
826
+ dim,
827
+ pt_seq_len,
828
+ ft_seq_len=None,
829
+ custom_freqs=None,
830
+ freqs_for="lang",
831
+ theta=10000,
832
+ max_freq=10,
833
+ num_freqs=1,
834
+ ):
835
+ super().__init__()
836
+ if custom_freqs:
837
+ freqs = custom_freqs
838
+ elif freqs_for == "lang":
839
+ freqs = 1.0 / (
840
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
841
+ )
842
+ elif freqs_for == "pixel":
843
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
844
+ elif freqs_for == "constant":
845
+ freqs = torch.ones(num_freqs).float()
846
+ else:
847
+ raise ValueError(f"unknown modality {freqs_for}")
848
+
849
+ if ft_seq_len is None:
850
+ ft_seq_len = pt_seq_len
851
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
852
+
853
+ freqs_h = torch.einsum("..., f -> ... f", t, freqs)
854
+ freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
855
+
856
+ freqs_w = torch.einsum("..., f -> ... f", t, freqs)
857
+ freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
858
+
859
+ freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
860
+
861
+ self.register_buffer("freqs_cos", freqs.cos())
862
+ self.register_buffer("freqs_sin", freqs.sin())
863
+
864
+ logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")
865
+
866
+ def forward(self, t, start_index=0):
867
+ rot_dim = self.freqs_cos.shape[-1]
868
+ end_index = start_index + rot_dim
869
+ assert rot_dim <= t.shape[-1], (
870
+ f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in"
871
+ f" all the positions {rot_dim}"
872
+ )
873
+ t_left, t, t_right = (
874
+ t[..., :start_index],
875
+ t[..., start_index:end_index],
876
+ t[..., end_index:],
877
+ )
878
+ t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
879
+
880
+ return torch.cat((t_left, t, t_right), dim=-1)
881
+
882
+
883
+ class VisionRotaryEmbeddingFast(nn.Module):
884
+ def __init__(
885
+ self,
886
+ dim,
887
+ pt_seq_len,
888
+ ft_seq_len=None,
889
+ custom_freqs=None,
890
+ freqs_for="lang",
891
+ theta=10000,
892
+ max_freq=10,
893
+ num_freqs=1,
894
+ patch_dropout=0.0,
895
+ ):
896
+ super().__init__()
897
+ if custom_freqs:
898
+ freqs = custom_freqs
899
+ elif freqs_for == "lang":
900
+ freqs = 1.0 / (
901
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
902
+ )
903
+ elif freqs_for == "pixel":
904
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
905
+ elif freqs_for == "constant":
906
+ freqs = torch.ones(num_freqs).float()
907
+ else:
908
+ raise ValueError(f"unknown modality {freqs_for}")
909
+
910
+ if ft_seq_len is None:
911
+ ft_seq_len = pt_seq_len
912
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
913
+
914
+ freqs = torch.einsum("..., f -> ... f", t, freqs)
915
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
916
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
917
+
918
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
919
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
920
+
921
+ self.patch_dropout = patch_dropout
922
+
923
+ self.register_buffer("freqs_cos", freqs_cos)
924
+ self.register_buffer("freqs_sin", freqs_sin)
925
+
926
+ logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")
927
+
928
+ def forward(self, t, patch_indices_keep=None):
929
+ if patch_indices_keep is not None:
930
+ batch = t.size()[0]
931
+ batch_indices = torch.arange(batch)
932
+ batch_indices = batch_indices[..., None]
933
+
934
+ freqs_cos = repeat(
935
+ self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1]
936
+ )
937
+ freqs_sin = repeat(
938
+ self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1]
939
+ )
940
+
941
+ freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
942
+ freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j")
943
+ freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
944
+ freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j")
945
+
946
+ return t * freqs_cos + rotate_half(t) * freqs_sin
947
+
948
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
flamingo.py ADDED
@@ -0,0 +1,261 @@
1
+ import inspect
2
+ import torch
3
+ from einops import rearrange
4
+ from torch import nn
5
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
6
+
7
+ from .helpers import PerceiverResampler
8
+
9
+
10
+ def unwrap_fsdp(m):
11
+ if isinstance(m, FSDP):
12
+ return unwrap_fsdp(m.module)
13
+ return m
14
+
15
+
16
+ def accepts_parameter(func, parameter_name):
17
+ signature = inspect.signature(func)
18
+ return parameter_name in signature.parameters
19
+
20
+
21
+ class Flamingo(nn.Module):
22
+ def __init__(
23
+ self,
24
+ vision_encoder: nn.Module,
25
+ lang_encoder: nn.Module,
26
+ eoc_token_id: int,
27
+ media_token_id: int,
28
+ vis_dim: int,
29
+ cross_attn_every_n_layers: int = 1,
30
+ gradient_checkpointing: bool = False,
31
+ enable_init_network_params: bool = False,
32
+ initializer_range: float = 0.02,
33
+ ):
34
+ """
35
+ Args:
36
+ vision_encoder (nn.Module): HF CLIPModel
37
+ lang_encoder (nn.Module): HF causal language model
38
+ eoc_token_id (int): Token id for <|endofchunk|>
39
+ media_token_id (int): Token id for <image>
40
+ vis_dim (int): Dimension of the visual features.
41
+ Visual features are projected to match this shape along the last dimension.
42
+ cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
43
+ """
44
+ super().__init__()
45
+ self.eoc_token_id = eoc_token_id
46
+ self.media_token_id = media_token_id
47
+ self.vis_dim = vis_dim
48
+ if hasattr(lang_encoder.config, "d_model"):
49
+ self.lang_dim = lang_encoder.config.d_model # mpt uses d_model
50
+ else:
51
+ self.lang_dim = lang_encoder.config.hidden_size
52
+
53
+ self.vision_encoder = (
54
+ vision_encoder.visual
55
+ if hasattr(vision_encoder, "visual")
56
+ else vision_encoder
57
+ )
58
+ self.perceiver = PerceiverResampler(
59
+ dim=self.vis_dim,
60
+ enable_init_network_params=enable_init_network_params,
61
+ initializer_range=initializer_range,
62
+ gradient_checkpointing=gradient_checkpointing,
63
+ )
64
+ self.lang_encoder = lang_encoder
65
+ self.lang_encoder.init_flamingo(
66
+ media_token_id=media_token_id,
67
+ lang_hidden_size=self.lang_dim,
68
+ vis_hidden_size=self.vis_dim,
69
+ cross_attn_every_n_layers=cross_attn_every_n_layers,
70
+ gradient_checkpointing=gradient_checkpointing,
71
+ enable_init_network_params=enable_init_network_params,
72
+ initializer_range=initializer_range,
73
+ )
74
+ self._use_gradient_checkpointing = gradient_checkpointing
75
+ self.perceiver._use_gradient_checkpointing = gradient_checkpointing
76
+
77
+ def forward(
78
+ self,
79
+ vision_x: torch.Tensor,
80
+ lang_x: torch.Tensor,
81
+ attention_mask: torch.Tensor = None,
82
+ labels: torch.Tensor = None,
83
+ clear_conditioned_layers: bool = True,
84
+ past_key_values=None,
85
+ use_cache: bool = False,
86
+ ):
87
+ """
88
+ Forward pass of Flamingo.
89
+
90
+ Args:
91
+ vision_x (torch.Tensor): Vision input
92
+ shape (B, T_img, F, C, H, W) with F=1
93
+ lang_x (torch.Tensor): Language input ids
94
+ shape (B, T_txt)
95
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
96
+ labels (torch.Tensor, optional): Labels. Defaults to None.
97
+ clear_conditioned_layers: if True, clear the conditioned layers
98
+ once the forward pass is completed. Set this to False if the
99
+ same set of images will be reused in another subsequent
100
+ forward pass.
101
+ past_key_values: pre-computed values to pass to language model.
102
+ See past_key_values documentation in Hugging Face
103
+ CausalLM models.
104
+ use_cache: whether to use cached key values. See use_cache
105
+ documentation in Hugging Face CausalLM models.
106
+ """
107
+ assert (
108
+ self.lang_encoder.initialized_flamingo
109
+ ), "Flamingo layers are not initialized. Please call `init_flamingo` first."
110
+
111
+ assert (
112
+ self.lang_encoder._use_cached_vision_x or vision_x is not None
113
+ ), "Must provide either vision_x or have precached media using cache_media()."
114
+
115
+ if self.lang_encoder._use_cached_vision_x:
116
+ # Case: use cached; vision_x should be cached and other
117
+ # vision-related inputs should not be provided.
118
+ assert vision_x is None, (
119
+ "Expect vision_x to be None when media has been cached using"
120
+ " cache_media(). Try uncache_media() first."
121
+ )
122
+ assert self.lang_encoder.is_conditioned()
123
+
124
+ else:
125
+ # Case: do not use caching (i.e. this is a standard forward pass);
126
+ self._encode_vision_x(vision_x=vision_x)
127
+ self._condition_media_locations(input_ids=lang_x)
128
+
129
+ output = self.lang_encoder(
130
+ input_ids=lang_x,
131
+ attention_mask=attention_mask,
132
+ labels=labels,
133
+ past_key_values=past_key_values,
134
+ use_cache=use_cache,
135
+ )
136
+
137
+ if clear_conditioned_layers:
138
+ self.lang_encoder.clear_conditioned_layers()
139
+
140
+ return output
141
+
142
+ def generate(
143
+ self,
144
+ vision_x: torch.Tensor,
145
+ lang_x: torch.Tensor,
146
+ attention_mask: torch.Tensor = None,
147
+ **kwargs,
148
+ ):
149
+ """
150
+ Generate text conditioned on vision and language inputs.
151
+
152
+ Args:
153
+ vision_x (torch.Tensor): Vision input
154
+ shape (B, T_img, F, C, H, W)
155
+ images in the same chunk are collated along T_img, and frames are collated along F
156
+ currently only F=1 is supported (single-frame videos)
157
+ lang_x (torch.Tensor): Language input
158
+ shape (B, T_txt)
159
+ **kwargs: see generate documentation in Hugging Face CausalLM models. Some notable kwargs:
160
+ max_length (int, optional): Maximum length of the output. Defaults to None.
161
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
162
+ num_beams (int, optional): Number of beams. Defaults to 1.
163
+ max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
164
+ temperature (float, optional): Temperature. Defaults to 1.0.
165
+ top_k (int, optional): Top k. Defaults to 50.
166
+ top_p (float, optional): Top p. Defaults to 1.0.
167
+ no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
168
+ length_penalty (float, optional): Length penalty. Defaults to 1.0.
169
+ num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
170
+ do_sample (bool, optional): Do sample. Defaults to False.
171
+ early_stopping (bool, optional): Early stopping. Defaults to False.
172
+ Returns:
173
+ torch.Tensor: lang_x with generated tokens appended to it
174
+ """
175
+ num_beams = kwargs.pop("num_beams", 1)
176
+ if num_beams > 1:
177
+ vision_x = vision_x.repeat_interleave(num_beams, dim=0)
178
+
179
+ self.lang_encoder._use_cached_vision_x = True
180
+ self._encode_vision_x(vision_x=vision_x)
181
+
182
+ # eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id)
183
+ output = self.lang_encoder.generate(
184
+ input_ids=lang_x,
185
+ attention_mask=attention_mask,
186
+ # eos_token_id=eos_token_id,
187
+ num_beams=num_beams,
188
+ **kwargs,
189
+ )
190
+
191
+ self.lang_encoder.clear_conditioned_layers()
192
+ self.lang_encoder._use_cached_vision_x = False
193
+ return output
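+ # Illustrative usage sketch (hypothetical `tokenizer`, `image_processor`, and
+ # `flamingo` instance; not part of this file's API): vision_x has shape
+ # (B, T_img, F, C, H, W) and the prompt contains one <image> token per image.
+ # vision_x = image_processor(image).unsqueeze(0).unsqueeze(0).unsqueeze(0)  # (1, 1, 1, C, H, W)
+ # lang_x = tokenizer("<image>Describe this image.", return_tensors="pt")
+ # out = flamingo.generate(vision_x, lang_x["input_ids"], lang_x["attention_mask"], max_new_tokens=32)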
194
+
195
+ def _encode_vision_x(self, vision_x: torch.Tensor):
196
+ """
197
+ Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
198
+ Args:
199
+ vision_x (torch.Tensor): Vision input
200
+ shape (B, T_img, F, C, H, W)
201
+ Images in the same chunk are collated along T_img, and frames are collated along F
202
+ Currently only F=1 is supported (single-frame videos)
203
+
204
+ rearrange code based on https://github.com/dhansmair/flamingo-mini
205
+ """
206
+
207
+ assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
208
+ b, T, F = vision_x.shape[:3]
209
+ assert F == 1, "Only single frame supported"
210
+
211
+ vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
212
+
213
+ with torch.no_grad():
214
+ module_to_inspect = unwrap_fsdp(self.vision_encoder)
215
+ if accepts_parameter(module_to_inspect.forward, "return_all_features"):
216
+ vision_x = self.vision_encoder(vision_x, return_all_features=True)
217
+ else:
218
+ vision_x = self.vision_encoder(vision_x)[1]
219
+
220
+ vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
221
+ vision_x = self.perceiver(vision_x)
222
+
223
+ for layer in self.lang_encoder._get_decoder_layers():
224
+ layer.condition_vis_x(vision_x)
225
+
226
+ def _condition_media_locations(self, input_ids: torch.Tensor):
227
+ """
228
+ Compute the media token locations from lang_x and condition the language model on these.
229
+ Args:
230
+ input_ids (torch.Tensor): Language input
231
+ shape (B, T_txt)
232
+ """
233
+ media_locations = input_ids == self.media_token_id
234
+
235
+ for layer in self.lang_encoder._get_decoder_layers():
236
+ layer.condition_media_locations(media_locations)
237
+
238
+ def cache_media(self, input_ids: torch.Tensor, vision_x: torch.Tensor):
239
+ """
240
+ Pre-cache a prompt/sequence of images / text for log-likelihood evaluations.
241
+ All subsequent calls to forward() will generate attending to the LAST
242
+ image in vision_x.
243
+ This is not meant to be used to cache things for generate().
244
+ Args:
245
+ input_ids (torch.Tensor): Language input
246
+ shape (B, T_txt)
247
+ vision_x (torch.Tensor): Vision input
248
+ shape (B, T_img, F, C, H, W)
249
+ Images in the same chunk are collated along T_img, and frames are collated along F
250
+ Currently only F=1 is supported (single-frame videos)
251
+ """
252
+ self._encode_vision_x(vision_x=vision_x)
253
+ self._condition_media_locations(input_ids=input_ids)
254
+ self.lang_encoder._use_cached_vision_x = True
255
+
256
+ def uncache_media(self):
257
+ """
258
+ Clear all conditioning.
259
+ """
260
+ self.lang_encoder.clear_conditioned_layers()
261
+ self.lang_encoder._use_cached_vision_x = False
flamingo_lm.py ADDED
@@ -0,0 +1,179 @@
1
+ import torch.nn as nn
2
+
3
+ from .helpers import GatedCrossAttentionBlock
4
+ from .utils import getattr_recursive, setattr_recursive
5
+
6
+
7
+ class FlamingoLayer(nn.Module):
8
+ """
9
+ FlamingoLayer is a wrapper around the GatedCrossAttentionBlock and DecoderLayer.
10
+ """
11
+
12
+ def __init__(
13
+ self, gated_cross_attn_layer, decoder_layer, gradient_checkpointing=False
14
+ ):
15
+ super().__init__()
16
+ self.gated_cross_attn_layer = gated_cross_attn_layer
17
+ self.decoder_layer = decoder_layer
18
+ self.vis_x = None
19
+ self.media_locations = None
+ self.use_cached_media = None
20
+ if self.gated_cross_attn_layer is not None:
21
+ self.gated_cross_attn_layer._use_gradient_checkpointing = (
22
+ gradient_checkpointing
23
+ )
24
+ self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing
25
+
26
+ def is_conditioned(self) -> bool:
27
+ """Check whether the layer is conditioned."""
28
+ return self.vis_x is not None and self.media_locations is not None
29
+
30
+ # Conditioning approach borrowed from the flamingo-mini implementation (https://github.com/dhansmair/flamingo-mini/)
31
+ def condition_vis_x(self, vis_x):
32
+ self.vis_x = vis_x
33
+
34
+ def condition_media_locations(self, media_locations):
35
+ self.media_locations = media_locations
36
+
37
+ def condition_use_cached_media(self, use_cached_media):
38
+ self.use_cached_media = use_cached_media
39
+
40
+ def forward(
41
+ self,
42
+ lang_x,
43
+ attention_mask=None,
44
+ **decoder_layer_kwargs,
45
+ ):
46
+ # Cross attention
47
+ if self.gated_cross_attn_layer is not None:
48
+ if self.vis_x is None:
49
+ raise ValueError("vis_x must be conditioned before forward pass")
50
+
51
+ if self.media_locations is None:
52
+ raise ValueError(
53
+ "media_locations must be conditioned before forward pass"
54
+ )
55
+
56
+ lang_x = self.gated_cross_attn_layer(
57
+ lang_x,
58
+ self.vis_x,
59
+ media_locations=self.media_locations,
60
+ use_cached_media=self.use_cached_media,
61
+ )
62
+
63
+ # Normal decoder layer
64
+ lang_x = self.decoder_layer(
65
+ lang_x, attention_mask=attention_mask, **decoder_layer_kwargs
66
+ )
67
+ return lang_x
68
+
69
+
70
+ class FlamingoLMMixin(nn.Module):
71
+ """
72
+ Mixin to add cross-attention layers to a language model.
73
+ """
74
+
75
+ def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
76
+ self.decoder_layers_attr_name = decoder_layers_attr_name
77
+
78
+ def _get_decoder_layers(self):
79
+ return getattr_recursive(self, self.decoder_layers_attr_name)
80
+
81
+ def _set_decoder_layers(self, value):
82
+ setattr_recursive(self, self.decoder_layers_attr_name, value)
83
+
84
+ def init_flamingo(
85
+ self,
86
+ media_token_id,
87
+ lang_hidden_size,
88
+ vis_hidden_size,
89
+ cross_attn_every_n_layers,
90
+ *,
91
+ enable_init_network_params=False,
92
+ initializer_range=0.02,
93
+ gradient_checkpointing=False,
94
+ ):
95
+ """
96
+ Initialize Flamingo by interleaving new gated cross-attention layers with the decoder layers. Store the media token id for computing the media locations.
97
+ """
98
+ self.old_decoder_blocks = self._get_decoder_layers()
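+ # For example, with cross_attn_every_n_layers=4 (value is illustrative; the real
+ # one comes from the model config), gated cross-attention blocks are attached to
+ # decoder layers 3, 7, 11, ... (0-indexed) and the remaining entries stay None.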
99
+ self.gated_cross_attn_layers = nn.ModuleList(
100
+ [
101
+ (
102
+ GatedCrossAttentionBlock(
103
+ dim=lang_hidden_size,
104
+ dim_visual=vis_hidden_size,
105
+ ff_mult=4,
106
+ enable_init_network_params=enable_init_network_params,
107
+ initializer_range=initializer_range,
108
+ gradient_checkpointing=gradient_checkpointing,
109
+ )
110
+ if (layer_idx + 1) % cross_attn_every_n_layers == 0
111
+ else None
112
+ )
113
+ for layer_idx, _ in enumerate(self._get_decoder_layers())
114
+ ]
115
+ )
116
+ self.init_flamingo_layers(gradient_checkpointing)
117
+ self.media_token_id = media_token_id
118
+ self.initialized_flamingo = True
119
+ self._use_cached_vision_x = False
120
+
121
+ def init_flamingo_layers(self, gradient_checkpointing):
122
+ """
123
+ Re-initializes the FlamingoLayers.
124
+ Propagates any changes made to self.gated_cross_attn_layers or self.old_decoder_blocks.
125
+ """
126
+ self._set_decoder_layers(
127
+ nn.ModuleList(
128
+ [
129
+ FlamingoLayer(
130
+ gated_cross_attn_layer, decoder_layer, gradient_checkpointing
131
+ )
132
+ for gated_cross_attn_layer, decoder_layer in zip(
133
+ self.gated_cross_attn_layers, self.old_decoder_blocks
134
+ )
135
+ ]
136
+ )
137
+ )
138
+
139
+ def forward(self, input_ids, attention_mask, **kwargs):
140
+ """Condition the Flamingo layers on the media locations before forward()"""
141
+ if not self.initialized_flamingo:
142
+ raise ValueError(
143
+ "Flamingo layers are not initialized. Please call `init_flamingo`"
144
+ " first."
145
+ )
146
+
147
+ media_locations = input_ids == self.media_token_id
148
+
149
+ # if there are media already cached and we're generating and there are no media tokens in the input,
150
+ # we'll assume that ALL input tokens should attend to the last previous media that is cached.
151
+ # this is especially important for HF generate() compatibility, since generate() calls forward()
152
+ # repeatedly one token at a time (with no media tokens).
153
+ # without this check, the model would not attend to any images when generating (after the first token)
154
+ use_cached_media_locations = (
155
+ self._use_cached_vision_x
156
+ and self.is_conditioned()
157
+ and not media_locations.any()
158
+ )
159
+
160
+ for layer in self._get_decoder_layers():
161
+ if not use_cached_media_locations:
162
+ layer.condition_media_locations(media_locations)
163
+ layer.condition_use_cached_media(use_cached_media_locations)
164
+
165
+ # package arguments for the other parent's forward. since we don't know the order of the arguments,
166
+ # make them all kwargs
167
+ kwargs["input_ids"] = input_ids
168
+ kwargs["attention_mask"] = attention_mask
169
+ return super().forward(**kwargs) # Call the other parent's forward method
170
+
171
+ def is_conditioned(self) -> bool:
172
+ """Check whether all decoder layers are already conditioned."""
173
+ return all(l.is_conditioned() for l in self._get_decoder_layers())
174
+
175
+ def clear_conditioned_layers(self):
176
+ for layer in self._get_decoder_layers():
177
+ layer.condition_vis_x(None)
178
+ layer.condition_media_locations(None)
179
+ layer.condition_use_cached_media(None)
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "do_sample": true,
3
+ "max_new_tokens": 512,
4
+ "top_p": 0.6,
5
+ "temperature": 0.9,
6
+ "transformers_version": "4.34.0"
7
+ }
helpers.py ADDED
@@ -0,0 +1,410 @@
1
+ """
2
+ Based on: https://github.com/lucidrains/flamingo-pytorch
3
+ """
4
+
5
+ import torch
6
+ from einops import rearrange, repeat
7
+ from torch import einsum, nn
8
+
9
+ from einops_exts import rearrange_many
10
+
11
+ try:
12
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
13
+ except ImportError:
14
+ from torch.utils.checkpoint import checkpoint
15
+
16
+
17
+ def exists(val):
18
+ return val is not None
19
+
20
+
21
+ def FeedForward(
22
+ dim,
23
+ mult=4,
24
+ enable_init_network_params=False,
25
+ initializer_range=0.02,
26
+ ):
27
+ inner_dim = int(dim * mult)
28
+ net = nn.Sequential(
29
+ nn.LayerNorm(dim),
30
+ nn.Linear(dim, inner_dim, bias=False),
31
+ nn.GELU(),
32
+ nn.Linear(inner_dim, dim, bias=False),
33
+ )
34
+
35
+ if enable_init_network_params:
36
+ # initialize the LayerNorm and Linear weights of the feed-forward block
37
+ net[0].weight.data.normal_(mean=0.0, std=initializer_range)
38
+ net[0].bias.data.zero_()
39
+ net[1].weight.data.normal_(mean=0.0, std=initializer_range)
40
+ net[3].weight.data.normal_(mean=0.0, std=initializer_range)
41
+ return net
42
+
43
+
44
+ class PerceiverAttention(nn.Module):
45
+ def __init__(
46
+ self,
47
+ *,
48
+ dim,
49
+ dim_head=64,
50
+ heads=8,
51
+ enable_init_network_params=False,
52
+ initializer_range=0.02,
53
+ ):
54
+ super().__init__()
55
+
56
+ self.scale = dim_head**-0.5
57
+ self.heads = heads
58
+ self.initializer_range = initializer_range
59
+
60
+ inner_dim = dim_head * heads
61
+
62
+ self.norm_media = nn.LayerNorm(dim)
63
+ self.norm_latents = nn.LayerNorm(dim)
64
+
65
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
66
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
67
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
68
+
69
+ if enable_init_network_params:
70
+ self.apply(self._init_weights)
71
+
72
+ def _init_weights(self, module):
73
+ if isinstance(module, nn.Linear):
74
+ # Slightly different from the TF version which uses truncated_normal for initialization
75
+ # cf https://github.com/pytorch/pytorch/pull/5617
76
+ module.weight.data.normal_(mean=0.0, std=self.initializer_range)
77
+ if module.bias is not None:
78
+ module.bias.data.zero_()
79
+
80
+ elif isinstance(module, nn.LayerNorm):
81
+ module.bias.data.zero_()
82
+ module.weight.data.fill_(1.0)
83
+
84
+ def forward(self, x, latents):
85
+ """
86
+ Args:
87
+ x (torch.Tensor): image features
88
+ shape (b, T, n1, D)
89
+ latents (torch.Tensor): latent features
90
+ shape (b, T, n2, D)
91
+ """
92
+ x = self.norm_media(x)
93
+ latents = self.norm_latents(latents.contiguous())
94
+
95
+ h = self.heads
96
+
97
+ q = self.to_q(latents)
98
+ kv_input = torch.cat((x, latents), dim=-2)
99
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
100
+
101
+ q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
102
+ q = q * self.scale
103
+ # attention
104
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
105
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
106
+ attn = sim.softmax(dim=-1)
107
+
108
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
109
+ out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
110
+ return self.to_out(out)
111
+
112
+
113
+ class PerceiverResampler(nn.Module):
114
+ def __init__(
115
+ self,
116
+ *,
117
+ dim,
118
+ depth=6,
119
+ dim_head=64,
120
+ heads=8,
121
+ num_latents=64,
122
+ max_num_media=None,
123
+ max_num_frames=None,
124
+ ff_mult=4,
125
+ enable_init_network_params=False,
126
+ initializer_range=0.02,
127
+ gradient_checkpointing=False,
128
+ ):
129
+ super().__init__()
130
+
131
+ self.gradient_checkpointing = gradient_checkpointing
132
+ self.initializer_range = initializer_range
133
+
134
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
135
+ self.frame_embs = (
136
+ nn.Parameter(torch.randn(max_num_frames, dim))
137
+ if exists(max_num_frames)
138
+ else None
139
+ )
140
+ self.media_time_embs = (
141
+ nn.Parameter(torch.randn(max_num_media, 1, dim))
142
+ if exists(max_num_media)
143
+ else None
144
+ )
145
+
146
+ self.layers = nn.ModuleList([])
147
+
148
+ for _ in range(depth):
149
+ self.layers.append(
150
+ nn.ModuleList(
151
+ [
152
+ PerceiverAttention(
153
+ dim=dim,
154
+ dim_head=dim_head,
155
+ heads=heads,
156
+ enable_init_network_params=enable_init_network_params,
157
+ initializer_range=initializer_range,
158
+ ),
159
+ FeedForward(
160
+ dim=dim,
161
+ mult=ff_mult,
162
+ enable_init_network_params=enable_init_network_params,
163
+ initializer_range=initializer_range,
164
+ ),
165
+ ]
166
+ )
167
+ )
168
+ # Should this norm layer also change?
169
+ self.norm = nn.LayerNorm(dim)
170
+ if enable_init_network_params:
171
+ self.apply(self._init_weights)
172
+
173
+ def _init_weights(self, module):
174
+ if isinstance(module, nn.Linear):
175
+ # Slightly different from the TF version which uses truncated_normal for initialization
176
+ # cf https://github.com/pytorch/pytorch/pull/5617
177
+ module.weight.data.normal_(mean=0.0, std=self.initializer_range)
178
+ if module.bias is not None:
179
+ module.bias.data.zero_()
180
+
181
+ elif isinstance(module, nn.LayerNorm):
182
+ module.bias.data.zero_()
183
+ module.weight.data.fill_(1.0)
184
+
185
+ elif isinstance(module, nn.Parameter):
186
+ module.data.normal_(mean=0.0, std=self.initializer_range)
187
+
188
+ def forward(self, x):
189
+ """
190
+ Args:
191
+ x (torch.Tensor): image features
192
+ shape (b, T, F, v, D)
193
+ Returns:
194
+ shape (b, T, n, D) where n is the number of latents (num_latents)
195
+ """
196
+
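+ # Shape example (hypothetical sizes): an input of (b=2, T=3, F=1, v=257, D=1024)
+ # is flattened to (2, 3, 257, 1024) and resampled to (2, 3, num_latents, 1024),
+ # i.e. (2, 3, 64, 1024) with the default num_latents=64.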
197
+ b, T, F, v = x.shape[:4]
198
+
199
+ # frame and media time embeddings
200
+ if exists(self.frame_embs):
201
+ frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
202
+ x = x + frame_embs
203
+ x = rearrange(
204
+ x, "b T F v d -> b T (F v) d"
205
+ ) # flatten the frame and spatial dimensions
206
+ if exists(self.media_time_embs):
207
+ x = x + self.media_time_embs[:T]
208
+
209
+ # blocks
210
+ latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
211
+ for attn, ff in self.layers:
212
+ if self.gradient_checkpointing and latents.requires_grad:
213
+ latents = checkpoint(attn, x, latents) + latents
214
+ latents = checkpoint(ff, latents) + latents
215
+ else:
216
+ latents = attn(x, latents) + latents
217
+ latents = ff(latents) + latents
218
+
219
+ return self.norm(latents)
220
+
221
+
222
+ # gated cross attention
223
+ class MaskedCrossAttention(nn.Module):
224
+ def __init__(
225
+ self,
226
+ *,
227
+ dim,
228
+ dim_visual,
229
+ dim_head=64,
230
+ heads=8,
231
+ only_attend_immediate_media=True,
232
+ enable_init_network_params=False,
233
+ initializer_range=0.02,
234
+ ):
235
+ super().__init__()
236
+ self.scale = dim_head**-0.5
237
+ self.heads = heads
238
+ self.initializer_range = initializer_range
239
+ inner_dim = dim_head * heads
240
+
241
+ self.norm = nn.LayerNorm(dim)
242
+
243
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
244
+ self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
245
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
246
+
247
+ # whether for text to only attend to immediate preceding image, or all previous images
248
+ self.only_attend_immediate_media = only_attend_immediate_media
249
+
250
+ if enable_init_network_params:
251
+ self.apply(self._init_weights)
252
+
253
+ def _init_weights(self, module):
254
+ if isinstance(module, nn.Linear):
255
+ # Slightly different from the TF version which uses truncated_normal for initialization
256
+ # cf https://github.com/pytorch/pytorch/pull/5617
257
+ module.weight.data.normal_(mean=0.0, std=self.initializer_range)
258
+ if module.bias is not None:
259
+ module.bias.data.zero_()
260
+
261
+ elif isinstance(module, nn.LayerNorm):
262
+ module.bias.data.zero_()
263
+ module.weight.data.fill_(1.0)
264
+
265
+ def forward(self, x, media, media_locations=None, use_cached_media=False):
266
+ """
267
+ Args:
268
+ x (torch.Tensor): text features
269
+ shape (B, T_txt, D_txt)
270
+ media (torch.Tensor): image features
271
+ shape (B, T_img, n, D_img) where n is the dim of the latents
272
+ media_locations: boolean mask identifying the media tokens in x
273
+ shape (B, T_txt)
274
+ use_cached_media: bool
275
+ If true, treat all of x as if they occur after the last media
276
+ registered in media_locations. T_txt does not need to exactly
277
+ equal media_locations.shape[1] in this case
278
+ """
279
+
280
+ if not use_cached_media:
281
+ assert media_locations.shape[1] == x.shape[1], (
282
+ f"media_location.shape is {media_locations.shape} but x.shape is"
283
+ f" {x.shape}"
284
+ )
285
+
286
+ T_txt = x.shape[1]
287
+ _, T_img, n = media.shape[:3]
288
+ h = self.heads
289
+
290
+ x = self.norm(x.contiguous())
291
+ q = self.to_q(x)
292
+ media = rearrange(media, "b t n d -> b (t n) d")
293
+
294
+ k, v = self.to_kv(media).chunk(2, dim=-1)
295
+
296
+ if exists(media_locations):
297
+ media_time = torch.arange(T_img, device=x.device) + 1
298
+
299
+ if use_cached_media:
300
+ # text time is set to the last cached media location
301
+ text_time = repeat(
302
+ torch.count_nonzero(media_locations, dim=1),
303
+ "b -> b i",
304
+ i=T_txt,
305
+ )
306
+ else:
307
+ # at each boolean of True, increment the time counter (relative to media time)
308
+ text_time = media_locations.cumsum(dim=-1)
309
+
310
+ # text time must equal media time if only attending to most immediate image
311
+ # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
312
+ mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
313
+ text_to_media_mask = mask_op(
314
+ rearrange(text_time, "b i -> b 1 i 1"),
315
+ repeat(media_time, "j -> 1 1 1 (j n)", n=n),
316
+ )
317
+
318
+ if self.only_attend_immediate_media:
319
+ # any text without a preceding media needs to have attention zeroed out
320
+ text_without_media_mask = text_time == 0
321
+ text_without_media_mask = rearrange(
322
+ text_without_media_mask, "b i -> b 1 i 1"
323
+ )
324
+
325
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
326
+ q = q * self.scale
327
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
328
+
329
+ if exists(media_locations):
330
+ sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
331
+
332
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
333
+ attn = sim.softmax(dim=-1)
334
+
335
+ if exists(media_locations) and self.only_attend_immediate_media:
336
+ # any text without a preceding media needs to have attention zeroed out
337
+ attn = attn.masked_fill(text_without_media_mask, 0.0)
338
+
339
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
340
+ out = rearrange(out, "b h n d -> b n (h d)")
341
+ return self.to_out(out)
342
+
343
+
344
+ class GatedCrossAttentionBlock(nn.Module):
345
+ def __init__(
346
+ self,
347
+ *,
348
+ dim,
349
+ dim_visual,
350
+ dim_head=64,
351
+ heads=8,
352
+ ff_mult=4,
353
+ only_attend_immediate_media=True,
354
+ enable_init_network_params=False,
355
+ initializer_range=0.02,
356
+ gradient_checkpointing=False,
357
+ ):
358
+ super().__init__()
359
+ self.attn = MaskedCrossAttention(
360
+ dim=dim,
361
+ dim_visual=dim_visual,
362
+ dim_head=dim_head,
363
+ heads=heads,
364
+ only_attend_immediate_media=only_attend_immediate_media,
365
+ enable_init_network_params=enable_init_network_params,
366
+ initializer_range=initializer_range,
367
+ )
368
+ self.attn_gate = nn.Parameter(torch.tensor([0.0]))
369
+ self.ff = FeedForward(dim, mult=ff_mult)
370
+ self.ff_gate = nn.Parameter(torch.tensor([0.0]))
371
+ self.gradient_checkpointing = gradient_checkpointing
372
+
373
+ def forward(
374
+ self,
375
+ x,
376
+ media,
377
+ media_locations=None,
378
+ use_cached_media=False,
379
+ ):
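+ # `flag` is 1 for samples whose text contains at least one media token and 0
+ # otherwise, so samples without any image skip the gated cross-attention and
+ # feed-forward updates entirely.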
380
+ if exists(media_locations):
381
+ flag = torch.sum(media_locations, dim=-1)
382
+ flag = torch.where(flag > 0.0, 1.0, 0.0)
383
+ flag = flag.unsqueeze(1).unsqueeze(1).to(torch.bfloat16)
384
+ else:
385
+ flag = 1.0
386
+
387
+ if self.gradient_checkpointing and media.requires_grad:
388
+ x = (
389
+ flag
390
+ * checkpoint(self.attn, x, media, media_locations, use_cached_media)
391
+ * self.attn_gate.tanh()
392
+ + x
393
+ )
394
+ x = flag * checkpoint(self.ff, x) * self.ff_gate.tanh() + x
395
+
396
+ else:
397
+ x = (
398
+ flag
399
+ * self.attn(
400
+ x,
401
+ media,
402
+ media_locations=media_locations,
403
+ use_cached_media=use_cached_media,
404
+ )
405
+ * self.attn_gate.tanh()
406
+ + x
407
+ )
408
+ x = flag * self.ff(x) * self.ff_gate.tanh() + x
409
+
410
+ return x
modeling_infimm_vicuna.py ADDED
@@ -0,0 +1,139 @@
1
+ import importlib
2
+ import math
3
+ from functools import partial
4
+ from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional, Tuple, Union
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torch.utils.checkpoint
8
+ from torch.cuda.amp import autocast
9
+
10
+ from transformers import GenerationConfig, PreTrainedTokenizer, StoppingCriteriaList
11
+ from transformers.generation.logits_process import LogitsProcessorList
12
+
13
+ if TYPE_CHECKING:
14
+ from transformers.generation.streamers import BaseStreamer
15
+
16
+ from transformers.generation.utils import GenerateOutput
17
+ from transformers.modeling_outputs import (
18
+ BaseModelOutputWithPast,
19
+ CausalLMOutputWithPast,
20
+ )
21
+ from transformers.modeling_utils import PreTrainedModel
22
+ from transformers import AutoModelForCausalLM, AutoTokenizer
23
+ from transformers.utils import logging
24
+
25
+ try:
26
+ from einops import rearrange
27
+ except ImportError:
28
+ rearrange = None
29
+ from torch import nn
30
+
31
+ from .configuration_infimm_vicuna import InfiMMConfig
32
+ from .eva_vit import CLIPVisionCfg, EVAVisionTransformer
33
+ from .flamingo import Flamingo
34
+ from .flamingo_lm import FlamingoLMMixin
35
+ from .helpers import PerceiverResampler
36
+ from .utils import _infer_decoder_layers_attr_name, extend_instance
37
+
38
+ SUPPORT_CUDA = torch.cuda.is_available()
39
+ SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
40
+ SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
41
+
42
+
43
+ class InfiMMPreTrainedModel(PreTrainedModel):
44
+ config_class = InfiMMConfig
45
+ base_model_prefix = "transformer"
46
+ is_parallelizable = False
47
+ supports_gradient_checkpointing = True
48
+
49
+ def __init__(self, *inputs, **kwargs):
50
+ super().__init__(*inputs, **kwargs)
51
+
52
+
53
+ class InfiMMVicunaModel(InfiMMPreTrainedModel):
54
+ def __init__(self, config):
55
+ super().__init__(config)
56
+
57
+ self.vision_config = config.visual
58
+ vision_encoder = self.build_vision_encoder()
59
+ self.language_config = config.language
60
+ language_encoder = self.build_language_encoder()
61
+
62
+ self.model = self.build_flamingo(vision_encoder, language_encoder)
63
+
64
+ def build_vision_encoder(self):
65
+ vision_cfg = CLIPVisionCfg(**self.vision_config)
66
+
67
+ vision_encoder = EVAVisionTransformer(
68
+ img_size=vision_cfg.image_size,
69
+ patch_size=vision_cfg.patch_size,
70
+ num_classes=vision_cfg.embed_dim,
71
+ use_mean_pooling=vision_cfg.global_average_pool, # False
72
+ init_values=vision_cfg.ls_init_value,
73
+ patch_dropout=vision_cfg.patch_dropout,
74
+ embed_dim=vision_cfg.width,
75
+ depth=vision_cfg.layers,
76
+ num_heads=vision_cfg.width // vision_cfg.head_width,
77
+ mlp_ratio=vision_cfg.mlp_ratio,
78
+ qkv_bias=vision_cfg.qkv_bias,
79
+ drop_path_rate=vision_cfg.drop_path_rate,
80
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
81
+ xattn=vision_cfg.xattn,
82
+ rope=vision_cfg.rope,
83
+ postnorm=vision_cfg.postnorm,
84
+ pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14
85
+ intp_freq=vision_cfg.intp_freq,
86
+ naiveswiglu=vision_cfg.naiveswiglu,
87
+ subln=vision_cfg.subln,
88
+ )
89
+
90
+ return vision_encoder
91
+
92
+ def build_language_encoder(self):
93
+ lang_encoder = AutoModelForCausalLM.from_pretrained(
94
+ self.language_config["_name_or_path"]
95
+ )
96
+ lang_encoder.resize_token_embeddings(self.language_config["vocab_size"])
97
+ return lang_encoder
98
+
99
+ def build_flamingo(self, vision_encoder, lang_encoder):
100
+ extend_instance(lang_encoder, FlamingoLMMixin)
101
+
102
+ decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
103
+ lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
104
+ # lang_encoder.resize_token_embeddings(self.config.)
105
+
106
+ model = Flamingo(
107
+ vision_encoder,
108
+ lang_encoder,
109
+ self.config.eoc_token_id,
110
+ self.config.image_token_id,
111
+ vis_dim=self.vision_config["width"],
112
+ cross_attn_every_n_layers=self.config.cross_attn_every_n_layers,
113
+ gradient_checkpointing=self.config.use_grad_checkpoint,
114
+ )
115
+
116
+ return model
117
+
118
+ def generate(
119
+ self,
120
+ input_ids,
121
+ attention_mask,
122
+ batch_images,
123
+ min_generation_length: int,
124
+ max_generation_length: int,
125
+ **kwargs,
126
+ ):
127
+ with torch.inference_mode():
128
+ outputs = self.model.generate(
129
+ batch_images,
130
+ input_ids,
131
+ attention_mask,
132
+ min_new_tokens=min_generation_length,
133
+ max_new_tokens=max_generation_length,
134
+ **kwargs,
135
+ )
136
+
137
+ # Extract only the newly generated tokens
138
+ outputs = outputs[:, len(input_ids[0]) :]
139
+ return outputs
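+
+ # Illustrative end-to-end sketch (names are assumptions, not defined here):
+ # processor = InfiMMVicunaProcessor(tokenizer)   # from processing_infimm_vicuna.py
+ # inputs = processor(prompts)
+ # outputs = model.generate(
+ #     inputs["input_ids"],
+ #     inputs["attention_mask"],
+ #     inputs["batch_images"],
+ #     min_generation_length=0,
+ #     max_generation_length=256,
+ # )
+ # text = processor.batch_decode(outputs, skip_special_tokens=True)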
preprocessor_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_name_or_path": "./",
3
+ "auto_map": {
4
+ "AutoProcessor": "processing_infimm_vicuna.InfiMMVicunaProcessor"
5
+ },
6
+ "image_size": 224
7
+ }
processing_infimm_vicuna.py ADDED
@@ -0,0 +1,349 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for InfiMMVicuna.
17
+ """
18
+
19
+ import random
20
+ from typing import List, Optional, Tuple, Union
21
+ import torch
22
+ import torchvision.transforms.functional as F
23
+ from PIL import Image
24
+ from torchvision.transforms import (
25
+ CenterCrop,
26
+ Compose,
27
+ InterpolationMode,
28
+ Normalize,
29
+ Resize,
30
+ ToTensor,
31
+ )
32
+
33
+ from transformers import AutoTokenizer
34
+ from transformers.image_processing_utils import ImageProcessingMixin
35
+ from transformers.processing_utils import ProcessorMixin
36
+ from transformers.tokenization_utils_base import BatchEncoding
37
+
38
+ IMAGE_TOKEN = "<image>"
39
+ END_OF_CHUNK_TOKEN = "<|endofchunk|>"
40
+
41
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
42
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
43
+
44
+
45
+ def _convert_to_rgb(image):
46
+ return image.convert("RGB")
47
+
48
+
49
+ class ResizeKeepRatio:
50
+ """Resize and Keep Ratio
51
+
52
+ Copy & paste from `timm`
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ size,
58
+ longest=0.0,
59
+ interpolation=InterpolationMode.BICUBIC,
60
+ random_scale_prob=0.0,
61
+ random_scale_range=(0.85, 1.05),
62
+ random_aspect_prob=0.0,
63
+ random_aspect_range=(0.9, 1.11),
64
+ ):
65
+ if isinstance(size, (list, tuple)):
66
+ self.size = tuple(size)
67
+ else:
68
+ self.size = (size, size)
69
+ self.interpolation = interpolation
70
+ self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest
71
+ self.random_scale_prob = random_scale_prob
72
+ self.random_scale_range = random_scale_range
73
+ self.random_aspect_prob = random_aspect_prob
74
+ self.random_aspect_range = random_aspect_range
75
+
76
+ @staticmethod
77
+ def get_params(
78
+ img,
79
+ target_size,
80
+ longest,
81
+ random_scale_prob=0.0,
82
+ random_scale_range=(0.85, 1.05),
83
+ random_aspect_prob=0.0,
84
+ random_aspect_range=(0.9, 1.11),
85
+ ):
86
+ """Get parameters"""
87
+ source_size = img.size[::-1] # h, w
88
+ h, w = source_size
89
+ target_h, target_w = target_size
90
+ ratio_h = h / target_h
91
+ ratio_w = w / target_w
92
+ ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (
93
+ 1.0 - longest
94
+ )
95
+ if random_scale_prob > 0 and random.random() < random_scale_prob:
96
+ ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
97
+ ratio_factor = (ratio_factor, ratio_factor)
98
+ else:
99
+ ratio_factor = (1.0, 1.0)
100
+ if random_aspect_prob > 0 and random.random() < random_aspect_prob:
101
+ aspect_factor = random.uniform(
102
+ random_aspect_range[0], random_aspect_range[1]
103
+ )
104
+ ratio_factor = (
105
+ ratio_factor[0] / aspect_factor,
106
+ ratio_factor[1] * aspect_factor,
107
+ )
108
+ size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
109
+ return size
110
+
111
+ def __call__(self, img):
112
+ """
113
+ Args:
114
+ img (PIL Image): Image to be cropped and resized.
115
+
116
+ Returns:
117
+ PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size
118
+ """
119
+ size = self.get_params(
120
+ img,
121
+ self.size,
122
+ self.longest,
123
+ self.random_scale_prob,
124
+ self.random_scale_range,
125
+ self.random_aspect_prob,
126
+ self.random_aspect_range,
127
+ )
128
+ img = F.resize(img, size, self.interpolation)
129
+ return img
130
+
131
+ def __repr__(self):
132
+ format_string = self.__class__.__name__ + "(size={0}".format(self.size)
133
+ format_string += f", interpolation={self.interpolation})"
134
+ format_string += f", longest={self.longest:.3f})"
135
+ return format_string
136
+
137
+
138
+ def image_transform(
139
+ image_size: Union[int, Tuple[int, int]],
140
+ mean: Optional[Tuple[float, ...]] = None,
141
+ std: Optional[Tuple[float, ...]] = None,
142
+ resize_mode: Optional[str] = None,
143
+ interpolation: Optional[str] = None,
144
+ ):
145
+ mean = mean or OPENAI_DATASET_MEAN
146
+ if not isinstance(mean, (list, tuple)):
147
+ mean = (mean,) * 3
148
+
149
+ std = std or OPENAI_DATASET_STD
150
+ if not isinstance(std, (list, tuple)):
151
+ std = (std,) * 3
152
+
153
+ interpolation = interpolation or "bicubic"
154
+ assert interpolation in ["bicubic", "bilinear", "random"]
155
+ # NOTE random is ignored for interpolation_mode, so defaults to BICUBIC for inference if set
156
+ interpolation_mode = (
157
+ InterpolationMode.BILINEAR
158
+ if interpolation == "bilinear"
159
+ else InterpolationMode.BICUBIC
160
+ )
161
+
162
+ resize_mode = resize_mode or "shortest"
163
+ assert resize_mode in ("shortest", "longest", "squash")
164
+
165
+ normalize = Normalize(mean=mean, std=std)
166
+
167
+ assert resize_mode == "shortest"
168
+ if not isinstance(image_size, (tuple, list)):
169
+ image_size = (image_size, image_size)
170
+ if image_size[0] == image_size[1]:
171
+ # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
172
+ transforms = [Resize(image_size[0], interpolation=interpolation_mode)]
173
+ else:
174
+ # resize shortest edge to matching target dim for non-square target
175
+ transforms = [ResizeKeepRatio(image_size)]
176
+ transforms += [CenterCrop(image_size)]
177
+
178
+ transforms.extend(
179
+ [
180
+ _convert_to_rgb,
181
+ ToTensor(),
182
+ normalize,
183
+ ]
184
+ )
185
+ return Compose(transforms)
186
+
187
+
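+ # With the defaults used by EVAClipImageProcessor below (image_size=224, OpenAI
+ # mean/std, bicubic interpolation), image_transform resolves to:
+ # Resize(224) -> CenterCrop(224) -> convert to RGB -> ToTensor() -> Normalize(...).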
188
+ class EVAClipImageProcessor(ImageProcessingMixin):
189
+ def __init__(self, **kwargs) -> None:
190
+ super().__init__(**kwargs)
191
+ self.processor = image_transform(image_size=224)
192
+
193
+ def _prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor:
194
+ """
195
+ Convert images to tensors, reshape them, and stack them.
196
+ Args:
197
+ batch: A list of lists of images.
198
+ Returns:
199
+ preprocessed images (tensors) or None
200
+ shape (B, T_img, F, C, H, W)
201
+ None if no images in batch
202
+ """
203
+ images_per_example = max(len(x) for x in batch)
204
+ batch_images = None
205
+ for iexample, example in enumerate(batch):
206
+ for iimage, image in enumerate(example):
207
+ preprocessed = self.processor(image)
208
+ if batch_images is None:
209
+ batch_images = torch.zeros(
210
+ (len(batch), images_per_example, 1) + preprocessed.shape,
211
+ dtype=preprocessed.dtype,
212
+ )
213
+ batch_images[iexample, iimage, 0] = preprocessed
214
+ return batch_images
215
+
216
+ def preprocess(self, imgpaths=None):
217
+ if imgpaths is None or len(imgpaths) == 0:
218
+ images = [(Image.new("RGB", (224, 224), color="black"))]
219
+ else:
220
+ images = [Image.open(fp) for fp in imgpaths]
221
+ return self._prepare_images([images])
222
+
223
+
224
+ class InfiMMVicunaProcessor(ProcessorMixin):
225
+ r"""
226
+ Constructs an InfiMMVicuna processor which wraps a tokenizer and an image processor into a single processor.
227
+
228
+ Args:
229
+ image_processor (`EVAClipImageProcessor`):
230
+ An instance of [`EVAClipImageProcessor`]. The image processor is a required input.
231
+ tokenizer (`LlamaTokenizer`):
232
+ An instance of [`LlamaTokenizer`]. The tokenizer is a required input.
233
+ image_size (`int`, *optional*, defaults to 224): Image size (assuming a square image)
234
+ """
235
+
236
+ attributes = ["tokenizer"]
237
+ tokenizer_class = "LlamaTokenizer"
238
+
239
+ def __init__(self, tokenizer=None, **kwargs):
240
+ self.image_processor = EVAClipImageProcessor()
241
+ if tokenizer is None:
242
+ tokenizer = AutoTokenizer.from_pretrained("infimm-vicuna", verbose=False)
243
+
244
+ super().__init__(tokenizer)
245
+
246
+ def _prepare_text(
247
+ self,
248
+ batch: List[List[str]],
249
+ padding="longest",
250
+ truncation=True,
251
+ max_length=2048,
252
+ ):
253
+ """
254
+ Tokenize the text and stack them.
255
+ Args:
256
+ batch: A list of lists of strings.
257
+ Returns:
258
+ input_ids (tensor)
259
+ shape (B, T_txt)
260
+ attention_mask (tensor)
261
+ shape (B, T_txt)
262
+ """
263
+ encodings = self.tokenizer(
264
+ batch,
265
+ padding=padding,
266
+ truncation=truncation,
267
+ return_tensors="pt",
268
+ max_length=max_length,
269
+ )
270
+ input_ids, attention_mask = encodings["input_ids"], encodings["attention_mask"]
271
+ return input_ids, attention_mask
272
+
273
+ def __call__(
274
+ self,
275
+ prompts,
276
+ ) -> BatchEncoding:
277
+ """This method takes batched or non-batched prompts made of text and images and converts them into prompts that
278
+ the model was trained on and prepares the image pixel values for the model to process.
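+
+ Example `prompts` structure (illustrative; the image path is a placeholder):
+ [
+ {"role": "user", "content": [{"image": "example.jpg"}, "Describe the image in detail."]},
+ ]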
279
+ """
280
+ image_paths = self._extract_image_paths(prompts)
281
+ images = self.image_processor.preprocess(image_paths)
282
+ prompts = self._replace_with_media_tokens(prompts)
283
+ final_prompt = self.apply_chat_template(prompts)
284
+ # print(final_prompt)  # debug output, left disabled
285
+ input_ids, attention_mask = self._prepare_text([final_prompt])
286
+ return BatchEncoding(
287
+ data={
288
+ "input_ids": input_ids,
289
+ "attention_mask": attention_mask,
290
+ "batch_images": images,
291
+ }
292
+ )
293
+
294
+ def _extract_image_paths(self, prompts):
295
+ image_paths = []
296
+ for round in prompts:
297
+ if round["role"] != "user":
298
+ continue
299
+ for piece in round["content"]:
300
+ if isinstance(piece, dict):
301
+ image_paths.append(piece["image"])
302
+ return image_paths
303
+
304
+ def _replace_with_media_tokens(self, prompts):
305
+ new_prompts = []
306
+ is_first_img = True
307
+ for round in prompts:
308
+ if round["role"] != "user":
309
+ new_prompts.append(round)
+ continue
310
+ new_content = []
311
+ for piece in round["content"]:
312
+ if isinstance(piece, dict):
313
+ new_content.append(
314
+ f"{IMAGE_TOKEN}" if is_first_img
315
+ else f"{END_OF_CHUNK_TOKEN}{IMAGE_TOKEN}"
316
+ )
317
+ is_first_img = False
318
+ else:
319
+ new_content.append(piece)
320
+ new_prompts.append({"role": "user", "content": "".join(new_content)})
321
+ return new_prompts
322
+
323
+ def apply_chat_template(self, messages, task="generation"):
324
+ prompt = self.tokenizer.apply_chat_template(
325
+ messages,
326
+ tokenize=False,
327
+ add_generation_prompt=True if task == "generation" else False,
328
+ )
329
+ return prompt
330
+
331
+ def batch_decode(self, *args, **kwargs):
332
+ """
333
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
334
+ refer to the docstring of this method for more information.
335
+ """
336
+ return self.tokenizer.batch_decode(*args, **kwargs)
337
+
338
+ def decode(self, *args, **kwargs):
339
+ """
340
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
341
+ the docstring of this method for more information.
342
+ """
343
+ return self.tokenizer.decode(*args, **kwargs)
344
+
345
+ @property
346
+ def model_input_names(self):
347
+ tokenizer_input_names = self.tokenizer.model_input_names
348
+ image_processor_input_names = self.image_processor.model_input_names
349
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6dfd97d3f84c72d126a8d7b4481ae88089b09a7c10c8765f4528001fcb27440e
3
+ size 39422353013
special_tokens_map.json ADDED
@@ -0,0 +1,46 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ {
4
+ "content": "<|endofchunk|>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ {
11
+ "content": "<image>",
12
+ "lstrip": false,
13
+ "normalized": false,
14
+ "rstrip": false,
15
+ "single_word": false
16
+ }
17
+ ],
18
+ "bos_token": {
19
+ "content": "<s>",
20
+ "lstrip": false,
21
+ "normalized": false,
22
+ "rstrip": false,
23
+ "single_word": false
24
+ },
25
+ "eos_token": {
26
+ "content": "</s>",
27
+ "lstrip": false,
28
+ "normalized": false,
29
+ "rstrip": false,
30
+ "single_word": false
31
+ },
32
+ "pad_token": {
33
+ "content": "<unk>",
34
+ "lstrip": false,
35
+ "normalized": false,
36
+ "rstrip": false,
37
+ "single_word": false
38
+ },
39
+ "unk_token": {
40
+ "content": "<unk>",
41
+ "lstrip": false,
42
+ "normalized": false,
43
+ "rstrip": false,
44
+ "single_word": false
45
+ }
46
+ }
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
tokenizer_config.json ADDED
@@ -0,0 +1,63 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "32000": {
30
+ "content": "<|endofchunk|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "32001": {
38
+ "content": "<image>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ }
45
+ },
46
+ "additional_special_tokens": [
47
+ "<|endofchunk|>",
48
+ "<image>"
49
+ ],
50
+ "bos_token": "<s>",
51
+ "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\\'s questions.' %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ system_message }}{% endif %}{% if message['role'] == 'user' %}{{ ' USER: ' + message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ ' ASSISTANT: ' + message['content'].strip() + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ ' ASSISTANT:' }}{% endif %}",
52
+ "clean_up_tokenization_spaces": false,
53
+ "eos_token": "</s>",
54
+ "legacy": false,
55
+ "model_max_length": 4096,
56
+ "pad_token": "<unk>",
57
+ "padding_side": "left",
58
+ "sp_model_kwargs": {},
59
+ "spaces_between_special_tokens": false,
60
+ "tokenizer_class": "LlamaTokenizer",
61
+ "unk_token": "<unk>",
62
+ "use_default_system_prompt": false
63
+ }
utils.py ADDED
@@ -0,0 +1,48 @@
1
+ def extend_instance(obj, mixin):
2
+ """Apply mixins to a class instance after creation"""
3
+ base_cls = obj.__class__
4
+ base_cls_name = obj.__class__.__name__
5
+ obj.__class__ = type(
6
+ base_cls_name, (mixin, base_cls), {}
7
+ ) # mixin needs to go first for our forward() logic to work
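+ # Example (illustrative): after extend_instance(lang_encoder, FlamingoLMMixin), the
+ # instance's class becomes a new type whose MRO is roughly
+ # (<new class>, FlamingoLMMixin, <original LM class>, ...), so forward() resolves to
+ # FlamingoLMMixin.forward and its super().forward() call reaches the original model.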
8
+
9
+
10
+ def getattr_recursive(obj, att):
11
+ """
12
+ Return nested attribute of obj
13
+ Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
14
+ """
15
+ if att == "":
16
+ return obj
17
+ i = att.find(".")
18
+ if i < 0:
19
+ return getattr(obj, att)
20
+ else:
21
+ return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
22
+
23
+
24
+ def setattr_recursive(obj, att, val):
25
+ """
26
+ Set nested attribute of obj
27
+ Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
28
+ """
29
+ if "." in att:
30
+ obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
31
+ setattr(obj, att.split(".")[-1], val)
32
+
33
+
34
+ def _infer_decoder_layers_attr_name(model):
35
+ for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
36
+ if k.lower() in model.__class__.__name__.lower():
37
+ return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
38
+
39
+ raise ValueError(
40
+ "We require the attribute name for the nn.ModuleList in the decoder storing"
41
+ " the transformer block layers. Please supply this string manually."
42
+ )
43
+
44
+
45
+ __KNOWN_DECODER_LAYERS_ATTR_NAMES = {
46
+ "llama": "model.layers",
47
+ "mistral": "model.layers",
48
+ }