duzx16 committed on
Commit 1676f07
1 Parent(s): 591fa87

Implement new interface

Files changed (2)
  1. modeling_chatglm.py +25 -17
  2. tokenization_chatglm.py +15 -10
modeling_chatglm.py CHANGED
@@ -996,18 +996,23 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             for layer_past in past
         )
 
-    def process_response(self, response):
-        response = response.strip()
-        response = response.replace("[[训练时间]]", "2023年")
-        return response
-
-    def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system: str = None):
-        inputs = tokenizer.build_chat_input(query, history=history, system=system)
-        inputs = inputs.to(self.device)
-        return inputs
+    def process_response(self, output, history):
+        content = ""
+        for response in output.split("<|assistant|>"):
+            metadata, content = response.split("\n", maxsplit=1)
+            history.append({"role": "assistant", "metadata": metadata, "content": content})
+            if not metadata.strip():
+                content = content.strip()
+                content = content.replace("[[训练时间]]", "2023年")
+            else:
+                content = "\n".join(content.split("\n")[1:-1])
+                def tool_call(**kwargs):
+                    return kwargs
+                content = eval(content)
+        return content, history
 
     @torch.inference_mode()
-    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system: str = None,
+    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = None,
              max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
              **kwargs):
         if history is None:
@@ -1017,17 +1022,19 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             logits_processor.append(InvalidScoreLogitsProcessor())
         gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                       "temperature": temperature, "logits_processor": logits_processor, **kwargs}
-        inputs = self.build_inputs(tokenizer, query, history=history, system=system)
-        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>")]
+        inputs = tokenizer.build_chat_input(query, history=history, role=role)
+        inputs = inputs.to(self.device)
+        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
+                        tokenizer.get_command("<|observation|>")]
         outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
         outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
         response = tokenizer.decode(outputs)
-        response = self.process_response(response)
-        history = history + [(query, response)]
+        history.append({"role": role, "content": query})
+        response, history = self.process_response(response, history)
         return response, history
 
     @torch.inference_mode()
-    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system: str = None,
+    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = None,
                     past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
                     logits_processor=None, return_past_key_values=False, **kwargs):
         if history is None:
@@ -1040,9 +1047,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                       "temperature": temperature, "logits_processor": logits_processor, **kwargs}
         if past_key_values is None:
-            inputs = self.build_inputs(tokenizer, query, history=history, system=system)
+            inputs = tokenizer.build_chat_input(query, history=history, role=role)
         else:
-            inputs = self.build_inputs(tokenizer, query)
+            inputs = tokenizer.build_chat_input(query, role=role)
+        input = inputs.to(self.device)
         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[0]
             if self.transformer.pre_seq_len is not None:
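
For reference, a minimal usage sketch of the new chat interface. It is not part of the commit: the checkpoint name, prompts, and history contents are assumptions. History entries are now role-tagged dicts instead of (query, response) tuples, and process_response returns a parsed tool call (a dict) rather than a string whenever the assistant reply carries metadata.

# Illustrative only: checkpoint path and prompts are assumptions, not part of this diff.
from transformers import AutoModel, AutoTokenizer

model_path = "THUDM/chatglm3-6b"  # hypothetical checkpoint exposing this remote code
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).eval()

# History is a list of {"role": ..., "content": ...} dicts; a system turn may also carry tools.
history = [{"role": "system", "content": "You are a helpful assistant."}]
response, history = model.chat(tokenizer, "你好", history=history, role="user")

# If the generated metadata is empty, `response` is the cleaned reply string; otherwise
# process_response eval()s the generated tool_call(...) body into a kwargs dict.
print(response)
print(history[-1])  # {"role": "assistant", "metadata": ..., "content": ...}
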
tokenization_chatglm.py CHANGED
@@ -1,3 +1,4 @@
+import json
 import os
 import torch
 from typing import List, Optional, Union, Dict
@@ -173,19 +174,23 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
         return prefix_tokens
 
-    def build_chat_input(self, query, history=None, system=None):
+    def build_single_message(self, role, metadata, message):
+        assert role in ["system", "user", "assistant", "observation"], role
+        role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
+        message_tokens = self.tokenizer.encode(message)
+        tokens = role_tokens + message_tokens
+        return tokens
+
+    def build_chat_input(self, query, history=None, role="user"):
         if history is None:
             history = []
         input_ids = []
-        if system is not None:
-            input_ids.extend(
-                [self.get_command("<|system|>")] + self.tokenizer.encode("\n") + self.tokenizer.encode(system))
-        for i, (old_query, old_response) in enumerate(history):
-            input_ids.extend(
-                [self.get_command("<|user|>")] + self.tokenizer.encode("\n") + self.tokenizer.encode(old_query))
-            input_ids.extend(
-                [self.get_command("<|assistant|>")] + self.tokenizer.encode("\n") + self.tokenizer.encode(old_response))
-        input_ids.extend([self.get_command("<|user|>")] + self.tokenizer.encode("\n") + self.tokenizer.encode(query))
+        for item in history:
+            content = item["content"]
+            if item["role"] == "system" and "tools" in item:
+                content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
+            input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
+        input_ids.extend(self.build_single_message(role, "", query))
         input_ids.extend([self.get_command("<|assistant|>")])
         return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
 
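
A sketch of what build_chat_input now assembles. The history below is invented purely for illustration and assumes a tokenizer loaded as in the example above; token ids are shown symbolically. Each turn becomes the <|role|> command token, the encoded "metadata\n" string, then the encoded content; a system turn carrying a "tools" field gets the JSON-dumped tool list appended; the prompt ends with a bare <|assistant|> token for the model to continue.

# Hypothetical history used only to illustrate the new prompt layout;
# assumes `tokenizer` is a ChatGLMTokenizer instance.
history = [
    {"role": "system", "content": "Answer questions, calling tools when needed.",
     "tools": [{"name": "get_weather", "parameters": {"city": "str"}}]},
    {"role": "user", "content": "What is the weather in Beijing?"},
]
inputs = tokenizer.build_chat_input("And in Shanghai?", history=history, role="user")

# Conceptual layout of inputs["input_ids"] (prefix tokens from get_prefix_tokens are
# prepended when special tokens are applied):
#   [gMASK] sop
#   <|system|>  "\n" + system content + "\n" + JSON-dumped tools
#   <|user|>    "\n" + "What is the weather in Beijing?"
#   <|user|>    "\n" + "And in Shanghai?"
#   <|assistant|>
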