import logging
import os
from typing import List

from openai import OpenAI

from constant import HYPERBOLIC_MODELS, MODEL_MAPPING


def model_name_mapping(model_name):
    """Map a display model name to its provider-specific model identifier."""
    if model_name in MODEL_MAPPING:
        return MODEL_MAPPING[model_name]
    raise ValueError(f"Invalid model name: {model_name}")


def urial_template(urial_prompt, history, message):
    """Build a URIAL-style plain-text prompt: the URIAL prefix, each prior turn
    as a # Query / # Answer pair, then the new query with an open # Answer
    block for the base model to complete."""
    current_prompt = urial_prompt + "\n"
    for user_msg, ai_msg in history:
        current_prompt += f'# Query:\n"""\n{user_msg}\n"""\n\n# Answer:\n"""\n{ai_msg}\n"""\n\n'
    current_prompt += f'# Query:\n"""\n{message}\n"""\n\n# Answer:\n"""\n'
    return current_prompt
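
# For illustration only (the query/answer text below is made up, not from this
# file): urial_template(prefix, [("Hi", "Hello!")], "What is 2+2?") renders a
# prompt of the form
#
#   <prefix>
#   # Query:
#   """
#   Hi
#   """
#
#   # Answer:
#   """
#   Hello!
#   """
#
#   # Query:
#   """
#   What is 2+2?
#   """
#
#   # Answer:
#   """
#
# so the model's continuation can be cut at the closing triple quotes.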

def chat_template(history, message):
    """Convert (user, assistant) history pairs plus the new user message into
    an OpenAI-style chat `messages` list."""
    messages = []
    for user_msg, ai_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": ai_msg})
    messages.append({"role": "user", "content": message})
    return messages

def openai_base_request(
    model: str = None,
    temperature: float = 0,
    max_tokens: int = 512,
    top_p: float = 1.0,
    prompt: str = None,
    n: int = 1,
    repetition_penalty: float = 1.0,
    stop: List[str] = None,
    api_key: str = None,
):
    """Send a streaming text-completion request to an OpenAI-compatible
    endpoint (Hyperbolic for models in HYPERBOLIC_MODELS, Together otherwise)
    and return the streaming response iterator."""

    # Pick the provider endpoint and default API key based on the model.
    if model in HYPERBOLIC_MODELS:
        base_url = "https://api.hyperbolic.xyz/v1"
        default_api_key = os.getenv("HYPERBOLIC_API_KEY")
    else:
        base_url = "https://api.together.xyz/v1"
        default_api_key = os.getenv("TOGETHER_API_KEY")

    if api_key is None:
        api_key = default_api_key
    client = OpenAI(api_key=api_key, base_url=base_url)

    logging.info(f"Requesting base completion from an OpenAI-compatible API with model {model}")
    logging.info(f"Prompt: {prompt}")
    logging.info(f"Temperature: {temperature}")
    logging.info(f"Max tokens: {max_tokens}")
    logging.info(f"Top-p: {top_p}")
    logging.info(f"Repetition penalty: {repetition_penalty}")
    logging.info(f"Stop: {stop}")

    request = client.completions.create(
        model=model,
        prompt=prompt,
        temperature=float(temperature),
        max_tokens=int(max_tokens),
        top_p=float(top_p),
        n=n,
        # repetition_penalty is not a standard OpenAI parameter; pass it via extra_body.
        extra_body={'repetition_penalty': float(repetition_penalty)},
        stop=stop,
        stream=True
    )

    return request



def openai_chat_request(
    model: str = None,
    temperature: float = 0,
    max_tokens: int = 512,
    top_p: float = 1.0,
    messages=None,
    n: int = 1,
    repetition_penalty: float = 1.0,
    stop: List[str] = None,
    api_key: str = None,
):
    """Send a streaming chat-completion request to an OpenAI-compatible
    endpoint (Hyperbolic for models in HYPERBOLIC_MODELS, Together otherwise)
    and return the streaming response iterator."""

    # Pick the provider endpoint and default API key based on the model.
    if model in HYPERBOLIC_MODELS:
        base_url = "https://api.hyperbolic.xyz/v1"
        default_api_key = os.getenv("HYPERBOLIC_API_KEY")
    else:
        base_url = "https://api.together.xyz/v1"
        default_api_key = os.getenv("TOGETHER_API_KEY")

    if api_key is None:
        api_key = default_api_key

    logging.info(f"Requesting chat completion from an OpenAI-compatible API with model {model}")

    client = OpenAI(api_key=api_key, base_url=base_url)

    request = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=float(temperature),
        max_tokens=int(max_tokens),
        top_p=float(top_p),
        n=n,
        # repetition_penalty is not a standard OpenAI parameter; pass it via extra_body.
        extra_body={'repetition_penalty': float(repetition_penalty)},
        stop=stop,
        stream=True
    )
    return request
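

# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes TOGETHER_API_KEY or HYPERBOLIC_API_KEY is set in the environment
# and that the model id below is served by the chosen provider; both are
# assumptions, not facts from this file.
if __name__ == "__main__":
    urial_prefix = "Below are example query/answer pairs. Continue the pattern."
    demo_prompt = urial_template(urial_prefix, [], "What is the capital of France?")
    stream = openai_base_request(
        model="meta-llama/Meta-Llama-3-8B",  # hypothetical model id; pick one your provider serves
        prompt=demo_prompt,
        temperature=0,
        max_tokens=128,
        stop=['"""'],  # stop at the closing triple quotes of the open # Answer block
    )
    # openai_base_request sets stream=True, so iterate over the chunks and
    # print each text delta as it arrives.
    for chunk in stream:
        if chunk.choices and chunk.choices[0].text:
            print(chunk.choices[0].text, end="", flush=True)
    print()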