BlinkDL committed
Commit e63ee0a
1 Parent(s): e176af0

Update app.py

Files changed (1)
  1. app.py +249 -78
app.py CHANGED
@@ -1,50 +1,63 @@
  import gradio as gr
- import os, gc, torch
  from datetime import datetime
  from huggingface_hub import hf_hub_download
  from pynvml import *
  nvmlInit()
  gpu_h = nvmlDeviceGetHandleByIndex(0)
  ctx_limit = 1024
- title = "RWKV-4-Pile-14B-20230313-ctx8192-test1050"
- desc = f'''Links:
- <a href='https://github.com/BlinkDL/ChatRWKV' target="_blank" style="margin:0 0.5em">ChatRWKV</a>
- <a href='https://github.com/BlinkDL/RWKV-LM' target="_blank" style="margin:0 0.5em">RWKV-LM</a>
- <a href="https://pypi.org/project/rwkv/" target="_blank" style="margin:0 0.5em">RWKV pip package</a>
- <a href="https://huggingface.co/spaces/BlinkDL/Raven-RWKV-7B" target="_blank" style="margin:0 0.5em">Raven 7B (alpaca-style)</a>
- '''

  os.environ["RWKV_JIT_ON"] = '1'
  os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)

  from rwkv.model import RWKV
- model_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-pile-14b", filename=f"{title}.pth")
  model = RWKV(model=model_path, strategy='cuda fp16i8 *24 -> cuda fp16')
  from rwkv.utils import PIPELINE, PIPELINE_ARGS
  pipeline = PIPELINE(model, "20B_tokenizer.json")

- def infer(
-     ctx,
-     token_count=10,
-     temperature=1.0,
-     top_p=0.8,
-     presencePenalty = 0.1,
-     countPenalty = 0.1,
  ):
      args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
                           alpha_frequency = countPenalty,
                           alpha_presence = presencePenalty,
-                          token_ban = [0], # ban the generation of some tokens
-                          token_stop = []) # stop generation whenever you see any token here
-
-     ctx = ctx.strip(' ')
-     if ctx.endswith('\n'):
-         ctx = f'\n{ctx.strip()}\n'
-     else:
-         ctx = f'\n{ctx.strip()}'

-     gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
-     print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')

      all_tokens = []
      out_last = 0
@@ -53,8 +66,6 @@ def infer(
      state = None
      for i in range(int(token_count)):
          out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
-         for n in args.token_ban:
-             out[n] = -float('inf')
          for n in occurrence:
              out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
@@ -72,62 +83,222 @@ def infer(
              out_str += tmp
              yield out_str.strip()
              out_last = i + 1

      gc.collect()
      torch.cuda.empty_cache()
      yield out_str.strip()

  examples = [
-     ["Expert Questions & Helpful Answers\nAsk Research Experts\nQuestion:\nHow can we eliminate poverty?\n\nFull Answer:\n", 150, 1.0, 0.7, 0.2, 0.2],
-     ["Here's a short cyberpunk sci-fi adventure story. The story's main character is an artificial human created by a company called OpenBot.\n\nThe Story:\n", 150, 1.0, 0.7, 0.2, 0.2],
-     ['''Below is an instruction that describes a task. Write a response that appropriately completes the request.
-
- ### Instruction:
- Generate a list of adjectives that describe a person as brave.
-
- ### Response:
- ''', 150, 1.0, 0.2, 0.5, 0.5],
-     ['''Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
-
- ### Instruction:
- Arrange the given numbers in ascending order.
-
- ### Input:
- 2, 4, 0, 8, 3
-
- ### Response:
- ''', 150, 1.0, 0.2, 0.5, 0.5],
-     ["Ask Expert\n\nQuestion:\nWhat are some good plans for world peace?\n\nExpert Full Answer:\n", 150, 1.0, 0.7, 0.2, 0.2],
-     ["Q & A\n\nQuestion:\nWhy is the sky blue?\n\nDetailed Expert Answer:\n", 150, 1.0, 0.7, 0.2, 0.2],
-     ["Dear sir,\nI would like to express my boundless apologies for the recent nuclear war.", 150, 1.0, 0.7, 0.2, 0.2],
-     ["Here is a shell script to find all .hpp files in /home/workspace and delete the 3th row string of these files:", 150, 1.0, 0.7, 0.1, 0.1],
-     ["Building a website can be done in 10 simple steps:\n1.", 150, 1.0, 0.7, 0.2, 0.2],
-     ["A Chinese phrase is provided: 百闻不如一见。\nThe masterful Chinese translator flawlessly translates the phrase into English:", 150, 1.0, 0.5, 0.2, 0.2],
-     ["I believe the meaning of life is", 150, 1.0, 0.7, 0.2, 0.2],
-     ["Simply put, the theory of relativity states that", 150, 1.0, 0.5, 0.2, 0.2],
  ]

- iface = gr.Interface(
-     fn=infer,
-     description=f'''{desc} *** <b>Please try examples first (bottom of page)</b> *** (edit them to use your question). Demo limited to ctxlen {ctx_limit}.''',
-     allow_flagging="never",
-     inputs=[
-         gr.Textbox(lines=10, label="Prompt", value="Here's a short cyberpunk sci-fi adventure story. The story's main character is an artificial human created by a company called OpenBot.\n\nThe Story:\n"), # prompt
-         gr.Slider(10, 200, step=10, value=150), # token_count
-         gr.Slider(0.2, 2.0, step=0.1, value=1.0), # temperature
-         gr.Slider(0.0, 1.0, step=0.05, value=0.7), # top_p
-         gr.Slider(0.0, 1.0, step=0.1, value=0.2), # presencePenalty
-         gr.Slider(0.0, 1.0, step=0.1, value=0.2), # countPenalty
-     ],
-     outputs=gr.Textbox(label="Generated Output", lines=28),
-     examples=examples,
-     cache_examples=False,
- ).queue()
-
- demo = gr.TabbedInterface(
-     [iface], ["Generative"],
-     title=title,
- )
-
- demo.queue(max_size=10)
  demo.launch(share=False)
 
@@ -1,50 +1,63 @@
  import gradio as gr
+ import os, gc, copy, torch
  from datetime import datetime
  from huggingface_hub import hf_hub_download
  from pynvml import *
  nvmlInit()
  gpu_h = nvmlDeviceGetHandleByIndex(0)
  ctx_limit = 1024
+ title = "RWKV-4-Raven-14B-v10-Eng99%-Other1%-20230427-ctx8192"

  os.environ["RWKV_JIT_ON"] = '1'
  os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)

  from rwkv.model import RWKV
+ model_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-raven", filename=f"{title}.pth")
  model = RWKV(model=model_path, strategy='cuda fp16i8 *24 -> cuda fp16')
  from rwkv.utils import PIPELINE, PIPELINE_ARGS
  pipeline = PIPELINE(model, "20B_tokenizer.json")

+ def generate_prompt(instruction, input=None):
+     instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
+     input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
+     if input:
+         return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+ # Instruction:
+ {instruction}
+
+ # Input:
+ {input}
+
+ # Response:
+ """
+     else:
+         return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+ # Instruction:
+ {instruction}
+
+ # Response:
+ """
+
+ def evaluate(
+     instruction,
+     input=None,
+     token_count=200,
+     temperature=1.0,
+     top_p=0.7,
+     presencePenalty = 0.1,
+     countPenalty = 0.1,
  ):
      args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
                           alpha_frequency = countPenalty,
                           alpha_presence = presencePenalty,
+                          token_ban = [], # ban the generation of some tokens
+                          token_stop = [0]) # stop generation whenever you see any token here

+     instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
+     input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
+     ctx = generate_prompt(instruction, input)

      all_tokens = []
      out_last = 0
@@ -53,8 +66,6 @@
      state = None
      for i in range(int(token_count)):
          out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
          for n in occurrence:
              out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
@@ -72,62 +83,222 @@
              out_str += tmp
              yield out_str.strip()
              out_last = i + 1
+
+     gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+     print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
+
      gc.collect()
      torch.cuda.empty_cache()
      yield out_str.strip()

  examples = [
+     ["Tell me about ravens.", "", 150, 1.2, 0.5, 0.4, 0.4],
+     ["Write a python function to mine 1 BTC, with details and comments.", "", 150, 1.2, 0.5, 0.4, 0.4],
+     ["Write a song about ravens.", "", 150, 1.2, 0.5, 0.4, 0.4],
+     ["Explain the following metaphor: Life is like cats.", "", 150, 1.2, 0.5, 0.4, 0.4],
+     ["Write a story using the following information", "A man named Alex chops a tree down", 150, 1.2, 0.5, 0.4, 0.4],
+     ["Generate a list of adjectives that describe a person as brave.", "", 150, 1.2, 0.5, 0.4, 0.4],
+     ["You have $100, and your goal is to turn that into as much money as possible with AI and Machine Learning. Please respond with detailed plan.", "", 150, 1.2, 0.5, 0.4, 0.4],
  ]

+ ##########################################################################
+
+ chat_intro = '''The following is a coherent verbose detailed conversation between <|user|> and an AI girl named <|bot|>.
+
+ <|user|>: Hi <|bot|>, Would you like to chat with me for a while?
+
+ <|bot|>: Hi <|user|>. Sure. What would you like to talk about? I'm listening.
+ '''
+
+ def user(message, chatbot):
+     chatbot = chatbot or []
+     # print(f"User: {message}")
+     return "", chatbot + [[message, None]]
+
+ def alternative(chatbot, history):
+     if not chatbot or not history:
+         return chatbot, history
+
+     chatbot[-1][1] = None
+     history[0] = copy.deepcopy(history[1])
+
+     return chatbot, history
+
+ def chat(
+     prompt,
+     user,
+     bot,
+     chatbot,
+     history,
+     temperature=1.0,
+     top_p=0.8,
+     presence_penalty=0.1,
+     count_penalty=0.1,
+ ):
+     args = PIPELINE_ARGS(temperature=max(0.2, float(temperature)), top_p=float(top_p),
+                          alpha_frequency=float(count_penalty),
+                          alpha_presence=float(presence_penalty),
+                          token_ban=[], # ban the generation of some tokens
+                          token_stop=[]) # stop generation whenever you see any token here
+
+     if not chatbot:
+         return chatbot, history
+
+     message = chatbot[-1][0]
+     message = message.strip().replace('\r\n','\n').replace('\n\n','\n')
+     ctx = f"{user}: {message}\n\n{bot}:"
+
+     if not history:
+         prompt = prompt.replace("<|user|>", user.strip())
+         prompt = prompt.replace("<|bot|>", bot.strip())
+         prompt = prompt.strip()
+         prompt = f"\n{prompt}\n\n"
+
+         out, state = model.forward(pipeline.encode(prompt), None)
+         history = [state, None, []] # [state, state_pre, tokens]
+         # print("History reloaded.")
+
+     [state, _, all_tokens] = history
+     state_pre_0 = copy.deepcopy(state)
+
+     out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:], state)
+     state_pre_1 = copy.deepcopy(state) # For recovery
+
+     # print("Bot:", end='')
+
+     begin = len(all_tokens)
+     out_last = begin
+     out_str: str = ''
+     occurrence = {}
+     for i in range(300):
+         # nl_bias biases token 187 ('\n' in the 20B tokenizer): ban an immediate newline, discourage it early, then increasingly encourage it so the reply wraps up
+         if i <= 0:
+             nl_bias = -float('inf')
+         elif i <= 30:
+             nl_bias = (i - 30) * 0.1
+         elif i <= 130:
+             nl_bias = 0
+         else:
+             nl_bias = (i - 130) * 0.25
+         out[187] += nl_bias
+         for n in occurrence:
+             out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+         token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
+         next_tokens = [token]
+         if token == 0:
+             next_tokens = pipeline.encode('\n\n')
+         all_tokens += next_tokens
+
+         if token not in occurrence:
+             occurrence[token] = 1
+         else:
+             occurrence[token] += 1
+
+         out, state = model.forward(next_tokens, state)
+
+         tmp = pipeline.decode(all_tokens[out_last:])
+         if '\ufffd' not in tmp:
+             # print(tmp, end='', flush=True)
+             out_last = begin + i + 1
+             out_str += tmp
+
+         chatbot[-1][1] = out_str.strip()
+         history = [state, all_tokens]
+         yield chatbot, history
+
+         out_str = pipeline.decode(all_tokens[begin:])
+         out_str = out_str.replace("\r\n", '\n').replace('\\n', '\n')
+
+         if '\n\n' in out_str:
+             break
+
+         # State recovery
+         if f'{user}:' in out_str or f'{bot}:' in out_str:
+             idx_user = out_str.find(f'{user}:')
+             idx_user = len(out_str) if idx_user == -1 else idx_user
+             idx_bot = out_str.find(f'{bot}:')
+             idx_bot = len(out_str) if idx_bot == -1 else idx_bot
+             idx = min(idx_user, idx_bot)
+
+             if idx < len(out_str):
+                 out_str = f" {out_str[:idx].strip()}\n\n"
+                 tokens = pipeline.encode(out_str)
+
+                 all_tokens = all_tokens[:begin] + tokens
+                 out, state = model.forward(tokens, state_pre_1)
+                 break
+
+     gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+     print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
+
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     chatbot[-1][1] = out_str.strip()
+     history = [state, state_pre_0, all_tokens]
+     yield chatbot, history
+
+ ##########################################################################
+
+ with gr.Blocks(title=title) as demo:
+     gr.HTML(f"<div style=\"text-align: center;\">\n<h1>🐦Raven - {title}</h1>\n</div>")
+     with gr.Tab("Instruct mode"):
+         gr.Markdown(f"Raven is [RWKV 14B](https://github.com/BlinkDL/ChatRWKV), a 100% RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM), finetuned to follow instructions. *** Please try the examples first (bottom of page) *** (edit them to use your question). Demo limited to ctxlen {ctx_limit}. Finetuned on alpaca, gpt4all, codealpaca and more. For best results, *** keep your prompt short and clear ***. <b>UPDATE: now with Chat (see above, as a tab)</b>.")
+         with gr.Row():
+             with gr.Column():
+                 instruction = gr.Textbox(lines=2, label="Instruction", value="Tell me about ravens.")
+                 input = gr.Textbox(lines=2, label="Input", placeholder="none")
+                 token_count = gr.Slider(10, 200, label="Max Tokens", step=10, value=150)
+                 temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.2)
+                 top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.5)
+                 presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.4)
+                 count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.4)
+             with gr.Column():
+                 with gr.Row():
+                     submit = gr.Button("Submit", variant="primary")
+                     clear = gr.Button("Clear", variant="secondary")
+                 output = gr.Textbox(label="Output", lines=5)
+         data = gr.Dataset(components=[instruction, input, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples, label="Example Instructions", headers=["Instruction", "Input", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
+         submit.click(evaluate, [instruction, input, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
+         clear.click(lambda: None, [], [output])
+         data.click(lambda x: x, [data], [instruction, input, token_count, temperature, top_p, presence_penalty, count_penalty])
+
+     with gr.Tab("Chat (Experimental - Might be buggy - use ChatRWKV for reference)"):
+         gr.Markdown(f'''<b>*** The length of response is restricted in this demo. Use ChatRWKV for longer generations. ***</b> Saying "go on" or "continue" can sometimes continue the response. If you'd like to edit the scenario, make sure to follow the exact same format: empty lines between (and only between) different speakers. Changes only take effect after you press [Clear]. <b>The default "Bob" & "Alice" names work the best.</b>''', label="Description")
+         with gr.Row():
+             with gr.Column():
+                 chatbot = gr.Chatbot()
+                 state = gr.State()
+                 message = gr.Textbox(label="Message", value="Write me a python code to land on moon.")
+                 with gr.Row():
+                     send = gr.Button("Send", variant="primary")
+                     alt = gr.Button("Alternative", variant="secondary")
+                     clear = gr.Button("Clear", variant="secondary")
+             with gr.Column():
+                 with gr.Row():
+                     user_name = gr.Textbox(lines=1, max_lines=1, label="User Name", value="Bob")
+                     bot_name = gr.Textbox(lines=1, max_lines=1, label="Bot Name", value="Alice")
+                 prompt = gr.Textbox(lines=10, max_lines=50, label="Scenario", value=chat_intro)
+                 temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.2)
+                 top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.5)
+                 presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.4)
+                 count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.4)
+         chat_inputs = [
+             prompt,
+             user_name,
+             bot_name,
+             chatbot,
+             state,
+             temperature,
+             top_p,
+             presence_penalty,
+             count_penalty
+         ]
+         chat_outputs = [chatbot, state]
+         message.submit(user, [message, chatbot], [message, chatbot], queue=False).then(chat, chat_inputs, chat_outputs)
+         send.click(user, [message, chatbot], [message, chatbot], queue=False).then(chat, chat_inputs, chat_outputs)
+         alt.click(alternative, [chatbot, state], [chatbot, state], queue=False).then(chat, chat_inputs, chat_outputs)
+         clear.click(lambda: ([], None, ""), [], [chatbot, state, message], queue=False)

+ demo.queue(concurrency_count=1, max_size=10)
  demo.launch(share=False)
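
For readers who want to try the new instruct path without the Gradio UI, here is a minimal sketch; it assumes the definitions introduced above (generate_prompt, evaluate, pipeline, model) are already loaded in the same process, before demo.launch(), and the instruction text and sampling values are just the ones from the examples list, used for illustration:

# Minimal usage sketch (assumption: run in the same process as app.py above,
# so generate_prompt() and evaluate() are in scope).
prompt_text = generate_prompt("Tell me about ravens.", "")
print(prompt_text)  # the "# Instruction / # Response" template that evaluate() feeds to the model

# evaluate() is a generator that yields the growing output after each decoded chunk;
# keep the last yield as the final answer.
answer = ""
for partial in evaluate("Tell me about ravens.", "", token_count=150,
                        temperature=1.2, top_p=0.5,
                        presencePenalty=0.4, countPenalty=0.4):
    answer = partial
print(answer)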