helboukkouri committed on
Commit
925f56a
1 Parent(s): 30397d8

Create modeling_character_bert.py

Files changed (1)
  1. modeling_character_bert.py +1954 -0
modeling_character_bert.py ADDED
@@ -0,0 +1,1954 @@
1
+ # coding=utf-8
2
+ # Copyright Hicham EL BOUKKOURI, Olivier FERRET, Thomas LAVERGNE, Hiroshi NOJI,
3
+ # Pierre ZWEIGENBAUM, Junichi TSUJII, The HuggingFace Inc. and AllenNLP teams.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """
18
+ PyTorch CharacterBERT model: this is a variant of BERT that uses the CharacterCNN module from ELMo instead of a
19
+ WordPiece embedding matrix. See: "CharacterBERT: Reconciling ELMo and BERT for Word-Level Open-Vocabulary
+ Representations From Characters" (https://www.aclweb.org/anthology/2020.coling-main.609/)
21
+ """
22
+
23
+ import math
24
+ import warnings
25
+ from dataclasses import dataclass
26
+ from typing import Callable, Optional, Tuple
27
+
28
+ import torch
29
+ import torch.utils.checkpoint
30
+ from torch import nn
31
+ from torch.nn import CrossEntropyLoss, MSELoss
32
+
33
+ from transformers.activations import ACT2FN
34
+ from transformers.file_utils import (
35
+ ModelOutput,
36
+ add_code_sample_docstrings,
37
+ add_start_docstrings,
38
+ add_start_docstrings_to_model_forward,
39
+ replace_return_docstrings,
40
+ )
41
+ from transformers.modeling_outputs import (
42
+ BaseModelOutputWithPastAndCrossAttentions,
43
+ BaseModelOutputWithPoolingAndCrossAttentions,
44
+ CausalLMOutputWithCrossAttentions,
45
+ MaskedLMOutput,
46
+ MultipleChoiceModelOutput,
47
+ NextSentencePredictorOutput,
48
+ QuestionAnsweringModelOutput,
49
+ SequenceClassifierOutput,
50
+ TokenClassifierOutput,
51
+ )
52
+ from transformers.modeling_utils import (
53
+ PreTrainedModel,
54
+ apply_chunking_to_forward,
55
+ find_pruneable_heads_and_indices,
56
+ prune_linear_layer,
57
+ )
58
+ from transformers.utils import logging
59
+ from .configuration_character_bert import CharacterBertConfig
60
+ from .tokenization_character_bert import CharacterMapper
61
+
62
+
63
+ logger = logging.get_logger(__name__)
64
+
65
+ _CHECKPOINT_FOR_DOC = "helboukkouri/character-bert"
66
+ _CONFIG_FOR_DOC = "CharacterBertConfig"
67
+ _TOKENIZER_FOR_DOC = "CharacterBertTokenizer"
68
+
69
+ CHARACTER_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
70
+ "helboukkouri/character-bert",
71
+ "helboukkouri/character-bert-medical",
72
+ # See all CharacterBERT models at https://huggingface.co/models?filter=character_bert
73
+ ]
74
+
75
+
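For reference, a minimal loading sketch (not part of the committed file). It assumes the checkpoints listed above expose these classes as Hub custom code via `trust_remote_code`; adjust if the classes are imported directly from this module instead.

from transformers import AutoTokenizer, AutoModel
import torch

# Assumption: the Hub repo wires CharacterBertModel/CharacterBertTokenizer in as custom code.
tokenizer = AutoTokenizer.from_pretrained("helboukkouri/character-bert", trust_remote_code=True)
model = AutoModel.from_pretrained("helboukkouri/character-bert", trust_remote_code=True)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)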
76
+ # NOTE: the following class is taken from:
77
+ # https://github.com/allenai/allennlp/blob/main/allennlp/modules/highway.py
78
+ class Highway(torch.nn.Module):
79
+ """
80
+ A `Highway layer <https://arxiv.org/abs/1505.00387>`__ does a gated combination of a linear transformation and a
81
+ non-linear transformation of its input. :math:`y = g * x + (1 - g) * f(A(x))`, where :math:`A` is a linear
82
+ transformation, :math:`f` is an element-wise non-linearity, and :math:`g` is an element-wise gate, computed as
83
+ :math:`sigmoid(B(x))`.
84
+
85
+ This module will apply a fixed number of highway layers to its input, returning the final result.
86
+
87
+ # Parameters
+
+ input_dim : `int`, required
+     The dimensionality of :math:`x`. We assume the input has shape `(batch_size, ..., input_dim)`.
+ num_layers : `int`, optional (default=`1`)
+     The number of highway layers to apply to the input.
+ activation : `Callable[[torch.Tensor], torch.Tensor]`, optional (default=`torch.nn.functional.relu`)
+     The non-linearity to use in the highway layers.
93
+ """
94
+
95
+ def __init__(
96
+ self,
97
+ input_dim: int,
98
+ num_layers: int = 1,
99
+ activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
100
+ ) -> None:
101
+ super().__init__()
102
+ self._input_dim = input_dim
103
+ self._layers = torch.nn.ModuleList([torch.nn.Linear(input_dim, input_dim * 2) for _ in range(num_layers)])
104
+ self._activation = activation
105
+ for layer in self._layers:
106
+ # We should bias the highway layer to just carry its input forward. We do that by
107
+ # setting the bias on `B(x)` to be positive, because that means `g` will be biased to
108
+ # be high, so we will carry the input forward. The bias on `B(x)` is the second half
109
+ # of the bias vector in each Linear layer.
110
+ layer.bias[input_dim:].data.fill_(1)
111
+
112
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
113
+ current_input = inputs
114
+ for layer in self._layers:
115
+ projected_input = layer(current_input)
116
+ linear_part = current_input
117
+ # NOTE: if you modify this, think about whether you should modify the initialization
118
+ # above, too.
119
+ nonlinear_part, gate = projected_input.chunk(2, dim=-1)
120
+ nonlinear_part = self._activation(nonlinear_part)
121
+ gate = torch.sigmoid(gate)
122
+ current_input = gate * linear_part + (1 - gate) * nonlinear_part
123
+ return current_input
124
+
125
+
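A minimal usage sketch for the Highway module above (illustrative only, not part of the committed file): the gated combination y = g * x + (1 - g) * f(A(x)) preserves the input shape.

import torch

highway = Highway(input_dim=8, num_layers=2)
x = torch.randn(4, 10, 8)      # (batch_size, sequence_length, input_dim)
y = highway(x)                 # gate * x + (1 - gate) * relu(A(x)), applied layer by layer
assert y.shape == x.shape      # each highway layer maps input_dim -> input_dim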
126
+ # NOTE: The CharacterCnn was adapted from `_ElmoCharacterEncoder`:
127
+ # https://github.com/allenai/allennlp/blob/main/allennlp/modules/elmo.py#L254
128
+ class CharacterCnn(torch.nn.Module):
129
+ """
130
+ Computes context insensitive token representation using multiple CNNs. This embedder has input character ids of
131
+ size (batch_size, sequence_length, 50) and returns (batch_size, sequence_length, hidden_size), where hidden_size is
132
+ typically 768.
133
+ """
134
+
135
+ def __init__(self, config):
136
+ super().__init__()
137
+ self.character_embeddings_dim = config.character_embeddings_dim
138
+ self.cnn_activation = config.cnn_activation
139
+ self.cnn_filters = config.cnn_filters
140
+ self.num_highway_layers = config.num_highway_layers
141
+ self.max_word_length = config.max_word_length
142
+ self.hidden_size = config.hidden_size
143
+ # NOTE: this is the 256 possible utf-8 bytes + special slots for the
144
+ # [CLS]/[SEP]/[PAD]/[MASK] characters as well as beginning/end of
145
+ # word symbols and character padding for short words -> total of 263
146
+ self.character_vocab_size = 263
147
+ self._init_weights()
148
+
149
+ def get_output_dim(self):
150
+ return self.hidden_size
151
+
152
+ def _init_weights(self):
153
+ self._init_char_embedding()
154
+ self._init_cnn_weights()
155
+ self._init_highway()
156
+ self._init_projection()
157
+
158
+ def _init_char_embedding(self):
159
+ weights = torch.empty((self.character_vocab_size, self.character_embeddings_dim))
160
+ nn.init.normal_(weights)
161
+ weights[0].fill_(0.0) # token padding
162
+ weights[CharacterMapper.padding_character + 1].fill_(0.0) # character padding
163
+ self._char_embedding_weights = torch.nn.Parameter(torch.FloatTensor(weights), requires_grad=True)
164
+
165
+ def _init_cnn_weights(self):
166
+ convolutions = []
167
+ for i, (width, num) in enumerate(self.cnn_filters):
168
+ conv = torch.nn.Conv1d(
169
+ in_channels=self.character_embeddings_dim, out_channels=num, kernel_size=width, bias=True
170
+ )
171
+ conv.weight.requires_grad = True
172
+ conv.bias.requires_grad = True
173
+ convolutions.append(conv)
174
+ self.add_module(f"char_conv_{i}", conv)
175
+ self._convolutions = convolutions
176
+
177
+ def _init_highway(self):
178
+ # the highway layers have same dimensionality as the number of cnn filters
179
+ n_filters = sum(f[1] for f in self.cnn_filters)
180
+ self._highways = Highway(n_filters, self.num_highway_layers, activation=nn.functional.relu)
181
+ for k in range(self.num_highway_layers):
182
+ # The AllenNLP highway is one matrix multiplication with concatenation of
183
+ # transform and carry weights.
184
+ self._highways._layers[k].weight.requires_grad = True
185
+ self._highways._layers[k].bias.requires_grad = True
186
+
187
+ def _init_projection(self):
188
+ n_filters = sum(f[1] for f in self.cnn_filters)
189
+ self._projection = torch.nn.Linear(n_filters, self.hidden_size, bias=True)
190
+ self._projection.weight.requires_grad = True
191
+ self._projection.bias.requires_grad = True
192
+
193
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
194
+ """
195
+ Compute context insensitive token embeddings from characters.
+
+ # Parameters
+
+ inputs : `torch.Tensor`
+     Shape `(batch_size, sequence_length, 50)` of character ids representing the current batch.
+
+ # Returns
+
+ output : `torch.Tensor`
+     Shape `(batch_size, sequence_length, embedding_dim)` tensor with context insensitive token representations.
199
+ """
200
+
201
+ # character embeddings
202
+ # (batch_size * sequence_length, max_word_length, embed_dim)
203
+ character_embedding = torch.nn.functional.embedding(
204
+ inputs.view(-1, self.max_word_length), self._char_embedding_weights
205
+ )
206
+
207
+ # CNN representations
208
+ if self.cnn_activation == "tanh":
209
+ activation = torch.tanh
210
+ elif self.cnn_activation == "relu":
211
+ activation = torch.nn.functional.relu
212
+ else:
213
+ raise Exception("ConfigurationError: Unknown activation")
214
+
215
+ # (batch_size * sequence_length, embed_dim, max_word_length)
216
+ character_embedding = torch.transpose(character_embedding, 1, 2)
217
+ convs = []
218
+ for i in range(len(self._convolutions)):
219
+ conv = getattr(self, "char_conv_{}".format(i))
220
+ convolved = conv(character_embedding)
221
+ # (batch_size * sequence_length, n_filters for this width)
222
+ convolved, _ = torch.max(convolved, dim=-1)
223
+ convolved = activation(convolved)
224
+ convs.append(convolved)
225
+
226
+ # (batch_size * sequence_length, n_filters)
227
+ token_embedding = torch.cat(convs, dim=-1)
228
+
229
+ # apply the highway layers (batch_size * sequence_length, n_filters)
230
+ token_embedding = self._highways(token_embedding)
231
+
232
+ # final projection (batch_size * sequence_length, embedding_dim)
233
+ token_embedding = self._projection(token_embedding)
234
+
235
+ # reshape to (batch_size, sequence_length, embedding_dim)
236
+ batch_size, sequence_length, _ = inputs.size()
237
+ output = token_embedding.view(batch_size, sequence_length, -1)
238
+
239
+ return output
240
+
241
+
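A shape-only sketch of the CharacterCnn above (illustrative, not part of the committed file; the config values below are assumptions, the real defaults come from CharacterBertConfig).

from types import SimpleNamespace
import torch

config = SimpleNamespace(
    character_embeddings_dim=16,
    cnn_activation="relu",
    cnn_filters=[(1, 32), (2, 32), (3, 64)],  # (kernel_width, num_filters) pairs -> 128 filters total
    num_highway_layers=2,
    max_word_length=50,
    hidden_size=768,
)
encoder = CharacterCnn(config)
char_ids = torch.randint(low=1, high=263, size=(2, 7, 50))  # (batch_size, sequence_length, max_word_length)
token_embeddings = encoder(char_ids)
print(token_embeddings.shape)  # torch.Size([2, 7, 768])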
242
+ class CharacterBertEmbeddings(nn.Module):
243
+ """Construct the embeddings from word, position and token_type embeddings."""
244
+
245
+ def __init__(self, config):
246
+ super().__init__()
247
+ self.word_embeddings = CharacterCnn(config)
248
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
249
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
250
+
251
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
252
+ # any TensorFlow checkpoint file
253
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
254
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
255
+
256
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
257
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
258
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
259
+
260
+ def forward(
261
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
262
+ ):
263
+ if input_ids is not None:
264
+ input_shape = input_ids[:, :, 0].size()
265
+ else:
266
+ input_shape = inputs_embeds.size()[:-1]
267
+
268
+ seq_length = input_shape[1]
269
+
270
+ if position_ids is None:
271
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
272
+
273
+ if token_type_ids is None:
274
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
275
+
276
+ if inputs_embeds is None:
277
+ inputs_embeds = self.word_embeddings(input_ids)
278
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
279
+
280
+ embeddings = inputs_embeds + token_type_embeddings
281
+ if self.position_embedding_type == "absolute":
282
+ position_embeddings = self.position_embeddings(position_ids)
283
+ embeddings += position_embeddings
284
+ embeddings = self.LayerNorm(embeddings)
285
+ embeddings = self.dropout(embeddings)
286
+ return embeddings
287
+
288
+
289
+ # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->CharacterBert
290
+ class CharacterBertSelfAttention(nn.Module):
291
+ def __init__(self, config, position_embedding_type=None):
292
+ super().__init__()
293
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
294
+ raise ValueError(
295
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
296
+ f"heads ({config.num_attention_heads})"
297
+ )
298
+
299
+ self.num_attention_heads = config.num_attention_heads
300
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
301
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
302
+
303
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
304
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
305
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
306
+
307
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
308
+ self.position_embedding_type = position_embedding_type or getattr(
309
+ config, "position_embedding_type", "absolute"
310
+ )
311
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
312
+ self.max_position_embeddings = config.max_position_embeddings
313
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
314
+
315
+ self.is_decoder = config.is_decoder
316
+
317
+ def transpose_for_scores(self, x):
318
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
319
+ x = x.view(*new_x_shape)
320
+ return x.permute(0, 2, 1, 3)
321
+
322
+ def forward(
323
+ self,
324
+ hidden_states,
325
+ attention_mask=None,
326
+ head_mask=None,
327
+ encoder_hidden_states=None,
328
+ encoder_attention_mask=None,
329
+ past_key_value=None,
330
+ output_attentions=False,
331
+ ):
332
+ mixed_query_layer = self.query(hidden_states)
333
+
334
+ # If this is instantiated as a cross-attention module, the keys
335
+ # and values come from an encoder; the attention mask needs to be
336
+ # such that the encoder's padding tokens are not attended to.
337
+ is_cross_attention = encoder_hidden_states is not None
338
+
339
+ if is_cross_attention and past_key_value is not None:
340
+ # reuse k,v, cross_attentions
341
+ key_layer = past_key_value[0]
342
+ value_layer = past_key_value[1]
343
+ attention_mask = encoder_attention_mask
344
+ elif is_cross_attention:
345
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
346
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
347
+ attention_mask = encoder_attention_mask
348
+ elif past_key_value is not None:
349
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
350
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
351
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
352
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
353
+ else:
354
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
355
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
356
+
357
+ query_layer = self.transpose_for_scores(mixed_query_layer)
358
+
359
+ if self.is_decoder:
360
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
361
+ # Further calls to cross_attention layer can then reuse all cross-attention
362
+ # key/value_states (first "if" case)
363
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
364
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
365
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
366
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
367
+ past_key_value = (key_layer, value_layer)
368
+
369
+ # Take the dot product between "query" and "key" to get the raw attention scores.
370
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
371
+
372
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
373
+ seq_length = hidden_states.size()[1]
374
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
375
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
376
+ distance = position_ids_l - position_ids_r
377
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
378
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
379
+
380
+ if self.position_embedding_type == "relative_key":
381
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
382
+ attention_scores = attention_scores + relative_position_scores
383
+ elif self.position_embedding_type == "relative_key_query":
384
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
385
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
386
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
387
+
388
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
389
+ if attention_mask is not None:
390
+ # Apply the attention mask (precomputed for all layers in CharacterBertModel forward() function)
391
+ attention_scores = attention_scores + attention_mask
392
+
393
+ # Normalize the attention scores to probabilities.
394
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
395
+
396
+ # This is actually dropping out entire tokens to attend to, which might
397
+ # seem a bit unusual, but is taken from the original Transformer paper.
398
+ attention_probs = self.dropout(attention_probs)
399
+
400
+ # Mask heads if we want to
401
+ if head_mask is not None:
402
+ attention_probs = attention_probs * head_mask
403
+
404
+ context_layer = torch.matmul(attention_probs, value_layer)
405
+
406
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
407
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
408
+ context_layer = context_layer.view(*new_context_layer_shape)
409
+
410
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
411
+
412
+ if self.is_decoder:
413
+ outputs = outputs + (past_key_value,)
414
+ return outputs
415
+
416
+
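A standalone sketch (not from the committed file) of how the relative-position branch above indexes the distance embedding when `position_embedding_type` is "relative_key" or "relative_key_query".

import torch

seq_length, max_position_embeddings = 5, 512
position_ids_l = torch.arange(seq_length).view(-1, 1)
position_ids_r = torch.arange(seq_length).view(1, -1)
distance = position_ids_l - position_ids_r        # (seq_length, seq_length), values in [-(L-1), L-1]
index = distance + max_position_embeddings - 1    # shifted into [0, 2 * max_position_embeddings - 2]
print(index.min().item(), index.max().item())     # 507 515 for this toy example
# These indices select rows of the (2 * max_position_embeddings - 1, head_size) distance
# embedding, which is then contracted with queries/keys via einsum("bhld,lrd->bhlr", ...).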
417
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->CharacterBert
418
+ class CharacterBertSelfOutput(nn.Module):
419
+ def __init__(self, config):
420
+ super().__init__()
421
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
422
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
423
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
424
+
425
+ def forward(self, hidden_states, input_tensor):
426
+ hidden_states = self.dense(hidden_states)
427
+ hidden_states = self.dropout(hidden_states)
428
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
429
+ return hidden_states
430
+
431
+
432
+ # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->CharacterBert
433
+ class CharacterBertAttention(nn.Module):
434
+ def __init__(self, config, position_embedding_type=None):
435
+ super().__init__()
436
+ self.self = CharacterBertSelfAttention(config, position_embedding_type=position_embedding_type)
437
+ self.output = CharacterBertSelfOutput(config)
438
+ self.pruned_heads = set()
439
+
440
+ def prune_heads(self, heads):
441
+ if len(heads) == 0:
442
+ return
443
+ heads, index = find_pruneable_heads_and_indices(
444
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
445
+ )
446
+
447
+ # Prune linear layers
448
+ self.self.query = prune_linear_layer(self.self.query, index)
449
+ self.self.key = prune_linear_layer(self.self.key, index)
450
+ self.self.value = prune_linear_layer(self.self.value, index)
451
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
452
+
453
+ # Update hyper params and store pruned heads
454
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
455
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
456
+ self.pruned_heads = self.pruned_heads.union(heads)
457
+
458
+ def forward(
459
+ self,
460
+ hidden_states,
461
+ attention_mask=None,
462
+ head_mask=None,
463
+ encoder_hidden_states=None,
464
+ encoder_attention_mask=None,
465
+ past_key_value=None,
466
+ output_attentions=False,
467
+ ):
468
+ self_outputs = self.self(
469
+ hidden_states,
470
+ attention_mask,
471
+ head_mask,
472
+ encoder_hidden_states,
473
+ encoder_attention_mask,
474
+ past_key_value,
475
+ output_attentions,
476
+ )
477
+ attention_output = self.output(self_outputs[0], hidden_states)
478
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
479
+ return outputs
480
+
481
+
482
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->CharacterBert
483
+ class CharacterBertIntermediate(nn.Module):
484
+ def __init__(self, config):
485
+ super().__init__()
486
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
487
+ if isinstance(config.hidden_act, str):
488
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
489
+ else:
490
+ self.intermediate_act_fn = config.hidden_act
491
+
492
+ def forward(self, hidden_states):
493
+ hidden_states = self.dense(hidden_states)
494
+ hidden_states = self.intermediate_act_fn(hidden_states)
495
+ return hidden_states
496
+
497
+
498
+ # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->CharacterBert
499
+ class CharacterBertOutput(nn.Module):
500
+ def __init__(self, config):
501
+ super().__init__()
502
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
503
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
504
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
505
+
506
+ def forward(self, hidden_states, input_tensor):
507
+ hidden_states = self.dense(hidden_states)
508
+ hidden_states = self.dropout(hidden_states)
509
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
510
+ return hidden_states
511
+
512
+
513
+ # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->CharacterBert
514
+ class CharacterBertLayer(nn.Module):
515
+ def __init__(self, config):
516
+ super().__init__()
517
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
518
+ self.seq_len_dim = 1
519
+ self.attention = CharacterBertAttention(config)
520
+ self.is_decoder = config.is_decoder
521
+ self.add_cross_attention = config.add_cross_attention
522
+ if self.add_cross_attention:
523
+ if not self.is_decoder:
524
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
525
+ self.crossattention = CharacterBertAttention(config, position_embedding_type="absolute")
526
+ self.intermediate = CharacterBertIntermediate(config)
527
+ self.output = CharacterBertOutput(config)
528
+
529
+ def forward(
530
+ self,
531
+ hidden_states,
532
+ attention_mask=None,
533
+ head_mask=None,
534
+ encoder_hidden_states=None,
535
+ encoder_attention_mask=None,
536
+ past_key_value=None,
537
+ output_attentions=False,
538
+ ):
539
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
540
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
541
+ self_attention_outputs = self.attention(
542
+ hidden_states,
543
+ attention_mask,
544
+ head_mask,
545
+ output_attentions=output_attentions,
546
+ past_key_value=self_attn_past_key_value,
547
+ )
548
+ attention_output = self_attention_outputs[0]
549
+
550
+ # if decoder, the last output is tuple of self-attn cache
551
+ if self.is_decoder:
552
+ outputs = self_attention_outputs[1:-1]
553
+ present_key_value = self_attention_outputs[-1]
554
+ else:
555
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
556
+
557
+ cross_attn_present_key_value = None
558
+ if self.is_decoder and encoder_hidden_states is not None:
559
+ if not hasattr(self, "crossattention"):
560
+ raise ValueError(
561
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
562
+ )
563
+
564
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
565
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
566
+ cross_attention_outputs = self.crossattention(
567
+ attention_output,
568
+ attention_mask,
569
+ head_mask,
570
+ encoder_hidden_states,
571
+ encoder_attention_mask,
572
+ cross_attn_past_key_value,
573
+ output_attentions,
574
+ )
575
+ attention_output = cross_attention_outputs[0]
576
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
577
+
578
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
579
+ cross_attn_present_key_value = cross_attention_outputs[-1]
580
+ present_key_value = present_key_value + cross_attn_present_key_value
581
+
582
+ layer_output = apply_chunking_to_forward(
583
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
584
+ )
585
+ outputs = (layer_output,) + outputs
586
+
587
+ # if decoder, return the attn key/values as the last output
588
+ if self.is_decoder:
589
+ outputs = outputs + (present_key_value,)
590
+
591
+ return outputs
592
+
593
+ def feed_forward_chunk(self, attention_output):
594
+ intermediate_output = self.intermediate(attention_output)
595
+ layer_output = self.output(intermediate_output, attention_output)
596
+ return layer_output
597
+
598
+
599
+ # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->CharacterBert
600
+ class CharacterBertEncoder(nn.Module):
601
+ def __init__(self, config):
602
+ super().__init__()
603
+ self.config = config
604
+ self.layer = nn.ModuleList([CharacterBertLayer(config) for _ in range(config.num_hidden_layers)])
605
+ self.gradient_checkpointing = False
606
+
607
+ def forward(
608
+ self,
609
+ hidden_states,
610
+ attention_mask=None,
611
+ head_mask=None,
612
+ encoder_hidden_states=None,
613
+ encoder_attention_mask=None,
614
+ past_key_values=None,
615
+ use_cache=None,
616
+ output_attentions=False,
617
+ output_hidden_states=False,
618
+ return_dict=True,
619
+ ):
620
+ all_hidden_states = () if output_hidden_states else None
621
+ all_self_attentions = () if output_attentions else None
622
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
623
+
624
+ next_decoder_cache = () if use_cache else None
625
+ for i, layer_module in enumerate(self.layer):
626
+ if output_hidden_states:
627
+ all_hidden_states = all_hidden_states + (hidden_states,)
628
+
629
+ layer_head_mask = head_mask[i] if head_mask is not None else None
630
+ past_key_value = past_key_values[i] if past_key_values is not None else None
631
+
632
+ if self.gradient_checkpointing and self.training:
633
+
634
+ if use_cache:
635
+ logger.warning(
636
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
637
+ )
638
+ use_cache = False
639
+
640
+ def create_custom_forward(module):
641
+ def custom_forward(*inputs):
642
+ return module(*inputs, past_key_value, output_attentions)
643
+
644
+ return custom_forward
645
+
646
+ layer_outputs = torch.utils.checkpoint.checkpoint(
647
+ create_custom_forward(layer_module),
648
+ hidden_states,
649
+ attention_mask,
650
+ layer_head_mask,
651
+ encoder_hidden_states,
652
+ encoder_attention_mask,
653
+ )
654
+ else:
655
+ layer_outputs = layer_module(
656
+ hidden_states,
657
+ attention_mask,
658
+ layer_head_mask,
659
+ encoder_hidden_states,
660
+ encoder_attention_mask,
661
+ past_key_value,
662
+ output_attentions,
663
+ )
664
+
665
+ hidden_states = layer_outputs[0]
666
+ if use_cache:
667
+ next_decoder_cache += (layer_outputs[-1],)
668
+ if output_attentions:
669
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
670
+ if self.config.add_cross_attention:
671
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
672
+
673
+ if output_hidden_states:
674
+ all_hidden_states = all_hidden_states + (hidden_states,)
675
+
676
+ if not return_dict:
677
+ return tuple(
678
+ v
679
+ for v in [
680
+ hidden_states,
681
+ next_decoder_cache,
682
+ all_hidden_states,
683
+ all_self_attentions,
684
+ all_cross_attentions,
685
+ ]
686
+ if v is not None
687
+ )
688
+ return BaseModelOutputWithPastAndCrossAttentions(
689
+ last_hidden_state=hidden_states,
690
+ past_key_values=next_decoder_cache,
691
+ hidden_states=all_hidden_states,
692
+ attentions=all_self_attentions,
693
+ cross_attentions=all_cross_attentions,
694
+ )
695
+
696
+
697
+ # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->CharacterBert
698
+ class CharacterBertPooler(nn.Module):
699
+ def __init__(self, config):
700
+ super().__init__()
701
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
702
+ self.activation = nn.Tanh()
703
+
704
+ def forward(self, hidden_states):
705
+ # We "pool" the model by simply taking the hidden state corresponding
706
+ # to the first token.
707
+ first_token_tensor = hidden_states[:, 0]
708
+ pooled_output = self.dense(first_token_tensor)
709
+ pooled_output = self.activation(pooled_output)
710
+ return pooled_output
711
+
712
+
713
+ # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->CharacterBert
714
+ class CharacterBertPredictionHeadTransform(nn.Module):
715
+ def __init__(self, config):
716
+ super().__init__()
717
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
718
+ if isinstance(config.hidden_act, str):
719
+ self.transform_act_fn = ACT2FN[config.hidden_act]
720
+ else:
721
+ self.transform_act_fn = config.hidden_act
722
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
723
+
724
+ def forward(self, hidden_states):
725
+ hidden_states = self.dense(hidden_states)
726
+ hidden_states = self.transform_act_fn(hidden_states)
727
+ hidden_states = self.LayerNorm(hidden_states)
728
+ return hidden_states
729
+
730
+
731
+ class CharacterBertLMPredictionHead(nn.Module):
732
+ def __init__(self, config):
733
+ super().__init__()
734
+ self.transform = CharacterBertPredictionHeadTransform(config)
735
+
736
+ # Unlike BERT, the output weights are not tied to the input embeddings (the inputs are
+ # built by a CharacterCNN); the decoder projects onto a separate MLM output vocabulary,
+ # with an output-only bias for each token.
738
+ self.decoder = nn.Linear(config.hidden_size, config.mlm_vocab_size, bias=False)
739
+
740
+ self.bias = nn.Parameter(torch.zeros(config.mlm_vocab_size))
741
+
742
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
743
+ self.decoder.bias = self.bias
744
+
745
+ def forward(self, hidden_states):
746
+ hidden_states = self.transform(hidden_states)
747
+ hidden_states = self.decoder(hidden_states)
748
+ return hidden_states
749
+
750
+
751
+ # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->CharacterBert
752
+ class CharacterBertOnlyMLMHead(nn.Module):
753
+ def __init__(self, config):
754
+ super().__init__()
755
+ self.predictions = CharacterBertLMPredictionHead(config)
756
+
757
+ def forward(self, sequence_output):
758
+ prediction_scores = self.predictions(sequence_output)
759
+ return prediction_scores
760
+
761
+
762
+ # Copied from transformers.models.bert.modeling_bert.BertOnlyNSPHead with Bert->CharacterBert
763
+ class CharacterBertOnlyNSPHead(nn.Module):
764
+ def __init__(self, config):
765
+ super().__init__()
766
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
767
+
768
+ def forward(self, pooled_output):
769
+ seq_relationship_score = self.seq_relationship(pooled_output)
770
+ return seq_relationship_score
771
+
772
+
773
+ # Copied from transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert->CharacterBert
774
+ class CharacterBertPreTrainingHeads(nn.Module):
775
+ def __init__(self, config):
776
+ super().__init__()
777
+ self.predictions = CharacterBertLMPredictionHead(config)
778
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
779
+
780
+ def forward(self, sequence_output, pooled_output):
781
+ prediction_scores = self.predictions(sequence_output)
782
+ seq_relationship_score = self.seq_relationship(pooled_output)
783
+ return prediction_scores, seq_relationship_score
784
+
785
+
786
+ class CharacterBertPreTrainedModel(PreTrainedModel):
787
+ """
788
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
789
+ models.
790
+ """
791
+
792
+ config_class = CharacterBertConfig
793
+ load_tf_weights = None
794
+ base_model_prefix = "character_bert"
795
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
796
+
797
+ def _init_weights(self, module):
798
+ """Initialize the weights"""
799
+ if isinstance(module, CharacterCnn):
800
+ # We need to handle this parameter separately since it is a raw Parameter rather than an actual module
801
+ module._char_embedding_weights.data.normal_()
802
+ # token padding
803
+ module._char_embedding_weights.data[0].fill_(0.0)
804
+ # character padding
805
+ module._char_embedding_weights.data[CharacterMapper.padding_character + 1].fill_(0.0)
806
+ if isinstance(module, nn.Linear):
807
+ # Slightly different from the TF version which uses truncated_normal for initialization
808
+ # cf https://github.com/pytorch/pytorch/pull/5617
809
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
810
+ if module.bias is not None:
811
+ module.bias.data.zero_()
812
+ elif isinstance(module, nn.Embedding):
813
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
814
+ if module.padding_idx is not None:
815
+ module.weight.data[module.padding_idx].zero_()
816
+ elif isinstance(module, nn.LayerNorm):
817
+ module.bias.data.zero_()
818
+ module.weight.data.fill_(1.0)
819
+
820
+
821
+ @dataclass
822
+ # Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->CharacterBert
823
+ class CharacterBertForPreTrainingOutput(ModelOutput):
824
+ """
825
+ Output type of [`CharacterBertForPreTraining`].
826
+
827
+ Args:
828
+ loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
829
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
830
+ (classification) loss.
831
+ prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.mlm_vocab_size)`):
832
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
833
+ seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
834
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
835
+ before SoftMax).
836
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
837
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
838
+ shape `(batch_size, sequence_length, hidden_size)`.
839
+
840
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
841
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
842
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
843
+ sequence_length)`.
844
+
845
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
846
+ heads.
847
+ """
848
+
849
+ loss: Optional[torch.FloatTensor] = None
850
+ prediction_logits: torch.FloatTensor = None
851
+ seq_relationship_logits: torch.FloatTensor = None
852
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
853
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
854
+
855
+
856
+ CHARACTER_BERT_START_DOCSTRING = r"""
857
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
858
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
859
+ behavior.
860
+
861
+ Parameters:
862
+ config ([`CharacterBertConfig`]): Model configuration class with all the parameters of the model.
864
+ Initializing with a config file does not load the weights associated with the model, only the
865
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model
866
+ weights.
867
+ """
868
+
869
+ CHARACTER_BERT_INPUTS_DOCSTRING = r"""
870
+ Args:
871
+ input_ids (`torch.LongTensor` of shape `{0}`):
872
+ Indices of input sequence tokens.
873
+
874
+ Indices can be obtained using [`CharacterBertTokenizer`]. See
875
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for
876
+ details.
877
+
878
+ [What are input IDs?](../glossary#input-ids)
879
+ attention_mask (`torch.FloatTensor` of shape `{1}`, *optional*):
880
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
881
+
882
+ - 1 for tokens that are **not masked**,
883
+ - 0 for tokens that are **masked**.
884
+
885
+ [What are attention masks?](../glossary#attention-mask)
886
+ token_type_ids (`torch.LongTensor` of shape `{1}`, *optional*):
887
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:
888
+
889
+ - 0 corresponds to a *sentence A* token,
890
+ - 1 corresponds to a *sentence B* token.
891
+
892
+ [What are token type IDs?](../glossary#token-type-ids)
893
+ position_ids (`torch.LongTensor` of shape `{1}`, *optional*):
894
+ Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`.
895
+
896
+ [What are position IDs?](../glossary#position-ids)
897
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): Mask
899
+ to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
900
+
901
+ - 1 indicates the head is **not masked**,
902
+ - 0 indicates the head is **masked**.
903
+
904
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
906
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
907
+ This is useful if you want more control over how to convert *input_ids* indices into associated vectors
908
+ than the model's internal embedding lookup matrix.
909
+ output_attentions (`bool`, *optional*):
910
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
911
+ tensors for more detail.
912
+ output_hidden_states (`bool`, *optional*):
913
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
914
+ more detail.
915
+ return_dict (`bool`, *optional*):
916
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
917
+ """
918
+
919
+
920
+ @add_start_docstrings(
921
+ "The bare CharacterBERT Model transformer outputting raw hidden-states without any specific head on top.",
922
+ CHARACTER_BERT_START_DOCSTRING,
923
+ )
924
+ class CharacterBertModel(CharacterBertPreTrainedModel):
925
+ """
926
+
927
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
928
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
929
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
930
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
931
+
932
+ To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration
+ set to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder`
934
+ argument and `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an
935
+ input to the forward pass.
936
+ """
937
+
938
+ def __init__(self, config, add_pooling_layer=True):
939
+ super().__init__(config)
940
+ self.config = config
941
+
942
+ self.embeddings = CharacterBertEmbeddings(config)
943
+ self.encoder = CharacterBertEncoder(config)
944
+
945
+ self.pooler = CharacterBertPooler(config) if add_pooling_layer else None
946
+
947
+ self.init_weights()
948
+
949
+ def get_input_embeddings(self):
950
+ return self.embeddings.word_embeddings
951
+
952
+ def set_input_embeddings(self, value):
953
+ self.embeddings.word_embeddings = value
954
+
955
+ def resize_token_embeddings(self, *args, **kwargs):
956
+ raise NotImplementedError("Cannot resize CharacterBERT's token embeddings.")
957
+
958
+ def _prune_heads(self, heads_to_prune):
959
+ """
960
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
961
+ class PreTrainedModel
962
+ """
963
+ for layer, heads in heads_to_prune.items():
964
+ self.encoder.layer[layer].attention.prune_heads(heads)
965
+
966
+ @add_start_docstrings_to_model_forward(
967
+ CHARACTER_BERT_INPUTS_DOCSTRING.format(
968
+ "(batch_size, sequence_length, maximum_token_length)", "(batch_size, sequence_length)"
969
+ )
970
+ )
971
+ @add_code_sample_docstrings(
972
+ processor_class=_TOKENIZER_FOR_DOC,
973
+ checkpoint=_CHECKPOINT_FOR_DOC,
974
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
975
+ config_class=_CONFIG_FOR_DOC,
976
+ )
977
+ def forward(
978
+ self,
979
+ input_ids=None,
980
+ attention_mask=None,
981
+ token_type_ids=None,
982
+ position_ids=None,
983
+ head_mask=None,
984
+ inputs_embeds=None,
985
+ encoder_hidden_states=None,
986
+ encoder_attention_mask=None,
987
+ past_key_values=None,
988
+ use_cache=None,
989
+ output_attentions=None,
990
+ output_hidden_states=None,
991
+ return_dict=None,
992
+ ):
993
+ r"""
994
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence
996
+ of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model
997
+ is configured as a decoder.
998
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
999
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1000
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1001
+
1002
+ - 1 for tokens that are **not masked**,
1003
+ - 0 for tokens that are **masked**.
1004
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of
1006
+ shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key
1007
+ and value hidden states of the attention blocks. Can be used to speed up decoding. If
1008
+ `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
1009
+ (those that don't have their past key value states given to this model) of shape `(batch_size, 1)`
1010
+ instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1011
+ use_cache (`bool`, *optional*):
1012
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up
1013
+ decoding (see `past_key_values`).
1014
+ """
1015
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1016
+ output_hidden_states = (
1017
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1018
+ )
1019
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1020
+
1021
+ if self.config.is_decoder:
1022
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1023
+ else:
1024
+ use_cache = False
1025
+
1026
+ if input_ids is not None and inputs_embeds is not None:
1027
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1028
+ elif input_ids is not None:
1029
+ input_shape = input_ids.size()[:-1]
1030
+ batch_size, seq_length = input_shape
1031
+ elif inputs_embeds is not None:
1032
+ input_shape = inputs_embeds.size()[:-1]
1033
+ batch_size, seq_length = input_shape
1034
+ else:
1035
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1036
+
1037
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1038
+
1039
+ # past_key_values_length
1040
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1041
+
1042
+ if attention_mask is None:
1043
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
1044
+ if token_type_ids is None:
1045
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1046
+
1047
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1048
+ # ourselves in which case we just need to make it broadcastable to all heads.
1049
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
1050
+
1051
+ # If a 2D or 3D attention mask is provided for the cross-attention
1052
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1053
+ if self.config.is_decoder and encoder_hidden_states is not None:
1054
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1055
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1056
+ if encoder_attention_mask is None:
1057
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1058
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1059
+ else:
1060
+ encoder_extended_attention_mask = None
1061
+
1062
+ # Prepare head mask if needed
1063
+ # 1.0 in head_mask indicate we keep the head
1064
+ # attention_probs has shape bsz x n_heads x N x N
1065
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1066
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1067
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1068
+
1069
+ embedding_output = self.embeddings(
1070
+ input_ids=input_ids,
1071
+ position_ids=position_ids,
1072
+ token_type_ids=token_type_ids,
1073
+ inputs_embeds=inputs_embeds,
1074
+ past_key_values_length=past_key_values_length,
1075
+ )
1076
+ encoder_outputs = self.encoder(
1077
+ embedding_output,
1078
+ attention_mask=extended_attention_mask,
1079
+ head_mask=head_mask,
1080
+ encoder_hidden_states=encoder_hidden_states,
1081
+ encoder_attention_mask=encoder_extended_attention_mask,
1082
+ past_key_values=past_key_values,
1083
+ use_cache=use_cache,
1084
+ output_attentions=output_attentions,
1085
+ output_hidden_states=output_hidden_states,
1086
+ return_dict=return_dict,
1087
+ )
1088
+ sequence_output = encoder_outputs[0]
1089
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1090
+
1091
+ if not return_dict:
1092
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1093
+
1094
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1095
+ last_hidden_state=sequence_output,
1096
+ pooler_output=pooled_output,
1097
+ past_key_values=encoder_outputs.past_key_values,
1098
+ hidden_states=encoder_outputs.hidden_states,
1099
+ attentions=encoder_outputs.attentions,
1100
+ cross_attentions=encoder_outputs.cross_attentions,
1101
+ )
1102
+
1103
+
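A minimal forward-pass sketch for CharacterBertModel (illustrative, not part of the committed file; it assumes the companion CharacterBertConfig provides working defaults, including `max_word_length`).

import torch

config = CharacterBertConfig()
model = CharacterBertModel(config)
model.eval()

batch_size, seq_length = 2, 9
# CharacterBERT consumes character ids rather than wordpiece ids:
# shape (batch_size, sequence_length, max_word_length).
input_ids = torch.randint(low=1, high=263, size=(batch_size, seq_length, config.max_word_length))
with torch.no_grad():
    outputs = model(input_ids=input_ids)
print(outputs.last_hidden_state.shape)  # (batch_size, seq_length, config.hidden_size)
print(outputs.pooler_output.shape)      # (batch_size, config.hidden_size)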
1104
+ @add_start_docstrings(
1105
+ """
1106
+ CharacterBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
1107
+ `next sentence prediction (classification)` head.
1108
+ """,
1109
+ CHARACTER_BERT_START_DOCSTRING,
1110
+ )
1111
+ class CharacterBertForPreTraining(CharacterBertPreTrainedModel):
1112
+ def __init__(self, config):
1113
+ super().__init__(config)
1114
+
1115
+ self.character_bert = CharacterBertModel(config)
1116
+ self.cls = CharacterBertPreTrainingHeads(config)
1117
+
1118
+ self.init_weights()
1119
+
1120
+ def get_output_embeddings(self):
1121
+ return self.cls.predictions.decoder
1122
+
1123
+ def set_output_embeddings(self, new_embeddings):
1124
+ self.cls.predictions.decoder = new_embeddings
1125
+
1126
+ @add_start_docstrings_to_model_forward(
1127
+ CHARACTER_BERT_INPUTS_DOCSTRING.format(
1128
+ "(batch_size, sequence_length, maximum_token_length)", "(batch_size, sequence_length)"
1129
+ )
1130
+ )
1131
+ @replace_return_docstrings(output_type=CharacterBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1132
+ def forward(
1133
+ self,
1134
+ input_ids=None,
1135
+ attention_mask=None,
1136
+ token_type_ids=None,
1137
+ position_ids=None,
1138
+ head_mask=None,
1139
+ inputs_embeds=None,
1140
+ labels=None,
1141
+ next_sentence_label=None,
1142
+ output_attentions=None,
1143
+ output_hidden_states=None,
1144
+ return_dict=None,
1145
+ ):
1146
+ r"""
1147
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1148
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., config.mlm_vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1149
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.mlm_vocab_size]`
1150
+ next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1151
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1152
+ (see `input_ids` docstring) Indices should be in `[0, 1]`:
1153
+
1154
+ - 0 indicates sequence B is a continuation of sequence A,
1155
+ - 1 indicates sequence B is a random sequence.
1156
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
1157
+ Used to hide legacy arguments that have been deprecated.
1158
+
1159
+ Returns:
1160
+
1161
+ Example:
1162
+
1163
+ ```python
1164
+ >>> from transformers import CharacterBertTokenizer, CharacterBertForPreTraining
+ >>> import torch
+
+ >>> tokenizer = CharacterBertTokenizer.from_pretrained('helboukkouri/character-bert')
+ >>> model = CharacterBertForPreTraining.from_pretrained('helboukkouri/character-bert')
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> prediction_logits = outputs.prediction_logits
+ >>> seq_relationship_logits = outputs.seq_relationship_logits
1173
+ ```
1174
+ """
1175
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1176
+
1177
+ outputs = self.character_bert(
1178
+ input_ids,
1179
+ attention_mask=attention_mask,
1180
+ token_type_ids=token_type_ids,
1181
+ position_ids=position_ids,
1182
+ head_mask=head_mask,
1183
+ inputs_embeds=inputs_embeds,
1184
+ output_attentions=output_attentions,
1185
+ output_hidden_states=output_hidden_states,
1186
+ return_dict=return_dict,
1187
+ )
1188
+
1189
+ sequence_output, pooled_output = outputs[:2]
1190
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
1191
+
1192
+ total_loss = None
1193
+ if labels is not None and next_sentence_label is not None:
1194
+ loss_fct = CrossEntropyLoss()
1195
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.mlm_vocab_size), labels.view(-1))
1196
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
1197
+ total_loss = masked_lm_loss + next_sentence_loss
1198
+
1199
+ if not return_dict:
1200
+ output = (prediction_scores, seq_relationship_score) + outputs[2:]
1201
+ return ((total_loss,) + output) if total_loss is not None else output
1202
+
1203
+ return CharacterBertForPreTrainingOutput(
1204
+ loss=total_loss,
1205
+ prediction_logits=prediction_scores,
1206
+ seq_relationship_logits=seq_relationship_score,
1207
+ hidden_states=outputs.hidden_states,
1208
+ attentions=outputs.attentions,
1209
+ )
1210
+
1211
+
1212
+ @add_start_docstrings(
1213
+ """CharacterBert Model with a `language modeling` head on top for CLM fine-tuning.""",
1214
+ CHARACTER_BERT_START_DOCSTRING,
1215
+ )
1216
+ class CharacterBertLMHeadModel(CharacterBertPreTrainedModel):
1217
+
1218
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1219
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1220
+
1221
+ def __init__(self, config):
1222
+ super().__init__(config)
1223
+
1224
+ if not config.is_decoder:
1225
+ logger.warning("If you want to use `CharacterBertLMHeadModel` as a standalone, add `is_decoder=True.`")
1226
+
1227
+ self.character_bert = CharacterBertModel(config, add_pooling_layer=False)
1228
+ self.cls = CharacterBertOnlyMLMHead(config)
1229
+
1230
+ self.init_weights()
1231
+
1232
+ def get_output_embeddings(self):
1233
+ return self.cls.predictions.decoder
1234
+
1235
+ def set_output_embeddings(self, new_embeddings):
1236
+ self.cls.predictions.decoder = new_embeddings
1237
+
1238
+ @add_start_docstrings_to_model_forward(
1239
+ CHARACTER_BERT_INPUTS_DOCSTRING.format(
1240
+ "(batch_size, sequence_length, maximum_token_length)", "(batch_size, sequence_length)"
1241
+ )
1242
+ )
1243
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
1244
+ def forward(
1245
+ self,
1246
+ input_ids=None,
1247
+ attention_mask=None,
1248
+ token_type_ids=None,
1249
+ position_ids=None,
1250
+ head_mask=None,
1251
+ inputs_embeds=None,
1252
+ encoder_hidden_states=None,
1253
+ encoder_attention_mask=None,
1254
+ labels=None,
1255
+ past_key_values=None,
1256
+ use_cache=None,
1257
+ output_attentions=None,
1258
+ output_hidden_states=None,
1259
+ return_dict=None,
1260
+ ):
1261
+ r"""
1262
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
1266
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1267
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1268
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1269
+
1270
+ - 1 for tokens that are **not masked**,
1271
+ - 0 for tokens that are **masked**.
1272
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1273
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1274
+ `[-100, 0, ..., config.mlm_vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100`
+ are ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.mlm_vocab_size]`.
1276
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of
+ shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1280
+
1281
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
1282
+ (those that don't have their past key value states given to this model) of shape `(batch_size, 1)`
1283
+ instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1284
+ use_cache (`bool`, *optional*):
1285
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up
1286
+ decoding (see `past_key_values`).
1287
+
1288
+ Returns:
1289
+
1290
+ Example:
1291
+
1292
+ ```python
1293
+ >>> from transformers import CharacterBertTokenizer, CharacterBertLMHeadModel, CharacterBertConfig
+ >>> import torch
+
+ >>> tokenizer = CharacterBertTokenizer.from_pretrained('helboukkouri/character-bert')
+ >>> config = CharacterBertConfig.from_pretrained("helboukkouri/character-bert")
+ >>> config.is_decoder = True
+ >>> model = CharacterBertLMHeadModel.from_pretrained('helboukkouri/character-bert', config=config)
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> prediction_logits = outputs.logits
1303
+ ```
1304
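+
+ As a further minimal sketch continuing the example above (the zero-valued labels are dummies that only
+ demonstrate the call signature; real targets are indices into the MLM output vocabulary), passing
+ `labels` returns the causal language modeling loss:
+
+ ```python
+ >>> labels = torch.zeros(inputs["input_ids"].shape[:2], dtype=torch.long)  # shape: (batch_size, sequence_length)
+ >>> outputs = model(**inputs, labels=labels)
+ >>> lm_loss = outputs.loss
+ ```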
+ """
1305
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1306
+ if labels is not None:
1307
+ use_cache = False
1308
+
1309
+ outputs = self.character_bert(
1310
+ input_ids,
1311
+ attention_mask=attention_mask,
1312
+ token_type_ids=token_type_ids,
1313
+ position_ids=position_ids,
1314
+ head_mask=head_mask,
1315
+ inputs_embeds=inputs_embeds,
1316
+ encoder_hidden_states=encoder_hidden_states,
1317
+ encoder_attention_mask=encoder_attention_mask,
1318
+ past_key_values=past_key_values,
1319
+ use_cache=use_cache,
1320
+ output_attentions=output_attentions,
1321
+ output_hidden_states=output_hidden_states,
1322
+ return_dict=return_dict,
1323
+ )
1324
+
1325
+ sequence_output = outputs[0]
1326
+ prediction_scores = self.cls(sequence_output)
1327
+
1328
+ lm_loss = None
1329
+ if labels is not None:
1330
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1331
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1332
+ labels = labels[:, 1:].contiguous()
1333
+ loss_fct = CrossEntropyLoss()
1334
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.mlm_vocab_size), labels.view(-1))
1335
+
1336
+ if not return_dict:
1337
+ output = (prediction_scores,) + outputs[2:]
1338
+ return ((lm_loss,) + output) if lm_loss is not None else output
1339
+
1340
+ return CausalLMOutputWithCrossAttentions(
1341
+ loss=lm_loss,
1342
+ logits=prediction_scores,
1343
+ past_key_values=outputs.past_key_values,
1344
+ hidden_states=outputs.hidden_states,
1345
+ attentions=outputs.attentions,
1346
+ cross_attentions=outputs.cross_attentions,
1347
+ )
1348
+
1349
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
1350
+ input_shape = input_ids.shape
1351
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1352
+ if attention_mask is None:
1353
+ attention_mask = input_ids.new_ones(input_shape)
1354
+
1355
+ # cut decoder_input_ids if past is used
1356
+ if past is not None:
1357
+ input_ids = input_ids[:, -1:]
1358
+
1359
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
1360
+
1361
+ def _reorder_cache(self, past, beam_idx):
1362
+ reordered_past = ()
1363
+ for layer_past in past:
1364
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1365
+ return reordered_past
1366
+
1367
+
1368
+ @add_start_docstrings(
1369
+ """CharacterBert Model with a `language modeling` head on top.""", CHARACTER_BERT_START_DOCSTRING
1370
+ )
1371
+ class CharacterBertForMaskedLM(CharacterBertPreTrainedModel):
1372
+
1373
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1374
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1375
+
1376
+ def __init__(self, config):
1377
+ super().__init__(config)
1378
+
1379
+ if config.is_decoder:
1380
+ logger.warning(
1381
+ "If you want to use `CharacterBertForMaskedLM` make sure `config.is_decoder=False` for "
1382
+ "bi-directional self-attention."
1383
+ )
1384
+ self.character_bert = CharacterBertModel(config, add_pooling_layer=False)
1385
+ self.cls = CharacterBertOnlyMLMHead(config)
1386
+
1387
+ self.init_weights()
1388
+
1389
+ def get_output_embeddings(self):
1390
+ return self.cls.predictions.decoder
1391
+
1392
+ def set_output_embeddings(self, new_embeddings):
1393
+ self.cls.predictions.decoder = new_embeddings
1394
+
1395
+ @add_start_docstrings_to_model_forward(
1396
+ CHARACTER_BERT_INPUTS_DOCSTRING.format(
1397
+ "(batch_size, sequence_length, maximum_token_length)", "(batch_size, sequence_length)"
1398
+ )
1399
+ )
1400
+ @add_code_sample_docstrings(
1401
+ processor_class=_TOKENIZER_FOR_DOC,
1402
+ checkpoint=_CHECKPOINT_FOR_DOC,
1403
+ output_type=MaskedLMOutput,
1404
+ config_class=_CONFIG_FOR_DOC,
1405
+ )
1406
+ def forward(
1407
+ self,
1408
+ input_ids=None,
1409
+ attention_mask=None,
1410
+ token_type_ids=None,
1411
+ position_ids=None,
1412
+ head_mask=None,
1413
+ inputs_embeds=None,
1414
+ encoder_hidden_states=None,
1415
+ encoder_attention_mask=None,
1416
+ labels=None,
1417
+ output_attentions=None,
1418
+ output_hidden_states=None,
1419
+ return_dict=None,
1420
+ ):
1421
+ r"""
1422
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1423
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., config.mlm_vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked); the loss is only computed for the tokens with labels in `[0, ..., config.mlm_vocab_size]`.
1425
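+
+ Example (a minimal sketch following the conventions of the examples above; the checkpoint name is the
+ same illustrative one and the zero-valued labels are dummies that only demonstrate the call signature):
+
+ ```python
+ >>> from transformers import CharacterBertTokenizer, CharacterBertForMaskedLM
+ >>> import torch
+
+ >>> tokenizer = CharacterBertTokenizer.from_pretrained('helboukkouri/character-bert')
+ >>> model = CharacterBertForMaskedLM.from_pretrained('helboukkouri/character-bert')
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> labels = torch.zeros(inputs["input_ids"].shape[:2], dtype=torch.long)  # shape: (batch_size, sequence_length)
+ >>> outputs = model(**inputs, labels=labels)
+ >>> masked_lm_loss = outputs.loss
+ >>> prediction_logits = outputs.logits  # last dimension is config.mlm_vocab_size
+ ```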
+ """
1426
+
1427
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1428
+
1429
+ outputs = self.character_bert(
1430
+ input_ids,
1431
+ attention_mask=attention_mask,
1432
+ token_type_ids=token_type_ids,
1433
+ position_ids=position_ids,
1434
+ head_mask=head_mask,
1435
+ inputs_embeds=inputs_embeds,
1436
+ encoder_hidden_states=encoder_hidden_states,
1437
+ encoder_attention_mask=encoder_attention_mask,
1438
+ output_attentions=output_attentions,
1439
+ output_hidden_states=output_hidden_states,
1440
+ return_dict=return_dict,
1441
+ )
1442
+
1443
+ sequence_output = outputs[0]
1444
+ prediction_scores = self.cls(sequence_output)
1445
+
1446
+ masked_lm_loss = None
1447
+ if labels is not None:
1448
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1449
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.mlm_vocab_size), labels.view(-1))
1450
+
1451
+ if not return_dict:
1452
+ output = (prediction_scores,) + outputs[2:]
1453
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1454
+
1455
+ return MaskedLMOutput(
1456
+ loss=masked_lm_loss,
1457
+ logits=prediction_scores,
1458
+ hidden_states=outputs.hidden_states,
1459
+ attentions=outputs.attentions,
1460
+ )
1461
+
1462
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1463
+ input_shape = input_ids.shape
1464
+ effective_batch_size = input_shape[0]
1465
+
1466
+ # add a dummy token
1467
+ assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
1468
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1469
+ dummy_token = torch.full(
1470
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1471
+ )
1472
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1473
+
1474
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1475
+
1476
+
1477
+ @add_start_docstrings(
1478
+ """CharacterBert Model with a `next sentence prediction (classification)` head on top.""",
1479
+ CHARACTER_BERT_START_DOCSTRING,
1480
+ )
1481
+ class CharacterBertForNextSentencePrediction(CharacterBertPreTrainedModel):
1482
+ def __init__(self, config):
1483
+ super().__init__(config)
1484
+
1485
+ self.character_bert = CharacterBertModel(config)
1486
+ self.cls = CharacterBertOnlyNSPHead(config)
1487
+
1488
+ self.init_weights()
1489
+
1490
+ @add_start_docstrings_to_model_forward(
1491
+ CHARACTER_BERT_INPUTS_DOCSTRING.format(
1492
+ "(batch_size, sequence_length, maximum_token_length)", "(batch_size, sequence_length)"
1493
+ )
1494
+ )
1495
+ @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
1496
+ def forward(
1497
+ self,
1498
+ input_ids=None,
1499
+ attention_mask=None,
1500
+ token_type_ids=None,
1501
+ position_ids=None,
1502
+ head_mask=None,
1503
+ inputs_embeds=None,
1504
+ labels=None,
1505
+ output_attentions=None,
1506
+ output_hidden_states=None,
1507
+ return_dict=None,
1508
+ **kwargs
1509
+ ):
1510
+ r"""
1511
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1512
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1513
+ (see `input_ids` docstring). Indices should be in `[0, 1]`:
1514
+
1515
+ - 0 indicates sequence B is a continuation of sequence A,
1516
+ - 1 indicates sequence B is a random sequence.
1517
+
1518
+ Returns:
1519
+
1520
+ Example:
1521
+
1522
+ ```python
1523
+ >>> from transformers import CharacterBertTokenizer, CharacterBertForNextSentencePrediction
+ >>> import torch
+
+ >>> tokenizer = CharacterBertTokenizer.from_pretrained('helboukkouri/character-bert')
+ >>> model = CharacterBertForNextSentencePrediction.from_pretrained('helboukkouri/character-bert')
+
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
+
+ >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
+ >>> logits = outputs.logits
+ >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
1535
+ ```
1536
+ """
1537
+
1538
+ if "next_sentence_label" in kwargs:
1539
+ warnings.warn(
1540
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
1541
+ FutureWarning,
1542
+ )
1543
+ labels = kwargs.pop("next_sentence_label")
1544
+
1545
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1546
+
1547
+ outputs = self.character_bert(
1548
+ input_ids,
1549
+ attention_mask=attention_mask,
1550
+ token_type_ids=token_type_ids,
1551
+ position_ids=position_ids,
1552
+ head_mask=head_mask,
1553
+ inputs_embeds=inputs_embeds,
1554
+ output_attentions=output_attentions,
1555
+ output_hidden_states=output_hidden_states,
1556
+ return_dict=return_dict,
1557
+ )
1558
+
1559
+ pooled_output = outputs[1]
1560
+
1561
+ seq_relationship_scores = self.cls(pooled_output)
1562
+
1563
+ next_sentence_loss = None
1564
+ if labels is not None:
1565
+ loss_fct = CrossEntropyLoss()
1566
+ next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
1567
+
1568
+ if not return_dict:
1569
+ output = (seq_relationship_scores,) + outputs[2:]
1570
+ return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
1571
+
1572
+ return NextSentencePredictorOutput(
1573
+ loss=next_sentence_loss,
1574
+ logits=seq_relationship_scores,
1575
+ hidden_states=outputs.hidden_states,
1576
+ attentions=outputs.attentions,
1577
+ )
1578
+
1579
+
1580
+ @add_start_docstrings(
1581
+ """
1582
+ CharacterBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the
1583
+ pooled output) e.g. for GLUE tasks.
1584
+ """,
1585
+ CHARACTER_BERT_START_DOCSTRING,
1586
+ )
1587
+ class CharacterBertForSequenceClassification(CharacterBertPreTrainedModel):
1588
+ def __init__(self, config):
1589
+ super().__init__(config)
1590
+ self.num_labels = config.num_labels
1591
+ self.character_bert = CharacterBertModel(config)
1592
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1593
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1594
+
1595
+ self.init_weights()
1596
+
1597
+ @add_start_docstrings_to_model_forward(
1598
+ CHARACTER_BERT_INPUTS_DOCSTRING.format(
1599
+ "(batch_size, sequence_length, maximum_token_length)", "(batch_size, sequence_length)"
1600
+ )
1601
+ )
1602
+ @add_code_sample_docstrings(
1603
+ processor_class=_TOKENIZER_FOR_DOC,
1604
+ checkpoint=_CHECKPOINT_FOR_DOC,
1605
+ output_type=SequenceClassifierOutput,
1606
+ config_class=_CONFIG_FOR_DOC,
1607
+ )
1608
+ def forward(
1609
+ self,
1610
+ input_ids=None,
1611
+ attention_mask=None,
1612
+ token_type_ids=None,
1613
+ position_ids=None,
1614
+ head_mask=None,
1615
+ inputs_embeds=None,
1616
+ labels=None,
1617
+ output_attentions=None,
1618
+ output_hidden_states=None,
1619
+ return_dict=None,
1620
+ ):
1621
+ r"""
1622
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1623
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss);
+ if `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1625
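+
+ Example (a minimal sketch in the style of the examples above; the checkpoint name is the same
+ illustrative one, the classification head starts out randomly initialized, and the label value is a dummy):
+
+ ```python
+ >>> from transformers import CharacterBertTokenizer, CharacterBertForSequenceClassification
+ >>> import torch
+
+ >>> tokenizer = CharacterBertTokenizer.from_pretrained('helboukkouri/character-bert')
+ >>> model = CharacterBertForSequenceClassification.from_pretrained('helboukkouri/character-bert')
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> labels = torch.tensor([1])  # shape: (batch_size,)
+ >>> outputs = model(**inputs, labels=labels)
+ >>> loss = outputs.loss
+ >>> logits = outputs.logits  # shape: (batch_size, config.num_labels)
+ ```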
+ """
1626
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1627
+
1628
+ outputs = self.character_bert(
1629
+ input_ids,
1630
+ attention_mask=attention_mask,
1631
+ token_type_ids=token_type_ids,
1632
+ position_ids=position_ids,
1633
+ head_mask=head_mask,
1634
+ inputs_embeds=inputs_embeds,
1635
+ output_attentions=output_attentions,
1636
+ output_hidden_states=output_hidden_states,
1637
+ return_dict=return_dict,
1638
+ )
1639
+
1640
+ pooled_output = outputs[1]
1641
+
1642
+ pooled_output = self.dropout(pooled_output)
1643
+ logits = self.classifier(pooled_output)
1644
+
1645
+ loss = None
1646
+ if labels is not None:
1647
+ if self.num_labels == 1:
1648
+ # We are doing regression
1649
+ loss_fct = MSELoss()
1650
+ loss = loss_fct(logits.view(-1), labels.view(-1))
1651
+ else:
1652
+ loss_fct = CrossEntropyLoss()
1653
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1654
+
1655
+ if not return_dict:
1656
+ output = (logits,) + outputs[2:]
1657
+ return ((loss,) + output) if loss is not None else output
1658
+
1659
+ return SequenceClassifierOutput(
1660
+ loss=loss,
1661
+ logits=logits,
1662
+ hidden_states=outputs.hidden_states,
1663
+ attentions=outputs.attentions,
1664
+ )
1665
+
1666
+
1667
+ @add_start_docstrings(
1668
+ """
1669
+ CharacterBERT Model with a multiple choice classification head on top (a linear layer on top of the pooled output
1670
+ and a softmax) e.g. for RocStories/SWAG tasks.
1671
+ """,
1672
+ CHARACTER_BERT_START_DOCSTRING,
1673
+ )
1674
+ class CharacterBertForMultipleChoice(CharacterBertPreTrainedModel):
1675
+ def __init__(self, config):
1676
+ super().__init__(config)
1677
+
1678
+ self.character_bert = CharacterBertModel(config)
1679
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1680
+ self.classifier = nn.Linear(config.hidden_size, 1)
1681
+
1682
+ self.init_weights()
1683
+
1684
+ @add_start_docstrings_to_model_forward(
1685
+ CHARACTER_BERT_INPUTS_DOCSTRING.format(
1686
+ "(batch_size, sequence_length, maximum_token_length)", "(batch_size, sequence_length)"
1687
+ )
1688
+ )
1689
+ @add_code_sample_docstrings(
1690
+ processor_class=_TOKENIZER_FOR_DOC,
1691
+ checkpoint=_CHECKPOINT_FOR_DOC,
1692
+ output_type=MultipleChoiceModelOutput,
1693
+ config_class=_CONFIG_FOR_DOC,
1694
+ )
1695
+ def forward(
1696
+ self,
1697
+ input_ids=None,
1698
+ attention_mask=None,
1699
+ token_type_ids=None,
1700
+ position_ids=None,
1701
+ head_mask=None,
1702
+ inputs_embeds=None,
1703
+ labels=None,
1704
+ output_attentions=None,
1705
+ output_hidden_states=None,
1706
+ return_dict=None,
1707
+ ):
1708
+ r"""
1709
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1710
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors (see
+ `input_ids` above).
1712
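+
+ Example (a minimal sketch adapted from the usual multiple-choice pattern; the checkpoint name is the
+ same illustrative one, the correct-choice label is a dummy, and it assumes the tokenizer can batch-encode
+ sentence pairs with padding):
+
+ ```python
+ >>> from transformers import CharacterBertTokenizer, CharacterBertForMultipleChoice
+ >>> import torch
+
+ >>> tokenizer = CharacterBertTokenizer.from_pretrained('helboukkouri/character-bert')
+ >>> model = CharacterBertForMultipleChoice.from_pretrained('helboukkouri/character-bert')
+
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+ >>> choice0 = "It is eaten with a fork and a knife."
+ >>> choice1 = "It is eaten while held in the hand."
+ >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct; shape: (batch_size,)
+
+ >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors='pt', padding=True)
+ >>> outputs = model(**{k: v.unsqueeze(0) for k, v in encoding.items()}, labels=labels)  # batch size is 1
+ >>> loss = outputs.loss
+ >>> logits = outputs.logits  # shape: (batch_size, num_choices)
+ ```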
+ """
1713
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1714
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1715
+
1716
+ input_ids = input_ids.view(-1, input_ids.size(-2), input_ids.size(-1)) if input_ids is not None else None
1717
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1718
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1719
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1720
+ inputs_embeds = (
1721
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1722
+ if inputs_embeds is not None
1723
+ else None
1724
+ )
1725
+
1726
+ outputs = self.character_bert(
1727
+ input_ids,
1728
+ attention_mask=attention_mask,
1729
+ token_type_ids=token_type_ids,
1730
+ position_ids=position_ids,
1731
+ head_mask=head_mask,
1732
+ inputs_embeds=inputs_embeds,
1733
+ output_attentions=output_attentions,
1734
+ output_hidden_states=output_hidden_states,
1735
+ return_dict=return_dict,
1736
+ )
1737
+
1738
+ pooled_output = outputs[1]
1739
+
1740
+ pooled_output = self.dropout(pooled_output)
1741
+ logits = self.classifier(pooled_output)
1742
+ reshaped_logits = logits.view(-1, num_choices)
1743
+
1744
+ loss = None
1745
+ if labels is not None:
1746
+ loss_fct = CrossEntropyLoss()
1747
+ loss = loss_fct(reshaped_logits, labels)
1748
+
1749
+ if not return_dict:
1750
+ output = (reshaped_logits,) + outputs[2:]
1751
+ return ((loss,) + output) if loss is not None else output
1752
+
1753
+ return MultipleChoiceModelOutput(
1754
+ loss=loss,
1755
+ logits=reshaped_logits,
1756
+ hidden_states=outputs.hidden_states,
1757
+ attentions=outputs.attentions,
1758
+ )
1759
+
1760
+
1761
+ @add_start_docstrings(
1762
+ """
1763
+ CharacterBERT Model with a token classification head on top (a linear layer on top of the hidden-states output)
1764
+ e.g. for Named-Entity-Recognition (NER) tasks.
1765
+ """,
1766
+ CHARACTER_BERT_START_DOCSTRING,
1767
+ )
1768
+ class CharacterBertForTokenClassification(CharacterBertPreTrainedModel):
1769
+ def __init__(self, config):
1770
+ super().__init__(config)
1771
+ self.num_labels = config.num_labels
1772
+
1773
+ self.character_bert = CharacterBertModel(config)
1774
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1775
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1776
+
1777
+ self.init_weights()
1778
+
1779
+ @add_start_docstrings_to_model_forward(
1780
+ CHARACTER_BERT_INPUTS_DOCSTRING.format(
1781
+ "(batch_size, sequence_length, maximum_token_length)", "(batch_size, sequence_length)"
1782
+ )
1783
+ )
1784
+ @add_code_sample_docstrings(
1785
+ processor_class=_TOKENIZER_FOR_DOC,
1786
+ checkpoint=_CHECKPOINT_FOR_DOC,
1787
+ output_type=TokenClassifierOutput,
1788
+ config_class=_CONFIG_FOR_DOC,
1789
+ )
1790
+ def forward(
1791
+ self,
1792
+ input_ids=None,
1793
+ attention_mask=None,
1794
+ token_type_ids=None,
1795
+ position_ids=None,
1796
+ head_mask=None,
1797
+ inputs_embeds=None,
1798
+ labels=None,
1799
+ output_attentions=None,
1800
+ output_hidden_states=None,
1801
+ return_dict=None,
1802
+ ):
1803
+ r"""
1804
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1805
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1806
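+
+ Example (a minimal sketch in the style of the examples above; the checkpoint name is the same
+ illustrative one and the all-zero labels are dummies covering every position, including special tokens):
+
+ ```python
+ >>> from transformers import CharacterBertTokenizer, CharacterBertForTokenClassification
+ >>> import torch
+
+ >>> tokenizer = CharacterBertTokenizer.from_pretrained('helboukkouri/character-bert')
+ >>> model = CharacterBertForTokenClassification.from_pretrained('helboukkouri/character-bert')
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> labels = torch.zeros(inputs["input_ids"].shape[:2], dtype=torch.long)  # one label id per token
+ >>> outputs = model(**inputs, labels=labels)
+ >>> loss = outputs.loss
+ >>> logits = outputs.logits  # shape: (batch_size, sequence_length, config.num_labels)
+ ```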
+ """
1807
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1808
+
1809
+ outputs = self.character_bert(
1810
+ input_ids,
1811
+ attention_mask=attention_mask,
1812
+ token_type_ids=token_type_ids,
1813
+ position_ids=position_ids,
1814
+ head_mask=head_mask,
1815
+ inputs_embeds=inputs_embeds,
1816
+ output_attentions=output_attentions,
1817
+ output_hidden_states=output_hidden_states,
1818
+ return_dict=return_dict,
1819
+ )
1820
+
1821
+ sequence_output = outputs[0]
1822
+
1823
+ sequence_output = self.dropout(sequence_output)
1824
+ logits = self.classifier(sequence_output)
1825
+
1826
+ loss = None
1827
+ if labels is not None:
1828
+ loss_fct = CrossEntropyLoss()
1829
+ # Only keep active parts of the loss
1830
+ if attention_mask is not None:
1831
+ active_loss = attention_mask.view(-1) == 1
1832
+ active_logits = logits.view(-1, self.num_labels)
1833
+ active_labels = torch.where(
1834
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
1835
+ )
1836
+ loss = loss_fct(active_logits, active_labels)
1837
+ else:
1838
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1839
+
1840
+ if not return_dict:
1841
+ output = (logits,) + outputs[2:]
1842
+ return ((loss,) + output) if loss is not None else output
1843
+
1844
+ return TokenClassifierOutput(
1845
+ loss=loss,
1846
+ logits=logits,
1847
+ hidden_states=outputs.hidden_states,
1848
+ attentions=outputs.attentions,
1849
+ )
1850
+
1851
+
1852
+ @add_start_docstrings(
1853
+ """
1854
+ CharacterBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
1855
+ linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1856
+ """,
1857
+ CHARACTER_BERT_START_DOCSTRING,
1858
+ )
1859
+ class CharacterBertForQuestionAnswering(CharacterBertPreTrainedModel):
1860
+ def __init__(self, config):
1861
+ super().__init__(config)
1862
+
1863
+ config.num_labels = 2
1864
+ self.num_labels = config.num_labels
1865
+
1866
+ self.character_bert = CharacterBertModel(config)
1867
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1868
+
1869
+ self.init_weights()
1870
+
1871
+ @add_start_docstrings_to_model_forward(
1872
+ CHARACTER_BERT_INPUTS_DOCSTRING.format(
1873
+ "(batch_size, sequence_length, maximum_token_length)", "(batch_size, sequence_length)"
1874
+ )
1875
+ )
1876
+ @add_code_sample_docstrings(
1877
+ processor_class=_TOKENIZER_FOR_DOC,
1878
+ checkpoint=_CHECKPOINT_FOR_DOC,
1879
+ output_type=QuestionAnsweringModelOutput,
1880
+ config_class=_CONFIG_FOR_DOC,
1881
+ )
1882
+ def forward(
1883
+ self,
1884
+ input_ids=None,
1885
+ attention_mask=None,
1886
+ token_type_ids=None,
1887
+ position_ids=None,
1888
+ head_mask=None,
1889
+ inputs_embeds=None,
1890
+ start_positions=None,
1891
+ end_positions=None,
1892
+ output_attentions=None,
1893
+ output_hidden_states=None,
1894
+ return_dict=None,
1895
+ ):
1896
+ r"""
1897
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1898
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1899
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
1900
+ sequence are not taken into account for computing the loss.
1901
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1902
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1903
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
1904
+ sequence are not taken into account for computing the loss.
1905
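+
+ Example (a minimal sketch in the style of the examples above; the checkpoint name is the same
+ illustrative one and the start/end positions are dummies chosen only to demonstrate the call):
+
+ ```python
+ >>> from transformers import CharacterBertTokenizer, CharacterBertForQuestionAnswering
+ >>> import torch
+
+ >>> tokenizer = CharacterBertTokenizer.from_pretrained('helboukkouri/character-bert')
+ >>> model = CharacterBertForQuestionAnswering.from_pretrained('helboukkouri/character-bert')
+
+ >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+ >>> inputs = tokenizer(question, text, return_tensors="pt")
+ >>> outputs = model(**inputs, start_positions=torch.tensor([1]), end_positions=torch.tensor([3]))
+ >>> loss = outputs.loss
+ >>> start_logits = outputs.start_logits
+ >>> end_logits = outputs.end_logits
+ ```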
+ """
1906
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1907
+
1908
+ outputs = self.character_bert(
1909
+ input_ids,
1910
+ attention_mask=attention_mask,
1911
+ token_type_ids=token_type_ids,
1912
+ position_ids=position_ids,
1913
+ head_mask=head_mask,
1914
+ inputs_embeds=inputs_embeds,
1915
+ output_attentions=output_attentions,
1916
+ output_hidden_states=output_hidden_states,
1917
+ return_dict=return_dict,
1918
+ )
1919
+
1920
+ sequence_output = outputs[0]
1921
+
1922
+ logits = self.qa_outputs(sequence_output)
1923
+ start_logits, end_logits = logits.split(1, dim=-1)
1924
+ start_logits = start_logits.squeeze(-1)
1925
+ end_logits = end_logits.squeeze(-1)
1926
+
1927
+ total_loss = None
1928
+ if start_positions is not None and end_positions is not None:
1929
+ # If we are on multi-GPU, split add a dimension
1930
+ if len(start_positions.size()) > 1:
1931
+ start_positions = start_positions.squeeze(-1)
1932
+ if len(end_positions.size()) > 1:
1933
+ end_positions = end_positions.squeeze(-1)
1934
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1935
+ ignored_index = start_logits.size(1)
1936
+ start_positions.clamp_(0, ignored_index)
1937
+ end_positions.clamp_(0, ignored_index)
1938
+
1939
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1940
+ start_loss = loss_fct(start_logits, start_positions)
1941
+ end_loss = loss_fct(end_logits, end_positions)
1942
+ total_loss = (start_loss + end_loss) / 2
1943
+
1944
+ if not return_dict:
1945
+ output = (start_logits, end_logits) + outputs[2:]
1946
+ return ((total_loss,) + output) if total_loss is not None else output
1947
+
1948
+ return QuestionAnsweringModelOutput(
1949
+ loss=total_loss,
1950
+ start_logits=start_logits,
1951
+ end_logits=end_logits,
1952
+ hidden_states=outputs.hidden_states,
1953
+ attentions=outputs.attentions,
1954
+ )