add set_input_embeddings to support resize_token_embeddings
resize_token_embeddings requires this method to be implemented.
ChatGLM implements it, but ChatGLM2 does not.
![企业微信截图_6553c363-ebfe-4d37-8ae2-c44a1244942a.png](https://cdn-uploads.huggingface.co/production/uploads/646efa216b3df773a2c75a17/XabvEgf2-8ESvqsXdMKZI.png)
- modeling_chatglm.py +3 -0
@@ -766,6 +766,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
     def get_input_embeddings(self):
         return self.embedding.word_embeddings
 
+    def set_input_embeddings(self, new_embeddings: torch.Tensor):
+        self.embedding.word_embeddings = new_embeddings
+
     def get_prompt(self, batch_size, device, dtype=torch.half):
         prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
         past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
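
For reference, a minimal usage sketch (not part of this change): `resize_token_embeddings` fetches the current embedding via `get_input_embeddings()`, builds a larger one, and installs it through `set_input_embeddings()`, so without this method ChatGLM2 fails with `NotImplementedError`. The checkpoint path and the added token below are placeholders, and the sketch assumes the patched modeling_chatglm.py is picked up via `trust_remote_code`.

```python
from transformers import AutoModel, AutoTokenizer

# Placeholder checkpoint path; assumes it serves the patched modeling_chatglm.py.
model_path = "THUDM/chatglm2-6b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)

# Grow the vocabulary, then grow the embedding matrix to match.
num_added = tokenizer.add_tokens(["<new_token>"])
if num_added > 0:
    # resize_token_embeddings() copies the old weights into a larger
    # nn.Embedding and installs it through set_input_embeddings();
    # without this commit, ChatGLM2 raises NotImplementedError here.
    model.resize_token_embeddings(len(tokenizer))
```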