aka7774 committed on
Commit
5653716
1 Parent(s): 69edd86

Upload 6 files

Files changed (6)
  1. app.py +126 -0
  2. fn.py +184 -0
  3. install.bat +56 -0
  4. main.py +49 -0
  5. requirements.txt +21 -0
  6. venv.sh +10 -0
app.py ADDED
@@ -0,0 +1,126 @@
+ import fn
+ import gradio as gr
+
+ with gr.Blocks() as demo:
+     with gr.Tab('config'):
+         info = gr.Markdown()
+         with gr.Row():
+             with gr.Column(scale=1):
+                 model = gr.Textbox(
+                     value=fn.cfg['model_name'],
+                     label='model',
+                     interactive=True,
+                     show_copy_button=True,
+                 )
+                 dtype = gr.Dropdown(
+                     value=fn.cfg['dtype'],
+                     choices=['4bit'],
+                     label='dtype',
+                     interactive=True,
+                     allow_custom_value=True,
+                 )
+
+             with gr.Column(scale=1):
+                 max_new_tokens = gr.Textbox(
+                     value=fn.cfg['max_new_tokens'],
+                     label='max_new_tokens',
+                     interactive=True,
+                     show_copy_button=True,
+                 )
+                 temperature = gr.Textbox(
+                     value=fn.cfg['temperature'],
+                     label='temperature',
+                     interactive=True,
+                     show_copy_button=True,
+                 )
+                 top_p = gr.Textbox(
+                     value=fn.cfg['top_p'],
+                     label='top_p',
+                     interactive=True,
+                     show_copy_button=True,
+                 )
+                 top_k = gr.Textbox(
+                     value=fn.cfg['top_k'],
+                     label='top_k',
+                     interactive=True,
+                     show_copy_button=True,
+                 )
+                 repetition_penalty = gr.Textbox(
+                     value=fn.cfg['repetition_penalty'],
+                     label='repetition_penalty',
+                     interactive=True,
+                     show_copy_button=True,
+                 )
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 inst_template = gr.Textbox(
+                     value='',
+                     lines=10,
+                     label='inst_template',
+                     interactive=True,
+                     show_copy_button=True,
+                 )
+             with gr.Column(scale=1):
+                 chat_template = gr.Textbox(
+                     value='',
+                     lines=10,
+                     label='chat_template',
+                     interactive=True,
+                     show_copy_button=True,
+                 )
+
+         set_button = gr.Button(value='Save')
+
+     with gr.Tab('instruct'):
+         with gr.Row():
+             with gr.Column(scale=1):
+                 instruction = gr.Textbox(
+                     lines=20,
+                     label='instruction',
+                     interactive=True,
+                     show_copy_button=True,
+                 )
+                 input = gr.Textbox(
+                     lines=1,
+                     label='input',
+                     interactive=True,
+                     show_copy_button=True,
+                 )
+             with gr.Column(scale=1):
+                 said = gr.Textbox(
+                     label='said',
+                     lines=20,
+                     show_copy_button=True,
+                 )
+                 numel = gr.Textbox(
+                     lines=1,
+                     label='numel',
+                     show_copy_button=True,
+                 )
+         inst_button = gr.Button(value='inst')
+         numel_button = gr.Button(value='numel')
+
+     with gr.Tab('chat'):
+         gr.ChatInterface(fn.chat)
+
+     set_button.click(
+         fn=fn.set_config,
+         inputs=[model, dtype, instruction, inst_template, chat_template, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+         outputs=[info],
+     )
+
+     # `input` is passed twice: fn.chat takes it as both message and history,
+     # and the history argument is ignored when an instruction is given
+     inst_button.click(
+         fn=fn.chat,
+         inputs=[input, input, instruction],
+         outputs=[said],
+     )
+
+     numel_button.click(
+         fn=fn.numel,
+         inputs=[input, input, instruction],
+         outputs=[numel],
+     )
+
+ if __name__ == '__main__':
+     demo.launch()
fn.py ADDED
@@ -0,0 +1,184 @@
+ import os
+ import torch
+ import json
+ import gc
+ import time
+ from unsloth import FastLanguageModel
+ from transformers import TextIteratorStreamer
+ from threading import Thread
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ tokenizer = None
+ model = None
+ default_cfg = {
+     'model_name': "unsloth/gemma-2-9b-it-bnb-4bit",
+     'dtype': None,
+     'instruction': None,
+     'inst_template': None,
+     'chat_template': None,
+     'max_length': 2400,
+     'max_seq_length': 2048,
+     'max_new_tokens': 512,
+     'temperature': 0.9,
+     'top_p': 0.95,
+     'top_k': 40,
+     'repetition_penalty': 1.2,
+ }
+ cfg = default_cfg.copy()
+
+ def load_model(model_name, dtype):
+     global tokenizer, model, cfg
+
+     if cfg['model_name'] == model_name and cfg['dtype'] == dtype:
+         return
+
+     # free the previous model before loading a new one
+     del model
+     del tokenizer
+     model = None
+     tokenizer = None
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     model, tokenizer = FastLanguageModel.from_pretrained(
+         model_name,
+         max_seq_length=cfg['max_seq_length'],
+         dtype=torch.bfloat16,
+         load_in_8bit=(dtype == '8bit'),
+         load_in_4bit=(dtype == '4bit'),
+     )
+
+     FastLanguageModel.for_inference(model)
+
+     cfg['model_name'] = model_name
+     cfg['dtype'] = dtype
+
+ def clear_config():
+     global cfg
+     cfg = default_cfg.copy()
+
+ def set_config(model_name, dtype, instruction, inst_template, chat_template, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
+     global cfg
+     load_model(model_name, dtype)
+     cfg.update({
+         'instruction': instruction,
+         'inst_template': inst_template,
+         'chat_template': chat_template,
+         'max_new_tokens': int(max_new_tokens),
+         'temperature': float(temperature),
+         'top_p': float(top_p),
+         'top_k': int(top_k),
+         'repetition_penalty': float(repetition_penalty),
+     })
+     return 'done.'
+
+ def set_config_args(args):
+     global cfg
+
+     load_model(args['model_name'], args['dtype'])
+     cfg.update(args)
+
+     return 'done.'
+
+ def chatinterface_to_messages(message, history):
+     global cfg
+
+     messages = []
+
+     if cfg['instruction']:
+         messages.append({'role': 'system', 'content': cfg['instruction']})
+
+     for user, assistant in history:
+         if user:
+             messages.append({'role': 'user', 'content': user})
+         if assistant:
+             messages.append({'role': 'assistant', 'content': assistant})
+
+     if message:
+         messages.append({'role': 'user', 'content': message})
+
+     return messages
+
+ def apply_template(messages):
+     global tokenizer, cfg
+
+     if cfg['chat_template']:
+         tokenizer.chat_template = cfg['chat_template']
+
+     if type(messages) is str:
+         if cfg['inst_template']:
+             return cfg['inst_template'].format(instruction=cfg['instruction'], input=messages)
+         return cfg['instruction'].format(input=messages)
+     if type(messages) is list:
+         return tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)
+
+ def chat(message, history=[], instruction=None, args={}):
+     global tokenizer, model, cfg
+
+     if instruction:
+         cfg['instruction'] = instruction
+         prompt = apply_template(message)
+     else:
+         messages = chatinterface_to_messages(message, history)
+         prompt = apply_template(messages)
+
+     inputs = tokenizer(prompt, return_tensors="pt",
+                        padding=True, max_length=cfg['max_length'], truncation=True).to("cuda")
+
+     streamer = TextIteratorStreamer(
+         tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True,
+     )
+
+     generate_kwargs = dict(
+         inputs,
+         do_sample=True,
+         streamer=streamer,
+         num_beams=1,
+     )
+
+     for k in [
+         'max_new_tokens',
+         'temperature',
+         'top_p',
+         'top_k',
+         'repetition_penalty',
+     ]:
+         if cfg[k]:
+             generate_kwargs[k] = cfg[k]
+
+     # generate in a background thread and stream tokens as they arrive
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     model_output = ""
+     for new_text in streamer:
+         model_output += new_text
+         if 'fastapi' in args:
+             # FastAPI clients want only the delta
+             yield new_text
+         else:
+             # Gradio wants the full text so far
+             yield model_output
+
+ def infer(message, history=[], instruction=None, args={}):
+     # expects 'fastapi' in args so that chat() yields deltas, not cumulative text
+     content = ''
+     for s in chat(message, history, instruction, args):
+         content += s
+     return content
+
+ def numel(message, history=[], instruction=None, args={}):
+     global tokenizer, model, cfg
+
+     if instruction:
+         cfg['instruction'] = instruction
+         prompt = apply_template(message)
+     else:
+         messages = chatinterface_to_messages(message, history)
+         prompt = apply_template(messages)
+
+     model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
+
+     return torch.numel(model_inputs['input_ids'])
+
+
+ load_model(cfg['model_name'], '4bit')
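
For reference, fn.py can also be driven directly from Python without the UI or the API. A minimal sketch, not part of this commit, assuming a CUDA GPU with enough free VRAM (importing fn loads the 4-bit gemma-2 model at import time); the instruction and prompt strings are hypothetical examples:

import fn

# when an instruction is given and no inst_template is set, chat() builds the
# prompt with cfg['instruction'].format(input=...), so include an {input} slot
instruction = 'You are a concise assistant. Reply to: {input}'

# pass {'fastapi': True} (as main.py does) so the generator yields deltas
# instead of the cumulative text that Gradio expects
for chunk in fn.chat('Write a haiku about GPUs.', [], instruction, {'fastapi': True}):
    print(chunk, end='', flush=True)

# token count of the prompt the same request would produce
print(fn.numel('Write a haiku about GPUs.', [], instruction))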
install.bat ADDED
@@ -0,0 +1,56 @@
+ @echo off
+
+ rem -------------------------------------------
+ rem NOT guaranteed to work on Windows
+
+ set APPDIR=gemma2_9b_7gb
+ set REPOS=https://huggingface.co/spaces/aka7774/%APPDIR%
+ set VENV=venv
+
+ rem -------------------------------------------
+
+ set INSTALL_DIR=%~dp0
+ cd /d %INSTALL_DIR%
+
+ :git_clone
+ set DL_URL=%REPOS%
+ set DL_DST=%APPDIR%
+ git clone %DL_URL% %APPDIR%
+ if exist %DL_DST% goto install_python
+
+ set DL_URL=https://github.com/git-for-windows/git/releases/download/v2.41.0.windows.3/PortableGit-2.41.0.3-64-bit.7z.exe
+ set DL_DST=PortableGit-2.41.0.3-64-bit.7z.exe
+ curl -L -o %DL_DST% %DL_URL%
+ if not exist %DL_DST% bitsadmin /transfer dl %DL_URL% %DL_DST%
+ %DL_DST% -y
+ del %DL_DST%
+
+ set GIT=%INSTALL_DIR%PortableGit\bin\git
+ %GIT% clone %REPOS%
+
+ :install_python
+ set DL_URL=https://github.com/indygreg/python-build-standalone/releases/download/20240415/cpython-3.10.14+20240415-x86_64-pc-windows-msvc-shared-install_only.tar.gz
+ set DL_DST="%INSTALL_DIR%python.tar.gz"
+ curl -L -o %DL_DST% %DL_URL%
+ if not exist %DL_DST% bitsadmin /transfer dl %DL_URL% %DL_DST%
+ tar -xzf %DL_DST%
+
+ set PYTHON=%INSTALL_DIR%python\python.exe
+ rem the standalone build extracts to .\python, so use its Scripts directory
+ set PATH=%PATH%;%INSTALL_DIR%python\Scripts
+
+ :install_venv
+ cd %APPDIR%
+ %PYTHON% -m venv %VENV%
+ set PYTHON=%VENV%\Scripts\python.exe
+
+ :install_pip
+ set DL_URL=https://bootstrap.pypa.io/get-pip.py
+ set DL_DST=%INSTALL_DIR%get-pip.py
+ curl -o %DL_DST% %DL_URL%
+ if not exist %DL_DST% bitsadmin /transfer dl %DL_URL% %DL_DST%
+ %PYTHON% %DL_DST%
+
+ %PYTHON% -m pip install gradio
+ %PYTHON% -m pip install -r requirements.txt
+
+ pause
main.py ADDED
@@ -0,0 +1,49 @@
+ import os
+ import sys
+ import time
+ import signal
+ import io
+
+ from fastapi import FastAPI, Request, status, Form, UploadFile
+ from fastapi.staticfiles import StaticFiles
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel, Field
+ from fastapi.exceptions import RequestValidationError
+ from fastapi.responses import JSONResponse, StreamingResponse
+ import fn
+ import gradio as gr
+ from app import demo
+
+ app = FastAPI()
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=['*'],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ gr.mount_gradio_app(app, demo, path="/gradio")
+
+ @app.post("/set_config")
+ async def api_set_config(args: dict):
+     content = fn.set_config_args(args)
+     return {'content': content}
+
+ @app.post("/infer")
+ async def api_infer(args: dict):
+     args['fastapi'] = True
+     # 'instruct' is optional; fall back to None so requests without it don't fail
+     if 'stream' in args and args['stream']:
+         return StreamingResponse(
+             fn.chat(args['input'], [], args.get('instruct'), args),
+             media_type="text/event-stream",
+         )
+     else:
+         content = fn.infer(args['input'], [], args.get('instruct'), args)
+         return {'content': content}
+
+ @app.post("/numel")
+ async def api_numel(args: dict):
+     content = fn.numel(args['input'], [], args.get('instruct'), args)
+     return {'numel': content}
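
For reference, a small client sketch for these endpoints using the requests library; the host, port, and payload values are placeholders, and the server is assumed to have been started with something like venv/bin/python -m uvicorn main:app --port 8000 (no launch command is included in this commit):

import requests

BASE = 'http://127.0.0.1:8000'  # placeholder address

payload = {
    'input': 'Write a haiku about GPUs.',
    'instruct': 'You are a concise assistant. Reply to: {input}',
}

# non-streaming: the server collects the whole completion via fn.infer
r = requests.post(f'{BASE}/infer', json=payload)
print(r.json()['content'])

# streaming: fn.chat yields deltas over a text/event-stream response
with requests.post(f'{BASE}/infer', json={**payload, 'stream': True}, stream=True) as r:
    r.encoding = 'utf-8'  # the stream carries UTF-8 text
    for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end='', flush=True)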
requirements.txt ADDED
@@ -0,0 +1,21 @@
+ # On Windows it is difficult to build flash_attn2, so this app probably cannot run there.
+ # On WSL2:
+ #   sudo apt install python3.10-dev
+ #   wget https://developer.download.nvidia.com/compute/cuda/repos/wsl-ubuntu/x86_64/cuda-keyring_1.0-1_all.deb
+ #   sudo dpkg -i cuda-keyring_1.0-1_all.deb
+ #   sudo apt update
+ #   sudo apt-get install cuda-toolkit-12-1
+ #   vi ~/.bashrc
+ #     if [ -e /usr/local/cuda ]; then
+ #       export PATH="/usr/local/cuda/bin:$PATH"
+ #       export LD_LIBRARY_PATH="/usr/local/cuda/lib64:$LD_LIBRARY_PATH"
+ #     fi
+
+ fastapi
+ uvicorn
+ transformers==4.43.3
+ bitsandbytes==0.43.3
+ accelerate==0.33.0
+ peft==0.12.0
+ wheel
+ python-multipart
venv.sh ADDED
@@ -0,0 +1,10 @@
+ #!/usr/bin/bash
+
+ python3 -m venv venv
+ curl -kL https://bootstrap.pypa.io/get-pip.py | venv/bin/python
+
+ venv/bin/python -m pip install gradio
+ venv/bin/python -m pip install -r requirements.txt
+ venv/bin/python -m pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121
+ venv/bin/python -m pip install flash-attn --no-build-isolation
+ venv/bin/python -m pip install "unsloth[cu121-ampere-torch230] @ git+https://github.com/unslothai/unsloth.git"