Upload 2 files
app.py
CHANGED
@@ -16,31 +16,31 @@ with gr.Blocks() as demo:
 
         with gr.Column(scale=1):
             max_new_tokens = gr.Textbox(
-                value=fn.…
+                value=fn.default_args['max_new_tokens'],
                 label='max_new_tokens',
                 interactive=True,
                 show_copy_button=True,
             )
             temperature = gr.Textbox(
-                value=fn.…
+                value=fn.default_args['temperature'],
                 label='temperature',
                 interactive=True,
                 show_copy_button=True,
             )
             top_p = gr.Textbox(
-                value=fn.…
+                value=fn.default_args['top_p'],
                 label='top_p',
                 interactive=True,
                 show_copy_button=True,
             )
             top_k = gr.Textbox(
-                value=fn.…
+                value=fn.default_args['top_k'],
                 label='top_k',
                 interactive=True,
                 show_copy_button=True,
             )
             repetition_penalty = gr.Textbox(
-                value=fn.…
+                value=fn.default_args['repetition_penalty'],
                 label='repetition_penalty',
                 interactive=True,
                 show_copy_button=True,
@@ -48,17 +48,18 @@ with gr.Blocks() as demo:
 
     with gr.Row():
         with gr.Column(scale=1):
-            …
+            first_assistant = gr.Textbox(
                 value='',
-                …
-                label='inst_template',
+                label='first_assistant',
                 interactive=True,
                 show_copy_button=True,
             )
-            …
-                value=…
-                …
+            chat_template = gr.Textbox(
+                value='',
+                lines=10,
+                label='chat_template',
                 interactive=True,
+                show_copy_button=True,
             )
 
     set_button = gr.Button(value='Save')
@@ -97,7 +98,7 @@ with gr.Blocks() as demo:
 
     set_button.click(
         fn=fn.set_config,
-        inputs=[size, instruction, …
+        inputs=[size, instruction, first_assistant, chat_template, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
         outputs=[info],
     )
 
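The Save wiring above passes each gr.Textbox straight into fn.set_config. Textbox values reach the click handler as strings, which is why set_config on the fn.py side casts them with int()/float(). Below is a minimal, self-contained sketch of the same pattern; demo_config and save are illustrative stand-ins, not part of this Space.

import gradio as gr

# Illustrative stand-ins for fn.default_args / fn.set_config, not the Space's code.
demo_config = {'max_new_tokens': 1024, 'temperature': 0.9}

def save(max_new_tokens, temperature):
    # Textbox values arrive as strings, so cast before storing.
    demo_config['max_new_tokens'] = int(max_new_tokens)
    demo_config['temperature'] = float(temperature)
    return 'done.'

with gr.Blocks() as demo:
    max_new_tokens = gr.Textbox(value=str(demo_config['max_new_tokens']),
                                label='max_new_tokens', interactive=True)
    temperature = gr.Textbox(value=str(demo_config['temperature']),
                             label='temperature', interactive=True)
    info = gr.Textbox(label='info')
    set_button = gr.Button(value='Save')
    set_button.click(fn=save, inputs=[max_new_tokens, temperature], outputs=[info])

if __name__ == '__main__':
    demo.launch()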
fn.py
CHANGED
@@ -14,20 +14,22 @@ from threading import Thread
 
 tokenizer = None
 model = None
-…
+cfg = {
     'size': None,
+}
+default_args = {
     'instruction': None,
-    '…
-    '…
+    'first_assistant': None,
+    'chat_template': None,
     'max_new_tokens': 1024,
     'temperature': 0.9,
     'top_p': 0.95,
     'top_k': 40,
     'repetition_penalty': 1.2,
 }
-…
-…
-…
+
+chat_past_key_values = {}
+chat_messages = {}
 
 def load_model(size = '9b'):
     global tokenizer, model, cfg
@@ -50,17 +52,13 @@ def load_model(size = '9b'):
 
     cfg['size'] = size
 
-def …
-    global …
-    cfg = default_cfg.copy()
-…
-def set_config(size, instruction, inst_template, is_use_cache, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
-    global cfg
+def set_config(size, instruction, first_assistant, chat_template, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
+    global default_args
     load_model(size)
-    …
+    default_args.update({
         'instruction': instruction,
-        '…
-        '…
+        'first_assistant': first_assistant,
+        'chat_template': chat_template,
         'max_new_tokens': int(max_new_tokens),
         'temperature': float(temperature),
         'top_p': float(top_p),
@@ -70,24 +68,16 @@ def set_config(size, instruction, inst_template, is_use_cache, max_new_tokens, t
     return 'done.'
 
 def set_config_args(args):
-    global …
+    global default_args
 
     load_model(args['size'])
-    …
+    default_args.update(args)
 
     return 'done.'
 
-def chatinterface_to_messages(…
-    global cfg
-…
+def chatinterface_to_messages(history):
     messages = []
 
-    if cfg['instruction']:
-        messages.append({'role': 'user', 'content': cfg['instruction']})
-        # user and assistant turns must alternate
-        if message:
-            messages.append({'role': 'assistant', 'content': '了解しました。'})
-
     for pair in history:
         [user, assistant] = pair
         if user:
@@ -95,41 +85,76 @@ def chatinterface_to_messages(message, history):
         if assistant:
             messages.append({'role': 'assistant', 'content': assistant})
 
-    if message:
-        messages.append({'role': 'user', 'content': message})
-
     return messages
 
-…
-…
-…
-…
-…
-…
-…
-…
+# fairly tricky
+def tokenize(user_input, history = [], instruction = None, conversation_id = 'gradio', args = {}):
+    global tokenizer, chat_messages
+
+    # build the messages inserted at the head of the conversation
+    inst_messages = []
+    if instruction:
+        if 'first_assistant' in args and args['first_assistant']:
+            # Claude-compatible format
+            # user and assistant turns must alternate
+            inst_messages = [
+                {'role': 'user', 'content': instruction},
+                {'role': 'assistant', 'content': args['first_assistant']},
+            ]
+        else:
+            # OpenAI-compatible format
+            inst_messages = [{'role': 'system', 'content': instruction}]
+
+    # when messages are passed in, overwrite the whole history
+    if conversation_id and 'messages' in args:
+        chat_messages[conversation_id] = inst_messages + args['messages']
+
+    # if a cache exists, send the conversation in messages format
+    # the instruction is already cached, so it is not needed here (it cannot be changed mid-conversation)
+    if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
+        # append user_input
+        chat_messages[conversation_id] += [{'role': 'user', 'content': user_input}]
         tokenized_chat = tokenizer.apply_chat_template(
-            …
+            chat_messages[conversation_id], tokenize=True, add_generation_prompt=True, return_tensors="pt"
         )
+    else:
+        # apply the instruction if present (the input is optional)
+        if instruction:
+            user_input = instruction.format(input=user_input)
+        # having neither is an error
+        if not user_input:
+            raise ValueError('require input or instruction.')
+        tokenized_chat = tokenizer(user_input, return_tensors="pt").input_ids
+
     return tokenized_chat
 
-def chat(message, history = [], instruction = None, args = {}):
-    global tokenizer, model, …
+def chat(message, history = [], instruction = None, conversation_id = 'gradio', args = {}):
+    global tokenizer, model, chat_past_key_values, chat_messages
 
-…
-…
-…
-…
-…
-…
+    for k, v in default_args.items():
+        args.setdefault(k, v)
+
+    cache = None
+    # read the cache when a conversation_id is given
+    if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
+        # if clear is requested, wipe it first
+        if 'clear' in args and args['clear']:
+            chat_past_key_values[conversation_id] = None
+            chat_messages[conversation_id] = None
+        else:
+            cache = chat_past_key_values[conversation_id]
+
+    # apply chat_template if present
+    if args['chat_template']:
+        tokenizer.chat_template = args['chat_template']
+
+    # tokenize
+    tokenized_chat = tokenize(message, history, instruction, conversation_id, args).to(device)
 
     device = local_gemma.utils.config.infer_device(None)
-    is_use_cache = cfg['is_use_cache']
     generation_kwargs = local_gemma.utils.config.get_generation_kwargs('chat')
 
     streamer = TextStreamer(tokenizer, skip_prompt=True, **{"skip_special_tokens": True})
-    tokenized_chat = tokenized_chat.to(device)
     generation_kwargs.update(
         {
             "streamer": streamer,
@@ -146,16 +171,16 @@ def chat(message, history = [], instruction = None, args = {}):
         'top_k',
         'repetition_penalty'
     ]:
-        if …
-            generation_kwargs[k] = …
+        if args[k]:
+            generation_kwargs[k] = args[k]
 
     # TODO(joao): this if shouldn't be needed, fix in transformers
     if cache is not None:
         generation_kwargs["cache_implementation"] = None
 
-    if …
+    if args['max_new_tokens'] is not None:
         input_ids_len = tokenized_chat.shape[-1]
-        max_cache_len = …
+        max_cache_len = args['max_new_tokens'] + input_ids_len
         if cache is not None and cache.max_cache_len < max_cache_len:
             # reset the cache
             generation_kwargs.pop("past_key_values")
@@ -169,34 +194,27 @@ def chat(message, history = [], instruction = None, args = {}):
     cache = gen_out.past_key_values
     model_tokens = gen_out.sequences[0, tokenized_chat.shape[1]:]
     model_output_text = tokenizer.decode(model_tokens, skip_special_tokens=True)
-    …
-    …
+    chat_messages += [{"role": "user", "content": message},]
+    chat_messages += [{"role": "assistant", "content": model_output_text},]
 
     # Sanity check: EOS was removed, ends in "<end_of_turn>\n"
     tokenized_chat = tokenizer.apply_chat_template(
-        …
+        chat_messages, tokenize=True, add_generation_prompt=False, return_tensors="pt"
     ).tolist()[0]
     assert tokenized_chat[0] == 2
    assert tokenized_chat[-1] == 108
     assert tokenized_chat[-2] == 107
 
-    …
-    cache = None
-    chat_history = []
-
+    # TODO: support streaming
     return model_output_text
 
-…
-…
-
-def numel(message, history = [], instruction = None, args = {}):
-    global tokenizer, …
+# return the answer without streaming
+def infer(message, history = [], instruction = None, conversation_id = 'gradio', args = {}):
+    return chat(message, history, instruction, conversation_id, args)
+
+def numel(message, history = [], instruction = None, conversation_id = 'gradio', args = {}):
+    global tokenizer, chat_messages
 
-    …
-    cfg['instruction'] = instruction
-    tokenized_chat = apply_template(message)
-    else:
-        messages = chatinterface_to_messages(message, history)
-        tokenized_chat = apply_template(messages)
+    tokenized_chat = tokenize(message, history, instruction, conversation_id, args).to(device)
 
     return torch.numel(tokenized_chat)
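The reworked tokenize helper builds the head of the conversation in one of two layouts: when first_assistant is set, the instruction goes in as a user turn followed by a fixed assistant reply so that user and assistant turns keep alternating (the Claude-compatible form in the diff's comments); otherwise it falls back to a single system message (the OpenAI-compatible form). A small stand-alone sketch of just that branching, independent of the tokenizer and model; build_inst_messages is an illustrative name, not part of fn.py.

def build_inst_messages(instruction, first_assistant=None):
    # Mirrors the branching in fn.tokenize (illustrative helper, not in the repo).
    if first_assistant:
        # Claude-compatible: the instruction becomes a user turn, answered by a
        # fixed assistant turn, so user/assistant keep alternating.
        return [
            {'role': 'user', 'content': instruction},
            {'role': 'assistant', 'content': first_assistant},
        ]
    # OpenAI-compatible: a single system message.
    return [{'role': 'system', 'content': instruction}]

print(build_inst_messages('Answer briefly in English.', 'Understood.'))
print(build_inst_messages('Answer briefly in English.'))

Keeping the alternation matters here because, as the diff's own comment notes, user and assistant turns must alternate before the messages are handed to tokenizer.apply_chat_template.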