aka7774 committed
Commit d347e01
1 Parent(s): 3350c44

Upload 2 files

Files changed (2)
  1. app.py +13 -12
  2. fn.py +87 -69
app.py CHANGED
@@ -16,31 +16,31 @@ with gr.Blocks() as demo:
 
         with gr.Column(scale=1):
             max_new_tokens = gr.Textbox(
-                value=fn.cfg['max_new_tokens'],
+                value=fn.default_args['max_new_tokens'],
                 label='max_new_tokens',
                 interactive=True,
                 show_copy_button=True,
             )
             temperature = gr.Textbox(
-                value=fn.cfg['temperature'],
+                value=fn.default_args['temperature'],
                 label='temperature',
                 interactive=True,
                 show_copy_button=True,
             )
             top_p = gr.Textbox(
-                value=fn.cfg['top_p'],
+                value=fn.default_args['top_p'],
                 label='top_p',
                 interactive=True,
                 show_copy_button=True,
             )
             top_k = gr.Textbox(
-                value=fn.cfg['top_k'],
+                value=fn.default_args['top_k'],
                 label='top_k',
                 interactive=True,
                 show_copy_button=True,
             )
             repetition_penalty = gr.Textbox(
-                value=fn.cfg['repetition_penalty'],
+                value=fn.default_args['repetition_penalty'],
                 label='repetition_penalty',
                 interactive=True,
                 show_copy_button=True,
@@ -48,17 +48,18 @@ with gr.Blocks() as demo:
 
     with gr.Row():
         with gr.Column(scale=1):
-            inst_template = gr.Textbox(
+            first_assistant = gr.Textbox(
                 value='',
-                lines=10,
-                label='inst_template',
+                label='first_assistant',
                 interactive=True,
                 show_copy_button=True,
             )
-            is_use_cache = gr.Checkbox(
-                value=False,
-                label='is_use_cache',
+            chat_template = gr.Textbox(
+                value='',
+                lines=10,
+                label='chat_template',
                 interactive=True,
+                show_copy_button=True,
             )
 
             set_button = gr.Button(value='Save')
@@ -97,7 +98,7 @@ with gr.Blocks() as demo:
 
     set_button.click(
         fn=fn.set_config,
-        inputs=[size, instruction, inst_template, is_use_cache, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+        inputs=[size, instruction, first_assistant, chat_template, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
         outputs=[info],
     )
 
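For reference, a minimal sketch (not part of the commit) of the call the Save button now makes: the renamed first_assistant and chat_template fields take the places of inst_template and is_use_cache in the inputs list, and every gr.Textbox value reaches fn.set_config as a string. All concrete values below are illustrative.

import fn

# Illustrative values only; gr.Textbox delivers strings, which set_config casts itself.
fn.set_config(
    size='9b',
    instruction='Answer briefly.\n{input}',   # hypothetical instruction text
    first_assistant='Understood.',            # hypothetical first assistant turn
    chat_template='',                         # empty string keeps the tokenizer default
    max_new_tokens='1024',
    temperature='0.9',
    top_p='0.95',
    top_k='40',
    repetition_penalty='1.2',
)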
fn.py CHANGED
@@ -14,20 +14,22 @@ from threading import Thread
 
 tokenizer = None
 model = None
-default_cfg = {
+cfg = {
     'size': None,
+}
+default_args = {
     'instruction': None,
-    'inst_template': None,
-    'is_use_cache': False,
+    'first_assistant': None,
+    'chat_template': None,
     'max_new_tokens': 1024,
     'temperature': 0.9,
     'top_p': 0.95,
     'top_k': 40,
     'repetition_penalty': 1.2,
 }
-cfg = default_cfg.copy()
-cache = None
-chat_history = []
+
+chat_past_key_values = {}
+chat_messages = {}
 
 def load_model(size = '9b'):
     global tokenizer, model, cfg
@@ -50,17 +52,13 @@ def load_model(size = '9b'):
 
     cfg['size'] = size
 
-def clear_config():
-    global cfg
-    cfg = default_cfg.copy()
-
-def set_config(size, instruction, inst_template, is_use_cache, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
-    global cfg
+def set_config(size, instruction, first_assistant, chat_template, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
+    global default_args
     load_model(size)
-    cfg.update({
+    default_args.update({
         'instruction': instruction,
-        'inst_template': inst_template,
-        'is_use_cache': bool(is_use_cache),
+        'first_assistant': first_assistant,
+        'chat_template': chat_template,
         'max_new_tokens': int(max_new_tokens),
         'temperature': float(temperature),
         'top_p': float(top_p),
@@ -70,24 +68,16 @@ def set_config(size, instruction, inst_template, is_use_cache, max_new_tokens, t
     return 'done.'
 
 def set_config_args(args):
-    global cfg
+    global default_args
 
     load_model(args['size'])
-    cfg.update(args)
+    default_args.update(args)
 
     return 'done.'
 
-def chatinterface_to_messages(message, history):
-    global cfg
-
+def chatinterface_to_messages(history):
     messages = []
 
-    if cfg['instruction']:
-        messages.append({'role': 'user', 'content': cfg['instruction']})
-        # user and assistant turns must alternate
-        if message:
-            messages.append({'role': 'assistant', 'content': '了解しました。'})
-
     for pair in history:
         [user, assistant] = pair
         if user:
@@ -95,41 +85,76 @@ def chatinterface_to_messages(message, history):
         if assistant:
             messages.append({'role': 'assistant', 'content': assistant})
 
-    if message:
-        messages.append({'role': 'user', 'content': message})
-
     return messages
 
-def apply_template(messages):
-    global tokenizer, cfg, cache, chat_history
+# fairly convoluted
+def tokenize(user_input, history = [], instruction = None, conversation_id = 'gradio', args = {}):
+    global tokenizer, chat_messages
 
-    if type(messages) is str:
-        if cfg['inst_template']:
-            user_input = cfg['inst_template'].format(instruction=cfg['instruction'], input=messages)
-        user_input = cfg['instruction'].format(input=messages)
-        tokenized_chat = tokenizer(user_input, return_tensors="pt").input_ids
-    if type(messages) is list:
+    # build the messages to prepend at the head
+    inst_messages = []
+    if instruction:
+        if 'first_assistant' in args and args['first_assistant']:
+            # Claude-compatible format
+            # user and assistant turns must alternate
+            inst_messages = [
+                {'role': 'user', 'content': instruction},
+                {'role': 'assistant', 'content': args['first_assistant']},
+            ]
+        else:
+            # OpenAI-compatible format
+            inst_messages = [{'role': 'system', 'content': instruction}]
+
+    # if messages are passed in args, overwrite everything
+    if conversation_id and 'messages' in args:
+        chat_messages[conversation_id] = inst_messages + args['messages']
+
+    # if a cache exists, send the conversation in messages format
+    # the instruction is already cached, so it is not needed (it cannot be changed mid-conversation)
+    if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
+        # append user_input
+        chat_messages[conversation_id] += [{'role': 'user', 'content': user_input}]
         tokenized_chat = tokenizer.apply_chat_template(
-            messages + chat_history, tokenize=True, add_generation_prompt=True, return_tensors="pt"
+            chat_messages[conversation_id], tokenize=True, add_generation_prompt=True, return_tensors="pt"
         )
+    else:
+        # apply the instruction if present (input is optional)
+        if instruction:
+            user_input = instruction.format(input=user_input)
+        # with neither, it really is an error
+        if not user_input:
+            raise ValueError('require input or instruction.')
+        tokenized_chat = tokenizer(user_input, return_tensors="pt").input_ids
+
     return tokenized_chat
 
-def chat(message, history = [], instruction = None, args = {}):
-    global tokenizer, model, cfg, cache, chat_history
+def chat(message, history = [], instruction = None, conversation_id = 'gradio', args = {}):
+    global tokenizer, model, chat_past_key_values, chat_messages
 
-    if instruction:
-        cfg['instruction'] = instruction
-        tokenized_chat = apply_template(message)
-    else:
-        messages = chatinterface_to_messages(message, history)
-        tokenized_chat = apply_template(messages)
+    for k, v in default_args.items():
+        args.setdefault(k, v)
+
+    cache = None
+    # if a conversation_id is given, read its cache
+    if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
+        # if clear is requested, wipe it first
+        if 'clear' in args and args['clear']:
+            chat_past_key_values[conversation_id] = None
+            chat_messages[conversation_id] = None
+        else:
+            cache = chat_past_key_values[conversation_id]
+
+    # apply chat_template if present
+    if args['chat_template']:
+        tokenizer.chat_template = args['chat_template']
+
+    # tokenize
+    tokenized_chat = tokenize(message, history, instruction, conversation_id, args).to(device)
 
     device = local_gemma.utils.config.infer_device(None)
-    is_use_cache = cfg['is_use_cache']
     generation_kwargs = local_gemma.utils.config.get_generation_kwargs('chat')
 
     streamer = TextStreamer(tokenizer, skip_prompt=True, **{"skip_special_tokens": True})
-    tokenized_chat = tokenized_chat.to(device)
     generation_kwargs.update(
         {
             "streamer": streamer,
@@ -146,16 +171,16 @@ def chat(message, history = [], instruction = None, args = {}):
         'top_k',
         'repetition_penalty'
     ]:
-        if cfg[k]:
-            generation_kwargs[k] = cfg[k]
+        if args[k]:
+            generation_kwargs[k] = args[k]
 
     # TODO(joao): this if shouldn't be needed, fix in transformers
     if cache is not None:
         generation_kwargs["cache_implementation"] = None
 
-    if cfg['max_new_tokens'] is not None:
+    if args['max_new_tokens'] is not None:
         input_ids_len = tokenized_chat.shape[-1]
-        max_cache_len = cfg['max_new_tokens'] + input_ids_len
+        max_cache_len = args['max_new_tokens'] + input_ids_len
         if cache is not None and cache.max_cache_len < max_cache_len:
             # reset the cache
             generation_kwargs.pop("past_key_values")
@@ -169,34 +194,27 @@ def chat(message, history = [], instruction = None, args = {}):
     cache = gen_out.past_key_values
     model_tokens = gen_out.sequences[0, tokenized_chat.shape[1]:]
     model_output_text = tokenizer.decode(model_tokens, skip_special_tokens=True)
-    chat_history += [{"role": "user", "content": message},]
-    chat_history += [{"role": "assistant", "content": model_output_text},]
+    chat_messages += [{"role": "user", "content": message},]
+    chat_messages += [{"role": "assistant", "content": model_output_text},]
 
     # Sanity check: EOS was removed, ends in "<end_of_turn>\n"
     tokenized_chat = tokenizer.apply_chat_template(
-        chat_history, tokenize=True, add_generation_prompt=False, return_tensors="pt"
+        chat_messages, tokenize=True, add_generation_prompt=False, return_tensors="pt"
    ).tolist()[0]
     assert tokenized_chat[0] == 2
     assert tokenized_chat[-1] == 108
     assert tokenized_chat[-2] == 107
 
-    if not is_use_cache:
-        cache = None
-        chat_history = []
-
+    # TODO: support streaming
     return model_output_text
 
-def infer(message, history = [], instruction = None, args = {}):
-    return chat(message, history, instruction, args)
+# return without streaming
+def infer(message, history = [], instruction = None, conversation_id = 'gradio', args = {}):
+    return chat(message, history, instruction, conversation_id, args)
 
-def numel(message, history = [], instruction = None, args = {}):
-    global tokenizer, model, cfg, cache, chat_history
+def numel(message, history = [], instruction = None, conversation_id = 'gradio', args = {}):
+    global tokenizer, chat_messages
 
-    if instruction:
-        cfg['instruction'] = instruction
-        tokenized_chat = apply_template(message)
-    else:
-        messages = chatinterface_to_messages(message, history)
-        tokenized_chat = apply_template(messages)
+    tokenized_chat = tokenize(message, history, instruction, conversation_id, args).to(device)
 
     return torch.numel(tokenized_chat)
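For reference, a minimal usage sketch of the reworked fn.py interface: generation defaults now live in default_args, and conversation state is keyed by conversation_id in chat_messages and chat_past_key_values. The prompt strings, size, and id below are illustrative, and this revision appears to reference device inside chat() before assigning it, so treat this as the intended call pattern rather than a verified run.

import fn

# Configure defaults; set_config_args also calls load_model(args['size']).
fn.set_config_args({
    'size': '9b',
    'instruction': 'You are a helpful assistant.',  # hypothetical system prompt
    'first_assistant': 'Understood.',               # switches to the alternating user/assistant form
    'chat_template': None,                          # keep the tokenizer's built-in template
    'max_new_tokens': 512,
})

# Calls sharing a conversation_id share the same chat_messages entry
# (and, once populated, the same past_key_values cache).
first = fn.chat('Hello!', conversation_id='demo')
followup = fn.chat('Summarize what I just said.', conversation_id='demo')

# args={'clear': True} wipes the cached messages and past_key_values for that id first.
fresh = fn.chat('Start over.', conversation_id='demo', args={'clear': True})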