Spaces:
Sleeping
Sleeping
import os | |
import re | |
import torch | |
import datetime | |
import json | |
import csv | |
import gc | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from transformers import TextIteratorStreamer | |
from transformers import BitsAndBytesConfig, GPTQConfig | |
from threading import Thread | |
# Currently loaded tokenizer and model; both remain None until load_model() assigns them.
tokenizer = model = None

# Baseline settings. `cfg` below is the mutable working copy that the functions
# in this module read and update; clear_config() resets it from this dict.
default_cfg = dict(
    model_name=None,
    qtype='bnb',
    dtype='4bit',
    instruction=None,
    inst_template=None,
    chat_template=None,
    max_new_tokens=1024,
    temperature=0.9,
    top_p=0.95,
    top_k=40,
    repetition_penalty=1.2,
)
cfg = dict(default_cfg)
def load_model(model_name, qtype = 'bnb', dtype = '4bit'): | |
global tokenizer, model, cfg | |
if cfg['model_name'] == model_name and cfg['qtype'] == qtype and cfg['dtype'] == dtype: | |
return | |
del model | |
del tokenizer | |
model = None | |
tokenizer = None | |
gc.collect() | |
torch.cuda.empty_cache() | |
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) | |
match qtype: | |
case 'bnb': | |
match dtype: | |
case '4bit' | 'int4': | |
kwargs = dict( | |
quantization_config=BitsAndBytesConfig( | |
load_in_4bit=True, | |
bnb_4bit_compute_dtype=torch.bfloat16, | |
), | |
) | |
case '8bit' | 'int8': | |
kwargs = dict( | |
quantization_config=BitsAndBytesConfig( | |
load_in_8bit=True, | |
bnb_4bit_compute_dtype=torch.bfloat16, | |
), | |
) | |
case 'fp16': | |
kwargs = dict( | |
torch_dtype=torch.float16, | |
) | |
case 'bf16': | |
kwargs = dict( | |
torch_dtype=torch.bfloat16, | |
) | |
case _: | |
kwargs = dict() | |
case 'gptq': | |
match dtype: | |
case '4bit' | 'int4': | |
kwargs = dict( | |
quantization_config=GPTQConfig( | |
bits=4, | |
tokenizer=tokenizer, | |
), | |
) | |
case '8bit' | 'int8': | |
kwargs = dict( | |
quantization_config=GPTQConfig( | |
bits=8, | |
tokenizer=tokenizer, | |
), | |
) | |
case 'gguf': | |
kwargs = dict( | |
gguf_file=qtype, | |
) | |
case 'awq': | |
match dtype: | |
case 'fa2': | |
kwargs = dict( | |
use_flash_attention_2=True, | |
) | |
case _: | |
kwargs = dict() | |
model = AutoModelForCausalLM.from_pretrained( | |
model_name, | |
device_map="auto", | |
trust_remote_code=True, | |
**kwargs, | |
) | |
cfg['model_name'] = model_name | |
cfg['qtype'] = qtype | |
cfg['dtype'] = dtype | |
def clear_config():
    """Reset the module-level ``cfg`` to a fresh shallow copy of ``default_cfg``.

    The loaded model and tokenizer objects are not touched. Because
    ``cfg['model_name']`` returns to its default, the next ``load_model`` call
    with a real model name will load that model again.
    """
    global cfg
    cfg = dict(default_cfg)
