aka7774 committed on
Commit 3350c44
1 parent: 6bbfa42

Upload 6 files

Files changed (6)
  1. app.py +117 -0
  2. fn.py +202 -0
  3. install.bat +56 -0
  4. main.py +43 -0
  5. requirements.txt +5 -0
  6. venv.sh +7 -0
app.py ADDED
@@ -0,0 +1,117 @@
+import fn
+import gradio as gr
+
+with gr.Blocks() as demo:
+    gr.Markdown('# gemma2')
+    with gr.Tab('config'):
+        info = gr.Markdown()
+        with gr.Row():
+            with gr.Column(scale=1):
+                size = gr.Dropdown(
+                    value=fn.cfg['size'],
+                    choices=['9b', '27b'],
+                    label='size',
+                    interactive=True,
+                )
+
+            with gr.Column(scale=1):
+                max_new_tokens = gr.Textbox(
+                    value=fn.cfg['max_new_tokens'],
+                    label='max_new_tokens',
+                    interactive=True,
+                    show_copy_button=True,
+                )
+                temperature = gr.Textbox(
+                    value=fn.cfg['temperature'],
+                    label='temperature',
+                    interactive=True,
+                    show_copy_button=True,
+                )
+                top_p = gr.Textbox(
+                    value=fn.cfg['top_p'],
+                    label='top_p',
+                    interactive=True,
+                    show_copy_button=True,
+                )
+                top_k = gr.Textbox(
+                    value=fn.cfg['top_k'],
+                    label='top_k',
+                    interactive=True,
+                    show_copy_button=True,
+                )
+                repetition_penalty = gr.Textbox(
+                    value=fn.cfg['repetition_penalty'],
+                    label='repetition_penalty',
+                    interactive=True,
+                    show_copy_button=True,
+                )
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                inst_template = gr.Textbox(
+                    value='',
+                    lines=10,
+                    label='inst_template',
+                    interactive=True,
+                    show_copy_button=True,
+                )
+                is_use_cache = gr.Checkbox(
+                    value=False,
+                    label='is_use_cache',
+                    interactive=True,
+                )
+
+        set_button = gr.Button(value='Save')
+
+    with gr.Tab('instruct'):
+        with gr.Row():
+            with gr.Column(scale=1):
+                instruction = gr.Textbox(
+                    lines=20,
+                    label='instruction',
+                    interactive=True,
+                    show_copy_button=True,
+                )
+                input = gr.Textbox(
+                    lines=1,
+                    label='input',
+                    interactive=True,
+                    show_copy_button=True,
+                )
+            with gr.Column(scale=1):
+                said = gr.Textbox(
+                    label='said',
+                    lines=20,
+                    show_copy_button=True,
+                )
+                numel = gr.Textbox(
+                    lines=1,
+                    label='numel',
+                    show_copy_button=True,
+                )
+        inst_button = gr.Button(value='inst')
+        numel_button = gr.Button(value='numel')
+
+    with gr.Tab('chat'):
+        gr.ChatInterface(fn.chat)
+
+    set_button.click(
+        fn=fn.set_config,
+        inputs=[size, instruction, inst_template, is_use_cache, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+        outputs=[info],
+    )
+
+    inst_button.click(
+        fn=fn.chat,
+        inputs=[input, input, instruction],
+        outputs=[said],
+    )
+
+    numel_button.click(
+        fn=fn.numel,
+        inputs=[input, input, instruction],
+        outputs=[numel],
+    )
+
+if __name__ == '__main__':
+    demo.launch()
fn.py ADDED
@@ -0,0 +1,202 @@
+import os
+import re
+import torch
+import datetime
+import json
+import csv
+import gc
+
+import local_gemma
+from transformers import AutoTokenizer, TextStreamer
+from transformers import TextIteratorStreamer
+from transformers import BitsAndBytesConfig, GPTQConfig
+from threading import Thread
+
+tokenizer = None
+model = None
+default_cfg = {
+    'size': None,
+    'instruction': None,
+    'inst_template': None,
+    'is_use_cache': False,
+    'max_new_tokens': 1024,
+    'temperature': 0.9,
+    'top_p': 0.95,
+    'top_k': 40,
+    'repetition_penalty': 1.2,
+}
+cfg = default_cfg.copy()
+cache = None
+chat_history = []
+
+def load_model(size = '9b'):
+    global tokenizer, model, cfg
+
+    if cfg['size'] == size:
+        return
+
+    del model
+    del tokenizer
+    model = None
+    tokenizer = None
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    model_name = f"SillyTilly/google-gemma-2-{size}-it"
+
+    model = local_gemma.LocalGemma2ForCausalLM.from_pretrained(model_name, preset="memory")
+    model._supports_cache_class = True
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    cfg['size'] = size
+
+def clear_config():
+    global cfg
+    cfg = default_cfg.copy()
+
+def set_config(size, instruction, inst_template, is_use_cache, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
+    global cfg
+    load_model(size)
+    cfg.update({
+        'instruction': instruction,
+        'inst_template': inst_template,
+        'is_use_cache': bool(is_use_cache),
+        'max_new_tokens': int(max_new_tokens),
+        'temperature': float(temperature),
+        'top_p': float(top_p),
+        'top_k': int(top_k),
+        'repetition_penalty': float(repetition_penalty),
+    })
+    return 'done.'
+
+def set_config_args(args):
+    global cfg
+
+    load_model(args['size'])
+    cfg.update(args)
+
+    return 'done.'
+
+def chatinterface_to_messages(message, history):
+    global cfg
+
+    messages = []
+
+    if cfg['instruction']:
+        messages.append({'role': 'user', 'content': cfg['instruction']})
+        # user and assistant turns must alternate
+        if message:
+            messages.append({'role': 'assistant', 'content': '了解しました。'})
+
+    for pair in history:
+        [user, assistant] = pair
+        if user:
+            messages.append({'role': 'user', 'content': user})
+        if assistant:
+            messages.append({'role': 'assistant', 'content': assistant})
+
+    if message:
+        messages.append({'role': 'user', 'content': message})
+
+    return messages
+
+def apply_template(messages):
+    global tokenizer, cfg, cache, chat_history
+
+    if type(messages) is str:
+        if cfg['inst_template']:
+            user_input = cfg['inst_template'].format(instruction=cfg['instruction'], input=messages)
+        else:
+            user_input = cfg['instruction'].format(input=messages)
+        tokenized_chat = tokenizer(user_input, return_tensors="pt").input_ids
+    if type(messages) is list:
+        tokenized_chat = tokenizer.apply_chat_template(
+            messages + chat_history, tokenize=True, add_generation_prompt=True, return_tensors="pt"
+        )
+    return tokenized_chat
+
+def chat(message, history = [], instruction = None, args = {}):
+    global tokenizer, model, cfg, cache, chat_history
+
+    if instruction:
+        cfg['instruction'] = instruction
+        tokenized_chat = apply_template(message)
+    else:
+        messages = chatinterface_to_messages(message, history)
+        tokenized_chat = apply_template(messages)
+
+    device = local_gemma.utils.config.infer_device(None)
+    is_use_cache = cfg['is_use_cache']
+    generation_kwargs = local_gemma.utils.config.get_generation_kwargs('chat')
+
+    streamer = TextStreamer(tokenizer, skip_prompt=True, **{"skip_special_tokens": True})
+    tokenized_chat = tokenized_chat.to(device)
+    generation_kwargs.update(
+        {
+            "streamer": streamer,
+            "assistant_model": None,
+            "return_dict_in_generate": True,
+            "past_key_values": cache,
+        }
+    )
+
+    for k in [
+        'max_new_tokens',
+        'temperature',
+        'top_p',
+        'top_k',
+        'repetition_penalty'
+    ]:
+        if cfg[k]:
+            generation_kwargs[k] = cfg[k]
+
+    # TODO(joao): this if shouldn't be needed, fix in transformers
+    if cache is not None:
+        generation_kwargs["cache_implementation"] = None
+
+    if cfg['max_new_tokens'] is not None:
+        input_ids_len = tokenized_chat.shape[-1]
+        max_cache_len = cfg['max_new_tokens'] + input_ids_len
+        if cache is not None and cache.max_cache_len < max_cache_len:
+            # reset the cache
+            generation_kwargs.pop("past_key_values")
+            generation_kwargs["cache_implementation"] = "hybrid"
+    else:
+        generation_kwargs["max_length"] = model.config.max_position_embeddings
+
+    gen_out = model.generate(input_ids=tokenized_chat, **generation_kwargs)
+
+    # Store the cache for the next generation round; pull the model output into the chat history.
+    cache = gen_out.past_key_values
+    model_tokens = gen_out.sequences[0, tokenized_chat.shape[1]:]
+    model_output_text = tokenizer.decode(model_tokens, skip_special_tokens=True)
+    chat_history += [{"role": "user", "content": message},]
+    chat_history += [{"role": "assistant", "content": model_output_text},]
+
+    # Sanity check: EOS was removed, ends in "<end_of_turn>\n"
+    tokenized_chat = tokenizer.apply_chat_template(
+        chat_history, tokenize=True, add_generation_prompt=False, return_tensors="pt"
+    ).tolist()[0]
+    assert tokenized_chat[0] == 2
+    assert tokenized_chat[-1] == 108
+    assert tokenized_chat[-2] == 107
+
+    if not is_use_cache:
+        cache = None
+        chat_history = []
+
+    return model_output_text
+
+def infer(message, history = [], instruction = None, args = {}):
+    return chat(message, history, instruction, args)
+
+def numel(message, history = [], instruction = None, args = {}):
+    global tokenizer, model, cfg, cache, chat_history
+
+    if instruction:
+        cfg['instruction'] = instruction
+        tokenized_chat = apply_template(message)
+    else:
+        messages = chatinterface_to_messages(message, history)
+        tokenized_chat = apply_template(messages)
+
+    return torch.numel(tokenized_chat)
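For reference, fn can also be driven directly from Python without the Gradio UI. The snippet below is a minimal sketch based on the function signatures above, not part of this commit: the prompt and sampling values are placeholders, and it still downloads the model and needs a capable GPU.

import fn

# configure and load the 9b model (placeholder sampling values)
fn.set_config_args({
    'size': '9b',
    'instruction': None,
    'inst_template': None,
    'is_use_cache': False,
    'max_new_tokens': 256,
    'temperature': 0.9,
    'top_p': 0.95,
    'top_k': 40,
    'repetition_penalty': 1.2,
})

# instruction mode: cfg['instruction'] is formatted with the input string
print(fn.chat('Mount Fuji', instruction='Explain {input} in one sentence.'))

# token count of the same templated prompt
print(fn.numel('Mount Fuji', instruction='Explain {input} in one sentence.'))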
install.bat ADDED
@@ -0,0 +1,56 @@
+@echo off
+
+rem -------------------------------------------
+rem NOT guaranteed to work on Windows
+
+set APPDIR=gemma2
+set REPOS=https://huggingface.co/spaces/aka7774/%APPDIR%
+set VENV=venv
+
+rem -------------------------------------------
+
+set INSTALL_DIR=%~dp0
+cd /d %INSTALL_DIR%
+
+:git_clone
+set DL_URL=%REPOS%
+set DL_DST=%APPDIR%
+git clone %DL_URL% %APPDIR%
+if exist %DL_DST% goto install_python
+
+set DL_URL=https://github.com/git-for-windows/git/releases/download/v2.41.0.windows.3/PortableGit-2.41.0.3-64-bit.7z.exe
+set DL_DST=PortableGit-2.41.0.3-64-bit.7z.exe
+curl -L -o %DL_DST% %DL_URL%
+if not exist %DL_DST% bitsadmin /transfer dl %DL_URL% %DL_DST%
+%DL_DST% -y
+del %DL_DST%
+
+set GIT=%INSTALL_DIR%PortableGit\bin\git
+%GIT% clone %REPOS%
+
+:install_python
+set DL_URL=https://github.com/indygreg/python-build-standalone/releases/download/20240415/cpython-3.10.14+20240415-x86_64-pc-windows-msvc-shared-install_only.tar.gz
+set DL_DST="%INSTALL_DIR%python.tar.gz"
+curl -L -o %DL_DST% %DL_URL%
+if not exist %DL_DST% bitsadmin /transfer dl %DL_URL% %DL_DST%
+tar -xzf %DL_DST%
+
+set PYTHON=%INSTALL_DIR%python\python.exe
+set PATH=%PATH%;%INSTALL_DIR%python310\Scripts
+
+:install_venv
+cd %APPDIR%
+%PYTHON% -m venv %VENV%
+set PYTHON=%VENV%\Scripts\python.exe
+
+:install_pip
+set DL_URL=https://bootstrap.pypa.io/get-pip.py
+set DL_DST=%INSTALL_DIR%get-pip.py
+curl -o %DL_DST% %DL_URL%
+if not exist %DL_DST% bitsadmin /transfer dl %DL_URL% %DL_DST%
+%PYTHON% %DL_DST%
+
+%PYTHON% -m pip install gradio
+%PYTHON% -m pip install -r requirements.txt
+
+pause
main.py ADDED
@@ -0,0 +1,43 @@
+import os
+import sys
+import time
+import signal
+import io
+
+from fastapi import FastAPI, Request, status, Form, UploadFile
+from fastapi.staticfiles import StaticFiles
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel, Field
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse, StreamingResponse
+import fn
+import gradio as gr
+from app import demo
+
+app = FastAPI()
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=['*'],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+gr.mount_gradio_app(app, demo, path="/gradio")
+
+@app.post("/set_config")
+async def api_set_config(args: dict):
+    content = fn.set_config_args(args)
+    return {'content': content}
+
+@app.post("/infer")
+async def api_infer(args: dict):
+    args['fastapi'] = True
+    content = fn.infer(args['input'], [], args['instruct'], args)
+    return {'content': content}
+
+@app.post("/numel")
+async def api_numel(args: dict):
+    content = fn.numel(args['input'], [], args['instruct'], args)
+    return {'numel': content}
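A minimal client sketch for these endpoints, assuming the app is served at http://127.0.0.1:8000; the host, port, and the requests dependency are assumptions, not fixed by this commit.

import requests

base = 'http://127.0.0.1:8000'

# choose the model size; set_config_args() forwards any other cfg keys as-is
requests.post(f'{base}/set_config', json={'size': '9b'})

payload = {'input': 'Mount Fuji', 'instruct': 'Explain {input} in one sentence.'}
print(requests.post(f'{base}/infer', json=payload).json()['content'])
print(requests.post(f'{base}/numel', json=payload).json()['numel'])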
requirements.txt ADDED
@@ -0,0 +1,5 @@
+fastapi
+uvicorn
+local-gemma
+bitsandbytes
+python-multipart
venv.sh ADDED
@@ -0,0 +1,7 @@
+#!/usr/bin/bash
+
+python3 -m venv venv
+curl -kL https://bootstrap.pypa.io/get-pip.py | venv/bin/python
+
+venv/bin/python -m pip install gradio
+venv/bin/python -m pip install -r requirements.txt
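Neither install.bat nor venv.sh starts the server itself. One way to launch the FastAPI app defined in main.py from Python, sketched under the assumption of a hypothetical run.py and an arbitrary port:

# hypothetical run.py: minimal launch sketch; uvicorn is listed in requirements.txt
import uvicorn

uvicorn.run('main:app', host='0.0.0.0', port=7860)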