Yes365 commited on
Commit
53bd0bf
1 Parent(s): 162b620

add set_input_embedding to support resize token embedding

Browse files

resize_token_embeddings need implements this method.

chatglm support it , but chatglm2 not support it
![企业微信截图_6553c363-ebfe-4d37-8ae2-c44a1244942a.png](https://cdn-uploads.huggingface.co/production/uploads/646efa216b3df773a2c75a17/XabvEgf2-8ESvqsXdMKZI.png)

Files changed (1) hide show
  1. modeling_chatglm.py +3 -0
modeling_chatglm.py CHANGED
@@ -766,6 +766,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
766
  def get_input_embeddings(self):
767
  return self.embedding.word_embeddings
768
 
 
 
 
769
  def get_prompt(self, batch_size, device, dtype=torch.half):
770
  prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
771
  past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
 
766
  def get_input_embeddings(self):
767
  return self.embedding.word_embeddings
768
 
769
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
770
+ self.embedding.word_embeddings = new_embeddings
771
+
772
  def get_prompt(self, batch_size, device, dtype=torch.half):
773
  prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
774
  past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)