Add support for "padding_side" == "right" in ChatGLMTokenizer.
#107
by
fengkaige
- opened
- tokenization_chatglm.py +14 -6
tokenization_chatglm.py
CHANGED
@@ -225,7 +225,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
|
225 |
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
226 |
"""
|
227 |
# Load from model defaults
|
228 |
-
assert self.padding_side
|
229 |
|
230 |
required_input = encoded_inputs[self.model_input_names[0]]
|
231 |
seq_length = len(required_input)
|
@@ -248,10 +248,18 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
|
248 |
if needs_to_be_padded:
|
249 |
difference = max_length - len(required_input)
|
250 |
|
251 |
-
if "
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
|
257 |
return encoded_inputs
|
|
|
|
225 |
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
226 |
"""
|
227 |
# Load from model defaults
|
228 |
+
assert self.padding_side in ["left", "right"]
|
229 |
|
230 |
required_input = encoded_inputs[self.model_input_names[0]]
|
231 |
seq_length = len(required_input)
|
|
|
248 |
if needs_to_be_padded:
|
249 |
difference = max_length - len(required_input)
|
250 |
|
251 |
+
if self.padding_side == "left":
|
252 |
+
if "attention_mask" in encoded_inputs:
|
253 |
+
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
|
254 |
+
if "position_ids" in encoded_inputs:
|
255 |
+
encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
|
256 |
+
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
257 |
+
else:
|
258 |
+
if "attention_mask" in encoded_inputs:
|
259 |
+
encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
|
260 |
+
if "position_ids" in encoded_inputs:
|
261 |
+
encoded_inputs["position_ids"] = encoded_inputs["position_ids"] + [0] * difference
|
262 |
+
encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
|
263 |
|
264 |
return encoded_inputs
|
265 |
+
|