import os
import re
import torch
import datetime
import json
import csv
import gc

import local_gemma
from transformers import AutoTokenizer, TextStreamer
from transformers import TextIteratorStreamer
from transformers import BitsAndBytesConfig, GPTQConfig
from threading import Thread

tokenizer = None
model = None
cfg = {
    'size': None,
}
default_args = {
    'instruction': None,
    'first_assistant': None,
    'chat_template': None,
    'max_new_tokens': 1024,
    'temperature': 0.9,
    'top_p': 0.95,
    'top_k': 40,
    'repetition_penalty': 1.2,
}

chat_past_key_values = {}
chat_messages = {}

def load_model(size = '9b'):
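    """Load the Gemma 2 model of the requested size (e.g. '9b'); a no-op if that size is already loaded."""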
    global tokenizer, model, cfg

    if cfg['size'] == size:
        return

    del model
    del tokenizer
    model = None
    tokenizer = None
    gc.collect()
    torch.cuda.empty_cache()

    model_name = f"SillyTilly/google-gemma-2-{size}-it"

    model = local_gemma.LocalGemma2ForCausalLM.from_pretrained(model_name, preset="memory")
    model._supports_cache_class = True  # let generate() accept/return Cache objects as past_key_values
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    cfg['size'] = size

def set_config(size, instruction, first_assistant, chat_template, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
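    """Load the requested model size and update the default generation/prompt arguments."""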
    global default_args
    load_model(size)
    default_args.update({
        'instruction': instruction,
        'first_assistant': first_assistant,
        'chat_template': chat_template,
        'max_new_tokens': int(max_new_tokens),
        'temperature': float(temperature),
        'top_p': float(top_p),
        'top_k': int(top_k),
        'repetition_penalty': float(repetition_penalty),
    })
    return 'done.'

def set_config_args(args):
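    """Like set_config, but takes all settings as a single dict (must include 'size')."""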
    global default_args

    load_model(args['size'])
    default_args.update(args)

    return 'done.'

def chatinterface_to_messages(history):
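    """Convert a gradio-style history of [user, assistant] pairs into role/content messages."""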
    messages = []

    for pair in history:
        [user, assistant] = pair
        if user:
            messages.append({'role': 'user', 'content': user})
        if assistant:
            messages.append({'role': 'assistant', 'content': assistant})

    return messages

# This logic is fairly involved
def tokenize(user_input, history = [], instruction = None, conversation_id = 'gradio', args = {}):
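    """Build the input token ids for one turn: from the cached message history of
    `conversation_id` (plus the new user_input) when one exists, otherwise from a plain
    prompt built from `instruction` and/or `user_input`. `history` is currently unused."""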
    global tokenizer, chat_messages

    # Build the messages to prepend at the head of the conversation
    inst_messages = []
    if instruction:
        if 'first_assistant' in args and args['first_assistant']:
            # Claude-compatible format:
            # user and assistant turns must alternate
            inst_messages = [
                {'role': 'user', 'content': instruction},
                {'role': 'assistant', 'content': args['first_assistant']},
                ]
        else:
            # OpenAI-compatible format
            inst_messages = [{'role': 'system', 'content': instruction}]

    # If 'messages' is given in args, replace the stored history for this conversation entirely
    if conversation_id and 'messages' in args:
        chat_messages[conversation_id] = inst_messages + args['messages']

    # If a cached conversation exists, build the prompt from its messages.
    # The instruction is already part of the cache, so it is not re-applied here
    # (it cannot be changed mid-conversation).
    if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
        # Append the user input for this turn only for tokenization; the stored history
        # is updated by chat() after generation, so appending it here as well would
        # duplicate the user turn (and numel() would mutate the history as a side effect)
        tokenized_chat = tokenizer.apply_chat_template(
            chat_messages[conversation_id] + [{'role': 'user', 'content': user_input}],
            tokenize=True, add_generation_prompt=True, return_tensors="pt"
        )
    else:
        # Apply the instruction as a template if present (the input itself is optional)
        if instruction:
            user_input = instruction.format(input=user_input)
        # With neither an input nor an instruction there is nothing to tokenize
        if not user_input:
            raise ValueError('Either an input or an instruction is required.')
        tokenized_chat = tokenizer(user_input, return_tensors="pt").input_ids

    return tokenized_chat

def chat(message, history = [], instruction = None, conversation_id = 'gradio', args = {}):
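    """Generate a reply to `message`, reusing the KV cache and message history stored
    for `conversation_id` when available; streams tokens to stdout and returns the
    decoded reply."""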
    global tokenizer, model, chat_past_key_values, chat_messages

    # Fill in defaults without mutating the caller's dict (or the shared `args={}` default)
    args = {**default_args, **args}

    cache = None
    # When a conversation_id is given, look up its stored cache
    if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
        # If 'clear' is set, wipe the stored cache and history first
        if 'clear' in args and args['clear']:
            chat_past_key_values[conversation_id] = None
            chat_messages[conversation_id] = None
        else:
            cache = chat_past_key_values[conversation_id]

    # Apply a custom chat_template if one was provided
    if args['chat_template']:
        tokenizer.chat_template = args['chat_template']

    device = local_gemma.utils.config.infer_device(None)
    generation_kwargs = local_gemma.utils.config.get_generation_kwargs('chat')

    # Tokenize the message together with any cached history
    tokenized_chat = tokenize(message, history, instruction, conversation_id, args).to(device)

    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs.update(
        {
            "streamer": streamer,
            "assistant_model": None,
            "return_dict_in_generate": True,
            "past_key_values": cache,
        }
    )

    for k in [
        'max_new_tokens',
        'temperature',
        'top_p',
        'top_k',
        'repetition_penalty'
        ]:
        if args[k]:
            generation_kwargs[k] = args[k]

    # TODO(joao): this if shouldn't be needed, fix in transformers
    if cache is not None:
        generation_kwargs["cache_implementation"] = None

    if args['max_new_tokens'] is not None:
        input_ids_len = tokenized_chat.shape[-1]
        max_cache_len = args['max_new_tokens'] + input_ids_len
        if cache is not None and cache.max_cache_len < max_cache_len:
            # reset the cache
            generation_kwargs.pop("past_key_values")
            generation_kwargs["cache_implementation"] = "hybrid"
    else:
        generation_kwargs["max_length"] = model.config.max_position_embeddings

    gen_out = model.generate(input_ids=tokenized_chat, **generation_kwargs)
    model_tokens = gen_out.sequences[0, tokenized_chat.shape[1]:]
    model_output_text = tokenizer.decode(model_tokens, skip_special_tokens=True)

    if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
        # Store the cache for the next generation round; Pull the model output into the chat history.
        chat_past_key_values[conversation_id] = gen_out.past_key_values
        chat_messages[conversation_id] += [{"role": "user", "content": message},]
        chat_messages[conversation_id] += [{"role": "assistant", "content": model_output_text},]

        # Sanity check: the history starts with <bos> (2), the EOS was removed, and it
        # ends with "<end_of_turn>\n" (107, 108)
        tokenized_chat = tokenizer.apply_chat_template(
            chat_messages[conversation_id], tokenize=True, add_generation_prompt=False, return_tensors="pt"
        ).tolist()[0]
        assert tokenized_chat[0] == 2     # <bos>
        assert tokenized_chat[-1] == 108  # "\n"
        assert tokenized_chat[-2] == 107  # <end_of_turn>

    # TODO: support streaming responses
    return model_output_text

# Return the full response without streaming
def infer(message, history = [], instruction = None, conversation_id = 'gradio', args = {}):
    return chat(message, history, instruction, conversation_id, args)

def numel(message, history = [], instruction = None, conversation_id = 'gradio', args = {}):
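    """Return the number of prompt tokens `message` would produce for the current conversation state."""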
    global tokenizer, chat_messages

    tokenized_chat = tokenize(message, history, instruction, conversation_id, args)

    return torch.numel(tokenized_chat)
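

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of the intended call flow, assuming a GPU with enough memory for
# the 9b model; the prompts, conversation_id, and settings below are made up for
# illustration only.
if __name__ == '__main__':
    set_config_args({
        'size': '9b',
        'instruction': None,
        'first_assistant': 'Understood. How can I help?',
        'chat_template': None,
        'max_new_tokens': 128,
        'temperature': 0.7,
        'top_p': 0.95,
        'top_k': 40,
        'repetition_penalty': 1.2,
    })
    # Passing 'messages': [] seeds a cached conversation under this conversation_id
    print(infer(
        'Summarize what a KV cache does in one sentence.',
        instruction='You are a concise technical assistant.',
        conversation_id='demo',
        args={'messages': []},
    ))
    # Follow-up turns reuse the stored history and past_key_values
    print(infer('Now explain it to a beginner.', conversation_id='demo'))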