def set_config(model_name, qtype, dtype, instruction, inst_template, chat_template, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
    """Ensure the requested model is loaded, then store prompt/generation settings in ``cfg``.

    Numeric settings are coerced with ``int``/``float``. All coercions happen
    before ``cfg`` is touched, so a bad value leaves ``cfg`` unchanged.

    Returns:
        The status string ``'done.'``.
    """
    global cfg
    load_model(model_name, qtype, dtype)
    settings = dict(
        instruction=instruction,
        inst_template=inst_template,
        chat_template=chat_template,
    )
    settings.update(
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        repetition_penalty=float(repetition_penalty),
    )
    cfg.update(settings)
    return 'done.'
def set_config_args(args):
    """Apply a whole configuration dict at once.

    ``args`` must contain ``'model_name'``, ``'qtype'`` and ``'dtype'`` (a
    missing key raises KeyError). Those three select the model to load. Every
    entry of ``args`` is then merged into ``cfg`` verbatim, without type coercion.

    Returns:
        The status string ``'done.'``.
    """
    global cfg
    model_name, qtype, dtype = (args[key] for key in ('model_name', 'qtype', 'dtype'))
    load_model(model_name, qtype, dtype)
    cfg.update(args)
    return 'done.'
def chatinterface_to_messages(message, history):
    """Build a chat-template message list from a ChatInterface-style call.

    Args:
        message: the new user input; skipped when empty.
        history: iterable of 2-element ``[user, assistant]`` pairs; empty
            entries inside a pair are skipped.

    Returns:
        A list of ``{'role', 'content'}`` dicts. It starts with a system
        message when ``cfg['instruction']`` is set, then the history turns in
        order, then the new user message.
    """
    system_prompt = cfg['instruction']
    turns = [{'role': 'system', 'content': system_prompt}] if system_prompt else []
    for user_text, assistant_text in history:
        turns.extend(
            {'role': role, 'content': text}
            for role, text in (('user', user_text), ('assistant', assistant_text))
            if text
        )
    if message:
        turns.append({'role': 'user', 'content': message})
    return turns
def chat(message, history = None, instruction = None, args = None):
    """Stream a model reply to *message*.

    There are two prompt modes:

    * ``instruction`` is truthy: it is written into ``cfg['instruction']``,
      and *message* alone is rendered by ``apply_template`` as an
      instruction-style prompt. *history* is ignored in this mode.
    * otherwise: *history* plus *message* are converted with
      ``chatinterface_to_messages`` and rendered with the chat template.

    ``generate`` runs on a background thread, and decoded text is read back
    from a ``TextIteratorStreamer``.

    Args:
        message: the new user input.
        history: sequence of ``[user, assistant]`` pairs; ``None`` means none.
        instruction: optional instruction that switches to instruction mode.
        args: options dict; ``None`` means empty. If it contains the key
            ``'fastapi'``, each yield is only the newly decoded text.
            Otherwise each yield is the accumulated reply so far.

    Yields:
        str chunks, as described for *args*.

    Returns:
        The full generated text, as the generator's return value.
    """
    global tokenizer, model, cfg
    # None sentinels instead of shared mutable [] / {} default arguments.
    if history is None:
        history = []
    if args is None:
        args = {}
    if instruction:
        # NOTE(review): this persists in the global cfg, so the instruction also
        # becomes the system prompt for later calls — confirm that is intended.
        cfg['instruction'] = instruction
        prompt = apply_template(message)
    else:
        messages = chatinterface_to_messages(message, history)
        prompt = apply_template(messages)
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    # timeout bounds each wait on the streamer's queue instead of blocking forever.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True,
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        do_sample=True,
        num_beams=1,
    )
    # Falsy settings (None or 0) are omitted so generate() uses its own defaults for them.
    for k in ('max_new_tokens', 'temperature', 'top_p', 'top_k', 'repetition_penalty'):
        if cfg[k]:
            generate_kwargs[k] = cfg[k]
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        if 'fastapi' in args:
            # FastAPI clients want only the newly generated delta.
            yield new_text
        else:
            # Gradio expects the full text so far on every update.
            yield model_output
    # The streamer is exhausted once generation has finished; reap the worker thread.
    t.join()
    return model_output
def apply_template(messages):
    """Render *messages* into the prompt string fed to the tokenizer.

    Args:
        messages: either a ``str`` (a single instruction-style input) or a
            ``list`` of ``{'role', 'content'}`` chat messages.

    Returns:
        For a ``str``, there are three cases:

        * ``cfg['inst_template']`` is set: that template is formatted with
          ``instruction`` and ``input``.
        * otherwise ``cfg['instruction']`` is set: the instruction is used as
          the template and formatted with ``input``.
        * neither is set: the input is returned unchanged (previously this
          crashed with AttributeError on ``None.format``).

        For a ``list``, the tokenizer's chat template is applied with a
        generation prompt appended. ``cfg['chat_template']``, when set,
        overrides the template for this call only.

    Raises:
        TypeError: for any other input type. Previously ``None`` was returned,
            which only failed later inside the tokenizer.
    """
    global tokenizer, cfg
    if isinstance(messages, str):
        if cfg['inst_template']:
            return cfg['inst_template'].format(instruction=cfg['instruction'], input=messages)
        if cfg['instruction']:
            return cfg['instruction'].format(input=messages)
        return messages
    if isinstance(messages, list):
        # Pass the override per call instead of assigning tokenizer.chat_template.
        # The old assignment stuck to the tokenizer, so clearing cfg['chat_template']
        # later had no effect.
        return tokenizer.apply_chat_template(
            conversation=messages,
            chat_template=cfg['chat_template'] or None,
            add_generation_prompt=True,
            tokenize=False,
        )
    raise TypeError(f"apply_template expects str or list, got {type(messages).__name__}")