merve committed
Commit fea549e
1 Parent(s): d1af207

Upload modeling_nllb_clip.py

Files changed (1)
  1. modeling_nllb_clip.py +1403 -0
modeling_nllb_clip.py ADDED
@@ -0,0 +1,1403 @@
1
+ """ PyTorch NLLB CLIP model."""
2
+
3
+
4
+ import math
5
+ from dataclasses import dataclass
6
+ from typing import Any, Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.utils.checkpoint
10
+ from configuration_nllb_clip import NLLBCLIPConfig, NLLBCLIPTextConfig
11
+ from torch import nn
12
+ from transformers import CLIPVisionConfig
13
+ from transformers.activations import ACT2FN
14
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
15
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers.utils import ModelOutput, logging
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ # contrastive loss function, adapted from
23
+ # https://sachinruk.github.io/blog/2021-03-07-clip.html
24
+ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
25
+ return nn.functional.cross_entropy(
26
+ logits, torch.arange(len(logits), device=logits.device)
27
+ )
28
+
29
+
30
+ def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
31
+ caption_loss = contrastive_loss(similarity)
32
+ image_loss = contrastive_loss(similarity.t())
33
+ return (caption_loss + image_loss) / 2.0
34
+
35
+
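+ # A minimal sketch of how `clip_loss` behaves on a toy 2x2 similarity matrix
+ # (illustrative values, not from any real checkpoint):
+ #
+ #     >>> sim = torch.tensor([[10.0, 0.0], [0.0, 10.0]])
+ #     >>> clip_loss(sim)  # close to zero: each caption matches its own image
+ #
+ # The loss is the mean of two cross-entropies with the diagonal as the target
+ # class: one over rows (caption -> image) and one over columns (image -> caption).
+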
36
+ class CLIPVisionEmbeddings(nn.Module):
37
+ def __init__(self, config: CLIPVisionConfig):
38
+ super().__init__()
39
+ self.config = config
40
+ self.embed_dim = config.hidden_size
41
+ self.image_size = config.image_size
42
+ self.patch_size = config.patch_size
43
+
44
+ self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
45
+
46
+ self.patch_embedding = nn.Conv2d(
47
+ in_channels=config.num_channels,
48
+ out_channels=self.embed_dim,
49
+ kernel_size=self.patch_size,
50
+ stride=self.patch_size,
51
+ bias=False,
52
+ )
53
+
54
+ self.num_patches = (self.image_size // self.patch_size) ** 2
55
+ self.num_positions = self.num_patches + 1
56
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
57
+ self.register_buffer(
58
+ "position_ids",
59
+ torch.arange(self.num_positions).expand((1, -1)),
60
+ persistent=False,
61
+ )
62
+
63
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
64
+ batch_size = pixel_values.shape[0]
65
+ target_dtype = self.patch_embedding.weight.dtype
66
+ patch_embeds = self.patch_embedding(
67
+ pixel_values.to(dtype=target_dtype)
68
+ ) # shape = [*, width, grid, grid]
69
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
70
+
71
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
72
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
73
+ embeddings = embeddings + self.position_embedding(self.position_ids)
74
+ return embeddings
75
+
76
+
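+ # Shape walkthrough for the embeddings above (illustrative numbers, assuming
+ # image_size=224, patch_size=32, hidden_size=768, so num_patches = (224 // 32) ** 2 = 49):
+ #   pixel_values:           (B, 3, 224, 224)
+ #   patch_embedding (conv): (B, 768, 7, 7) -> flatten(2).transpose(1, 2) -> (B, 49, 768)
+ #   prepend class token:    (B, 50, 768)
+ #   position_embedding adds a learned (1, 50, 768) table, one row per position.
+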
77
+ class CLIPAttention(nn.Module):
78
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
79
+
80
+ def __init__(self, config):
81
+ super().__init__()
82
+ self.config = config
83
+ self.embed_dim = config.hidden_size
84
+ self.num_heads = config.num_attention_heads
85
+ self.head_dim = self.embed_dim // self.num_heads
86
+ if self.head_dim * self.num_heads != self.embed_dim:
87
+ raise ValueError(
88
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
89
+ f" {self.num_heads})."
90
+ )
91
+ self.scale = self.head_dim**-0.5
92
+ self.dropout = config.attention_dropout
93
+
94
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
95
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
96
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
97
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
98
+
99
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
100
+ return (
101
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
102
+ .transpose(1, 2)
103
+ .contiguous()
104
+ )
105
+
106
+ def forward(
107
+ self,
108
+ hidden_states: torch.Tensor,
109
+ attention_mask: Optional[torch.Tensor] = None,
110
+ causal_attention_mask: Optional[torch.Tensor] = None,
111
+ output_attentions: Optional[bool] = False,
112
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
113
+ """Input shape: Batch x Time x Channel"""
114
+
115
+ bsz, tgt_len, embed_dim = hidden_states.size()
116
+
117
+ # get query proj
118
+ query_states = self.q_proj(hidden_states) * self.scale
119
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
120
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
121
+
122
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
123
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
124
+ key_states = key_states.view(*proj_shape)
125
+ value_states = value_states.view(*proj_shape)
126
+
127
+ src_len = key_states.size(1)
128
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
129
+
130
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
131
+ raise ValueError(
132
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
133
+ f" {attn_weights.size()}"
134
+ )
135
+
136
+ # apply the causal_attention_mask first
137
+ if causal_attention_mask is not None:
138
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
139
+ raise ValueError(
140
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
141
+ f" {causal_attention_mask.size()}"
142
+ )
143
+ attn_weights = (
144
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
145
+ + causal_attention_mask
146
+ )
147
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
148
+
149
+ if attention_mask is not None:
150
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
151
+ raise ValueError(
152
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
153
+ )
154
+ attn_weights = (
155
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
156
+ + attention_mask
157
+ )
158
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
159
+
160
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
161
+
162
+ if output_attentions:
163
+ # this operation is a bit awkward, but it's required to
164
+ # make sure that attn_weights keeps its gradient.
165
+ # In order to do so, attn_weights have to be reshaped
166
+ # twice and have to be reused in the following
167
+ attn_weights_reshaped = attn_weights.view(
168
+ bsz, self.num_heads, tgt_len, src_len
169
+ )
170
+ attn_weights = attn_weights_reshaped.view(
171
+ bsz * self.num_heads, tgt_len, src_len
172
+ )
173
+ else:
174
+ attn_weights_reshaped = None
175
+
176
+ attn_probs = nn.functional.dropout(
177
+ attn_weights, p=self.dropout, training=self.training
178
+ )
179
+
180
+ attn_output = torch.bmm(attn_probs, value_states)
181
+
182
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
183
+ raise ValueError(
184
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
185
+ f" {attn_output.size()}"
186
+ )
187
+
188
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
189
+ attn_output = attn_output.transpose(1, 2)
190
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
191
+
192
+ attn_output = self.out_proj(attn_output)
193
+
194
+ return attn_output, attn_weights_reshaped
195
+
196
+
197
+ class CLIPMLP(nn.Module):
198
+ def __init__(self, config):
199
+ super().__init__()
200
+ self.config = config
201
+ self.activation_fn = ACT2FN[config.hidden_act]
202
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
203
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
204
+
205
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
206
+ hidden_states = self.fc1(hidden_states)
207
+ hidden_states = self.activation_fn(hidden_states)
208
+ hidden_states = self.fc2(hidden_states)
209
+ return hidden_states
210
+
211
+
212
+ class CLIPEncoderLayer(nn.Module):
213
+ def __init__(self, config: NLLBCLIPConfig):
214
+ super().__init__()
215
+ self.embed_dim = config.hidden_size
216
+ self.self_attn = CLIPAttention(config)
217
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
218
+ self.mlp = CLIPMLP(config)
219
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
220
+
221
+ def forward(
222
+ self,
223
+ hidden_states: torch.Tensor,
224
+ attention_mask: torch.Tensor,
225
+ causal_attention_mask: torch.Tensor,
226
+ output_attentions: Optional[bool] = False,
227
+ ) -> Tuple[torch.FloatTensor]:
228
+ """
229
+ Args:
230
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
231
+ attention_mask (`torch.FloatTensor`): attention mask of size
232
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
234
+ output_attentions (`bool`, *optional*):
235
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
236
+ returned tensors for more detail.
237
+ """
238
+ residual = hidden_states
239
+
240
+ hidden_states = self.layer_norm1(hidden_states)
241
+ hidden_states, attn_weights = self.self_attn(
242
+ hidden_states=hidden_states,
243
+ attention_mask=attention_mask,
244
+ causal_attention_mask=causal_attention_mask,
245
+ output_attentions=output_attentions,
246
+ )
247
+ hidden_states = residual + hidden_states
248
+
249
+ residual = hidden_states
250
+ hidden_states = self.layer_norm2(hidden_states)
251
+ hidden_states = self.mlp(hidden_states)
252
+ hidden_states = residual + hidden_states
253
+
254
+ outputs = (hidden_states,)
255
+
256
+ if output_attentions:
257
+ outputs += (attn_weights,)
258
+
259
+ return outputs
260
+
261
+
262
+ class CLIPEncoder(nn.Module):
263
+ """
264
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
265
+ [`CLIPEncoderLayer`].
266
+
267
+ Args:
268
+ config: CLIPConfig
269
+ """
270
+
271
+ def __init__(self, config: NLLBCLIPConfig):
272
+ super().__init__()
273
+ self.config = config
274
+ self.layers = nn.ModuleList(
275
+ [CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]
276
+ )
277
+ self.gradient_checkpointing = False
278
+
279
+ def forward(
280
+ self,
281
+ inputs_embeds,
282
+ attention_mask: Optional[torch.Tensor] = None,
283
+ causal_attention_mask: Optional[torch.Tensor] = None,
284
+ output_attentions: Optional[bool] = None,
285
+ output_hidden_states: Optional[bool] = None,
286
+ return_dict: Optional[bool] = None,
287
+ ) -> Union[Tuple, BaseModelOutput]:
288
+ r"""
289
+ Args:
290
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
291
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
292
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
293
+ than the model's internal embedding lookup matrix.
294
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
295
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
296
+
297
+ - 1 for tokens that are **not masked**,
298
+ - 0 for tokens that are **masked**.
299
+
300
+ [What are attention masks?](../glossary#attention-mask)
301
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
302
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
303
+
304
+ - 1 for tokens that are **not masked**,
305
+ - 0 for tokens that are **masked**.
306
+
307
+ [What are attention masks?](../glossary#attention-mask)
308
+ output_attentions (`bool`, *optional*):
309
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
310
+ returned tensors for more detail.
311
+ output_hidden_states (`bool`, *optional*):
312
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
313
+ for more detail.
314
+ return_dict (`bool`, *optional*):
315
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
316
+ """
317
+ output_attentions = (
318
+ output_attentions
319
+ if output_attentions is not None
320
+ else self.config.output_attentions
321
+ )
322
+ output_hidden_states = (
323
+ output_hidden_states
324
+ if output_hidden_states is not None
325
+ else self.config.output_hidden_states
326
+ )
327
+ return_dict = (
328
+ return_dict if return_dict is not None else self.config.use_return_dict
329
+ )
330
+
331
+ encoder_states = () if output_hidden_states else None
332
+ all_attentions = () if output_attentions else None
333
+
334
+ hidden_states = inputs_embeds
335
+ for idx, encoder_layer in enumerate(self.layers):
336
+ if output_hidden_states:
337
+ encoder_states = encoder_states + (hidden_states,)
338
+ if self.gradient_checkpointing and self.training:
339
+
340
+ def create_custom_forward(module):
341
+ def custom_forward(*inputs):
342
+ return module(*inputs, output_attentions)
343
+
344
+ return custom_forward
345
+
346
+ layer_outputs = torch.utils.checkpoint.checkpoint(
347
+ create_custom_forward(encoder_layer),
348
+ hidden_states,
349
+ attention_mask,
350
+ causal_attention_mask,
351
+ )
352
+ else:
353
+ layer_outputs = encoder_layer(
354
+ hidden_states,
355
+ attention_mask,
356
+ causal_attention_mask,
357
+ output_attentions=output_attentions,
358
+ )
359
+
360
+ hidden_states = layer_outputs[0]
361
+
362
+ if output_attentions:
363
+ all_attentions = all_attentions + (layer_outputs[1],)
364
+
365
+ if output_hidden_states:
366
+ encoder_states = encoder_states + (hidden_states,)
367
+
368
+ if not return_dict:
369
+ return tuple(
370
+ v
371
+ for v in [hidden_states, encoder_states, all_attentions]
372
+ if v is not None
373
+ )
374
+ return BaseModelOutput(
375
+ last_hidden_state=hidden_states,
376
+ hidden_states=encoder_states,
377
+ attentions=all_attentions,
378
+ )
379
+
380
+
381
+ class CLIPVisionTransformer(nn.Module):
382
+ def __init__(self, config: CLIPVisionConfig):
383
+ super().__init__()
384
+ self.config = config
385
+ embed_dim = config.hidden_size
386
+
387
+ self.embeddings = CLIPVisionEmbeddings(config)
388
+ # NB: `pre_layrnorm` (sic) mirrors the attribute name used in the original CLIP checkpoints and is kept for weight compatibility.
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
389
+ self.encoder = CLIPEncoder(config)
390
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
391
+
392
+ def forward(
393
+ self,
394
+ pixel_values: Optional[torch.FloatTensor] = None,
395
+ output_attentions: Optional[bool] = None,
396
+ output_hidden_states: Optional[bool] = None,
397
+ return_dict: Optional[bool] = None,
398
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
399
+ r"""
400
+ Returns:
401
+
402
+ """
403
+ output_attentions = (
404
+ output_attentions
405
+ if output_attentions is not None
406
+ else self.config.output_attentions
407
+ )
408
+ output_hidden_states = (
409
+ output_hidden_states
410
+ if output_hidden_states is not None
411
+ else self.config.output_hidden_states
412
+ )
413
+ return_dict = (
414
+ return_dict if return_dict is not None else self.config.use_return_dict
415
+ )
416
+
417
+ if pixel_values is None:
418
+ raise ValueError("You have to specify pixel_values")
419
+
420
+ hidden_states = self.embeddings(pixel_values)
421
+ hidden_states = self.pre_layrnorm(hidden_states)
422
+
423
+ encoder_outputs = self.encoder(
424
+ inputs_embeds=hidden_states,
425
+ output_attentions=output_attentions,
426
+ output_hidden_states=output_hidden_states,
427
+ return_dict=return_dict,
428
+ )
429
+
430
+ last_hidden_state = encoder_outputs[0]
431
+ pooled_output = last_hidden_state[:, 0, :]
432
+ pooled_output = self.post_layernorm(pooled_output)
433
+
434
+ if not return_dict:
435
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
436
+
437
+ return BaseModelOutputWithPooling(
438
+ last_hidden_state=last_hidden_state,
439
+ pooler_output=pooled_output,
440
+ hidden_states=encoder_outputs.hidden_states,
441
+ attentions=encoder_outputs.attentions,
442
+ )
443
+
444
+
445
+ @dataclass
446
+ class NLLBCLIPOutput(ModelOutput):
447
+ """
448
+ Args:
449
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
450
+ Contrastive loss for image-text similarity.
451
+ logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
452
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
453
+ similarity scores.
454
+ logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
455
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
456
+ similarity scores.
457
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
458
+ The text embeddings obtained by applying the projection layer to the pooled output of the NLLB text encoder.
459
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
460
+ The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
461
+ text_model_output(`BaseModelOutputWithPooling`):
462
+ The output of the [`CLIPTextModel`].
463
+ vision_model_output(`BaseModelOutputWithPooling`):
464
+ The output of the [`CLIPVisionModel`].
465
+ """
466
+
467
+ loss: Optional[torch.FloatTensor] = None
468
+ logits_per_image: torch.FloatTensor = None
469
+ logits_per_text: torch.FloatTensor = None
470
+ text_embeds: torch.FloatTensor = None
471
+ image_embeds: torch.FloatTensor = None
472
+ text_model_output: BaseModelOutputWithPooling = None
473
+ vision_model_output: BaseModelOutputWithPooling = None
474
+
475
+ def to_tuple(self) -> Tuple[Any]:
476
+ return tuple(
477
+ self[k]
478
+ if k not in ["text_model_output", "vision_model_output"]
479
+ else getattr(self, k).to_tuple()
480
+ for k in self.keys()
481
+ )
482
+
483
+
484
+ class M2M100Attention(nn.Module):
485
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
486
+
487
+ def __init__(
488
+ self,
489
+ embed_dim: int,
490
+ num_heads: int,
491
+ dropout: float = 0.0,
492
+ is_decoder: bool = False,
493
+ bias: bool = True,
494
+ ):
495
+ super().__init__()
496
+ self.embed_dim = embed_dim
497
+ self.num_heads = num_heads
498
+ self.dropout = dropout
499
+ self.head_dim = embed_dim // num_heads
500
+
501
+ if (self.head_dim * num_heads) != self.embed_dim:
502
+ raise ValueError(
503
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
504
+ f" and `num_heads`: {num_heads})."
505
+ )
506
+ self.scaling = self.head_dim**-0.5
507
+ self.is_decoder = is_decoder
508
+
509
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
510
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
511
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
512
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
513
+
514
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
515
+ return (
516
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
517
+ .transpose(1, 2)
518
+ .contiguous()
519
+ )
520
+
521
+ def forward(
522
+ self,
523
+ hidden_states: torch.Tensor,
524
+ key_value_states: Optional[torch.Tensor] = None,
525
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
526
+ attention_mask: Optional[torch.Tensor] = None,
527
+ layer_head_mask: Optional[torch.Tensor] = None,
528
+ output_attentions: bool = False,
529
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
530
+ """Input shape: Batch x Time x Channel"""
531
+
532
+ # if key_value_states are provided this layer is used as a cross-attention layer
533
+ # for the decoder
534
+ is_cross_attention = key_value_states is not None
535
+
536
+ bsz, tgt_len, _ = hidden_states.size()
537
+
538
+ # get query proj
539
+ query_states = self.q_proj(hidden_states) * self.scaling
540
+ # get key, value proj
541
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
542
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
543
+ # the provided `key_value_states` to support prefix tuning
544
+ if (
545
+ is_cross_attention
546
+ and past_key_value is not None
547
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
548
+ ):
549
+ # reuse k,v, cross_attentions
550
+ key_states = past_key_value[0]
551
+ value_states = past_key_value[1]
552
+ elif is_cross_attention:
553
+ # cross_attentions
554
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
555
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
556
+ elif past_key_value is not None:
557
+ # reuse k, v, self_attention
558
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
559
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
560
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
561
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
562
+ else:
563
+ # self_attention
564
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
565
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
566
+
567
+ if self.is_decoder:
568
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
569
+ # Further calls to cross_attention layer can then reuse all cross-attention
570
+ # key/value_states (first "if" case)
571
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
572
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
573
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
574
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
575
+ past_key_value = (key_states, value_states)
576
+
577
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
578
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
579
+ key_states = key_states.reshape(*proj_shape)
580
+ value_states = value_states.reshape(*proj_shape)
581
+
582
+ src_len = key_states.size(1)
583
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
584
+
585
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
586
+ raise ValueError(
587
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
588
+ f" {attn_weights.size()}"
589
+ )
590
+
591
+ if attention_mask is not None:
592
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
593
+ raise ValueError(
594
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
595
+ )
596
+ attn_weights = (
597
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
598
+ + attention_mask
599
+ )
600
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
601
+
602
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
603
+
604
+ if layer_head_mask is not None:
605
+ if layer_head_mask.size() != (self.num_heads,):
606
+ raise ValueError(
607
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
608
+ f" {layer_head_mask.size()}"
609
+ )
610
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
611
+ bsz, self.num_heads, tgt_len, src_len
612
+ )
613
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
614
+
615
+ if output_attentions:
616
+ # this operation is a bit awkward, but it's required to
617
+ # make sure that attn_weights keeps its gradient.
618
+ # In order to do so, attn_weights have to be reshaped
619
+ # twice and have to be reused in the following
620
+ attn_weights_reshaped = attn_weights.view(
621
+ bsz, self.num_heads, tgt_len, src_len
622
+ )
623
+ attn_weights = attn_weights_reshaped.view(
624
+ bsz * self.num_heads, tgt_len, src_len
625
+ )
626
+ else:
627
+ attn_weights_reshaped = None
628
+
629
+ attn_probs = nn.functional.dropout(
630
+ attn_weights, p=self.dropout, training=self.training
631
+ )
632
+
633
+ attn_output = torch.bmm(attn_probs, value_states)
634
+
635
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
636
+ raise ValueError(
637
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
638
+ f" {attn_output.size()}"
639
+ )
640
+
641
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
642
+ attn_output = attn_output.transpose(1, 2)
643
+
644
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
645
+ # partitioned across GPUs when using tensor-parallelism.
646
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
647
+
648
+ attn_output = self.out_proj(attn_output)
649
+
650
+ return attn_output, attn_weights_reshaped, past_key_value
651
+
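+ # Cache-shape note for the attention above (descriptive only): when `is_decoder=True`,
+ # `past_key_value` holds two tensors of shape (bsz, num_heads, cached_len, head_dim);
+ # on the next call the new keys/values are concatenated along dim=2 (self-attention)
+ # or reused as-is (cross-attention), so attention then runs over cached_len + tgt_len
+ # source positions. In this encoder-only file the cache path is never exercised; it is
+ # kept from the original M2M-100 implementation.
+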
652
+ # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100
653
+
654
+
655
+ class M2M100EncoderLayer(nn.Module):
656
+ def __init__(self, config: NLLBCLIPConfig):
657
+ super().__init__()
658
+ self.embed_dim = config.d_model
659
+ self.self_attn = M2M100Attention(
660
+ embed_dim=self.embed_dim,
661
+ num_heads=config.encoder_attention_heads,
662
+ dropout=config.attention_dropout,
663
+ )
664
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
665
+ self.dropout = config.dropout
666
+ self.activation_fn = ACT2FN[config.activation_function]
667
+ self.activation_dropout = config.activation_dropout
668
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
669
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
670
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
671
+
672
+ def forward(
673
+ self,
674
+ hidden_states: torch.Tensor,
675
+ attention_mask: torch.Tensor,
676
+ layer_head_mask: torch.Tensor,
677
+ output_attentions: bool = False,
678
+ ) -> torch.Tensor:
679
+ """
680
+ Args:
681
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
682
+ attention_mask (`torch.FloatTensor`): attention mask of size
683
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
684
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
685
+ `(encoder_attention_heads,)`.
686
+ output_attentions (`bool`, *optional*):
687
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
688
+ returned tensors for more detail.
689
+ """
690
+ residual = hidden_states
691
+ hidden_states = self.self_attn_layer_norm(hidden_states)
692
+ hidden_states, attn_weights, _ = self.self_attn(
693
+ hidden_states=hidden_states,
694
+ attention_mask=attention_mask,
695
+ layer_head_mask=layer_head_mask,
696
+ output_attentions=output_attentions,
697
+ )
698
+ hidden_states = nn.functional.dropout(
699
+ hidden_states, p=self.dropout, training=self.training
700
+ )
701
+ hidden_states = residual + hidden_states
702
+
703
+ residual = hidden_states
704
+ hidden_states = self.final_layer_norm(hidden_states)
705
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
706
+ hidden_states = nn.functional.dropout(
707
+ hidden_states, p=self.activation_dropout, training=self.training
708
+ )
709
+ hidden_states = self.fc2(hidden_states)
710
+ hidden_states = nn.functional.dropout(
711
+ hidden_states, p=self.dropout, training=self.training
712
+ )
713
+ hidden_states = residual + hidden_states
714
+
715
+ if hidden_states.dtype == torch.float16 and (
716
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
717
+ ):
718
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
719
+ hidden_states = torch.clamp(
720
+ hidden_states, min=-clamp_value, max=clamp_value
721
+ )
722
+
723
+ outputs = (hidden_states,)
724
+
725
+ if output_attentions:
726
+ outputs += (attn_weights,)
727
+
728
+ return outputs
729
+
730
+
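+ # Note on the float16 clamp above (descriptive only): after the residual adds,
+ # fp16 activations can overflow to inf/nan; clamping to just inside the fp16
+ # range keeps mixed-precision inference numerically stable without changing
+ # results in the normal case.
+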
731
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
732
+ """
733
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
734
+ """
735
+ bsz, src_len = mask.size()
736
+ tgt_len = tgt_len if tgt_len is not None else src_len
737
+
738
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
739
+
740
+ inverted_mask = 1.0 - expanded_mask
741
+
742
+ return inverted_mask.masked_fill(
743
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
744
+ )
745
+
746
+
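+ # Sketch of `_expand_mask` on a toy input (illustrative only):
+ #   mask = [[1, 1, 0]]  (one padded position, seq_len = 3)
+ #   -> expanded to shape (1, 1, tgt_len, 3), then inverted, so kept positions
+ #      become 0.0 and the padded position becomes torch.finfo(dtype).min,
+ #      which the attention softmax effectively treats as -inf.
+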
747
+ def create_position_ids_from_input_ids(
748
+ input_ids, padding_idx, past_key_values_length=0
749
+ ):
750
+ """
751
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
752
+ are ignored. This is modified from fairseq's `utils.make_positions`.
753
+ """
754
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
755
+ mask = input_ids.ne(padding_idx).int()
756
+ incremental_indices = (
757
+ torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length
758
+ ) * mask
759
+ return incremental_indices.long() + padding_idx
760
+
761
+
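+ # Worked example for the helper above (illustrative), with padding_idx = 1:
+ #   input_ids           = [[5, 7, 1, 1]]
+ #   mask                = [[1, 1, 0, 0]]
+ #   cumsum(mask) * mask = [[1, 2, 0, 0]]
+ #   + padding_idx       = [[2, 3, 1, 1]]
+ # Real tokens count up from padding_idx + 1 while padded slots keep padding_idx.
+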
762
+ class M2M100SinusoidalPositionalEmbedding(nn.Module):
763
+ """This module produces sinusoidal positional embeddings of any length."""
764
+
765
+ def __init__(
766
+ self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None
767
+ ):
768
+ super().__init__()
769
+ self.offset = 2
770
+ self.embedding_dim = embedding_dim
771
+ self.padding_idx = padding_idx
772
+ self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
773
+
774
+ def make_weights(
775
+ self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
776
+ ):
777
+ emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
778
+ if hasattr(self, "weights"):
779
+ # in forward put the weights on the correct dtype and device of the param
780
+ emb_weights = emb_weights.to(
781
+ dtype=self.weights.dtype, device=self.weights.device
782
+ )
783
+
784
+ self.register_buffer("weights", emb_weights, persistent=False)
785
+
786
+ @staticmethod
787
+ def get_embedding(
788
+ num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
789
+ ):
790
+ """
791
+ Build sinusoidal embeddings.
792
+
793
+ This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
794
+ "Attention Is All You Need".
795
+ """
796
+ half_dim = embedding_dim // 2
797
+ emb = math.log(10000) / (half_dim - 1)
798
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
799
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
800
+ 1
801
+ ) * emb.unsqueeze(0)
802
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
803
+ num_embeddings, -1
804
+ )
805
+ if embedding_dim % 2 == 1:
806
+ # zero pad
807
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
808
+ if padding_idx is not None:
809
+ emb[padding_idx, :] = 0
810
+
811
+ return emb.to(torch.get_default_dtype())
812
+
813
+ @torch.no_grad()
814
+ def forward(
815
+ self,
816
+ input_ids: torch.Tensor = None,
817
+ inputs_embeds: torch.Tensor = None,
818
+ past_key_values_length: int = 0,
819
+ ):
820
+ if input_ids is not None:
821
+ bsz, seq_len = input_ids.size()
822
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
823
+ position_ids = create_position_ids_from_input_ids(
824
+ input_ids, self.padding_idx, past_key_values_length
825
+ ).to(input_ids.device)
826
+ else:
827
+ bsz, seq_len = inputs_embeds.size()[:-1]
828
+ position_ids = self.create_position_ids_from_inputs_embeds(
829
+ inputs_embeds, past_key_values_length
830
+ )
831
+
832
+ # expand embeddings if needed
833
+ max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
834
+ if max_pos > self.weights.size(0):
835
+ self.make_weights(
836
+ max_pos + self.offset, self.embedding_dim, self.padding_idx
837
+ )
838
+
839
+ return (
840
+ self.weights.index_select(0, position_ids.view(-1))
841
+ .view(bsz, seq_len, self.weights.shape[-1])
842
+ .detach()
843
+ )
844
+
845
+ def create_position_ids_from_inputs_embeds(
846
+ self, inputs_embeds, past_key_values_length
847
+ ):
848
+ """
849
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
850
+
851
+ Args:
852
+ inputs_embeds: torch.Tensor
853
+
854
+ Returns: torch.Tensor
855
+ """
856
+ input_shape = inputs_embeds.size()[:-1]
857
+ sequence_length = input_shape[1]
858
+
859
+ position_ids = torch.arange(
860
+ self.padding_idx + 1,
861
+ sequence_length + self.padding_idx + 1,
862
+ dtype=torch.long,
863
+ device=inputs_embeds.device,
864
+ )
865
+ return (
866
+ position_ids.unsqueeze(0).expand(input_shape).contiguous()
867
+ + past_key_values_length
868
+ )
869
+
870
+
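+ # Layout note for the sinusoidal table above (descriptive only): with
+ # half_dim = embedding_dim // 2, row p is
+ #   [sin(p * f_0), ..., sin(p * f_{half_dim-1}), cos(p * f_0), ..., cos(p * f_{half_dim-1})]
+ # where f_i = 10000 ** (-i / (half_dim - 1)), i.e. all sines come before all cosines
+ # (tensor2tensor layout) rather than being interleaved. The padding_idx row is zeroed,
+ # and `offset = 2` reserves two extra rows, following fairseq.
+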
871
+ class M2M100Encoder(PreTrainedModel):
872
+ """
873
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
874
+ [`M2M100EncoderLayer`].
875
+
876
+ Args:
877
+ config: NLLBCLIPTextConfig
878
+ embed_tokens (nn.Embedding): output embedding
879
+ """
880
+
881
+ def __init__(
882
+ self, config: NLLBCLIPConfig, embed_tokens: Optional[nn.Embedding] = None
883
+ ):
884
+ super().__init__(config)
885
+
886
+ self.dropout = config.dropout
887
+ self.layerdrop = config.encoder_layerdrop
888
+
889
+ embed_dim = config.d_model
890
+ self.padding_idx = config.pad_token_id
891
+ self.max_source_positions = config.max_position_embeddings
892
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
893
+
894
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
895
+
896
+ if embed_tokens is not None:
897
+ self.embed_tokens.weight = embed_tokens.weight
898
+
899
+ self.embed_positions = M2M100SinusoidalPositionalEmbedding(
900
+ config.max_position_embeddings,
901
+ embed_dim,
902
+ self.padding_idx,
903
+ )
904
+ self.layers = nn.ModuleList(
905
+ [M2M100EncoderLayer(config) for _ in range(config.encoder_layers)]
906
+ )
907
+ self.layer_norm = nn.LayerNorm(config.d_model)
908
+
909
+ self.gradient_checkpointing = False
910
+ # Initialize weights and apply final processing
911
+ self.post_init()
912
+
913
+ def forward(
914
+ self,
915
+ input_ids: Optional[torch.Tensor] = None,
916
+ attention_mask: Optional[torch.Tensor] = None,
917
+ head_mask: Optional[torch.Tensor] = None,
918
+ inputs_embeds: Optional[torch.Tensor] = None,
919
+ output_attentions: Optional[bool] = None,
920
+ output_hidden_states: Optional[bool] = None,
921
+ return_dict: Optional[bool] = None,
922
+ ):
923
+ r"""
924
+ Args:
925
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
926
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
927
+ provide it.
928
+
929
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
930
+ [`PreTrainedTokenizer.__call__`] for details.
931
+
932
+ [What are input IDs?](../glossary#input-ids)
933
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
934
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
935
+
936
+ - 1 for tokens that are **not masked**,
937
+ - 0 for tokens that are **masked**.
938
+
939
+ [What are attention masks?](../glossary#attention-mask)
940
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
941
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
942
+
943
+ - 1 indicates the head is **not masked**,
944
+ - 0 indicates the head is **masked**.
945
+
946
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
947
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
948
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
949
+ than the model's internal embedding lookup matrix.
950
+ output_attentions (`bool`, *optional*):
951
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
952
+ returned tensors for more detail.
953
+ output_hidden_states (`bool`, *optional*):
954
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
955
+ for more detail.
956
+ return_dict (`bool`, *optional*):
957
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
958
+ """
959
+ output_attentions = (
960
+ output_attentions
961
+ if output_attentions is not None
962
+ else self.config.output_attentions
963
+ )
964
+ output_hidden_states = (
965
+ output_hidden_states
966
+ if output_hidden_states is not None
967
+ else self.config.output_hidden_states
968
+ )
969
+ return_dict = (
970
+ return_dict if return_dict is not None else self.config.use_return_dict
971
+ )
972
+
973
+ # retrieve input_ids and inputs_embeds
974
+ if input_ids is not None and inputs_embeds is not None:
975
+ raise ValueError(
976
+ "You cannot specify both input_ids and inputs_embeds at the same time"
977
+ )
978
+ elif input_ids is not None:
979
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
980
+ input_shape = input_ids.size()
981
+ input_ids = input_ids.view(-1, input_shape[-1])
982
+ elif inputs_embeds is not None:
983
+ input_shape = inputs_embeds.size()[:-1]
984
+ else:
985
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
986
+
987
+ if inputs_embeds is None:
988
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
989
+
990
+ embed_pos = self.embed_positions(input_ids, inputs_embeds)
991
+ embed_pos = embed_pos.to(inputs_embeds.device)
992
+
993
+ hidden_states = inputs_embeds + embed_pos
994
+ hidden_states = nn.functional.dropout(
995
+ hidden_states, p=self.dropout, training=self.training
996
+ )
997
+
998
+ # expand attention_mask
999
+ if attention_mask is not None:
1000
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1001
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
1002
+
1003
+ encoder_states = () if output_hidden_states else None
1004
+ all_attentions = () if output_attentions else None
1005
+
1006
+ # check if head_mask has a correct number of layers specified if desired
1007
+ if head_mask is not None:
1008
+ if head_mask.size()[0] != len(self.layers):
1009
+ raise ValueError(
1010
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
1011
+ f" {head_mask.size()[0]}."
1012
+ )
1013
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
1014
+
1015
+ for idx, encoder_layer in enumerate(self.layers):
1016
+ if output_hidden_states:
1017
+ encoder_states = encoder_states + (hidden_states,)
1018
+
1019
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1020
+ dropout_probability = torch.rand([])
1021
+
1022
+ skip_the_layer = (
1023
+ True
1024
+ if self.training and (dropout_probability < self.layerdrop)
1025
+ else False
1026
+ )
1027
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
1028
+ # under deepspeed zero3 all gpus must run in sync
1029
+
1030
+ if self.gradient_checkpointing and self.training:
1031
+ # create gradient checkpointing function
1032
+ def create_custom_forward(module):
1033
+ def custom_forward(*inputs):
1034
+ return module(*inputs, output_attentions)
1035
+
1036
+ return custom_forward
1037
+
1038
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1039
+ create_custom_forward(encoder_layer),
1040
+ hidden_states,
1041
+ attention_mask,
1042
+ (head_mask[idx] if head_mask is not None else None),
1043
+ )
1044
+ else:
1045
+ layer_outputs = encoder_layer(
1046
+ hidden_states,
1047
+ attention_mask,
1048
+ layer_head_mask=(
1049
+ head_mask[idx] if head_mask is not None else None
1050
+ ),
1051
+ output_attentions=output_attentions,
1052
+ )
1053
+
1054
+ hidden_states = layer_outputs[0]
1055
+
1056
+ if skip_the_layer:
1057
+ layer_outputs = (None, None)
1058
+
1059
+ if output_attentions:
1060
+ all_attentions = all_attentions + (layer_outputs[1],)
1061
+
1062
+ hidden_states = self.layer_norm(hidden_states)
1063
+
1064
+ if output_hidden_states:
1065
+ encoder_states = encoder_states + (hidden_states,)
1066
+
1067
+ if not return_dict:
1068
+ return tuple(
1069
+ v
1070
+ for v in [hidden_states, encoder_states, all_attentions]
1071
+ if v is not None
1072
+ )
1073
+ return BaseModelOutput(
1074
+ last_hidden_state=hidden_states,
1075
+ hidden_states=encoder_states,
1076
+ attentions=all_attentions,
1077
+ )
1078
+
1079
+
1080
+ class CLIPTextTransformer(nn.Module):
1081
+ def __init__(self, config: NLLBCLIPTextConfig):
1082
+ super().__init__()
1083
+ self.config = config
1084
+ embed_dim = config.hidden_size
1085
+ self.encoder = M2M100Encoder(config)
1086
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
1087
+
1088
+ # For `pooled_output` computation
1089
+ self.eos_token_id = config.eos_token_id
1090
+
1091
+ def forward(
1092
+ self,
1093
+ input_ids: Optional[torch.Tensor] = None,
1094
+ attention_mask: Optional[torch.Tensor] = None,
1095
+ output_attentions: Optional[bool] = None,
1096
+ output_hidden_states: Optional[bool] = None,
1097
+ return_dict: Optional[bool] = None,
1098
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1099
+ r"""
1100
+ Returns:
1101
+
1102
+ """
1103
+ output_attentions = (
1104
+ output_attentions
1105
+ if output_attentions is not None
1106
+ else self.config.output_attentions
1107
+ )
1108
+ output_hidden_states = (
1109
+ output_hidden_states
1110
+ if output_hidden_states is not None
1111
+ else self.config.output_hidden_states
1112
+ )
1113
+ return_dict = (
1114
+ return_dict if return_dict is not None else self.config.use_return_dict
1115
+ )
1116
+
1117
+ if input_ids is None:
1118
+ raise ValueError("You have to specify input_ids")
1119
+
1120
+ input_shape = input_ids.size()
1121
+ input_ids = input_ids.view(-1, input_shape[-1])
1122
+
1123
+ encoder_outputs = self.encoder(
1124
+ input_ids=input_ids,
1125
+ attention_mask=attention_mask,
1126
+ output_attentions=output_attentions,
1127
+ output_hidden_states=output_hidden_states,
1128
+ return_dict=return_dict,
1129
+ )
1130
+
1131
+ last_hidden_state = encoder_outputs[0]
1132
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
1133
+
1134
+ pooled_output = last_hidden_state[
1135
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
1136
+ 0,
1137
+ ]
1138
+
1139
+ if not return_dict:
1140
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
1141
+
1142
+ return BaseModelOutputWithPooling(
1143
+ last_hidden_state=last_hidden_state,
1144
+ pooler_output=pooled_output,
1145
+ hidden_states=encoder_outputs.hidden_states,
1146
+ attentions=encoder_outputs.attentions,
1147
+ )
1148
+
1149
+
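+ # Pooling note for the text tower above (an observation about this code, not a
+ # requirement of the architecture): the pooled output is the hidden state at
+ # position 0, unlike the original CLIP text model, which pools the end-of-text
+ # token; `eos_token_id` is stored in __init__ but is not used for pooling here.
+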
1150
+ class NLLBCLIPModel(PreTrainedModel):
1151
+ config_class = NLLBCLIPConfig
1152
+
1153
+ def __init__(self, config: NLLBCLIPConfig):
1154
+ super().__init__(config)
1155
+
1156
+ if not isinstance(config.text_config, NLLBCLIPTextConfig):
1157
+ raise ValueError(
1158
+ "config.text_config is expected to be of type NLLBCLIPTextConfig but is of type"
1159
+ f" {type(config.text_config)}."
1160
+ )
1161
+
1162
+ if not isinstance(config.vision_config, CLIPVisionConfig):
1163
+ raise ValueError(
1164
+ "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
1165
+ f" {type(config.vision_config)}."
1166
+ )
1167
+
1168
+ text_config = config.text_config
1169
+ vision_config = config.vision_config
1170
+
1171
+ self.projection_dim = config.projection_dim
1172
+ self.text_embed_dim = text_config.hidden_size
1173
+ self.vision_embed_dim = vision_config.hidden_size
1174
+
1175
+ self.text_model = CLIPTextTransformer(text_config)
1176
+ self.vision_model = CLIPVisionTransformer(vision_config)
1177
+
1178
+ self.visual_projection = nn.Linear(
1179
+ self.vision_embed_dim, self.projection_dim, bias=False
1180
+ )
1181
+ self.text_projection = nn.Linear(
1182
+ self.text_embed_dim, self.projection_dim, bias=False
1183
+ )
1184
+ self.logit_scale = nn.Parameter(
1185
+ torch.tensor(self.config.logit_scale_init_value)
1186
+ )
1187
+
1188
+ # Initialize weights and apply final processing
1189
+ self.post_init()
1190
+
1191
+ def get_text_features(
1192
+ self,
1193
+ input_ids: Optional[torch.Tensor] = None,
1194
+ attention_mask: Optional[torch.Tensor] = None,
1195
+ position_ids: Optional[torch.Tensor] = None,  # accepted for CLIP API compatibility; not used by the NLLB text encoder
1196
+ output_attentions: Optional[bool] = None,
1197
+ output_hidden_states: Optional[bool] = None,
1198
+ return_dict: Optional[bool] = None,
1199
+ ) -> torch.FloatTensor:
1200
+ r"""
1201
+ Returns:
1202
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
1203
+ applying the projection layer to the pooled output of the NLLB text encoder.
1204
+
1205
+ Examples:
1206
+
1207
+ ```python
1208
+ >>> from transformers import AutoTokenizer, CLIPModel
1209
+
1210
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1211
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
1212
+
1213
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
1214
+ >>> text_features = model.get_text_features(**inputs)
1215
+ ```"""
1216
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1217
+ output_attentions = (
1218
+ output_attentions
1219
+ if output_attentions is not None
1220
+ else self.config.output_attentions
1221
+ )
1222
+ output_hidden_states = (
1223
+ output_hidden_states
1224
+ if output_hidden_states is not None
1225
+ else self.config.output_hidden_states
1226
+ )
1227
+ return_dict = (
1228
+ return_dict if return_dict is not None else self.config.use_return_dict
1229
+ )
1230
+
1231
+ text_outputs = self.text_model(
1232
+ input_ids=input_ids,
1233
+ attention_mask=attention_mask,
1235
+ output_attentions=output_attentions,
1236
+ output_hidden_states=output_hidden_states,
1237
+ return_dict=return_dict,
1238
+ )
1239
+
1240
+ pooled_output = text_outputs[1]
1241
+ text_features = self.text_projection(pooled_output)
1242
+
1243
+ return text_features
1244
+
1245
+ def get_image_features(
1246
+ self,
1247
+ pixel_values: Optional[torch.FloatTensor] = None,
1248
+ output_attentions: Optional[bool] = None,
1249
+ output_hidden_states: Optional[bool] = None,
1250
+ return_dict: Optional[bool] = None,
1251
+ ) -> torch.FloatTensor:
1252
+ r"""
1253
+ Returns:
1254
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
1255
+ applying the projection layer to the pooled output of [`CLIPVisionModel`].
1256
+
1257
+ Examples:
1258
+
1259
+ ```python
1260
+ >>> from PIL import Image
1261
+ >>> import requests
1262
+ >>> from transformers import AutoProcessor, CLIPModel
1263
+
1264
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1265
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1266
+
1267
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1268
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1269
+
1270
+ >>> inputs = processor(images=image, return_tensors="pt")
1271
+
1272
+ >>> image_features = model.get_image_features(**inputs)
1273
+ ```"""
1274
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1275
+ output_attentions = (
1276
+ output_attentions
1277
+ if output_attentions is not None
1278
+ else self.config.output_attentions
1279
+ )
1280
+ output_hidden_states = (
1281
+ output_hidden_states
1282
+ if output_hidden_states is not None
1283
+ else self.config.output_hidden_states
1284
+ )
1285
+ return_dict = (
1286
+ return_dict if return_dict is not None else self.config.use_return_dict
1287
+ )
1288
+
1289
+ vision_outputs = self.vision_model(
1290
+ pixel_values=pixel_values,
1291
+ output_attentions=output_attentions,
1292
+ output_hidden_states=output_hidden_states,
1293
+ return_dict=return_dict,
1294
+ )
1295
+
1296
+ pooled_output = vision_outputs[1] # pooled_output
1297
+ image_features = self.visual_projection(pooled_output)
1298
+
1299
+ return image_features
1300
+
1301
+ def forward(
1302
+ self,
1303
+ input_ids: Optional[torch.LongTensor] = None,
1304
+ pixel_values: Optional[torch.FloatTensor] = None,
1305
+ attention_mask: Optional[torch.Tensor] = None,
1306
+ return_loss: Optional[bool] = None,
1307
+ output_attentions: Optional[bool] = None,
1308
+ output_hidden_states: Optional[bool] = None,
1309
+ return_dict: Optional[bool] = None,
1310
+ ) -> Union[Tuple, NLLBCLIPOutput]:
1311
+ r"""
1312
+ Returns:
1313
+
1314
+ Examples:
1315
+
1316
+ ```python
1317
+ >>> from PIL import Image
1318
+ >>> import requests
1319
+ >>> from transformers import AutoProcessor, CLIPModel
1320
+
1321
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1322
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1323
+
1324
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1325
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1326
+
1327
+ >>> inputs = processor(
1328
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
1329
+ ... )
1330
+
1331
+ >>> outputs = model(**inputs)
1332
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
1333
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
1334
+ ```"""
1335
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1336
+ output_attentions = (
1337
+ output_attentions
1338
+ if output_attentions is not None
1339
+ else self.config.output_attentions
1340
+ )
1341
+ output_hidden_states = (
1342
+ output_hidden_states
1343
+ if output_hidden_states is not None
1344
+ else self.config.output_hidden_states
1345
+ )
1346
+ return_dict = (
1347
+ return_dict if return_dict is not None else self.config.use_return_dict
1348
+ )
1349
+
1350
+ vision_outputs = self.vision_model(
1351
+ pixel_values=pixel_values,
1352
+ output_attentions=output_attentions,
1353
+ output_hidden_states=output_hidden_states,
1354
+ return_dict=return_dict,
1355
+ )
1356
+
1357
+ text_outputs = self.text_model(
1358
+ input_ids=input_ids,
1359
+ attention_mask=attention_mask,
1360
+ output_attentions=output_attentions,
1361
+ output_hidden_states=output_hidden_states,
1362
+ return_dict=return_dict,
1363
+ )
1364
+
1365
+ image_embeds = vision_outputs[1]
1366
+ image_embeds = self.visual_projection(image_embeds)
1367
+
1368
+ text_embeds = text_outputs[1]
1369
+ text_embeds = self.text_projection(text_embeds)
1370
+
1371
+ # normalized features
1372
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1373
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1374
+
1375
+ # cosine similarity as logits
1376
+ logit_scale = self.logit_scale.exp()
1377
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
1378
+ logits_per_image = logits_per_text.t()
1379
+
1380
+ loss = None
1381
+ if return_loss:
1382
+ loss = clip_loss(logits_per_text)
1383
+
1384
+ if not return_dict:
1385
+ output = (
1386
+ logits_per_image,
1387
+ logits_per_text,
1388
+ text_embeds,
1389
+ image_embeds,
1390
+ text_outputs,
1391
+ vision_outputs,
1392
+ )
1393
+ return ((loss,) + output) if loss is not None else output
1394
+
1395
+ return NLLBCLIPOutput(
1396
+ loss=loss,
1397
+ logits_per_image=logits_per_image,
1398
+ logits_per_text=logits_per_text,
1399
+ text_embeds=text_embeds,
1400
+ image_embeds=image_embeds,
1401
+ text_model_output=text_outputs,
1402
+ vision_model_output=vision_outputs,
1403
+ )
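+
+ # Minimal usage sketch (hedged): this file is meant to ship as custom modeling code
+ # on the Hugging Face Hub, so a checkpoint that registers it would normally be loaded
+ # through the Auto classes with `trust_remote_code=True`. The repo id below is a
+ # placeholder, not a real checkpoint name.
+ #
+ #     >>> from PIL import Image
+ #     >>> import requests
+ #     >>> from transformers import AutoModel, AutoProcessor
+ #     >>> image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
+ #     >>> model = AutoModel.from_pretrained("<nllb-clip-repo>", trust_remote_code=True)
+ #     >>> processor = AutoProcessor.from_pretrained("<nllb-clip-repo>", trust_remote_code=True)
+ #     >>> inputs = processor(text=["a photo of a cat"], images=image, return_tensors="pt", padding=True)
+ #     >>> outputs = model(**inputs)
+ #     >>> outputs.logits_per_image  # image-text similarity logits, shape (num_images, num_texts)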