import os
import re
import torch
import datetime
import json
import csv
import gc
import local_gemma
from transformers import AutoTokenizer, TextStreamer
from transformers import TextIteratorStreamer
from transformers import BitsAndBytesConfig, GPTQConfig
from threading import Thread
tokenizer = None
model = None

cfg = {
    'size': None,
}

default_args = {
    'instruction': None,
    'first_assistant': None,
    'chat_template': None,
    'max_new_tokens': 1024,
    'temperature': 0.9,
    'top_p': 0.95,
    'top_k': 40,
    'repetition_penalty': 1.2,
}

chat_past_key_values = {}
chat_messages = {}

def load_model(size='9b'):
    global tokenizer, model, cfg
    if cfg['size'] == size:
        return
    # Release the previously loaded model before switching checkpoints.
    del model
    del tokenizer
    model = None
    tokenizer = None
    gc.collect()
    torch.cuda.empty_cache()
    model_name = f"SillyTilly/google-gemma-2-{size}-it"
    model = local_gemma.LocalGemma2ForCausalLM.from_pretrained(model_name, preset="memory")
    model._supports_cache_class = True
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    cfg['size'] = size

def set_config(size, instruction, first_assistant, chat_template, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
    global default_args
    load_model(size)
    default_args.update({
        'instruction': instruction,
        'first_assistant': first_assistant,
        'chat_template': chat_template,
        'max_new_tokens': int(max_new_tokens),
        'temperature': float(temperature),
        'top_p': float(top_p),
        'top_k': int(top_k),
        'repetition_penalty': float(repetition_penalty),
    })
    return 'done.'

def set_config_args(args):
    global default_args
    load_model(args['size'])
    default_args.update(args)
    return 'done.'
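
# Illustrative only: one way set_config_args might be called (the values below are
# made up, not part of the app). 'size' picks the gemma-2 checkpoint to load and the
# remaining keys override the sampling defaults above.
# set_config_args({'size': '9b', 'instruction': 'You are a helpful assistant.',
#                  'max_new_tokens': 256, 'temperature': 0.7})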

def chatinterface_to_messages(history):
    # Convert a Gradio ChatInterface history ([user, assistant] pairs) into an
    # OpenAI-style messages list, skipping empty entries.
    messages = []
    for pair in history:
        user, assistant = pair
        if user:
            messages.append({'role': 'user', 'content': user})
        if assistant:
            messages.append({'role': 'assistant', 'content': assistant})
    return messages
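
# Illustrative example (values made up): a ChatInterface history such as
#   [['Hi', 'Hello! How can I help?'], ['Tell me a joke', None]]
# is converted to
#   [{'role': 'user', 'content': 'Hi'},
#    {'role': 'assistant', 'content': 'Hello! How can I help?'},
#    {'role': 'user', 'content': 'Tell me a joke'}]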

# This part is fairly convoluted.
def tokenize(user_input, history=[], instruction=None, conversation_id='gradio', args={}):
    global tokenizer, chat_messages
    # Build the messages that get inserted at the head of the conversation.
    inst_messages = []
    if instruction:
        if 'first_assistant' in args and args['first_assistant']:
            # Claude-compatible format:
            # user and assistant turns must alternate.
            inst_messages = [
                {'role': 'user', 'content': instruction},
                {'role': 'assistant', 'content': args['first_assistant']},
            ]
        else:
            # OpenAI-compatible format.
            inst_messages = [{'role': 'system', 'content': instruction}]
    # If 'messages' is passed in args, overwrite the whole conversation.
    if conversation_id and 'messages' in args:
        chat_messages[conversation_id] = inst_messages + args['messages']
    # If a cached conversation exists, send it in messages form.
    # The instruction is already part of the cache, so it is not re-applied here
    # (it cannot be changed mid-conversation).
    if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
        # Append the user input.
        chat_messages[conversation_id] += [{'role': 'user', 'content': user_input}]
        tokenized_chat = tokenizer.apply_chat_template(
            chat_messages[conversation_id], tokenize=True, add_generation_prompt=True, return_tensors="pt"
        )
    else:
        # Apply the instruction if given (input is optional).
        if instruction:
            user_input = instruction.format(input=user_input)
        # With neither an input nor an instruction there is nothing to generate from.
        if not user_input:
            raise ValueError('require input or instruction.')
        tokenized_chat = tokenizer(user_input, return_tensors="pt").input_ids
    return tokenized_chat

def chat(message, history=[], instruction=None, conversation_id='gradio', args={}):
    global tokenizer, model, chat_past_key_values, chat_messages
    # Copy before filling in defaults so the shared default dict is never mutated.
    args = dict(args)
    for k, v in default_args.items():
        args.setdefault(k, v)
    cache = None
    # If a conversation_id is given, read its cache.
    if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
        # If 'clear' is specified, wipe the cache and history first.
        if 'clear' in args and args['clear']:
            chat_past_key_values[conversation_id] = None
            chat_messages[conversation_id] = None
        else:
            cache = chat_past_key_values.get(conversation_id)
    # Apply the chat_template if given.
    if args['chat_template']:
        tokenizer.chat_template = args['chat_template']
    device = local_gemma.utils.config.infer_device(None)
    generation_kwargs = local_gemma.utils.config.get_generation_kwargs('chat')
    # Tokenize the input (messages form when a cached conversation exists).
    tokenized_chat = tokenize(message, history, instruction, conversation_id, args).to(device)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs.update(
        {
            "streamer": streamer,
            "assistant_model": None,
            "return_dict_in_generate": True,
            "past_key_values": cache,
        }
    )
    for k in [
        'max_new_tokens',
        'temperature',
        'top_p',
        'top_k',
        'repetition_penalty',
    ]:
        if args[k]:
            generation_kwargs[k] = args[k]
    # TODO(joao): this if shouldn't be needed, fix in transformers
    if cache is not None:
        generation_kwargs["cache_implementation"] = None
    if args['max_new_tokens'] is not None:
        input_ids_len = tokenized_chat.shape[-1]
        max_cache_len = args['max_new_tokens'] + input_ids_len
        if cache is not None and cache.max_cache_len < max_cache_len:
            # reset the cache
            generation_kwargs.pop("past_key_values")
            generation_kwargs["cache_implementation"] = "hybrid"
    else:
        generation_kwargs["max_length"] = model.config.max_position_embeddings
    gen_out = model.generate(input_ids=tokenized_chat, **generation_kwargs)
    model_tokens = gen_out.sequences[0, tokenized_chat.shape[1]:]
    model_output_text = tokenizer.decode(model_tokens, skip_special_tokens=True)
    if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
        # Store the cache for the next generation round and pull the model output into
        # the chat history. The user turn was already appended inside tokenize(), so
        # only the assistant reply is added here (appending the user turn again would
        # break the alternating-role check in the chat template).
        chat_past_key_values[conversation_id] = gen_out.past_key_values
        chat_messages[conversation_id] += [{"role": "assistant", "content": model_output_text}]
        # Sanity check: EOS was removed, ends in "<end_of_turn>\n"
        tokenized_chat = tokenizer.apply_chat_template(
            chat_messages[conversation_id], tokenize=True, add_generation_prompt=False, return_tensors="pt"
        ).tolist()[0]
        assert tokenized_chat[0] == 2
        assert tokenized_chat[-1] == 108
        assert tokenized_chat[-2] == 107
    # TODO: support streaming
    return model_output_text

# Return the result without streaming.
def infer(message, history=[], instruction=None, conversation_id='gradio', args={}):
    return chat(message, history, instruction, conversation_id, args)

def numel(message, history=[], instruction=None, conversation_id='gradio', args={}):
    global tokenizer, chat_messages
    tokenized_chat = tokenize(message, history, instruction, conversation_id, args)
    return torch.numel(tokenized_chat)
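
# Minimal usage sketch, not part of the app itself: it assumes a machine with enough
# GPU memory for the 9b checkpoint and that the SillyTilly mirror above is reachable.
# The conversation_id, instruction and prompts are illustrative values. Passing
# 'messages': [] seeds chat_messages so the KV-cache path is used on later turns, and
# 'first_assistant' keeps the seeded prefix in the user/assistant alternation that the
# Gemma chat template expects.
if __name__ == '__main__':
    load_model('9b')
    # First turn: seeds the conversation history (the reply is also streamed to
    # stdout by the TextStreamer inside chat()).
    first = chat(
        'Hello, who are you?',
        instruction='You are a concise assistant.',
        conversation_id='demo',
        args={'messages': [], 'first_assistant': 'Understood.'},
    )
    # Second turn: reuses chat_messages['demo'] and the stored past_key_values.
    second = chat('Summarise that in one sentence.', conversation_id='demo')
    print(second)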