Add support for "padding_side" == "right" in ChatGLMTokenizer.

#107
Files changed (1) hide show
  1. tokenization_chatglm.py +14 -6
tokenization_chatglm.py CHANGED
@@ -225,7 +225,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
225
  (optional) Set to False to avoid returning attention mask (default: set to model specifics)
226
  """
227
  # Load from model defaults
228
- assert self.padding_side == "left"
229
 
230
  required_input = encoded_inputs[self.model_input_names[0]]
231
  seq_length = len(required_input)
@@ -248,10 +248,18 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
248
  if needs_to_be_padded:
249
  difference = max_length - len(required_input)
250
 
251
- if "attention_mask" in encoded_inputs:
252
- encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
253
- if "position_ids" in encoded_inputs:
254
- encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
255
- encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
 
 
 
 
 
 
 
256
 
257
  return encoded_inputs
 
 
225
  (optional) Set to False to avoid returning attention mask (default: set to model specifics)
226
  """
227
  # Load from model defaults
228
+ assert self.padding_side in ["left", "right"]
229
 
230
  required_input = encoded_inputs[self.model_input_names[0]]
231
  seq_length = len(required_input)
 
248
  if needs_to_be_padded:
249
  difference = max_length - len(required_input)
250
 
251
+ if self.padding_side == "left":
252
+ if "attention_mask" in encoded_inputs:
253
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
254
+ if "position_ids" in encoded_inputs:
255
+ encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
256
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
257
+ else:
258
+ if "attention_mask" in encoded_inputs:
259
+ encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
260
+ if "position_ids" in encoded_inputs:
261
+ encoded_inputs["position_ids"] = encoded_inputs["position_ids"] + [0] * difference
262
+ encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
263
 
264
  return encoded_inputs
265
+