JohnSmith9982 committed on
Commit
dae5193
1 Parent(s): ffff3dc

Delete chat_func.py

Files changed (1)
chat_func.py +0 -456
chat_func.py DELETED
@@ -1,456 +0,0 @@
- # -*- coding:utf-8 -*-
- from __future__ import annotations
- from typing import TYPE_CHECKING, List
-
- import logging
- import json
- import os
- import requests
- import urllib3
-
- from tqdm import tqdm
- import colorama
- from duckduckgo_search import ddg
- import asyncio
- import aiohttp
-
- from presets import *
- from llama_func import *
- from utils import *
-
- # logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
-
- if TYPE_CHECKING:
-     from typing import TypedDict
-
-     class DataframeData(TypedDict):
-         headers: List[str]
-         data: List[List[str | int | bool]]
-
-
- initial_prompt = "You are a helpful assistant."
- API_URL = "https://api.openai.com/v1/chat/completions"
- HISTORY_DIR = "history"
- TEMPLATES_DIR = "templates"
-
- def get_response(
-     openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model
- ):
-     headers = {
-         "Content-Type": "application/json",
-         "Authorization": f"Bearer {openai_api_key}",
-     }
-
-     history = [construct_system(system_prompt), *history]
-
-     payload = {
-         "model": selected_model,
-         "messages": history,  # [{"role": "user", "content": f"{inputs}"}],
-         "temperature": temperature,  # 1.0,
-         "top_p": top_p,  # 1.0,
-         "n": 1,
-         "stream": stream,
-         "presence_penalty": 0,
-         "frequency_penalty": 0,
-     }
-     if stream:
-         timeout = timeout_streaming
-     else:
-         timeout = timeout_all
-
-     # Read proxy settings from the environment variables
-     http_proxy = os.environ.get("HTTP_PROXY") or os.environ.get("http_proxy")
-     https_proxy = os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy")
-
-     # If proxy settings exist, use them
-     proxies = {}
-     if http_proxy:
-         logging.info(f"Using HTTP proxy: {http_proxy}")
-         proxies["http"] = http_proxy
-     if https_proxy:
-         logging.info(f"Using HTTPS proxy: {https_proxy}")
-         proxies["https"] = https_proxy
-
-     # Send the request through the proxy if one is configured, otherwise with default settings
-     if proxies:
-         response = requests.post(
-             API_URL,
-             headers=headers,
-             json=payload,
-             stream=True,
-             timeout=timeout,
-             proxies=proxies,
-         )
-     else:
-         response = requests.post(
-             API_URL,
-             headers=headers,
-             json=payload,
-             stream=True,
-             timeout=timeout,
-         )
-     return response
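
A minimal, self-contained sketch of the request that get_response() above assembles. The model name, messages, and timeout below are illustrative placeholders, not values taken from this file:

import os
import requests

def demo_chat_request():
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",
    }
    payload = {
        "model": "gpt-3.5-turbo",  # placeholder for selected_model
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
        "temperature": 1.0,
        "top_p": 1.0,
        "n": 1,
        "stream": False,
        "presence_penalty": 0,
        "frequency_penalty": 0,
    }
    # stream=True only defers downloading the body; response.text still works
    return requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=headers,
        json=payload,
        stream=True,
        timeout=30,  # placeholder for timeout_streaming / timeout_all
    )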
-
-
- def stream_predict(
-     openai_api_key,
-     system_prompt,
-     history,
-     inputs,
-     chatbot,
-     all_token_counts,
-     top_p,
-     temperature,
-     selected_model,
-     fake_input=None,
-     display_append=""
- ):
-     def get_return_value():
-         return chatbot, history, status_text, all_token_counts
-
-     logging.info("Streaming answer mode")
-     partial_words = ""
-     counter = 0
-     status_text = "Starting to stream the answer..."
-     history.append(construct_user(inputs))
-     history.append(construct_assistant(""))
-     if fake_input:
-         chatbot.append((fake_input, ""))
-     else:
-         chatbot.append((inputs, ""))
-     user_token_count = 0
-     if len(all_token_counts) == 0:
-         system_prompt_token_count = count_token(construct_system(system_prompt))
-         user_token_count = (
-             count_token(construct_user(inputs)) + system_prompt_token_count
-         )
-     else:
-         user_token_count = count_token(construct_user(inputs))
-     all_token_counts.append(user_token_count)
-     logging.info(f"Input token count: {user_token_count}")
-     yield get_return_value()
-     try:
-         response = get_response(
-             openai_api_key,
-             system_prompt,
-             history,
-             temperature,
-             top_p,
-             True,
-             selected_model,
-         )
-     except requests.exceptions.ConnectTimeout:
-         status_text = (
-             standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
-         )
-         yield get_return_value()
-         return
-     except requests.exceptions.ReadTimeout:
-         status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
-         yield get_return_value()
-         return
-
-     yield get_return_value()
-     error_json_str = ""
-
-     for chunk in response.iter_lines():
-         if counter == 0:
-             counter += 1
-             continue
-         counter += 1
-         # check whether each line is non-empty
-         if chunk:
-             # decode each line, as the response data arrives in bytes
-             chunk = chunk.decode()
-             chunklength = len(chunk)
-             try:
-                 chunk = json.loads(chunk[6:])  # strip the leading "data: " prefix
-             except json.JSONDecodeError:
-                 logging.info(chunk)
-                 error_json_str += chunk
-                 status_text = f"JSON parse error. Please reset the conversation. Received content: {error_json_str}"
-                 yield get_return_value()
-                 continue
-             # a real data line is longer than the 6-character prefix and carries a delta
-             if chunklength > 6 and "delta" in chunk["choices"][0]:
-                 finish_reason = chunk["choices"][0]["finish_reason"]
-                 status_text = construct_token_message(
-                     sum(all_token_counts), stream=True
-                 )
-                 if finish_reason == "stop":
-                     yield get_return_value()
-                     break
-                 try:
-                     partial_words = (
-                         partial_words + chunk["choices"][0]["delta"]["content"]
-                     )
-                 except KeyError:
-                     status_text = (
-                         standard_error_msg
-                         + "No content found in the API reply. The token count most likely hit the limit. Please reset the conversation. Current token count: "
-                         + str(sum(all_token_counts))
-                     )
-                     yield get_return_value()
-                     break
-             history[-1] = construct_assistant(partial_words)
-             chatbot[-1] = (chatbot[-1][0], partial_words + display_append)
-             all_token_counts[-1] += 1
-             yield get_return_value()
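
The chunk[6:] slice above strips the "data: " prefix that the server-sent-events stream places before each JSON payload. A minimal sketch of that parsing, assuming response came from get_response(..., stream=True):

import json

def demo_iter_deltas(response):
    for line in response.iter_lines():
        if not line:
            continue  # skip keep-alive blank lines
        line = line.decode()
        if not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":  # end-of-stream sentinel
            break
        chunk = json.loads(data)
        delta = chunk["choices"][0].get("delta", {})
        if "content" in delta:
            yield delta["content"]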
198
-
199
-
200
- def predict_all(
201
- openai_api_key,
202
- system_prompt,
203
- history,
204
- inputs,
205
- chatbot,
206
- all_token_counts,
207
- top_p,
208
- temperature,
209
- selected_model,
210
- fake_input=None,
211
- display_append=""
212
- ):
213
- logging.info("一次性回答模式")
214
- history.append(construct_user(inputs))
215
- history.append(construct_assistant(""))
216
- if fake_input:
217
- chatbot.append((fake_input, ""))
218
- else:
219
- chatbot.append((inputs, ""))
220
- all_token_counts.append(count_token(construct_user(inputs)))
221
- try:
222
- response = get_response(
223
- openai_api_key,
224
- system_prompt,
225
- history,
226
- temperature,
227
- top_p,
228
- False,
229
- selected_model,
230
- )
231
- except requests.exceptions.ConnectTimeout:
232
- status_text = (
233
- standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
234
- )
235
- return chatbot, history, status_text, all_token_counts
236
- except requests.exceptions.ProxyError:
237
- status_text = standard_error_msg + proxy_error_prompt + error_retrieve_prompt
238
- return chatbot, history, status_text, all_token_counts
239
- except requests.exceptions.SSLError:
240
- status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
241
- return chatbot, history, status_text, all_token_counts
242
- response = json.loads(response.text)
243
- content = response["choices"][0]["message"]["content"]
244
- history[-1] = construct_assistant(content)
245
- chatbot[-1] = (chatbot[-1][0], content+display_append)
246
- total_token_count = response["usage"]["total_tokens"]
247
- all_token_counts[-1] = total_token_count - sum(all_token_counts)
248
- status_text = construct_token_message(total_token_count)
249
- return chatbot, history, status_text, all_token_counts
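
For reference, a minimal sketch of unpacking the non-streaming body the way predict_all() does, assuming response came from get_response(..., stream=False):

import json

def demo_read_answer(response):
    body = json.loads(response.text)
    content = body["choices"][0]["message"]["content"]
    total_tokens = body["usage"]["total_tokens"]  # prompt + completion tokens
    return content, total_tokens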
-
-
- def predict(
-     openai_api_key,
-     system_prompt,
-     history,
-     inputs,
-     chatbot,
-     all_token_counts,
-     top_p,
-     temperature,
-     stream=False,
-     selected_model=MODELS[0],
-     use_websearch=False,
-     files=None,
-     should_check_token_count=True,
- ):  # repetition_penalty, top_k
-     logging.info("Input: " + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
-     if files:
-         msg = "Building index... (this may take quite a while)"
-         logging.info(msg)
-         yield chatbot, history, msg, all_token_counts
-         index = construct_index(openai_api_key, file_src=files)
-         msg = "Index built; fetching answer..."
-         yield chatbot, history, msg, all_token_counts
-         history, chatbot, status_text = chat_ai(openai_api_key, index, inputs, history, chatbot)
-         yield chatbot, history, status_text, all_token_counts
-         return
-
-     old_inputs = ""
-     link_references = []
-     if use_websearch:
-         search_results = ddg(inputs, max_results=5)
-         old_inputs = inputs
-         web_results = []
-         for idx, result in enumerate(search_results):
-             logging.info(f"Search result {idx + 1}: {result}")
-             domain_name = urllib3.util.parse_url(result["href"]).host
-             web_results.append(f'[{idx+1}]"{result["body"]}"\nURL: {result["href"]}')
-             link_references.append(f"{idx+1}. [{domain_name}]({result['href']})\n")
-         link_references = "\n\n" + "".join(link_references)
-         inputs = (
-             replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
-             .replace("{query}", inputs)
-             .replace("{web_results}", "\n\n".join(web_results))
-         )
-     else:
-         link_references = ""
-
-     if len(openai_api_key) != 51:
-         status_text = standard_error_msg + no_apikey_msg
-         logging.info(status_text)
-         chatbot.append((inputs, ""))
-         if len(history) == 0:
-             history.append(construct_user(inputs))
-             history.append("")
-             all_token_counts.append(0)
-         else:
-             history[-2] = construct_user(inputs)
-         yield chatbot, history, status_text, all_token_counts
-         return
-
-     yield chatbot, history, "Generating the answer...", all_token_counts
-
-     if stream:
-         logging.info("Using streaming")
-         iter = stream_predict(
-             openai_api_key,
-             system_prompt,
-             history,
-             inputs,
-             chatbot,
-             all_token_counts,
-             top_p,
-             temperature,
-             selected_model,
-             fake_input=old_inputs,
-             display_append=link_references
-         )
-         for chatbot, history, status_text, all_token_counts in iter:
-             yield chatbot, history, status_text, all_token_counts
-     else:
-         logging.info("Not using streaming")
-         chatbot, history, status_text, all_token_counts = predict_all(
-             openai_api_key,
-             system_prompt,
-             history,
-             inputs,
-             chatbot,
-             all_token_counts,
-             top_p,
-             temperature,
-             selected_model,
-             fake_input=old_inputs,
-             display_append=link_references
-         )
-         yield chatbot, history, status_text, all_token_counts
-
-     logging.info(f"Transfer complete. Current token counts: {all_token_counts}")
-     if len(history) > 1 and history[-1]["content"] != inputs:
-         logging.info(
-             "Answer: "
-             + colorama.Fore.BLUE
-             + f"{history[-1]['content']}"
-             + colorama.Style.RESET_ALL
-         )
-
-     if stream:
-         max_token = max_token_streaming
-     else:
-         max_token = max_token_all
-
-     if sum(all_token_counts) > max_token and should_check_token_count:
-         status_text = f"Trimming tokens: {all_token_counts}/{max_token}"
-         logging.info(status_text)
-         yield chatbot, history, status_text, all_token_counts
-         iter = reduce_token_size(
-             openai_api_key,
-             system_prompt,
-             history,
-             chatbot,
-             all_token_counts,
-             top_p,
-             temperature,
-             max_token // 2,
-             selected_model=selected_model,
-         )
-         for chatbot, history, status_text, all_token_counts in iter:
-             status_text = f"Token limit reached; the token count has been automatically reduced to {status_text}"
-             yield chatbot, history, status_text, all_token_counts
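
A minimal sketch of the web-search prompt injection done in predict() above. The template literal here is an illustrative stand-in; the real WEBSEARCH_PTOMPT_TEMPLATE is defined in presets.py:

WEBSEARCH_TEMPLATE = "Web search results:\n\n{web_results}\n\nQuery: {query}"  # stand-in

def demo_build_websearch_prompt(query, search_results):
    # search_results is a list of dicts with "body" and "href" keys, as ddg() returns
    web_results = [
        f'[{idx + 1}]"{result["body"]}"\nURL: {result["href"]}'
        for idx, result in enumerate(search_results)
    ]
    return WEBSEARCH_TEMPLATE.replace("{query}", query).replace(
        "{web_results}", "\n\n".join(web_results)
    )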
380
-
381
-
382
- def retry(
383
- openai_api_key,
384
- system_prompt,
385
- history,
386
- chatbot,
387
- token_count,
388
- top_p,
389
- temperature,
390
- stream=False,
391
- selected_model=MODELS[0],
392
- ):
393
- logging.info("重试中……")
394
- if len(history) == 0:
395
- yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count
396
- return
397
- history.pop()
398
- inputs = history.pop()["content"]
399
- token_count.pop()
400
- iter = predict(
401
- openai_api_key,
402
- system_prompt,
403
- history,
404
- inputs,
405
- chatbot,
406
- token_count,
407
- top_p,
408
- temperature,
409
- stream=stream,
410
- selected_model=selected_model,
411
- )
412
- logging.info("重试中……")
413
- for x in iter:
414
- yield x
415
- logging.info("重试完毕")
-
-
- def reduce_token_size(
-     openai_api_key,
-     system_prompt,
-     history,
-     chatbot,
-     token_count,
-     top_p,
-     temperature,
-     max_token_count,
-     selected_model=MODELS[0],
- ):
-     logging.info("Starting to reduce the token count...")
-     iter = predict(
-         openai_api_key,
-         system_prompt,
-         history,
-         summarize_prompt,
-         chatbot,
-         token_count,
-         top_p,
-         temperature,
-         selected_model=selected_model,
-         should_check_token_count=False,
-     )
-     logging.info(f"chatbot: {chatbot}")
-     flag = False
-     for chatbot, history, status_text, previous_token_count in iter:
-         num_chat = find_n(previous_token_count, max_token_count)
-         if flag:
-             chatbot = chatbot[:-1]
-         flag = True
-         history = history[-2 * num_chat:] if num_chat > 0 else []
-         token_count = previous_token_count[-num_chat:] if num_chat > 0 else []
-         msg = f"Kept the last {num_chat} rounds of conversation"
-         yield chatbot, history, msg + ", " + construct_token_message(
-             sum(token_count) if len(token_count) > 0 else 0,
-         ), token_count
-         logging.info(msg)
-     logging.info("Finished reducing the token count")