Spaces:
Runtime error
Runtime error
pseudotensor
commited on
Commit
•
29ea64b
1
Parent(s):
c596b21
Update with h2oGPT hash c7453f1b1ab51fb1cd342a9867f23cd7b538e000
Browse files- src/client_test.py +8 -17
- src/gen.py +7 -44
- src/gpt_langchain.py +279 -337
- src/gradio_runner.py +10 -7
src/client_test.py
CHANGED
@@ -80,7 +80,7 @@ def get_args(prompt, prompt_type=None, chat=False, stream_output=False,
|
|
80 |
version=None,
|
81 |
h2ogpt_key=None,
|
82 |
visible_models=None,
|
83 |
-
system_prompt='', # default of no system prompt
|
84 |
add_search_to_context=False,
|
85 |
chat_conversation=None,
|
86 |
text_context_list=None,
|
@@ -256,18 +256,13 @@ def run_client_nochat_api(prompt, prompt_type, max_new_tokens, version=None, h2o
|
|
256 |
|
257 |
|
258 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
259 |
-
def test_client_basic_api_lean(
|
260 |
-
|
261 |
-
|
262 |
-
version=version, h2ogpt_key=h2ogpt_key,
|
263 |
-
chat_conversation=chat_conversation,
|
264 |
-
system_prompt=system_prompt)
|
265 |
|
266 |
|
267 |
-
def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens, version=None, h2ogpt_key=None
|
268 |
-
|
269 |
-
kwargs = dict(instruction_nochat=prompt, h2ogpt_key=h2ogpt_key, chat_conversation=chat_conversation,
|
270 |
-
system_prompt=system_prompt)
|
271 |
|
272 |
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
273 |
client = get_client(serialize=True)
|
@@ -367,9 +362,7 @@ def run_client_chat(prompt='',
|
|
367 |
langchain_agents=[],
|
368 |
prompt_type=None, prompt_dict=None,
|
369 |
version=None,
|
370 |
-
h2ogpt_key=None
|
371 |
-
chat_conversation=None,
|
372 |
-
system_prompt=''):
|
373 |
client = get_client(serialize=False)
|
374 |
|
375 |
kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
|
@@ -379,9 +372,7 @@ def run_client_chat(prompt='',
|
|
379 |
langchain_agents=langchain_agents,
|
380 |
prompt_dict=prompt_dict,
|
381 |
version=version,
|
382 |
-
h2ogpt_key=h2ogpt_key
|
383 |
-
chat_conversation=chat_conversation,
|
384 |
-
system_prompt=system_prompt)
|
385 |
return run_client(client, prompt, args, kwargs)
|
386 |
|
387 |
|
|
|
80 |
version=None,
|
81 |
h2ogpt_key=None,
|
82 |
visible_models=None,
|
83 |
+
system_prompt='', # default of no system prompt tiggered by empty string
|
84 |
add_search_to_context=False,
|
85 |
chat_conversation=None,
|
86 |
text_context_list=None,
|
|
|
256 |
|
257 |
|
258 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
259 |
+
def test_client_basic_api_lean(prompt_type='human_bot', version=None, h2ogpt_key=None):
|
260 |
+
return run_client_nochat_api_lean(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50,
|
261 |
+
version=version, h2ogpt_key=h2ogpt_key)
|
|
|
|
|
|
|
262 |
|
263 |
|
264 |
+
def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens, version=None, h2ogpt_key=None):
|
265 |
+
kwargs = dict(instruction_nochat=prompt, h2ogpt_key=h2ogpt_key)
|
|
|
|
|
266 |
|
267 |
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
268 |
client = get_client(serialize=True)
|
|
|
362 |
langchain_agents=[],
|
363 |
prompt_type=None, prompt_dict=None,
|
364 |
version=None,
|
365 |
+
h2ogpt_key=None):
|
|
|
|
|
366 |
client = get_client(serialize=False)
|
367 |
|
368 |
kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
|
|
|
372 |
langchain_agents=langchain_agents,
|
373 |
prompt_dict=prompt_dict,
|
374 |
version=version,
|
375 |
+
h2ogpt_key=h2ogpt_key)
|
|
|
|
|
376 |
return run_client(client, prompt, args, kwargs)
|
377 |
|
378 |
|
src/gen.py
CHANGED
@@ -335,7 +335,7 @@ def main(
|
|
335 |
|
336 |
Or Address can be for vLLM:
|
337 |
Use: "vllm:IP:port" for OpenAI-compliant vLLM endpoint
|
338 |
-
|
339 |
|
340 |
Or Address can be replicate:
|
341 |
Use:
|
@@ -2236,17 +2236,6 @@ def evaluate(
|
|
2236 |
instruction = instruction_nochat
|
2237 |
iinput = iinput_nochat
|
2238 |
|
2239 |
-
# avoid instruction in chat_conversation itself, since always used as additional context to prompt in what follows
|
2240 |
-
if isinstance(chat_conversation, list) and \
|
2241 |
-
len(chat_conversation) > 0 and \
|
2242 |
-
len(chat_conversation[-1]) == 2 and \
|
2243 |
-
chat_conversation[-1][0] == instruction:
|
2244 |
-
chat_conversation = chat_conversation[:-1]
|
2245 |
-
if not add_chat_history_to_context:
|
2246 |
-
# make it easy to ignore without needing add_chat_history_to_context
|
2247 |
-
# some langchain or unit test may need to then handle more general case
|
2248 |
-
chat_conversation = []
|
2249 |
-
|
2250 |
# in some cases, like lean nochat API, don't want to force sending prompt_type, allow default choice
|
2251 |
model_lower = base_model.lower()
|
2252 |
if not prompt_type and model_lower in inv_prompt_type_to_model_lower and prompt_type != 'custom':
|
@@ -2495,8 +2484,7 @@ def evaluate(
|
|
2495 |
prompt, \
|
2496 |
instruction, iinput, context, \
|
2497 |
num_prompt_tokens, max_new_tokens, num_prompt_tokens0, num_prompt_tokens_actual, \
|
2498 |
-
chat_index,
|
2499 |
-
top_k_docs_trial, one_doc_size = \
|
2500 |
get_limited_prompt(instruction,
|
2501 |
iinput,
|
2502 |
tokenizer,
|
@@ -2564,6 +2552,8 @@ def evaluate(
|
|
2564 |
sanitize_bot_response=sanitize_bot_response)
|
2565 |
yield dict(response=response, sources=sources, save_dict=dict())
|
2566 |
elif inf_type == 'vllm_chat' or inference_server == 'openai_chat':
|
|
|
|
|
2567 |
if system_prompt in [None, 'None', 'auto']:
|
2568 |
openai_system_prompt = "You are a helpful assistant."
|
2569 |
else:
|
@@ -2571,16 +2561,7 @@ def evaluate(
|
|
2571 |
messages0 = []
|
2572 |
if openai_system_prompt:
|
2573 |
messages0.append({"role": "system", "content": openai_system_prompt})
|
2574 |
-
|
2575 |
-
assert external_handle_chat_conversation, "Should be handling only externally"
|
2576 |
-
# chat_index handles token counting issues
|
2577 |
-
for message1 in chat_conversation[chat_index:]:
|
2578 |
-
if len(message1) == 2:
|
2579 |
-
messages0.append(
|
2580 |
-
{'role': 'user', 'content': message1[0] if message1[0] is not None else ''})
|
2581 |
-
messages0.append(
|
2582 |
-
{'role': 'assistant', 'content': message1[1] if message1[1] is not None else ''})
|
2583 |
-
messages0.append({'role': 'user', 'content': prompt if prompt is not None else ''})
|
2584 |
responses = openai.ChatCompletion.create(
|
2585 |
model=base_model,
|
2586 |
messages=messages0,
|
@@ -3628,27 +3609,13 @@ def get_limited_prompt(instruction,
|
|
3628 |
stream_output = prompter.stream_output
|
3629 |
system_prompt = prompter.system_prompt
|
3630 |
|
3631 |
-
generate_prompt_type = prompt_type
|
3632 |
-
external_handle_chat_conversation = False
|
3633 |
-
if any(inference_server.startswith(x) for x in ['openai_chat', 'openai_azure_chat', 'vllm_chat']):
|
3634 |
-
# Chat APIs do not take prompting
|
3635 |
-
# Replicate does not need prompting if no chat history, but in general can take prompting
|
3636 |
-
# if using prompter, prompter.system_prompt will already be filled with automatic (e.g. from llama-2),
|
3637 |
-
# so if replicate final prompt with system prompt still correct because only access prompter.system_prompt that was already set
|
3638 |
-
# below already true for openai,
|
3639 |
-
# but not vllm by default as that can be any model and handled by FastChat API inside vLLM itself
|
3640 |
-
generate_prompt_type = 'plain'
|
3641 |
-
# Chat APIs don't handle chat history via single prompt, but in messages, assumed to be handled outside this function
|
3642 |
-
chat_conversation = []
|
3643 |
-
external_handle_chat_conversation = True
|
3644 |
-
|
3645 |
# merge handles if chat_conversation is None
|
3646 |
history = []
|
3647 |
history = merge_chat_conversation_history(chat_conversation, history)
|
3648 |
history_to_context_func = functools.partial(history_to_context,
|
3649 |
langchain_mode=langchain_mode,
|
3650 |
add_chat_history_to_context=add_chat_history_to_context,
|
3651 |
-
prompt_type=
|
3652 |
prompt_dict=prompt_dict,
|
3653 |
chat=chat,
|
3654 |
model_max_length=model_max_length,
|
@@ -3781,9 +3748,6 @@ def get_limited_prompt(instruction,
|
|
3781 |
stream_output = False # doesn't matter
|
3782 |
prompter = Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output,
|
3783 |
system_prompt=system_prompt)
|
3784 |
-
if prompt_type != generate_prompt_type:
|
3785 |
-
# override just this attribute, keep system_prompt etc. from original prompt_type
|
3786 |
-
prompter.prompt_type = generate_prompt_type
|
3787 |
|
3788 |
data_point = dict(context=context, instruction=instruction, input=iinput)
|
3789 |
# handle promptA/promptB addition if really from history.
|
@@ -3796,8 +3760,7 @@ def get_limited_prompt(instruction,
|
|
3796 |
return prompt, \
|
3797 |
instruction, iinput, context, \
|
3798 |
num_prompt_tokens, max_new_tokens, num_prompt_tokens0, num_prompt_tokens_actual, \
|
3799 |
-
chat_index,
|
3800 |
-
top_k_docs, one_doc_size
|
3801 |
|
3802 |
|
3803 |
def get_docs_tokens(tokenizer, text_context_list=[], max_input_tokens=None):
|
|
|
335 |
|
336 |
Or Address can be for vLLM:
|
337 |
Use: "vllm:IP:port" for OpenAI-compliant vLLM endpoint
|
338 |
+
Note: vllm_chat not supported by vLLM project.
|
339 |
|
340 |
Or Address can be replicate:
|
341 |
Use:
|
|
|
2236 |
instruction = instruction_nochat
|
2237 |
iinput = iinput_nochat
|
2238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2239 |
# in some cases, like lean nochat API, don't want to force sending prompt_type, allow default choice
|
2240 |
model_lower = base_model.lower()
|
2241 |
if not prompt_type and model_lower in inv_prompt_type_to_model_lower and prompt_type != 'custom':
|
|
|
2484 |
prompt, \
|
2485 |
instruction, iinput, context, \
|
2486 |
num_prompt_tokens, max_new_tokens, num_prompt_tokens0, num_prompt_tokens_actual, \
|
2487 |
+
chat_index, top_k_docs_trial, one_doc_size = \
|
|
|
2488 |
get_limited_prompt(instruction,
|
2489 |
iinput,
|
2490 |
tokenizer,
|
|
|
2552 |
sanitize_bot_response=sanitize_bot_response)
|
2553 |
yield dict(response=response, sources=sources, save_dict=dict())
|
2554 |
elif inf_type == 'vllm_chat' or inference_server == 'openai_chat':
|
2555 |
+
if inf_type == 'vllm_chat':
|
2556 |
+
raise NotImplementedError('%s not supported by vLLM' % inf_type)
|
2557 |
if system_prompt in [None, 'None', 'auto']:
|
2558 |
openai_system_prompt = "You are a helpful assistant."
|
2559 |
else:
|
|
|
2561 |
messages0 = []
|
2562 |
if openai_system_prompt:
|
2563 |
messages0.append({"role": "system", "content": openai_system_prompt})
|
2564 |
+
messages0.append({'role': 'user', 'content': prompt})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2565 |
responses = openai.ChatCompletion.create(
|
2566 |
model=base_model,
|
2567 |
messages=messages0,
|
|
|
3609 |
stream_output = prompter.stream_output
|
3610 |
system_prompt = prompter.system_prompt
|
3611 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3612 |
# merge handles if chat_conversation is None
|
3613 |
history = []
|
3614 |
history = merge_chat_conversation_history(chat_conversation, history)
|
3615 |
history_to_context_func = functools.partial(history_to_context,
|
3616 |
langchain_mode=langchain_mode,
|
3617 |
add_chat_history_to_context=add_chat_history_to_context,
|
3618 |
+
prompt_type=prompt_type,
|
3619 |
prompt_dict=prompt_dict,
|
3620 |
chat=chat,
|
3621 |
model_max_length=model_max_length,
|
|
|
3748 |
stream_output = False # doesn't matter
|
3749 |
prompter = Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output,
|
3750 |
system_prompt=system_prompt)
|
|
|
|
|
|
|
3751 |
|
3752 |
data_point = dict(context=context, instruction=instruction, input=iinput)
|
3753 |
# handle promptA/promptB addition if really from history.
|
|
|
3760 |
return prompt, \
|
3761 |
instruction, iinput, context, \
|
3762 |
num_prompt_tokens, max_new_tokens, num_prompt_tokens0, num_prompt_tokens_actual, \
|
3763 |
+
chat_index, top_k_docs, one_doc_size
|
|
|
3764 |
|
3765 |
|
3766 |
def get_docs_tokens(tokenizer, text_context_list=[], max_input_tokens=None):
|
src/gpt_langchain.py
CHANGED
@@ -29,11 +29,10 @@ import yaml
|
|
29 |
|
30 |
from joblib import delayed
|
31 |
from langchain.callbacks import streaming_stdout
|
32 |
-
from langchain.callbacks.base import Callbacks
|
33 |
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
34 |
from langchain.llms.huggingface_pipeline import VALID_TASKS
|
35 |
from langchain.llms.utils import enforce_stop_tokens
|
36 |
-
from langchain.schema import LLMResult, Generation
|
37 |
from langchain.tools import PythonREPLTool
|
38 |
from langchain.tools.json.tool import JsonSpec
|
39 |
from tqdm import tqdm
|
@@ -945,10 +944,7 @@ class H2OReplicate(Replicate):
|
|
945 |
assert self.tokenizer is not None
|
946 |
from h2oai_pipeline import H2OTextGenerationPipeline
|
947 |
prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
|
948 |
-
# Note Replicate handles the prompting of the specific model
|
949 |
-
data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
|
950 |
-
prompt = self.prompter.generate_prompt(data_point)
|
951 |
-
|
952 |
return super()._call(prompt, stop=stop, run_manager=run_manager, **kwargs)
|
953 |
|
954 |
def get_token_ids(self, text: str) -> List[int]:
|
@@ -957,98 +953,21 @@ class H2OReplicate(Replicate):
|
|
957 |
# return _get_token_ids_default_method(text)
|
958 |
|
959 |
|
960 |
-
class
|
961 |
-
def get_messages(self, prompts):
|
962 |
-
from langchain.schema import AIMessage, SystemMessage, HumanMessage
|
963 |
-
messages = []
|
964 |
-
if self.system_prompt:
|
965 |
-
messages.append(SystemMessage(content=self.system_prompt))
|
966 |
-
if self.chat_conversation:
|
967 |
-
for messages1 in self.chat_conversation:
|
968 |
-
messages.append(HumanMessage(content=messages1[0] if messages1[0] is not None else ''))
|
969 |
-
messages.append(AIMessage(content=messages1[1] if messages1[1] is not None else ''))
|
970 |
-
assert len(prompts) == 1, "Not implemented"
|
971 |
-
messages.append(HumanMessage(content=prompts[0].text if prompts[0].text is not None else ''))
|
972 |
-
return [messages]
|
973 |
-
|
974 |
-
|
975 |
-
class H2OChatOpenAI(ChatOpenAI, ExtraChat):
|
976 |
-
tokenizer: Any = None # for vllm_chat
|
977 |
-
system_prompt: Any = None
|
978 |
-
chat_conversation: Any = []
|
979 |
-
|
980 |
@classmethod
|
981 |
def _all_required_field_names(cls) -> Set:
|
982 |
_all_required_field_names = super(ChatOpenAI, cls)._all_required_field_names()
|
983 |
_all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'})
|
984 |
return _all_required_field_names
|
985 |
|
986 |
-
def get_token_ids(self, text: str) -> List[int]:
|
987 |
-
if self.tokenizer is not None:
|
988 |
-
return self.tokenizer.encode(text)
|
989 |
-
else:
|
990 |
-
# OpenAI uses tiktoken
|
991 |
-
return super().get_token_ids(text)
|
992 |
-
|
993 |
-
def generate_prompt(
|
994 |
-
self,
|
995 |
-
prompts: List[PromptValue],
|
996 |
-
stop: Optional[List[str]] = None,
|
997 |
-
callbacks: Callbacks = None,
|
998 |
-
**kwargs: Any,
|
999 |
-
) -> LLMResult:
|
1000 |
-
prompt_messages = self.get_messages(prompts)
|
1001 |
-
# prompt_messages = [p.to_messages() for p in prompts]
|
1002 |
-
return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
|
1003 |
-
|
1004 |
-
async def agenerate_prompt(
|
1005 |
-
self,
|
1006 |
-
prompts: List[PromptValue],
|
1007 |
-
stop: Optional[List[str]] = None,
|
1008 |
-
callbacks: Callbacks = None,
|
1009 |
-
**kwargs: Any,
|
1010 |
-
) -> LLMResult:
|
1011 |
-
prompt_messages = self.get_messages(prompts)
|
1012 |
-
# prompt_messages = [p.to_messages() for p in prompts]
|
1013 |
-
return await self.agenerate(
|
1014 |
-
prompt_messages, stop=stop, callbacks=callbacks, **kwargs
|
1015 |
-
)
|
1016 |
-
|
1017 |
-
|
1018 |
-
class H2OAzureChatOpenAI(AzureChatOpenAI, ExtraChat):
|
1019 |
-
system_prompt: Any = None
|
1020 |
-
chat_conversation: Any = []
|
1021 |
|
|
|
1022 |
@classmethod
|
1023 |
def _all_required_field_names(cls) -> Set:
|
1024 |
_all_required_field_names = super(AzureChatOpenAI, cls)._all_required_field_names()
|
1025 |
_all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'})
|
1026 |
return _all_required_field_names
|
1027 |
|
1028 |
-
def generate_prompt(
|
1029 |
-
self,
|
1030 |
-
prompts: List[PromptValue],
|
1031 |
-
stop: Optional[List[str]] = None,
|
1032 |
-
callbacks: Callbacks = None,
|
1033 |
-
**kwargs: Any,
|
1034 |
-
) -> LLMResult:
|
1035 |
-
prompt_messages = self.get_messages(prompts)
|
1036 |
-
# prompt_messages = [p.to_messages() for p in prompts]
|
1037 |
-
return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
|
1038 |
-
|
1039 |
-
async def agenerate_prompt(
|
1040 |
-
self,
|
1041 |
-
prompts: List[PromptValue],
|
1042 |
-
stop: Optional[List[str]] = None,
|
1043 |
-
callbacks: Callbacks = None,
|
1044 |
-
**kwargs: Any,
|
1045 |
-
) -> LLMResult:
|
1046 |
-
prompt_messages = self.get_messages(prompts)
|
1047 |
-
# prompt_messages = [p.to_messages() for p in prompts]
|
1048 |
-
return await self.agenerate(
|
1049 |
-
prompt_messages, stop=stop, callbacks=callbacks, **kwargs
|
1050 |
-
)
|
1051 |
-
|
1052 |
|
1053 |
class H2OAzureOpenAI(AzureOpenAI):
|
1054 |
@classmethod
|
@@ -1133,7 +1052,7 @@ def get_llm(use_openai_model=False,
|
|
1133 |
if 'meta/llama' in model_string:
|
1134 |
temperature = max(0.01, temperature if do_sample else 0)
|
1135 |
else:
|
1136 |
-
temperature =
|
1137 |
gen_kwargs = dict(temperature=temperature,
|
1138 |
seed=1234,
|
1139 |
max_length=max_new_tokens, # langchain
|
@@ -1149,7 +1068,8 @@ def get_llm(use_openai_model=False,
|
|
1149 |
if system_prompt:
|
1150 |
gen_kwargs.update(dict(system_prompt=system_prompt))
|
1151 |
|
1152 |
-
# replicate handles prompting
|
|
|
1153 |
if stream_output:
|
1154 |
callbacks = [StreamingGradioCallbackHandler()]
|
1155 |
streamer = callbacks[0] if stream_output else None
|
@@ -1188,8 +1108,8 @@ def get_llm(use_openai_model=False,
|
|
1188 |
if inf_type == 'openai_chat' or inf_type == 'vllm_chat':
|
1189 |
cls = H2OChatOpenAI
|
1190 |
# FIXME: Support context, iinput
|
1191 |
-
if inf_type == 'vllm_chat':
|
1192 |
-
|
1193 |
openai_api_key = openai.api_key
|
1194 |
elif inf_type == 'openai_azure_chat':
|
1195 |
cls = H2OAzureChatOpenAI
|
@@ -1248,8 +1168,6 @@ def get_llm(use_openai_model=False,
|
|
1248 |
logit_bias=None if inf_type == 'vllm' else {},
|
1249 |
max_retries=6,
|
1250 |
streaming=stream_output,
|
1251 |
-
system_prompt=system_prompt,
|
1252 |
-
# chat_conversation=chat_conversation, # don't do here, not token aware
|
1253 |
**kwargs_extra
|
1254 |
)
|
1255 |
streamer = callbacks[0] if stream_output else None
|
@@ -3582,6 +3500,7 @@ Respond to prompt of Final Answer with your final high-quality bullet list answe
|
|
3582 |
prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=chat, stream_output=stream_output,
|
3583 |
system_prompt=system_prompt)
|
3584 |
|
|
|
3585 |
scores = []
|
3586 |
chain = None
|
3587 |
|
@@ -3598,8 +3517,8 @@ Respond to prompt of Final Answer with your final high-quality bullet list answe
|
|
3598 |
missing_kwargs = [x for x in func_names if x not in sim_kwargs]
|
3599 |
assert not missing_kwargs, "Missing: %s" % missing_kwargs
|
3600 |
docs, chain, scores, \
|
3601 |
-
num_docs_before_cut, \
|
3602 |
-
use_llm_if_no_docs, top_k_docs_max_show = \
|
3603 |
get_chain(**sim_kwargs)
|
3604 |
if document_subset in non_query_commands:
|
3605 |
formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
|
@@ -3620,21 +3539,23 @@ Respond to prompt of Final Answer with your final high-quality bullet list answe
|
|
3620 |
ret, extra = get_sources_answer(*get_answer_args, **get_answer_kwargs)
|
3621 |
yield dict(prompt=prompt_basic, response=formatted_doc_chunks, sources=extra, num_prompt_tokens=0)
|
3622 |
return
|
3623 |
-
if
|
3624 |
-
if not docs
|
3625 |
-
|
3626 |
-
|
3627 |
-
|
3628 |
-
|
3629 |
-
|
3630 |
-
|
|
|
|
|
3631 |
extra = ''
|
3632 |
yield dict(prompt=prompt_basic, response=ret, sources=extra, num_prompt_tokens=0)
|
3633 |
return
|
3634 |
|
3635 |
-
|
3636 |
-
|
3637 |
-
|
3638 |
return
|
3639 |
|
3640 |
# context stuff similar to used in evaluate()
|
@@ -3735,8 +3656,7 @@ Respond to prompt of Final Answer with your final high-quality bullet list answe
|
|
3735 |
prompt = prompt_basic
|
3736 |
num_prompt_tokens = get_token_count(prompt, tokenizer)
|
3737 |
|
3738 |
-
if
|
3739 |
-
# if no docs, then no sources to cite
|
3740 |
ret = answer
|
3741 |
extra = ''
|
3742 |
yield dict(prompt=prompt, response=ret, sources=extra, num_prompt_tokens=num_prompt_tokens)
|
@@ -3895,7 +3815,8 @@ def get_chain(query=None,
|
|
3895 |
if text_context_list is None:
|
3896 |
text_context_list = []
|
3897 |
|
3898 |
-
#
|
|
|
3899 |
query_action = langchain_action == LangChainAction.QUERY.value
|
3900 |
summarize_action = langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
|
3901 |
LangChainAction.SUMMARIZE_ALL.value,
|
@@ -3927,6 +3848,8 @@ def get_chain(query=None,
|
|
3927 |
add_search_to_context &= len(docs_search) > 0
|
3928 |
top_k_docs_max_show = max(top_k_docs_max_show, len(docs_search))
|
3929 |
|
|
|
|
|
3930 |
use_llm_if_no_docs = True
|
3931 |
|
3932 |
from src.output_parser import H2OMRKLOutputParser
|
@@ -3954,9 +3877,10 @@ def get_chain(query=None,
|
|
3954 |
|
3955 |
docs = []
|
3956 |
scores = []
|
|
|
3957 |
num_docs_before_cut = 0
|
3958 |
use_llm_if_no_docs = True
|
3959 |
-
return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
|
3960 |
|
3961 |
if LangChainAgent.COLLECTION.value in langchain_agents:
|
3962 |
output_parser = H2OMRKLOutputParser()
|
@@ -3975,9 +3899,10 @@ def get_chain(query=None,
|
|
3975 |
|
3976 |
docs = []
|
3977 |
scores = []
|
|
|
3978 |
num_docs_before_cut = 0
|
3979 |
use_llm_if_no_docs = True
|
3980 |
-
return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
|
3981 |
|
3982 |
if LangChainAgent.PYTHON.value in langchain_agents and inference_server.startswith('openai'):
|
3983 |
chain = create_python_agent(
|
@@ -3993,9 +3918,10 @@ def get_chain(query=None,
|
|
3993 |
|
3994 |
docs = []
|
3995 |
scores = []
|
|
|
3996 |
num_docs_before_cut = 0
|
3997 |
use_llm_if_no_docs = True
|
3998 |
-
return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
|
3999 |
|
4000 |
if LangChainAgent.PANDAS.value in langchain_agents and inference_server.startswith('openai_chat'):
|
4001 |
# FIXME: DATA
|
@@ -4012,9 +3938,10 @@ def get_chain(query=None,
|
|
4012 |
|
4013 |
docs = []
|
4014 |
scores = []
|
|
|
4015 |
num_docs_before_cut = 0
|
4016 |
use_llm_if_no_docs = True
|
4017 |
-
return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
|
4018 |
|
4019 |
if isinstance(document_choice, str):
|
4020 |
document_choice = [document_choice]
|
@@ -4044,9 +3971,10 @@ def get_chain(query=None,
|
|
4044 |
|
4045 |
docs = []
|
4046 |
scores = []
|
|
|
4047 |
num_docs_before_cut = 0
|
4048 |
use_llm_if_no_docs = True
|
4049 |
-
return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
|
4050 |
|
4051 |
if isinstance(document_choice, str):
|
4052 |
document_choice = [document_choice]
|
@@ -4057,7 +3985,7 @@ def get_chain(query=None,
|
|
4057 |
document_choice_agent = [x for x in document_choice_agent if x.endswith('.csv')]
|
4058 |
if LangChainAgent.CSV.value in langchain_agents and len(document_choice_agent) == 1 and document_choice_agent[
|
4059 |
0].endswith(
|
4060 |
-
|
4061 |
data_file = document_choice[0]
|
4062 |
if inference_server.startswith('openai_chat'):
|
4063 |
chain = create_csv_agent(
|
@@ -4078,9 +4006,19 @@ def get_chain(query=None,
|
|
4078 |
|
4079 |
docs = []
|
4080 |
scores = []
|
|
|
4081 |
num_docs_before_cut = 0
|
4082 |
use_llm_if_no_docs = True
|
4083 |
-
return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4084 |
|
4085 |
# https://github.com/hwchase17/langchain/issues/1946
|
4086 |
# FIXME: Seems to way to get size of chroma db to limit top_k_docs to avoid
|
@@ -4152,7 +4090,8 @@ def get_chain(query=None,
|
|
4152 |
pre_prompt_query, prompt_query,
|
4153 |
pre_prompt_summary, prompt_summary,
|
4154 |
langchain_action,
|
4155 |
-
|
|
|
4156 |
auto_reduce_chunks,
|
4157 |
got_db_docs,
|
4158 |
add_search_to_context)
|
@@ -4160,242 +4099,239 @@ def get_chain(query=None,
|
|
4160 |
max_input_tokens = get_max_input_tokens(llm=llm, tokenizer=tokenizer, inference_server=inference_server,
|
4161 |
model_name=model_name, max_new_tokens=max_new_tokens)
|
4162 |
|
4163 |
-
if
|
4164 |
-
|
4165 |
-
|
4166 |
-
|
4167 |
-
|
4168 |
-
|
4169 |
-
|
4170 |
-
|
4171 |
-
|
4172 |
-
|
4173 |
-
|
4174 |
-
|
4175 |
-
|
4176 |
-
|
4177 |
-
|
4178 |
-
|
4179 |
-
|
4180 |
-
|
4181 |
-
|
4182 |
-
|
4183 |
-
|
4184 |
-
|
4185 |
-
{"filter": {"chunk_id": {"$
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4186 |
else:
|
4187 |
-
|
4188 |
-
|
4189 |
-
|
4190 |
-
filter_kwargs = {}
|
4191 |
-
|
|
|
|
|
|
|
|
|
4192 |
or_filter = [
|
4193 |
-
{"
|
4194 |
-
|
|
|
4195 |
for x in document_choice]
|
4196 |
filter_kwargs = dict(filter={"$or": or_filter})
|
4197 |
-
|
4198 |
-
|
|
|
|
|
|
|
|
|
4199 |
one_filter = \
|
4200 |
-
[{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {
|
4201 |
-
|
4202 |
-
|
4203 |
-
"$eq": -1}}
|
4204 |
for x in document_choice][0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4205 |
|
4206 |
-
|
4207 |
-
|
4208 |
-
|
4209 |
-
|
4210 |
-
|
4211 |
-
0] == DocumentChoice.ALL.value:
|
4212 |
-
filter_kwargs = {"filter": {"chunk_id": {"$gte": 0}}} if query_action else \
|
4213 |
-
{"filter": {"chunk_id": {"$eq": -1}}}
|
4214 |
-
filter_kwargs_backup = {"filter": {"chunk_id": {"$gte": 0}}}
|
4215 |
-
elif len(document_choice) >= 2:
|
4216 |
-
if document_choice[0] == DocumentChoice.ALL.value:
|
4217 |
-
document_choice = document_choice[1:]
|
4218 |
-
or_filter = [
|
4219 |
-
{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x},
|
4220 |
-
"chunk_id": {
|
4221 |
-
"$eq": -1}}
|
4222 |
-
for x in document_choice]
|
4223 |
-
filter_kwargs = dict(filter={"$or": or_filter})
|
4224 |
-
or_filter_backup = [
|
4225 |
-
{"source": {"$eq": x}} if query_action else {"source": {"$eq": x}}
|
4226 |
-
for x in document_choice]
|
4227 |
-
filter_kwargs_backup = dict(filter={"$or": or_filter_backup})
|
4228 |
-
elif len(document_choice) == 1:
|
4229 |
-
# degenerate UX bug in chroma
|
4230 |
-
one_filter = \
|
4231 |
-
[{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x},
|
4232 |
-
"chunk_id": {
|
4233 |
-
"$eq": -1}}
|
4234 |
-
for x in document_choice][0]
|
4235 |
-
filter_kwargs = dict(filter=one_filter)
|
4236 |
-
one_filter_backup = \
|
4237 |
-
[{"source": {"$eq": x}} if query_action else {"source": {"$eq": x}}
|
4238 |
-
for x in document_choice][0]
|
4239 |
-
filter_kwargs_backup = dict(filter=one_filter_backup)
|
4240 |
-
else:
|
4241 |
-
# shouldn't reach
|
4242 |
-
filter_kwargs = {}
|
4243 |
-
filter_kwargs_backup = {}
|
4244 |
-
|
4245 |
-
if document_subset == DocumentSubset.TopKSources.name or query in [None, '', '\n']:
|
4246 |
-
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs,
|
4247 |
-
text_context_list=text_context_list)
|
4248 |
-
if len(db_documents) == 0 and filter_kwargs_backup:
|
4249 |
-
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs_backup,
|
4250 |
text_context_list=text_context_list)
|
4251 |
-
|
4252 |
-
|
4253 |
-
|
4254 |
-
|
4255 |
-
|
4256 |
-
|
4257 |
-
|
4258 |
-
|
4259 |
-
|
4260 |
-
|
4261 |
-
|
4262 |
-
|
4263 |
-
|
4264 |
-
|
4265 |
-
|
4266 |
-
|
4267 |
-
|
4268 |
-
|
4269 |
-
|
4270 |
-
|
4271 |
-
|
4272 |
-
|
4273 |
-
]
|
4274 |
-
if len(docs_with_score2) == 0 and len(docs_with_score) > 0:
|
4275 |
-
# old database without chunk_id, migration added 0 but didn't make -1 as that would be expensive
|
4276 |
-
# just do again and relax filter, let summarize operate on actual chunks if nothing else
|
4277 |
docs_with_score2 = [x for hx, cx, x in
|
4278 |
-
sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score),
|
4279 |
-
|
4280 |
]
|
4281 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4282 |
|
4283 |
-
|
4284 |
-
|
4285 |
-
|
4286 |
-
|
4287 |
-
|
4288 |
-
|
4289 |
-
|
4290 |
-
docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs, db, db_type,
|
4291 |
-
text_context_list=text_context_list,
|
4292 |
-
verbose=verbose)
|
4293 |
-
if len(docs_with_score) == 0 and filter_kwargs_backup:
|
4294 |
-
docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs_backup, db,
|
4295 |
-
db_type,
|
4296 |
text_context_list=text_context_list,
|
4297 |
verbose=verbose)
|
4298 |
-
|
4299 |
-
|
4300 |
-
|
4301 |
-
|
4302 |
-
|
4303 |
-
|
4304 |
-
|
4305 |
-
|
4306 |
-
|
4307 |
-
|
4308 |
-
|
4309 |
-
|
4310 |
-
|
4311 |
-
|
4312 |
-
|
4313 |
-
|
4314 |
-
|
4315 |
-
|
4316 |
-
|
4317 |
-
|
4318 |
-
|
4319 |
-
|
4320 |
-
|
4321 |
-
|
4322 |
-
|
4323 |
-
|
4324 |
-
|
4325 |
-
|
4326 |
-
|
4327 |
-
|
4328 |
-
|
4329 |
-
|
4330 |
-
|
4331 |
-
|
4332 |
-
|
4333 |
-
|
4334 |
-
|
4335 |
-
|
4336 |
-
|
4337 |
-
|
4338 |
-
|
4339 |
-
|
4340 |
-
|
4341 |
-
|
4342 |
-
|
4343 |
-
assert external_handle_chat_conversation, "Should be handling only externally"
|
4344 |
-
llm.chat_conversation = chat_conversation[chat_index:]
|
4345 |
-
if hasattr(llm, 'context'):
|
4346 |
-
llm.context = context
|
4347 |
-
if hasattr(llm, 'iinput'):
|
4348 |
-
llm.iinput = iinput
|
4349 |
-
# avoid craziness
|
4350 |
-
if 0 < top_k_docs_trial < max_chunks:
|
4351 |
# avoid craziness
|
4352 |
-
if
|
4353 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4354 |
else:
|
4355 |
-
|
4356 |
-
elif top_k_docs_trial >= max_chunks:
|
4357 |
-
top_k_docs = max_chunks
|
4358 |
-
if top_k_docs > 0:
|
4359 |
-
docs_with_score = docs_with_score[:top_k_docs]
|
4360 |
-
elif one_doc_size is not None:
|
4361 |
-
docs_with_score = [docs_with_score[0][:one_doc_size]]
|
4362 |
else:
|
4363 |
-
|
4364 |
-
|
4365 |
-
|
4366 |
-
|
4367 |
-
|
4368 |
-
|
4369 |
-
text_context_list=[x[0].page_content for x in docs_with_score],
|
4370 |
-
max_input_tokens=total_tokens_for_docs)
|
4371 |
|
4372 |
-
|
4373 |
-
|
4374 |
-
# put most relevant chunks closest to question,
|
4375 |
-
# esp. if truncation occurs will be "oldest" or "farthest from response" text that is truncated
|
4376 |
-
# BUT: for small models, e.g. 6_9 pythia, if sees some stuff related to h2oGPT first, it can connect that and not listen to rest
|
4377 |
-
if docs_ordering_type in ['best_first']:
|
4378 |
-
pass
|
4379 |
-
elif docs_ordering_type in ['best_near_prompt', 'reverse_sort']:
|
4380 |
-
docs_with_score.reverse()
|
4381 |
-
elif docs_ordering_type in ['', None, 'reverse_ucurve_sort']:
|
4382 |
-
docs_with_score = reverse_ucurve_list(docs_with_score)
|
4383 |
-
else:
|
4384 |
-
raise ValueError("No such docs_ordering_type=%s" % docs_ordering_type)
|
4385 |
|
4386 |
-
|
4387 |
-
|
4388 |
-
|
4389 |
-
|
4390 |
-
|
4391 |
-
|
4392 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4393 |
|
4394 |
-
|
|
|
|
|
4395 |
|
4396 |
if document_subset in non_query_commands:
|
4397 |
-
# no LLM use
|
4398 |
-
return docs, None, [], num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
|
4399 |
|
4400 |
# FIXME: WIP
|
4401 |
common_words_file = "data/NGSL_1.2_stats.csv.zip"
|
@@ -4413,6 +4349,7 @@ def get_chain(query=None,
|
|
4413 |
|
4414 |
if len(docs) == 0:
|
4415 |
# avoid context == in prompt then
|
|
|
4416 |
template = template_if_no_docs
|
4417 |
|
4418 |
got_db_docs = got_db_docs and len(text_context_list) < len(docs)
|
@@ -4424,7 +4361,8 @@ def get_chain(query=None,
|
|
4424 |
pre_prompt_query, prompt_query,
|
4425 |
pre_prompt_summary, prompt_summary,
|
4426 |
langchain_action,
|
4427 |
-
|
|
|
4428 |
auto_reduce_chunks,
|
4429 |
got_db_docs,
|
4430 |
add_search_to_context)
|
@@ -4442,7 +4380,10 @@ def get_chain(query=None,
|
|
4442 |
else:
|
4443 |
# only if use_openai_model = True, unused normally except in testing
|
4444 |
chain = load_qa_with_sources_chain(llm)
|
4445 |
-
|
|
|
|
|
|
|
4446 |
target = wrapped_partial(chain, chain_kwargs)
|
4447 |
elif langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
|
4448 |
LangChainAction.SUMMARIZE_REFINE,
|
@@ -4486,7 +4427,7 @@ def get_chain(query=None,
|
|
4486 |
else:
|
4487 |
raise RuntimeError("No such langchain_action=%s" % langchain_action)
|
4488 |
|
4489 |
-
return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
|
4490 |
|
4491 |
|
4492 |
def get_max_model_length(llm=None, tokenizer=None, inference_server=None, model_name=None):
|
@@ -4532,11 +4473,11 @@ def get_tokenizer(db=None, llm=None, tokenizer=None, inference_server=None, use_
|
|
4532 |
if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'tokenizer'):
|
4533 |
# more accurate
|
4534 |
return llm.pipeline.tokenizer
|
4535 |
-
elif hasattr(llm, 'tokenizer')
|
4536 |
# e.g. TGI client mode etc.
|
4537 |
return llm.tokenizer
|
4538 |
elif inference_server in ['openai', 'openai_chat', 'openai_azure',
|
4539 |
-
'openai_azure_chat']
|
4540 |
return tokenizer
|
4541 |
elif isinstance(tokenizer, FakeTokenizer):
|
4542 |
return tokenizer
|
@@ -4559,7 +4500,8 @@ def get_template(query, iinput,
|
|
4559 |
pre_prompt_query, prompt_query,
|
4560 |
pre_prompt_summary, prompt_summary,
|
4561 |
langchain_action,
|
4562 |
-
|
|
|
4563 |
auto_reduce_chunks,
|
4564 |
got_db_docs,
|
4565 |
add_search_to_context):
|
@@ -4581,7 +4523,7 @@ def get_template(query, iinput,
|
|
4581 |
if langchain_action == LangChainAction.QUERY.value:
|
4582 |
if iinput:
|
4583 |
query = "%s\n%s" % (query, iinput)
|
4584 |
-
if not
|
4585 |
template_if_no_docs = template = """{context}{question}"""
|
4586 |
else:
|
4587 |
template = """%s
|
|
|
29 |
|
30 |
from joblib import delayed
|
31 |
from langchain.callbacks import streaming_stdout
|
|
|
32 |
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
33 |
from langchain.llms.huggingface_pipeline import VALID_TASKS
|
34 |
from langchain.llms.utils import enforce_stop_tokens
|
35 |
+
from langchain.schema import LLMResult, Generation
|
36 |
from langchain.tools import PythonREPLTool
|
37 |
from langchain.tools.json.tool import JsonSpec
|
38 |
from tqdm import tqdm
|
|
|
944 |
assert self.tokenizer is not None
|
945 |
from h2oai_pipeline import H2OTextGenerationPipeline
|
946 |
prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
|
947 |
+
# Note Replicate handles the prompting of the specific model
|
|
|
|
|
|
|
948 |
return super()._call(prompt, stop=stop, run_manager=run_manager, **kwargs)
|
949 |
|
950 |
def get_token_ids(self, text: str) -> List[int]:
|
|
|
953 |
# return _get_token_ids_default_method(text)
|
954 |
|
955 |
|
956 |
+
class H2OChatOpenAI(ChatOpenAI):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
957 |
@classmethod
|
958 |
def _all_required_field_names(cls) -> Set:
|
959 |
_all_required_field_names = super(ChatOpenAI, cls)._all_required_field_names()
|
960 |
_all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'})
|
961 |
return _all_required_field_names
|
962 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
963 |
|
964 |
+
class H2OAzureChatOpenAI(AzureChatOpenAI):
|
965 |
@classmethod
|
966 |
def _all_required_field_names(cls) -> Set:
|
967 |
_all_required_field_names = super(AzureChatOpenAI, cls)._all_required_field_names()
|
968 |
_all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'})
|
969 |
return _all_required_field_names
|
970 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
971 |
|
972 |
class H2OAzureOpenAI(AzureOpenAI):
|
973 |
@classmethod
|
|
|
1052 |
if 'meta/llama' in model_string:
|
1053 |
temperature = max(0.01, temperature if do_sample else 0)
|
1054 |
else:
|
1055 |
+
temperature =temperature if do_sample else 0
|
1056 |
gen_kwargs = dict(temperature=temperature,
|
1057 |
seed=1234,
|
1058 |
max_length=max_new_tokens, # langchain
|
|
|
1068 |
if system_prompt:
|
1069 |
gen_kwargs.update(dict(system_prompt=system_prompt))
|
1070 |
|
1071 |
+
# replicate handles prompting, so avoid get_response() filter
|
1072 |
+
prompter.prompt_type = 'plain'
|
1073 |
if stream_output:
|
1074 |
callbacks = [StreamingGradioCallbackHandler()]
|
1075 |
streamer = callbacks[0] if stream_output else None
|
|
|
1108 |
if inf_type == 'openai_chat' or inf_type == 'vllm_chat':
|
1109 |
cls = H2OChatOpenAI
|
1110 |
# FIXME: Support context, iinput
|
1111 |
+
# if inf_type == 'vllm_chat':
|
1112 |
+
# kwargs_extra.update(dict(tokenizer=tokenizer))
|
1113 |
openai_api_key = openai.api_key
|
1114 |
elif inf_type == 'openai_azure_chat':
|
1115 |
cls = H2OAzureChatOpenAI
|
|
|
1168 |
logit_bias=None if inf_type == 'vllm' else {},
|
1169 |
max_retries=6,
|
1170 |
streaming=stream_output,
|
|
|
|
|
1171 |
**kwargs_extra
|
1172 |
)
|
1173 |
streamer = callbacks[0] if stream_output else None
|
|
|
3500 |
prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=chat, stream_output=stream_output,
|
3501 |
system_prompt=system_prompt)
|
3502 |
|
3503 |
+
use_docs_planned = False
|
3504 |
scores = []
|
3505 |
chain = None
|
3506 |
|
|
|
3517 |
missing_kwargs = [x for x in func_names if x not in sim_kwargs]
|
3518 |
assert not missing_kwargs, "Missing: %s" % missing_kwargs
|
3519 |
docs, chain, scores, \
|
3520 |
+
use_docs_planned, num_docs_before_cut, \
|
3521 |
+
use_llm_if_no_docs, llm_mode, top_k_docs_max_show = \
|
3522 |
get_chain(**sim_kwargs)
|
3523 |
if document_subset in non_query_commands:
|
3524 |
formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
|
|
|
3539 |
ret, extra = get_sources_answer(*get_answer_args, **get_answer_kwargs)
|
3540 |
yield dict(prompt=prompt_basic, response=formatted_doc_chunks, sources=extra, num_prompt_tokens=0)
|
3541 |
return
|
3542 |
+
if not use_llm_if_no_docs:
|
3543 |
+
if not docs and langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
|
3544 |
+
LangChainAction.SUMMARIZE_ALL.value,
|
3545 |
+
LangChainAction.SUMMARIZE_REFINE.value]:
|
3546 |
+
ret = 'No relevant documents to summarize.' if num_docs_before_cut else 'No documents to summarize.'
|
3547 |
+
extra = ''
|
3548 |
+
yield dict(prompt=prompt_basic, response=ret, sources=extra, num_prompt_tokens=0)
|
3549 |
+
return
|
3550 |
+
if not docs and not llm_mode:
|
3551 |
+
ret = 'No relevant documents to query (for chatting with LLM, pick Resources->Collections->LLM).' if num_docs_before_cut else 'No documents to query (for chatting with LLM, pick Resources->Collections->LLM).'
|
3552 |
extra = ''
|
3553 |
yield dict(prompt=prompt_basic, response=ret, sources=extra, num_prompt_tokens=0)
|
3554 |
return
|
3555 |
|
3556 |
+
if chain is None and not langchain_only_model:
|
3557 |
+
# here if no docs at all and not HF type
|
3558 |
+
# can only return if HF type
|
3559 |
return
|
3560 |
|
3561 |
# context stuff similar to used in evaluate()
|
|
|
3656 |
prompt = prompt_basic
|
3657 |
num_prompt_tokens = get_token_count(prompt, tokenizer)
|
3658 |
|
3659 |
+
if not use_docs_planned:
|
|
|
3660 |
ret = answer
|
3661 |
extra = ''
|
3662 |
yield dict(prompt=prompt, response=ret, sources=extra, num_prompt_tokens=num_prompt_tokens)
|
|
|
3815 |
if text_context_list is None:
|
3816 |
text_context_list = []
|
3817 |
|
3818 |
+
# default value:
|
3819 |
+
llm_mode = langchain_mode in ['Disabled', 'LLM'] and len(text_context_list) == 0
|
3820 |
query_action = langchain_action == LangChainAction.QUERY.value
|
3821 |
summarize_action = langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
|
3822 |
LangChainAction.SUMMARIZE_ALL.value,
|
|
|
3848 |
add_search_to_context &= len(docs_search) > 0
|
3849 |
top_k_docs_max_show = max(top_k_docs_max_show, len(docs_search))
|
3850 |
|
3851 |
+
if len(text_context_list) > 0:
|
3852 |
+
llm_mode = False
|
3853 |
use_llm_if_no_docs = True
|
3854 |
|
3855 |
from src.output_parser import H2OMRKLOutputParser
|
|
|
3877 |
|
3878 |
docs = []
|
3879 |
scores = []
|
3880 |
+
use_docs_planned = False
|
3881 |
num_docs_before_cut = 0
|
3882 |
use_llm_if_no_docs = True
|
3883 |
+
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
|
3884 |
|
3885 |
if LangChainAgent.COLLECTION.value in langchain_agents:
|
3886 |
output_parser = H2OMRKLOutputParser()
|
|
|
3899 |
|
3900 |
docs = []
|
3901 |
scores = []
|
3902 |
+
use_docs_planned = False
|
3903 |
num_docs_before_cut = 0
|
3904 |
use_llm_if_no_docs = True
|
3905 |
+
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
|
3906 |
|
3907 |
if LangChainAgent.PYTHON.value in langchain_agents and inference_server.startswith('openai'):
|
3908 |
chain = create_python_agent(
|
|
|
3918 |
|
3919 |
docs = []
|
3920 |
scores = []
|
3921 |
+
use_docs_planned = False
|
3922 |
num_docs_before_cut = 0
|
3923 |
use_llm_if_no_docs = True
|
3924 |
+
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
|
3925 |
|
3926 |
if LangChainAgent.PANDAS.value in langchain_agents and inference_server.startswith('openai_chat'):
|
3927 |
# FIXME: DATA
|
|
|
3938 |
|
3939 |
docs = []
|
3940 |
scores = []
|
3941 |
+
use_docs_planned = False
|
3942 |
num_docs_before_cut = 0
|
3943 |
use_llm_if_no_docs = True
|
3944 |
+
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
|
3945 |
|
3946 |
if isinstance(document_choice, str):
|
3947 |
document_choice = [document_choice]
|
|
|
3971 |
|
3972 |
docs = []
|
3973 |
scores = []
|
3974 |
+
use_docs_planned = False
|
3975 |
num_docs_before_cut = 0
|
3976 |
use_llm_if_no_docs = True
|
3977 |
+
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
|
3978 |
|
3979 |
if isinstance(document_choice, str):
|
3980 |
document_choice = [document_choice]
|
|
|
3985 |
document_choice_agent = [x for x in document_choice_agent if x.endswith('.csv')]
|
3986 |
if LangChainAgent.CSV.value in langchain_agents and len(document_choice_agent) == 1 and document_choice_agent[
|
3987 |
0].endswith(
|
3988 |
+
'.csv'):
|
3989 |
data_file = document_choice[0]
|
3990 |
if inference_server.startswith('openai_chat'):
|
3991 |
chain = create_csv_agent(
|
|
|
4006 |
|
4007 |
docs = []
|
4008 |
scores = []
|
4009 |
+
use_docs_planned = False
|
4010 |
num_docs_before_cut = 0
|
4011 |
use_llm_if_no_docs = True
|
4012 |
+
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
|
4013 |
+
|
4014 |
+
# determine whether use of context out of docs is planned
|
4015 |
+
if not use_openai_model and prompt_type not in ['plain'] or langchain_only_model:
|
4016 |
+
if llm_mode:
|
4017 |
+
use_docs_planned = False
|
4018 |
+
else:
|
4019 |
+
use_docs_planned = True
|
4020 |
+
else:
|
4021 |
+
use_docs_planned = True
|
4022 |
|
4023 |
# https://github.com/hwchase17/langchain/issues/1946
|
4024 |
# FIXME: Seems to way to get size of chroma db to limit top_k_docs to avoid
|
|
|
4090 |
pre_prompt_query, prompt_query,
|
4091 |
pre_prompt_summary, prompt_summary,
|
4092 |
langchain_action,
|
4093 |
+
llm_mode,
|
4094 |
+
use_docs_planned,
|
4095 |
auto_reduce_chunks,
|
4096 |
got_db_docs,
|
4097 |
add_search_to_context)
|
|
|
4099 |
max_input_tokens = get_max_input_tokens(llm=llm, tokenizer=tokenizer, inference_server=inference_server,
|
4100 |
model_name=model_name, max_new_tokens=max_new_tokens)
|
4101 |
|
4102 |
+
if (db or text_context_list) and use_docs_planned:
|
4103 |
+
if hasattr(db, '_persist_directory'):
|
4104 |
+
lock_file = get_db_lock_file(db, lock_type='sim')
|
4105 |
+
else:
|
4106 |
+
base_path = 'locks'
|
4107 |
+
base_path = makedirs(base_path, exist_ok=True, tmp_ok=True, use_base=True)
|
4108 |
+
name_path = "sim.lock"
|
4109 |
+
lock_file = os.path.join(base_path, name_path)
|
4110 |
+
|
4111 |
+
if not (isinstance(db, Chroma) or isinstance(db, ChromaMig) or ChromaMig.__name__ in str(db)):
|
4112 |
+
# only chroma supports filtering
|
4113 |
+
filter_kwargs = {}
|
4114 |
+
filter_kwargs_backup = {}
|
4115 |
+
else:
|
4116 |
+
import logging
|
4117 |
+
logging.getLogger("chromadb").setLevel(logging.ERROR)
|
4118 |
+
assert document_choice is not None, "Document choice was None"
|
4119 |
+
if isinstance(db, Chroma):
|
4120 |
+
filter_kwargs_backup = {} # shouldn't ever need backup
|
4121 |
+
# chroma >= 0.4
|
4122 |
+
if len(document_choice) == 0 or len(document_choice) >= 1 and document_choice[
|
4123 |
+
0] == DocumentChoice.ALL.value:
|
4124 |
+
filter_kwargs = {"filter": {"chunk_id": {"$gte": 0}}} if query_action else \
|
4125 |
+
{"filter": {"chunk_id": {"$eq": -1}}}
|
4126 |
+
else:
|
4127 |
+
if document_choice[0] == DocumentChoice.ALL.value:
|
4128 |
+
document_choice = document_choice[1:]
|
4129 |
+
if len(document_choice) == 0:
|
4130 |
+
filter_kwargs = {}
|
4131 |
+
elif len(document_choice) > 1:
|
4132 |
+
or_filter = [
|
4133 |
+
{"$and": [dict(source={"$eq": x}), dict(chunk_id={"$gte": 0})]} if query_action else {
|
4134 |
+
"$and": [dict(source={"$eq": x}), dict(chunk_id={"$eq": -1})]}
|
4135 |
+
for x in document_choice]
|
4136 |
+
filter_kwargs = dict(filter={"$or": or_filter})
|
4137 |
+
else:
|
4138 |
+
# still chromadb UX bug, have to do different thing for 1 vs. 2+ docs when doing filter
|
4139 |
+
one_filter = \
|
4140 |
+
[{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {
|
4141 |
+
"source": {"$eq": x},
|
4142 |
+
"chunk_id": {
|
4143 |
+
"$eq": -1}}
|
4144 |
+
for x in document_choice][0]
|
4145 |
+
|
4146 |
+
filter_kwargs = dict(filter={"$and": [dict(source=one_filter['source']),
|
4147 |
+
dict(chunk_id=one_filter['chunk_id'])]})
|
4148 |
else:
|
4149 |
+
# migration for chroma < 0.4
|
4150 |
+
if len(document_choice) == 0 or len(document_choice) >= 1 and document_choice[
|
4151 |
+
0] == DocumentChoice.ALL.value:
|
4152 |
+
filter_kwargs = {"filter": {"chunk_id": {"$gte": 0}}} if query_action else \
|
4153 |
+
{"filter": {"chunk_id": {"$eq": -1}}}
|
4154 |
+
filter_kwargs_backup = {"filter": {"chunk_id": {"$gte": 0}}}
|
4155 |
+
elif len(document_choice) >= 2:
|
4156 |
+
if document_choice[0] == DocumentChoice.ALL.value:
|
4157 |
+
document_choice = document_choice[1:]
|
4158 |
or_filter = [
|
4159 |
+
{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x},
|
4160 |
+
"chunk_id": {
|
4161 |
+
"$eq": -1}}
|
4162 |
for x in document_choice]
|
4163 |
filter_kwargs = dict(filter={"$or": or_filter})
|
4164 |
+
or_filter_backup = [
|
4165 |
+
{"source": {"$eq": x}} if query_action else {"source": {"$eq": x}}
|
4166 |
+
for x in document_choice]
|
4167 |
+
filter_kwargs_backup = dict(filter={"$or": or_filter_backup})
|
4168 |
+
elif len(document_choice) == 1:
|
4169 |
+
# degenerate UX bug in chroma
|
4170 |
one_filter = \
|
4171 |
+
[{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x},
|
4172 |
+
"chunk_id": {
|
4173 |
+
"$eq": -1}}
|
|
|
4174 |
for x in document_choice][0]
|
4175 |
+
filter_kwargs = dict(filter=one_filter)
|
4176 |
+
one_filter_backup = \
|
4177 |
+
[{"source": {"$eq": x}} if query_action else {"source": {"$eq": x}}
|
4178 |
+
for x in document_choice][0]
|
4179 |
+
filter_kwargs_backup = dict(filter=one_filter_backup)
|
4180 |
+
else:
|
4181 |
+
# shouldn't reach
|
4182 |
+
filter_kwargs = {}
|
4183 |
+
filter_kwargs_backup = {}
|
4184 |
|
4185 |
+
if llm_mode:
|
4186 |
+
docs = []
|
4187 |
+
scores = []
|
4188 |
+
elif document_subset == DocumentSubset.TopKSources.name or query in [None, '', '\n']:
|
4189 |
+
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4190 |
text_context_list=text_context_list)
|
4191 |
+
if len(db_documents) == 0 and filter_kwargs_backup:
|
4192 |
+
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs_backup,
|
4193 |
+
text_context_list=text_context_list)
|
4194 |
+
|
4195 |
+
if top_k_docs == -1:
|
4196 |
+
top_k_docs = len(db_documents)
|
4197 |
+
# similar to langchain's chroma's _results_to_docs_and_scores
|
4198 |
+
docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
|
4199 |
+
for result in zip(db_documents, db_metadatas)]
|
4200 |
+
# set in metadata original order of docs
|
4201 |
+
[x[0].metadata.update(orig_index=ii) for ii, x in enumerate(docs_with_score)]
|
4202 |
+
|
4203 |
+
# order documents
|
4204 |
+
doc_hashes = [x.get('doc_hash', 'None') for x in db_metadatas]
|
4205 |
+
if query_action:
|
4206 |
+
doc_chunk_ids = [x.get('chunk_id', 0) for x in db_metadatas]
|
4207 |
+
docs_with_score2 = [x for hx, cx, x in
|
4208 |
+
sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
|
4209 |
+
if cx >= 0]
|
4210 |
+
else:
|
4211 |
+
assert summarize_action
|
4212 |
+
doc_chunk_ids = [x.get('chunk_id', -1) for x in db_metadatas]
|
|
|
|
|
|
|
|
|
4213 |
docs_with_score2 = [x for hx, cx, x in
|
4214 |
+
sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
|
4215 |
+
if cx == -1
|
4216 |
]
|
4217 |
+
if len(docs_with_score2) == 0 and len(docs_with_score) > 0:
|
4218 |
+
# old database without chunk_id, migration added 0 but didn't make -1 as that would be expensive
|
4219 |
+
# just do again and relax filter, let summarize operate on actual chunks if nothing else
|
4220 |
+
docs_with_score2 = [x for hx, cx, x in
|
4221 |
+
sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score),
|
4222 |
+
key=lambda x: (x[0], x[1]))
|
4223 |
+
]
|
4224 |
+
docs_with_score = docs_with_score2
|
4225 |
|
4226 |
+
docs_with_score = docs_with_score[:top_k_docs]
|
4227 |
+
docs = [x[0] for x in docs_with_score]
|
4228 |
+
scores = [x[1] for x in docs_with_score]
|
4229 |
+
num_docs_before_cut = len(docs)
|
4230 |
+
else:
|
4231 |
+
with filelock.FileLock(lock_file):
|
4232 |
+
docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs, db, db_type,
|
|
|
|
|
|
|
|
|
|
|
|
|
4233 |
text_context_list=text_context_list,
|
4234 |
verbose=verbose)
|
4235 |
+
if len(docs_with_score) == 0 and filter_kwargs_backup:
|
4236 |
+
docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs_backup, db,
|
4237 |
+
db_type,
|
4238 |
+
text_context_list=text_context_list,
|
4239 |
+
verbose=verbose)
|
4240 |
+
|
4241 |
+
tokenizer = get_tokenizer(db=db, llm=llm, tokenizer=tokenizer, inference_server=inference_server,
|
4242 |
+
use_openai_model=use_openai_model,
|
4243 |
+
db_type=db_type)
|
4244 |
+
# NOTE: if map_reduce, then no need to auto reduce chunks
|
4245 |
+
if query_action and (top_k_docs == -1 or auto_reduce_chunks):
|
4246 |
+
top_k_docs_tokenize = 100
|
4247 |
+
docs_with_score = docs_with_score[:top_k_docs_tokenize]
|
4248 |
+
|
4249 |
+
prompt_no_docs = template.format(context='', question=query)
|
4250 |
+
|
4251 |
+
model_max_length = tokenizer.model_max_length
|
4252 |
+
chat = True # FIXME?
|
4253 |
+
|
4254 |
+
# first docs_with_score are most important with highest score
|
4255 |
+
full_prompt, \
|
4256 |
+
instruction, iinput, context, \
|
4257 |
+
num_prompt_tokens, max_new_tokens, \
|
4258 |
+
num_prompt_tokens0, num_prompt_tokens_actual, \
|
4259 |
+
chat_index, top_k_docs_trial, one_doc_size = \
|
4260 |
+
get_limited_prompt(prompt_no_docs,
|
4261 |
+
iinput,
|
4262 |
+
tokenizer,
|
4263 |
+
prompter=prompter,
|
4264 |
+
inference_server=inference_server,
|
4265 |
+
prompt_type=prompt_type,
|
4266 |
+
prompt_dict=prompt_dict,
|
4267 |
+
chat=chat,
|
4268 |
+
max_new_tokens=max_new_tokens,
|
4269 |
+
system_prompt=system_prompt,
|
4270 |
+
context=context,
|
4271 |
+
chat_conversation=chat_conversation,
|
4272 |
+
text_context_list=[x[0].page_content for x in docs_with_score],
|
4273 |
+
keep_sources_in_context=keep_sources_in_context,
|
4274 |
+
model_max_length=model_max_length,
|
4275 |
+
memory_restriction_level=memory_restriction_level,
|
4276 |
+
langchain_mode=langchain_mode,
|
4277 |
+
add_chat_history_to_context=add_chat_history_to_context,
|
4278 |
+
min_max_new_tokens=min_max_new_tokens,
|
4279 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4280 |
# avoid craziness
|
4281 |
+
if 0 < top_k_docs_trial < max_chunks:
|
4282 |
+
# avoid craziness
|
4283 |
+
if top_k_docs == -1:
|
4284 |
+
top_k_docs = top_k_docs_trial
|
4285 |
+
else:
|
4286 |
+
top_k_docs = min(top_k_docs, top_k_docs_trial)
|
4287 |
+
elif top_k_docs_trial >= max_chunks:
|
4288 |
+
top_k_docs = max_chunks
|
4289 |
+
if top_k_docs > 0:
|
4290 |
+
docs_with_score = docs_with_score[:top_k_docs]
|
4291 |
+
elif one_doc_size is not None:
|
4292 |
+
docs_with_score = [docs_with_score[0][:one_doc_size]]
|
4293 |
else:
|
4294 |
+
docs_with_score = []
|
|
|
|
|
|
|
|
|
|
|
|
|
4295 |
else:
|
4296 |
+
if total_tokens_for_docs is not None:
|
4297 |
+
# used to limit tokens for summarization, e.g. public instance
|
4298 |
+
top_k_docs, one_doc_size, num_doc_tokens = \
|
4299 |
+
get_docs_tokens(tokenizer,
|
4300 |
+
text_context_list=[x[0].page_content for x in docs_with_score],
|
4301 |
+
max_input_tokens=total_tokens_for_docs)
|
|
|
|
|
4302 |
|
4303 |
+
docs_with_score = docs_with_score[:top_k_docs]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4304 |
|
4305 |
+
# put most relevant chunks closest to question,
|
4306 |
+
# esp. if truncation occurs will be "oldest" or "farthest from response" text that is truncated
|
4307 |
+
# BUT: for small models, e.g. 6_9 pythia, if sees some stuff related to h2oGPT first, it can connect that and not listen to rest
|
4308 |
+
if docs_ordering_type in ['best_first']:
|
4309 |
+
pass
|
4310 |
+
elif docs_ordering_type in ['best_near_prompt', 'reverse_sort']:
|
4311 |
+
docs_with_score.reverse()
|
4312 |
+
elif docs_ordering_type in ['', None, 'reverse_ucurve_sort']:
|
4313 |
+
docs_with_score = reverse_ucurve_list(docs_with_score)
|
4314 |
+
else:
|
4315 |
+
raise ValueError("No such docs_ordering_type=%s" % docs_ordering_type)
|
4316 |
+
|
4317 |
+
# cut off so no high distance docs/sources considered
|
4318 |
+
num_docs_before_cut = len(docs_with_score)
|
4319 |
+
docs = [x[0] for x in docs_with_score if x[1] < cut_distance]
|
4320 |
+
scores = [x[1] for x in docs_with_score if x[1] < cut_distance]
|
4321 |
+
if len(scores) > 0 and verbose:
|
4322 |
+
print("Distance: min: %s max: %s mean: %s median: %s" %
|
4323 |
+
(scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
|
4324 |
+
else:
|
4325 |
+
docs = []
|
4326 |
+
scores = []
|
4327 |
|
4328 |
+
if not docs and use_docs_planned and not langchain_only_model:
|
4329 |
+
# if HF type and have no docs, can bail out
|
4330 |
+
return docs, None, [], False, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
|
4331 |
|
4332 |
if document_subset in non_query_commands:
|
4333 |
+
# no LLM use
|
4334 |
+
return docs, None, [], False, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
|
4335 |
|
4336 |
# FIXME: WIP
|
4337 |
common_words_file = "data/NGSL_1.2_stats.csv.zip"
|
|
|
4349 |
|
4350 |
if len(docs) == 0:
|
4351 |
# avoid context == in prompt then
|
4352 |
+
use_docs_planned = False
|
4353 |
template = template_if_no_docs
|
4354 |
|
4355 |
got_db_docs = got_db_docs and len(text_context_list) < len(docs)
|
|
|
4361 |
pre_prompt_query, prompt_query,
|
4362 |
pre_prompt_summary, prompt_summary,
|
4363 |
langchain_action,
|
4364 |
+
llm_mode,
|
4365 |
+
use_docs_planned,
|
4366 |
auto_reduce_chunks,
|
4367 |
got_db_docs,
|
4368 |
add_search_to_context)
|
|
|
4380 |
else:
|
4381 |
# only if use_openai_model = True, unused normally except in testing
|
4382 |
chain = load_qa_with_sources_chain(llm)
|
4383 |
+
if not use_docs_planned:
|
4384 |
+
chain_kwargs = dict(input_documents=[], question=query)
|
4385 |
+
else:
|
4386 |
+
chain_kwargs = dict(input_documents=docs, question=query)
|
4387 |
target = wrapped_partial(chain, chain_kwargs)
|
4388 |
elif langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
|
4389 |
LangChainAction.SUMMARIZE_REFINE,
|
|
|
4427 |
else:
|
4428 |
raise RuntimeError("No such langchain_action=%s" % langchain_action)
|
4429 |
|
4430 |
+
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
|
4431 |
|
4432 |
|
4433 |
def get_max_model_length(llm=None, tokenizer=None, inference_server=None, model_name=None):
|
|
|
4473 |
if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'tokenizer'):
|
4474 |
# more accurate
|
4475 |
return llm.pipeline.tokenizer
|
4476 |
+
elif hasattr(llm, 'tokenizer'):
|
4477 |
# e.g. TGI client mode etc.
|
4478 |
return llm.tokenizer
|
4479 |
elif inference_server in ['openai', 'openai_chat', 'openai_azure',
|
4480 |
+
'openai_azure_chat']:
|
4481 |
return tokenizer
|
4482 |
elif isinstance(tokenizer, FakeTokenizer):
|
4483 |
return tokenizer
|
|
|
4500 |
pre_prompt_query, prompt_query,
|
4501 |
pre_prompt_summary, prompt_summary,
|
4502 |
langchain_action,
|
4503 |
+
llm_mode,
|
4504 |
+
use_docs_planned,
|
4505 |
auto_reduce_chunks,
|
4506 |
got_db_docs,
|
4507 |
add_search_to_context):
|
|
|
4523 |
if langchain_action == LangChainAction.QUERY.value:
|
4524 |
if iinput:
|
4525 |
query = "%s\n%s" % (query, iinput)
|
4526 |
+
if llm_mode or not use_docs_planned:
|
4527 |
template_if_no_docs = template = """{context}{question}"""
|
4528 |
else:
|
4529 |
template = """%s
|
src/gradio_runner.py
CHANGED
@@ -737,7 +737,8 @@ def go_gradio(**kwargs):
|
|
737 |
visible=True,
|
738 |
elem_id="langchain_agents",
|
739 |
filterable=False)
|
740 |
-
visible_doc_track = upload_visible and kwargs['visible_doc_track'] and not kwargs[
|
|
|
741 |
row_doc_track = gr.Row(visible=visible_doc_track)
|
742 |
with row_doc_track:
|
743 |
if kwargs['langchain_mode'] in langchain_modes_non_db:
|
@@ -784,6 +785,9 @@ def go_gradio(**kwargs):
|
|
784 |
text_output_nochat_api = gr.Textbox(lines=5, label='API nochat output', visible=False,
|
785 |
show_copy_button=True)
|
786 |
|
|
|
|
|
|
|
787 |
# CHAT
|
788 |
col_chat = gr.Column(visible=kwargs['chat'])
|
789 |
with col_chat:
|
@@ -806,7 +810,8 @@ def go_gradio(**kwargs):
|
|
806 |
size="sm",
|
807 |
min_width=24,
|
808 |
file_types=['.' + x for x in file_types],
|
809 |
-
file_count="multiple"
|
|
|
810 |
|
811 |
submit_buttons = gr.Row(equal_height=False, visible=kwargs['visible_submit_buttons'])
|
812 |
with submit_buttons:
|
@@ -886,11 +891,9 @@ def go_gradio(**kwargs):
|
|
886 |
visible=sources_visible and allow_upload_to_user_data)
|
887 |
with gr.Column(scale=4):
|
888 |
pass
|
|
|
889 |
with gr.Row():
|
890 |
with gr.Column(scale=1):
|
891 |
-
visible_add_remove_collection = (allow_upload_to_user_data or
|
892 |
-
allow_upload_to_my_data) and \
|
893 |
-
kwargs['langchain_mode'] != 'Disabled'
|
894 |
add_placeholder = "e.g. UserData2, shared, user_path2" \
|
895 |
if not is_public else "e.g. MyData2, personal (optional)"
|
896 |
remove_placeholder = "e.g. UserData2" if not is_public else "e.g. MyData2"
|
@@ -1143,7 +1146,8 @@ def go_gradio(**kwargs):
|
|
1143 |
)
|
1144 |
min_max_new_tokens = gr.Slider(
|
1145 |
minimum=1, maximum=max_max_new_tokens, step=1,
|
1146 |
-
value=min(max_max_new_tokens, kwargs['min_max_new_tokens']),
|
|
|
1147 |
)
|
1148 |
early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
|
1149 |
value=kwargs['early_stopping'], visible=max_beams > 1)
|
@@ -2881,7 +2885,6 @@ def go_gradio(**kwargs):
|
|
2881 |
history = args_list[-1]
|
2882 |
if not history:
|
2883 |
history = []
|
2884 |
-
# NOTE: For these, could check if None, then automatically use CLI values, but too complex behavior
|
2885 |
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
|
2886 |
prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')]
|
2887 |
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
|
|
737 |
visible=True,
|
738 |
elem_id="langchain_agents",
|
739 |
filterable=False)
|
740 |
+
visible_doc_track = upload_visible and kwargs['visible_doc_track'] and not kwargs[
|
741 |
+
'large_file_count_mode']
|
742 |
row_doc_track = gr.Row(visible=visible_doc_track)
|
743 |
with row_doc_track:
|
744 |
if kwargs['langchain_mode'] in langchain_modes_non_db:
|
|
|
785 |
text_output_nochat_api = gr.Textbox(lines=5, label='API nochat output', visible=False,
|
786 |
show_copy_button=True)
|
787 |
|
788 |
+
visible_upload = (allow_upload_to_user_data or
|
789 |
+
allow_upload_to_my_data) and \
|
790 |
+
kwargs['langchain_mode'] != 'Disabled'
|
791 |
# CHAT
|
792 |
col_chat = gr.Column(visible=kwargs['chat'])
|
793 |
with col_chat:
|
|
|
810 |
size="sm",
|
811 |
min_width=24,
|
812 |
file_types=['.' + x for x in file_types],
|
813 |
+
file_count="multiple",
|
814 |
+
visible=visible_upload)
|
815 |
|
816 |
submit_buttons = gr.Row(equal_height=False, visible=kwargs['visible_submit_buttons'])
|
817 |
with submit_buttons:
|
|
|
891 |
visible=sources_visible and allow_upload_to_user_data)
|
892 |
with gr.Column(scale=4):
|
893 |
pass
|
894 |
+
visible_add_remove_collection = visible_upload
|
895 |
with gr.Row():
|
896 |
with gr.Column(scale=1):
|
|
|
|
|
|
|
897 |
add_placeholder = "e.g. UserData2, shared, user_path2" \
|
898 |
if not is_public else "e.g. MyData2, personal (optional)"
|
899 |
remove_placeholder = "e.g. UserData2" if not is_public else "e.g. MyData2"
|
|
|
1146 |
)
|
1147 |
min_max_new_tokens = gr.Slider(
|
1148 |
minimum=1, maximum=max_max_new_tokens, step=1,
|
1149 |
+
value=min(max_max_new_tokens, kwargs['min_max_new_tokens']),
|
1150 |
+
label="Min. of Max output length",
|
1151 |
)
|
1152 |
early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
|
1153 |
value=kwargs['early_stopping'], visible=max_beams > 1)
|
|
|
2885 |
history = args_list[-1]
|
2886 |
if not history:
|
2887 |
history = []
|
|
|
2888 |
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
|
2889 |
prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')]
|
2890 |
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|