import os
import re
import torch
import datetime
import json
import csv
import gc
import local_gemma
from transformers import AutoTokenizer, TextStreamer
from transformers import TextIteratorStreamer
from transformers import BitsAndBytesConfig, GPTQConfig
from threading import Thread
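
# Chat wrapper around local_gemma's Gemma 2 models for this Space: one model is
# kept loaded at a time, and per-conversation message histories and KV caches
# live in module-level dicts (the Gradio UI uses the default conversation_id
# 'gradio').
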
tokenizer = None
model = None
cfg = {
    'size': None,
}
default_args = {
    'instruction': None,
    'first_assistant': None,
    'chat_template': None,
    'max_new_tokens': 1024,
    'temperature': 0.9,
    'top_p': 0.95,
    'top_k': 40,
    'repetition_penalty': 1.2,
}
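
# Per-conversation state, keyed by conversation_id: the KV cache from the last
# generation round and the accumulated message history.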
chat_past_key_values = {}
chat_messages = {}

def load_model(size = '9b'):
    global tokenizer, model, cfg
    if cfg['size'] == size:
        return
    del model
    del tokenizer
    model = None
    tokenizer = None
    gc.collect()
    torch.cuda.empty_cache()
    model_name = f"SillyTilly/google-gemma-2-{size}-it"
    model = local_gemma.LocalGemma2ForCausalLM.from_pretrained(model_name, preset="memory")
    model._supports_cache_class = True
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    cfg['size'] = size
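
# Example (sizes are illustrative; a matching SillyTilly/google-gemma-2-<size>-it
# checkpoint must be downloadable):
#   load_model('2b')  # frees the current model, then loads the 2b variant
#   load_model('2b')  # no-op: this size is already loaded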

def set_config(size, instruction, first_assistant, chat_template, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
    global default_args
    load_model(size)
    default_args.update({
        'instruction': instruction,
        'first_assistant': first_assistant,
        'chat_template': chat_template,
        'max_new_tokens': int(max_new_tokens),
        'temperature': float(temperature),
        'top_p': float(top_p),
        'top_k': int(top_k),
        'repetition_penalty': float(repetition_penalty),
    })
    return 'done.'

def set_config_args(args):
    global default_args
    load_model(args['size'])
    default_args.update(args)
    return 'done.'

def chatinterface_to_messages(history):
    messages = []
    for pair in history:
        [user, assistant] = pair
        if user:
            messages.append({'role': 'user', 'content': user})
        if assistant:
            messages.append({'role': 'assistant', 'content': assistant})
    return messages
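
# Example:
#   chatinterface_to_messages([['Hi', 'Hello! How can I help?']])
#   -> [{'role': 'user', 'content': 'Hi'},
#       {'role': 'assistant', 'content': 'Hello! How can I help?'}]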

# This part is fairly tricky.
def tokenize(user_input, history = [], instruction = None, conversation_id = 'gradio', args = {}):
    global tokenizer, chat_messages
    # Build the messages that get inserted at the head of the conversation.
    inst_messages = []
    if instruction:
        if 'first_assistant' in args and args['first_assistant']:
            # Claude-compatible format:
            # user and assistant turns must strictly alternate.
            inst_messages = [
                {'role': 'user', 'content': instruction},
                {'role': 'assistant', 'content': args['first_assistant']},
            ]
        else:
            # OpenAI-compatible format (system message).
            inst_messages = [{'role': 'system', 'content': instruction}]
    # When 'messages' is passed, overwrite the whole stored history.
    if conversation_id and 'messages' in args:
        chat_messages[conversation_id] = inst_messages + args['messages']
    # If a cached history exists, send the conversation in messages format.
    # The instruction is already part of the cache, so it is not added again (it cannot be changed mid-conversation).
    if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
        # Append user_input to the stored messages.
        chat_messages[conversation_id] += [{'role': 'user', 'content': user_input}]
        tokenized_chat = tokenizer.apply_chat_template(
            chat_messages[conversation_id], tokenize=True, add_generation_prompt=True, return_tensors="pt"
        )
    else:
        # Apply the instruction if present (its {input} placeholder is optional).
        if instruction:
            user_input = instruction.format(input=user_input)
        # With neither an input nor an instruction there is nothing to generate from.
        if not user_input:
            raise ValueError('require input or instruction.')
        tokenized_chat = tokenizer(user_input, return_tensors="pt").input_ids
    return tokenized_chat
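
# Example of seeding a cached conversation (values are illustrative only):
#   tokenize('How about Osaka?',
#            instruction='You are a travel guide.',
#            conversation_id='abc123',
#            args={'first_assistant': 'Understood.',
#                  'messages': [{'role': 'user', 'content': 'Tell me about Tokyo.'},
#                               {'role': 'assistant', 'content': 'Tokyo is ...'}]})
# stores the instruction pair plus 'messages' under 'abc123', appends the new
# user turn, and returns the prompt tokenized with the model's chat template.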

def chat(message, history = [], instruction = None, conversation_id = 'gradio', args = {}):
    global tokenizer, model, chat_past_key_values, chat_messages
    # Work on a copy so neither the caller's dict nor the shared default is mutated.
    args = dict(args)
    for k, v in default_args.items():
        args.setdefault(k, v)
    cache = None
    # When a conversation_id is given, read its cache.
    if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
        # If 'clear' is set, wipe the stored state first.
        if 'clear' in args and args['clear']:
            chat_past_key_values[conversation_id] = None
            chat_messages[conversation_id] = None
        else:
            cache = chat_past_key_values[conversation_id]
    # Apply a custom chat_template if one was provided.
    if args['chat_template']:
        tokenizer.chat_template = args['chat_template']
    device = local_gemma.utils.config.infer_device(None)
    generation_kwargs = local_gemma.utils.config.get_generation_kwargs('chat')
    # Tokenize the prompt.
    tokenized_chat = tokenize(message, history, instruction, conversation_id, args).to(device)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs.update(
        {
            "streamer": streamer,
            "assistant_model": None,
            "return_dict_in_generate": True,
            "past_key_values": cache,
        }
    )
    for k in [
        'max_new_tokens',
        'temperature',
        'top_p',
        'top_k',
        'repetition_penalty',
    ]:
        if args[k]:
            generation_kwargs[k] = args[k]
    # TODO(joao): this if shouldn't be needed, fix in transformers
    if cache is not None:
        generation_kwargs["cache_implementation"] = None
    if args['max_new_tokens'] is not None:
        input_ids_len = tokenized_chat.shape[-1]
        max_cache_len = args['max_new_tokens'] + input_ids_len
        if cache is not None and cache.max_cache_len < max_cache_len:
            # reset the cache
            generation_kwargs.pop("past_key_values")
            generation_kwargs["cache_implementation"] = "hybrid"
    else:
        generation_kwargs["max_length"] = model.config.max_position_embeddings
    gen_out = model.generate(input_ids=tokenized_chat, **generation_kwargs)
    model_tokens = gen_out.sequences[0, tokenized_chat.shape[1]:]
    model_output_text = tokenizer.decode(model_tokens, skip_special_tokens=True)
    if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
        # Store the cache for the next generation round; pull the model output into the chat history.
        # tokenize() already appended the user turn, so only the assistant reply is added here.
        chat_past_key_values[conversation_id] = gen_out.past_key_values
        chat_messages[conversation_id] += [{"role": "assistant", "content": model_output_text}]
        # Sanity check: EOS was removed, ends in "<end_of_turn>\n"
        tokenized_chat = tokenizer.apply_chat_template(
            chat_messages[conversation_id], tokenize=True, add_generation_prompt=False, return_tensors="pt"
        ).tolist()[0]
        assert tokenized_chat[0] == 2
        assert tokenized_chat[-1] == 108
        assert tokenized_chat[-2] == 107
    # TODO: support streaming.
    return model_output_text

# Return the response without streaming.
def infer(message, history = [], instruction = None, conversation_id = 'gradio', args = {}):
    return chat(message, history, instruction, conversation_id, args)

def numel(message, history = [], instruction = None, conversation_id = 'gradio', args = {}):
    global tokenizer, chat_messages
    tokenized_chat = tokenize(message, history, instruction, conversation_id, args)
    return torch.numel(tokenized_chat)
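
# Minimal usage sketch (assumes a CUDA-capable GPU and network access to the
# SillyTilly/google-gemma-2-2b-it checkpoint; in this Space the Gradio app
# normally drives chat() instead of this block):
if __name__ == '__main__':
    set_config_args({'size': '2b', 'max_new_tokens': 256})
    # Without a cached conversation, the instruction's {input} placeholder is
    # where the user message is substituted before plain tokenization.
    print(chat('Who are you?', instruction='Answer in one sentence: {input}'))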