ahuang11 committed on
Commit
a0db240
1 Parent(s): 44c8f56

Create ai.py

Files changed (1)
  1. ai.py +291 -0
ai.py ADDED
@@ -0,0 +1,291 @@
+ # pylint: disable=W0707
+ # pylint: disable=W0719
+
+ import os
+ import json
+ import tiktoken
+ import openai
+ from openai import OpenAI
+ import requests
+
+ from constants.cli import OPENAI_MODELS
+ from constants.ai import SYSTEM_PROMPT, PROMPT, API_URL
+
+
+ def retrieve(query, k=10, filters=None):
+     """Queries the Fleet Context API and returns the retrieved results.
+
+     Args:
+         query (str): User query to pass in.
+         k (int, optional): Number of results passed back. Defaults to 10.
+         filters (dict, optional): Filters to apply to the query. You can filter based off
+             any piece of metadata by passing in a dict of the format
+             {metadata_name: filter_value}, e.g. {"library_id": "1234"}.
+
+             See the README for more details:
+             https://github.com/fleet-ai/context/tree/main#using-fleet-contexts-rich-metadata
+
+     Returns:
+         list: List of queried results.
+     """
+     url = f"{API_URL}/query"
+     params = {
+         "query": query,
+         "dataset": "python_libraries",
+         "n_results": k,
+         "filters": filters,
+     }
+     return requests.post(url, json=params, timeout=120).json()
+
+
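A rough usage sketch (illustrative only, not part of the file): retrieve can be called directly against the hosted API, with an optional metadata filter as described in the docstring.

# Sketch: the query text is arbitrary and "1234" is the placeholder library_id from
# the docstring; the "url"/"text" metadata keys mirror how retrieve_context() below
# reads the results.
results = retrieve("How do I merge two DataFrames?", k=3, filters={"library_id": "1234"})
for result in results:
    print(result["metadata"]["url"])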
+ def retrieve_context(query, openai_api_key, k=10, filters=None):
+     """Gets the context from our libraries vector db for a given query.
+
+     Args:
+         query (str): User input query.
+         openai_api_key (str): OpenAI API key (not used by this function).
+         k (int, optional): Number of retrieved results. Defaults to 10.
+         filters (dict, optional): Metadata filters passed through to retrieve().
+
+     Returns:
+         dict: A user-role chat message whose content concatenates the retrieved chunks.
+     """
+     # First, we query the API
+     responses = retrieve(query, k=k, filters=filters)
+
+     # Then, we build the prompt_with_context string
+     prompt_with_context = ""
+     for response in responses:
+         prompt_with_context += f"\n\n### Context {response['metadata']['url']} ###\n{response['metadata']['text']}"
+     return {"role": "user", "content": prompt_with_context}
+
+
+ def construct_prompt(
+     messages,
+     context_message,
+     model="gpt-4-1106-preview",
+     cite_sources=True,
+     context_window=3000,
+ ):
+     """
+     Constructs a RAG (Retrieval-Augmented Generation) prompt by balancing the token
+     count of messages and context_message. If the total token count exceeds the
+     maximum limit, it adjusts the token count of each to maintain a 1:1 proportion.
+     It then combines both lists and returns the result.
+
+     Parameters:
+         messages (List[dict]): List of messages to be included in the prompt.
+         context_message (dict): Context message to be included in the prompt.
+         model (str): The model used to select a tiktoken encoding. Defaults to "gpt-4-1106-preview".
+         cite_sources (bool): Whether to insert the citation instruction prompt. Defaults to True.
+         context_window (int): Token budget for the combined prompt. Defaults to 3000.
+
+     Returns:
+         List[dict]: The constructed RAG prompt.
+     """
+     # Get the encoding; default to cl100k_base
+     if model in OPENAI_MODELS:
+         encoding = tiktoken.encoding_for_model(model)
+     else:
+         encoding = tiktoken.get_encoding("cl100k_base")
+
+     # 1) Calculate the token budget, reserving space for the model's reply
+     reserved_space = 1000
+     max_messages_count = int((context_window - reserved_space) / 2)
+     max_context_count = int((context_window - reserved_space) / 2)
+
+     # 2) Construct the prompt: system prompt first, citation prompt before the last message
+     prompts = messages.copy()
+     prompts.insert(0, {"role": "system", "content": SYSTEM_PROMPT})
+     if cite_sources:
+         prompts.insert(-1, {"role": "user", "content": PROMPT})
+
+     # 3) Find how many tokens each list has
+     messages_token_count = len(
+         encoding.encode(
+             "\n".join(
+                 [
+                     f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>"
+                     for message in prompts
+                 ]
+             )
+         )
+     )
+     context_token_count = len(
+         encoding.encode(
+             f"<|im_start|>{context_message['role']}\n{context_message['content']}<|im_end|>"
+         )
+     )
+
+     # 4) Balance the token count for each
+     if (messages_token_count + context_token_count) > (context_window - reserved_space):
+         # Context exceeds its limit while messages are under theirs: give the slack to context
+         if (messages_token_count < max_messages_count) and (
+             context_token_count > max_context_count
+         ):
+             max_context_count += max_messages_count - messages_token_count
+         # Messages exceed their limit while context is under its limit: give the slack to messages
+         elif (messages_token_count > max_messages_count) and (
+             context_token_count < max_context_count
+         ):
+             max_messages_count += max_context_count - context_token_count
+
+     # 5) Cut each list down to its max count
+
+     # Cut down messages, dropping the oldest non-system message first
+     while messages_token_count > max_messages_count:
+         removed_encoding = encoding.encode(
+             f"<|im_start|>{prompts[1]['role']}\n{prompts[1]['content']}<|im_end|>"
+         )
+         messages_token_count -= len(removed_encoding)
+         if messages_token_count < max_messages_count:
+             # Keep a truncated version of the message that crossed the limit
+             prompts = (
+                 [prompts[0]]
+                 + [
+                     {
+                         "role": prompts[1]["role"],
+                         "content": encoding.decode(
+                             removed_encoding[
+                                 : min(
+                                     int(max_messages_count - messages_token_count),
+                                     len(removed_encoding),
+                                 )
+                             ]
+                         )
+                         .replace("<|im_start|>", "")
+                         .replace("<|im_end|>", ""),
+                     }
+                 ]
+                 + prompts[2:]
+             )
+         else:
+             prompts = [prompts[0]] + prompts[2:]
+
+     # Cut down context
+     if context_token_count > max_context_count:
+         # Take a proportional slice of the content's character length
+         reduced_chars_length = int(
+             len(context_message["content"]) * (max_context_count / context_token_count)
+         )
+         context_message["content"] = context_message["content"][:reduced_chars_length]
+
+     # 6) Combine both lists: insert the context right before the last user message
+     prompts.insert(-1, context_message)
+
+     return prompts
+
+
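For orientation, an illustrative sketch of how the helpers above chain together: retrieve_context builds the context message and construct_prompt folds it into the chat history within the token budget (the query text and values below are assumptions, not part of the file).

# Sketch: assumes OPENAI_API_KEY is set and the constants module is importable.
messages = [{"role": "user", "content": "How do I add a slider widget?"}]
context_message = retrieve_context(messages[-1]["content"], os.environ.get("OPENAI_API_KEY"), k=5)
prompts = construct_prompt(messages, context_message, context_window=3000)
# Resulting order: system prompt, citation prompt, context message, original user message.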
+ def get_remote_chat_response(messages, model="gpt-4-1106-preview"):
+     """
+     Streams an OpenAI chat completion.
+
+     Parameters:
+         messages (List[dict]): List of messages to be included in the prompt.
+         model (str): The chat model to use. Defaults to "gpt-4-1106-preview".
+
+     Yields:
+         str: The next streamed chunk of the OpenAI chat response.
+     """
+     client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
+
+     try:
+         response = client.chat.completions.create(
+             model=model, messages=messages, temperature=0.2, stream=True
+         )
+
+         for chunk in response:
+             current_context = chunk.choices[0].delta.content
+             yield current_context
+
+     except openai.AuthenticationError as error:
+         print("401 Authentication Error:", error)
+         raise Exception("Invalid OPENAI_API_KEY. Please re-run with a valid key.")
+
+     except Exception as error:
+         print("Streaming Error:", error)
+         raise Exception("Internal Server Error")
+
+
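A minimal consumption sketch (illustrative only): because the stream's final chunk can carry a None delta, empty pieces are filtered before printing.

# Sketch: assumes `prompts` was built with construct_prompt() and OPENAI_API_KEY is set.
for token in get_remote_chat_response(prompts, model="gpt-4-1106-preview"):
    if token:
        print(token, end="", flush=True)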
+ def get_other_chat_response(messages, model="local-model"):
+     """
+     Streams a chat response from a local server or from OpenRouter.
+
+     Parameters:
+         messages (List[dict]): List of messages to be included in the prompt.
+         model (str): The model to use. Defaults to "local-model", which targets a local
+             server on port 1234; any other value is routed through OpenRouter.
+
+     Yields:
+         str: The next streamed chunk of the chat response.
+     """
+     try:
+         if model == "local-model":
+             url = "http://localhost:1234/v1/chat/completions"
+             headers = {"Content-Type": "application/json"}
+             data = {
+                 "messages": messages,
+                 "temperature": 0.2,
+                 "max_tokens": -1,
+                 "stream": True,
+             }
+             response = requests.post(
+                 url, headers=headers, data=json.dumps(data), stream=True, timeout=120
+             )
+
+             if response.status_code == 200:
+                 for chunk in response.iter_content(chunk_size=None):
+                     decoded_chunk = chunk.decode()
+                     if (
+                         "data:" in decoded_chunk
+                         and decoded_chunk.split("data:")[1].strip()
+                     ):  # Check that the chunk is not empty
+                         try:
+                             chunk_dict = json.loads(
+                                 decoded_chunk.split("data:")[1].strip()
+                             )
+                             yield chunk_dict["choices"][0]["delta"].get("content", "")
+                         except json.JSONDecodeError:
+                             pass
+             else:
+                 print(f"Error: {response.status_code}, {response.text}")
+                 raise Exception("Internal Server Error")
+         else:
+             if not os.environ.get("OPENROUTER_API_KEY"):
+                 raise Exception(
+                     f"For non-OpenAI models, like {model}, set your OPENROUTER_API_KEY."
+                 )
+
+             response = requests.post(
+                 url="https://openrouter.ai/api/v1/chat/completions",
+                 headers={
+                     "Authorization": f"Bearer {os.environ.get('OPENROUTER_API_KEY')}",
+                     "HTTP-Referer": os.environ.get(
+                         "OPENROUTER_APP_URL", "https://fleet.so/context"
+                     ),
+                     "X-Title": os.environ.get("OPENROUTER_APP_TITLE", "Fleet Context"),
+                     "Content-Type": "application/json",
+                 },
+                 data=json.dumps({"model": model, "messages": messages, "stream": True}),
+                 stream=True,
+                 timeout=120,
+             )
+             if response.status_code == 200:
+                 for chunk in response.iter_lines():
+                     decoded_chunk = chunk.decode("utf-8")
+                     if (
+                         "data:" in decoded_chunk
+                         and decoded_chunk.split("data:")[1].strip()
+                     ):  # Check that the chunk is not empty
+                         try:
+                             chunk_dict = json.loads(
+                                 decoded_chunk.split("data:")[1].strip()
+                             )
+                             yield chunk_dict["choices"][0]["delta"].get("content", "")
+                         except json.JSONDecodeError:
+                             pass
+             else:
+                 print(f"Error: {response.status_code}, {response.text}")
+                 raise Exception("Internal Server Error")
+
+     except requests.exceptions.RequestException as error:
+         print("Request Error:", error)
+         raise Exception("Invalid request. Please check your request parameters.")