AlekseyCalvin commited on
Commit
c838d44
1 Parent(s): d40efd6

Upload 2 files

Browse files
Files changed (2) hide show
  1. text_encoder.py +465 -0
  2. text_encoder_2.safetensors +3 -0
text_encoder.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from dataclasses import dataclass
3
+ import math
4
+ from safetensors.torch import load_model
5
+ import torch
6
+ from torch import nn
7
+ from transformers.modeling_utils import ModuleUtilsMixin, PretrainedConfig, PreTrainedModel
8
+ from transformers.modeling_outputs import ModelOutput
9
+ from transformers.models.t5.configuration_t5 import T5Config
10
+ from transformers.models.t5.modeling_t5 import (
11
+ T5LayerNorm,
12
+ T5DenseGatedActDense,
13
+ )
14
+ from typing import Optional
15
+
16
+ ###
17
+ # Code from from PiotrNawrot/nanoT5/nanoT5/utils/t5_model.py
18
+
19
+ @dataclass
20
+ class EncoderOutput(ModelOutput):
21
+ hidden_states: torch.FloatTensor = None
22
+ attention_mask: torch.FloatTensor = None
23
+
24
+
25
+ @dataclass
26
+ class Seq2SeqLMOutput(ModelOutput):
27
+ loss: torch.FloatTensor = None
28
+ logits: torch.FloatTensor = None
29
+ encoder_outputs: EncoderOutput = None
30
+
31
+
32
+ class T5LayerFF(nn.Module):
33
+ def __init__(self, config: T5Config):
34
+ super().__init__()
35
+ assert config.is_gated_act
36
+ self.DenseReluDense = T5DenseGatedActDense(config)
37
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
38
+ self.dropout = nn.Dropout(config.dropout_rate)
39
+
40
+ def forward(self, hidden_states):
41
+ forwarded_states = self.layer_norm(hidden_states).type_as(hidden_states)
42
+ forwarded_states = self.DenseReluDense(forwarded_states)
43
+ hidden_states = hidden_states + self.dropout(forwarded_states)
44
+ return hidden_states
45
+
46
+
47
+ class T5Attention(nn.Module):
48
+ def __init__(self, config: T5Config, has_relative_attention_bias=False):
49
+ super().__init__()
50
+ self.is_decoder = config.is_decoder
51
+ self.has_relative_attention_bias = has_relative_attention_bias
52
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
53
+ self.relative_attention_max_distance = config.relative_attention_max_distance
54
+ self.d_model = config.d_model
55
+ self.key_value_proj_dim = config.d_kv
56
+ self.n_heads = config.num_heads
57
+ self.dropout = config.dropout_rate
58
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
59
+
60
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
61
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
62
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
63
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
64
+
65
+ if self.has_relative_attention_bias:
66
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
67
+
68
+ @staticmethod
69
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
70
+ """
71
+ Adapted from Mesh Tensorflow:
72
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
73
+
74
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
75
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
76
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
77
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
78
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
79
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
80
+
81
+ Args:
82
+ relative_position: an int32 Tensor
83
+ bidirectional: a boolean - whether the attention is bidirectional
84
+ num_buckets: an integer
85
+ max_distance: an integer
86
+
87
+ Returns:
88
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
89
+ """
90
+ relative_buckets = 0
91
+ if bidirectional:
92
+ num_buckets //= 2
93
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
94
+ relative_position = torch.abs(relative_position)
95
+ else:
96
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
97
+ # now relative_position is in the range [0, inf)
98
+
99
+ # half of the buckets are for exact increments in positions
100
+ max_exact = num_buckets // 2
101
+ is_small = relative_position < max_exact
102
+
103
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
104
+ relative_position_if_large = max_exact + (
105
+ torch.log(relative_position.float() / max_exact)
106
+ / math.log(max_distance / max_exact)
107
+ * (num_buckets - max_exact)
108
+ ).to(torch.long)
109
+ relative_position_if_large = torch.min(
110
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
111
+ )
112
+
113
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
114
+ return relative_buckets
115
+
116
+ def compute_bias(self, query_length, key_length, device=None):
117
+ """Compute binned relative position bias"""
118
+ if device is None:
119
+ device = self.relative_attention_bias.weight.device
120
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
121
+ memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
122
+ relative_position = memory_position - context_position # shape (query_length, key_length)
123
+ relative_position_bucket = self._relative_position_bucket(
124
+ relative_position, # shape (query_length, key_length)
125
+ bidirectional=(not self.is_decoder),
126
+ num_buckets=self.relative_attention_num_buckets,
127
+ max_distance=self.relative_attention_max_distance,
128
+ )
129
+ values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
130
+ values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
131
+ return values
132
+
133
+ def forward(
134
+ self,
135
+ hidden_states,
136
+ mask=None,
137
+ key_value_states=None,
138
+ position_bias=None,
139
+ ):
140
+ """
141
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
142
+ """
143
+ # Input is (batch_size, seq_length, dim)
144
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
145
+ batch_size, seq_length = hidden_states.shape[:2]
146
+ real_seq_length = seq_length
147
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
148
+
149
+ def shape(states):
150
+ """projection"""
151
+ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
152
+
153
+ def unshape(states):
154
+ """reshape"""
155
+ return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
156
+
157
+ query_states = self.q(hidden_states)
158
+ if key_value_states is None:
159
+ key_states, value_states = self.k(hidden_states), self.v(hidden_states)
160
+ else:
161
+ key_states, value_states = self.k(key_value_states), self.v(key_value_states)
162
+ query_states, key_states, value_states = shape(query_states), shape(key_states), shape(value_states)
163
+
164
+ scores = torch.matmul(
165
+ query_states, key_states.transpose(3, 2)
166
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
167
+
168
+ if position_bias is None:
169
+ if not self.has_relative_attention_bias:
170
+ position_bias = torch.zeros(
171
+ (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
172
+ )
173
+ else:
174
+ position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
175
+
176
+ if mask is not None:
177
+ # Masking happens here, masked elements in the mask have the value of -inf
178
+ position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
179
+
180
+ position_bias_masked = position_bias
181
+
182
+ scores += position_bias_masked
183
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
184
+ scores
185
+ ) # (batch_size, n_heads, seq_length, key_length)
186
+ attn_weights = nn.functional.dropout(
187
+ attn_weights, p=self.dropout, training=self.training
188
+ ) # (batch_size, n_heads, seq_length, key_length)
189
+
190
+ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
191
+ attn_output = self.o(attn_output)
192
+
193
+ return (attn_output, position_bias)
194
+
195
+
196
+ class T5LayerSelfAttention(nn.Module):
197
+ def __init__(self, config, has_relative_attention_bias=False):
198
+ super().__init__()
199
+ self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
200
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
201
+ self.dropout = nn.Dropout(config.dropout_rate)
202
+
203
+ def forward(
204
+ self,
205
+ hidden_states,
206
+ attention_mask=None,
207
+ position_bias=None,
208
+ ):
209
+ normed_hidden_states = self.layer_norm(hidden_states).type_as(hidden_states)
210
+ attention_output = self.SelfAttention(
211
+ normed_hidden_states,
212
+ mask=attention_mask,
213
+ position_bias=position_bias,
214
+ )
215
+ hidden_states = hidden_states + self.dropout(attention_output[0])
216
+ outputs = (hidden_states,) + attention_output[1:]
217
+ return outputs
218
+
219
+
220
+ class T5LayerCrossAttention(nn.Module):
221
+ def __init__(self, config):
222
+ super().__init__()
223
+ self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
224
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
225
+ self.dropout = nn.Dropout(config.dropout_rate)
226
+
227
+ def forward(
228
+ self,
229
+ hidden_states,
230
+ key_value_states,
231
+ attention_mask=None,
232
+ position_bias=None,
233
+ ):
234
+ normed_hidden_states = self.layer_norm(hidden_states)
235
+ attention_output = self.EncDecAttention(
236
+ normed_hidden_states,
237
+ mask=attention_mask,
238
+ key_value_states=key_value_states,
239
+ position_bias=position_bias,
240
+ )
241
+ layer_output = hidden_states + self.dropout(attention_output[0])
242
+ outputs = (layer_output,) + attention_output[1:]
243
+ return outputs
244
+
245
+
246
+ class T5Block(nn.Module):
247
+ def __init__(self, config, has_relative_attention_bias=False):
248
+ super().__init__()
249
+ self.is_decoder = config.is_decoder
250
+ self.layer = nn.ModuleList()
251
+ self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
252
+ if self.is_decoder:
253
+ self.layer.append(T5LayerCrossAttention(config))
254
+
255
+ self.layer.append(T5LayerFF(config))
256
+
257
+ def forward(
258
+ self,
259
+ hidden_states,
260
+ attention_mask=None,
261
+ position_bias=None,
262
+ encoder_hidden_states=None,
263
+ encoder_attention_mask=None,
264
+ encoder_decoder_position_bias=None,
265
+ ):
266
+ self_attention_outputs = self.layer[0](
267
+ hidden_states,
268
+ attention_mask=attention_mask,
269
+ position_bias=position_bias,
270
+ )
271
+ hidden_states = self_attention_outputs[0]
272
+ attention_outputs = self_attention_outputs[1:] # Relative position weights
273
+
274
+ if self.is_decoder and encoder_hidden_states is not None:
275
+ cross_attention_outputs = self.layer[1](
276
+ hidden_states,
277
+ key_value_states=encoder_hidden_states,
278
+ attention_mask=encoder_attention_mask,
279
+ position_bias=encoder_decoder_position_bias,
280
+ )
281
+ hidden_states = cross_attention_outputs[0]
282
+
283
+ # Keep relative position weights
284
+ attention_outputs = attention_outputs + cross_attention_outputs[1:]
285
+
286
+ # Apply Feed Forward layer
287
+ hidden_states = self.layer[-1](hidden_states)
288
+
289
+ outputs = (hidden_states,)
290
+ outputs = outputs + attention_outputs
291
+
292
+ return outputs # hidden-states, (self-attention position bias), (cross-attention position bias)
293
+
294
+
295
+ class T5Stack(nn.Module, ModuleUtilsMixin):
296
+ def __init__(self, config, embed_tokens):
297
+ super().__init__()
298
+ assert embed_tokens is not None
299
+
300
+ self.config = config
301
+ self.embed_tokens = embed_tokens
302
+ self.is_decoder = config.is_decoder
303
+
304
+ self.block = nn.ModuleList(
305
+ [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
306
+ )
307
+
308
+ self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
309
+ self.dropout = nn.Dropout(config.dropout_rate)
310
+
311
+ def forward(
312
+ self,
313
+ input_ids=None,
314
+ attention_mask=None,
315
+ encoder_hidden_states=None,
316
+ encoder_attention_mask=None,
317
+ ) -> EncoderOutput:
318
+ input_shape = input_ids.size()
319
+ batch_size, seq_length = input_shape
320
+
321
+ inputs_embeds = self.embed_tokens(input_ids)
322
+
323
+ if hasattr(self.config, 'is_bf16') and self.config.is_bf16:
324
+ inputs_embeds = inputs_embeds.to(torch.bfloat16)
325
+
326
+ # Masking
327
+ if attention_mask is None:
328
+ attention_mask = torch.ones(batch_size, seq_length, device=inputs_embeds.device)
329
+
330
+ if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
331
+ encoder_seq_length = encoder_hidden_states.shape[1]
332
+ encoder_attention_mask = torch.ones(
333
+ batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
334
+ )
335
+
336
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
337
+ # ourselves in which case we just need to make it broadcastable to all heads.
338
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
339
+
340
+ # If a 2D or 3D attention mask is provided for the cross-attention
341
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
342
+ if self.is_decoder and encoder_hidden_states is not None:
343
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
344
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
345
+ if encoder_attention_mask is None:
346
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
347
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
348
+ else:
349
+ encoder_extended_attention_mask = None
350
+
351
+ position_bias = None
352
+ encoder_decoder_position_bias = None
353
+
354
+ hidden_states = self.dropout(inputs_embeds)
355
+
356
+ for _, layer_module in enumerate(self.block):
357
+ layer_outputs = layer_module(
358
+ hidden_states,
359
+ attention_mask=extended_attention_mask,
360
+ position_bias=position_bias,
361
+ encoder_hidden_states=encoder_hidden_states,
362
+ encoder_attention_mask=encoder_extended_attention_mask,
363
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
364
+ )
365
+ hidden_states = layer_outputs[0]
366
+
367
+ # We share the position biases between the layers - the first layer store them
368
+ position_bias = layer_outputs[1]
369
+ if self.is_decoder and encoder_hidden_states is not None:
370
+ encoder_decoder_position_bias = layer_outputs[2]
371
+
372
+ hidden_states = self.final_layer_norm(hidden_states).type_as(hidden_states)
373
+ hidden_states = self.dropout(hidden_states)
374
+
375
+ return EncoderOutput(
376
+ hidden_states=hidden_states,
377
+ attention_mask=attention_mask,
378
+ )
379
+
380
+
381
+ ###
382
+ # Code from huggingface/twodgirl
383
+ # License: apache-2.0
384
+
385
+ class T5EncoderModel(nn.Module):
386
+ def __init__(self, config: T5Config):
387
+ super().__init__()
388
+ config.is_encoder_decoder = False
389
+ assert not config.tie_word_embeddings
390
+ self.config = config
391
+ self.model_dim = config.d_model
392
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
393
+ encoder_config = copy.deepcopy(config)
394
+ encoder_config.is_decoder = False
395
+ self.encoder = T5Stack(encoder_config, self.shared)
396
+ flux_dev_d_model = 4096
397
+ self.last_layer = nn.Sequential(
398
+ nn.Linear(config.d_model, flux_dev_d_model),
399
+ nn.ReLU(),
400
+ nn.Linear(flux_dev_d_model, flux_dev_d_model)
401
+ )
402
+ self.apply(self._init_weights)
403
+
404
+ def forward(
405
+ self,
406
+ input_ids: Optional[torch.LongTensor] = None,
407
+ attention_mask: Optional[torch.FloatTensor] = None
408
+ ):
409
+ encoder_outputs = self.encoder(
410
+ input_ids=input_ids,
411
+ attention_mask=attention_mask,
412
+ )
413
+
414
+ return self.last_layer(encoder_outputs.hidden_states)
415
+
416
+ def get_input_embeddings(self):
417
+ return self.shared
418
+
419
+ def _init_weights(self, module):
420
+ factor = self.config.initializer_factor # Used for testing weights initialization
421
+ if isinstance(module, T5LayerNorm):
422
+ module.weight.data.fill_(factor * 1.0)
423
+ elif isinstance(module, T5EncoderModel):
424
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
425
+ if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
426
+ module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
427
+ elif isinstance(module, T5DenseGatedActDense):
428
+ d_ff, d_model = module.wi_0.weight.data.size()
429
+ module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
430
+ module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
431
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5))
432
+ elif isinstance(module, T5Attention):
433
+ d_model = self.config.d_model
434
+ key_value_proj_dim = self.config.d_kv
435
+ n_heads = self.config.num_heads
436
+ module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
437
+ module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
438
+ module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
439
+ module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
440
+ if hasattr(module, "relative_attention_bias"):
441
+ module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
442
+
443
+ class PretrainedTextEncoder(PreTrainedModel):
444
+ # Call by:
445
+ # t5 = PretrainedTextEncoder(t5_config, T5EncoderModel(t5_config)).to(dtype=torch.float16)
446
+ # t5.load_model('text_encoder_2.safetensors')
447
+ # ...
448
+ # FluxPipeline.from_pretrained(..., text_encoder_2=t5)
449
+ def __init__(self, config, model):
450
+ super().__init__(config, model)
451
+ self.model = model
452
+
453
+ def load_model(self, filepath):
454
+ load_model(self.model, filepath)
455
+
456
+ def forward(self, x, output_hidden_states=False):
457
+ return self.model(x),
458
+
459
+ t5_config = T5Config(d_model=4096 // 2,
460
+ dd_ff=10240 // 2,
461
+ num_layers=2,
462
+ num_heads=32,
463
+ is_gated_act=True,
464
+ tie_word_embeddings=False,
465
+ max_seq_len=512)
text_encoder_2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:abe00664256642069f8f7db268933d3e7dbd545b04b675cf8355c2e6f3ac6cca
3
+ size 598817664