# localgemma2 / fn.py
import os
import re
import torch
import datetime
import json
import csv
import gc
import local_gemma
from transformers import AutoTokenizer, TextStreamer
from transformers import TextIteratorStreamer
from transformers import BitsAndBytesConfig, GPTQConfig
from threading import Thread
tokenizer = None
model = None
cfg = {
'size': None,
}
default_args = {
'instruction': None,
'first_assistant': None,
'chat_template': None,
'max_new_tokens': 1024,
'temperature': 0.9,
'top_p': 0.95,
'top_k': 40,
'repetition_penalty': 1.2,
}
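# Per-conversation KV cache and message history, keyed by conversation_id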
chat_past_key_values = {}
chat_messages = {}
def load_model(size = '9b'):
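    """Load (or reload) the Gemma 2 model and tokenizer for the requested size.

    If the requested size is already loaded this is a no-op; otherwise the
    current model and tokenizer are released and the checkpoint is loaded
    through local_gemma with the "memory" preset.
    """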
global tokenizer, model, cfg
if cfg['size'] == size:
return
del model
del tokenizer
model = None
tokenizer = None
gc.collect()
torch.cuda.empty_cache()
model_name = f"SillyTilly/google-gemma-2-{size}-it"
model = local_gemma.LocalGemma2ForCausalLM.from_pretrained(model_name, preset="memory")
model._supports_cache_class = True
tokenizer = AutoTokenizer.from_pretrained(model_name)
cfg['size'] = size
def set_config(size, instruction, first_assistant, chat_template, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
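    """Load the requested model size and update the default generation arguments."""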
global default_args
load_model(size)
default_args.update({
'instruction': instruction,
'first_assistant': first_assistant,
'chat_template': chat_template,
'max_new_tokens': int(max_new_tokens),
'temperature': float(temperature),
'top_p': float(top_p),
'top_k': int(top_k),
'repetition_penalty': float(repetition_penalty),
})
return 'done.'
def set_config_args(args):
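    """Same as set_config, but takes all settings as a single dict (must include 'size')."""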
global default_args
load_model(args['size'])
default_args.update(args)
return 'done.'
def chatinterface_to_messages(history):
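    """Convert Gradio ChatInterface history ([[user, assistant], ...]) into a messages list."""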
messages = []
for pair in history:
        user, assistant = pair
if user:
messages.append({'role': 'user', 'content': user})
if assistant:
messages.append({'role': 'assistant', 'content': assistant})
return messages
# Fairly involved: the prompt is built differently depending on the cache state
def tokenize(user_input, history = [], instruction = None, conversation_id = 'gradio', args = {}):
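    """Build the tokenized prompt for generation.

    When a cached conversation exists for conversation_id (or args['messages']
    seeds one), the prompt is built from the full messages list via the chat
    template. Otherwise user_input is formatted with the instruction template
    and tokenized directly.
    """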
global tokenizer, chat_messages
    # Build the messages to prepend at the head of the conversation
inst_messages = []
if instruction:
if 'first_assistant' in args and args['first_assistant']:
            # Claude-compatible format:
            # user and assistant turns must strictly alternate
inst_messages = [
{'role': 'user', 'content': instruction},
{'role': 'assistant', 'content': args['first_assistant']},
]
else:
            # OpenAI-compatible format
inst_messages = [{'role': 'system', 'content': instruction}]
    # If args contains 'messages', overwrite the whole conversation with it
if conversation_id and 'messages' in args:
chat_messages[conversation_id] = inst_messages + args['messages']
    # If a cached conversation exists, send the prompt in messages format.
    # The instruction is already part of the cache, so it is not re-applied (it cannot be changed mid-conversation).
if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
        # Append the new user_input
chat_messages[conversation_id] += [{'role': 'user', 'content': user_input}]
tokenized_chat = tokenizer.apply_chat_template(
chat_messages[conversation_id], tokenize=True, add_generation_prompt=True, return_tensors="pt"
)
else:
        # Apply the instruction template if present (input is optional)
if instruction:
user_input = instruction.format(input=user_input)
        # With neither an input nor an instruction there is nothing to generate from
        if not user_input:
            raise ValueError('Either input or instruction is required.')
tokenized_chat = tokenizer(user_input, return_tensors="pt").input_ids
return tokenized_chat
def chat(message, history = [], instruction = None, conversation_id = 'gradio', args = {}):
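    """Generate a reply with local_gemma.

    Per-conversation past_key_values and message history are reused and
    updated when conversation_id refers to a cached conversation.
    """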
global tokenizer, model, chat_past_key_values, chat_messages
    # Copy so the caller's dict (and the shared mutable default) is never modified in place
    args = dict(args)
    for k, v in default_args.items():
        args.setdefault(k, v)
    if instruction is None:
        # Fall back to the instruction configured via set_config / set_config_args
        instruction = args.get('instruction')
cache = None
    # When a conversation_id is given, read its cache
if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
        # If 'clear' was requested, wipe the cached state first
        if 'clear' in args and args['clear']:
            chat_past_key_values[conversation_id] = None
            chat_messages[conversation_id] = None
        else:
            cache = chat_past_key_values.get(conversation_id)
    # Apply a custom chat_template if one was provided
if args['chat_template']:
tokenizer.chat_template = args['chat_template']
device = local_gemma.utils.config.infer_device(None)
generation_kwargs = local_gemma.utils.config.get_generation_kwargs('chat')
    # Tokenize the prompt
tokenized_chat = tokenize(message, history, instruction, conversation_id, args).to(device)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs.update(
{
"streamer": streamer,
"assistant_model": None,
"return_dict_in_generate": True,
"past_key_values": cache,
}
)
for k in [
'max_new_tokens',
'temperature',
'top_p',
'top_k',
'repetition_penalty'
]:
if args[k]:
generation_kwargs[k] = args[k]
# TODO(joao): this if shouldn't be needed, fix in transformers
if cache is not None:
generation_kwargs["cache_implementation"] = None
if args['max_new_tokens'] is not None:
input_ids_len = tokenized_chat.shape[-1]
max_cache_len = args['max_new_tokens'] + input_ids_len
if cache is not None and cache.max_cache_len < max_cache_len:
# reset the cache
generation_kwargs.pop("past_key_values")
generation_kwargs["cache_implementation"] = "hybrid"
else:
generation_kwargs["max_length"] = model.config.max_position_embeddings
gen_out = model.generate(input_ids=tokenized_chat, **generation_kwargs)
model_tokens = gen_out.sequences[0, tokenized_chat.shape[1]:]
model_output_text = tokenizer.decode(model_tokens, skip_special_tokens=True)
if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
        # Store the cache for the next generation round and pull the model output
        # into the chat history (the user turn was already appended in tokenize()).
        chat_past_key_values[conversation_id] = gen_out.past_key_values
        chat_messages[conversation_id] += [{"role": "assistant", "content": model_output_text}]
# Sanity check: EOS was removed, ends in "<end_of_turn>\n"
tokenized_chat = tokenizer.apply_chat_template(
chat_messages[conversation_id], tokenize=True, add_generation_prompt=False, return_tensors="pt"
).tolist()[0]
assert tokenized_chat[0] == 2
assert tokenized_chat[-1] == 108
assert tokenized_chat[-2] == 107
    # TODO: support streaming responses
return model_output_text
# Return the reply without streaming
def infer(message, history = [], instruction = None, conversation_id = 'gradio', args = {}):
return chat(message, history, instruction, conversation_id, args)
def numel(message, history = [], instruction = None, conversation_id = 'gradio', args = {}):
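    """Return the number of prompt tokens that chat() would feed to the model."""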
global tokenizer, chat_messages
tokenized_chat = tokenize(message, history, instruction, conversation_id, args)
return torch.numel(tokenized_chat)
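
# --- Usage sketch (not part of the original module) ---
# A minimal, illustrative way to exercise this module. It assumes a CUDA-capable
# machine and that the SillyTilly/google-gemma-2-9b-it checkpoint can be
# downloaded; 'demo' and the prompts below are arbitrary example values.
# Note that TextStreamer also echoes tokens to stdout while generating.
if __name__ == '__main__':
    set_config_args({'size': '9b', 'max_new_tokens': 128})
    # Cache-less path: the instruction acts as a prompt template with an {input} slot
    print(infer('GPUs', instruction='Write a haiku about {input}.'))
    # Cached path: seed a conversation via args['messages'], then continue it
    seed = [
        {'role': 'user', 'content': 'Say hello in French.'},
        {'role': 'assistant', 'content': 'Bonjour !'},
    ]
    print(chat('Now say it in Spanish.', conversation_id='demo', args={'messages': seed}))
    # The second turn reuses the stored past_key_values for 'demo'
    print(chat('And in German?', conversation_id='demo'))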