# protobench/models.py
import requests
import json
import torch
import os
from datetime import datetime, timedelta
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
class GigaChat:
    """Thin wrapper around the GigaChat REST API: caches an OAuth token in a
    local JSON file and sends chat-completion requests."""

    def __init__(self, auth_file='auth_token.json'):
        # self.auth_url = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth"  # production auth endpoint
        self.auth_url = "https://api.mlrnd.ru/api/v2/oauth"
        # self.gen_url = "https://gigachat.devices.sberbank.ru/api/v1/chat/completions"  # production chat endpoint
        self.gen_url = "https://api.mlrnd.ru/api/v1/chat/completions"
        # self.payload = 'scope=GIGACHAT_API_CORP'  # corporate scope alternative
        self.payload = 'scope=API_v1'
        self.auth_file = auth_file
        # fetch a fresh OAuth token if there is no cached auth file yet
        if self.auth_file is None or not os.path.isfile(auth_file):
            self.gen_giga_token(auth_file)
    @classmethod
    def get_giga(cls, auth_file='auth_token.json'):
        print('got giga')
        return cls(auth_file)
    def gen_giga_token(self, auth_file):
        """Request a fresh OAuth token and cache the response JSON in auth_file."""
        headers = {
            'Content-Type': 'application/x-www-form-urlencoded',
            'Accept': 'application/json',
            'RqUID': '1b519047-0ee9-4b63-8599-e5ffc9c77e72',
            # the full Authorization header value is taken from the environment
            'Authorization': os.getenv('GIGACHAT_API_TOKEN')
        }
        response = requests.request(
            "POST",
            self.auth_url,
            headers=headers,
            data=self.payload,
            verify=False  # the endpoint uses a certificate outside the default CA bundle
        )
        response.raise_for_status()  # avoid silently caching an error body as a token
        with open(auth_file, 'w') as f:
            json.dump(json.loads(response.text), f, ensure_ascii=False)
    def get_text(self, content, auth_token=None, params=None):
        """Send a chat-completion request; `content` is a list of message dicts."""
        if params is None:
            params = dict()
        payload = json.dumps(
            {
                "model": "Test_model",
                "messages": content,
                "temperature": params.get("temperature") or 1,
                "top_p": params.get("top_p") or 0.9,
                "n": params.get("n") or 1,
                "stream": False,
                "max_tokens": params.get("max_tokens") or 512,
                "repetition_penalty": params.get("repetition_penalty") or 1
            }
        )
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
            'Authorization': f'Bearer {auth_token}'
        }
        response = requests.request("POST", self.gen_url, headers=headers, data=payload, verify=False)
        return json.loads(response.text)
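
# Illustrative usage sketch (not invoked anywhere in this module): how the
# GigaChat wrapper above is meant to be called directly. The parameter values
# and the auth file name are example assumptions, not requirements.
def _example_gigachat_get_text():
    giga = GigaChat.get_giga('auth_token.json')
    with open(giga.auth_file) as f:
        access_token = json.load(f)['access_token']
    messages = [{'role': 'user', 'content': 'Hello!'}]
    reply = giga.get_text(messages, auth_token=access_token, params={'temperature': 0.7, 'max_tokens': 128})
    return reply['choices'][0]['message']['content']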
def get_tinyllama():
    """Load the TinyLlama chat model as a text-generation pipeline."""
    print('got llama')
    tinyllama = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16, device_map="auto")
    return tinyllama
def get_qwen2ins1b():
    """Load Qwen2-1.5B-Instruct and its tokenizer; returns {'model': ..., 'tokenizer': ...}."""
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2-1.5B-Instruct",
        torch_dtype="auto",
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
    return {'model': model, 'tokenizer': tokenizer}
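
# Illustrative sketch (not used by the rest of this module): the two local
# loaders return different shapes, a transformers pipeline for TinyLlama and a
# {'model', 'tokenizer'} dict for Qwen2, so a registry like the one assumed
# below keeps that distinction explicit.
def _example_load_local_models():
    return {
        'tinyllama': get_tinyllama(),      # transformers text-generation pipeline
        'qwen2-1.5b': get_qwen2ins1b(),    # dict with 'model' and 'tokenizer'
    }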
def response_tinyllama(
        model=None,
        messages=None,
        params=None
):
    """Generate a reply with the TinyLlama pipeline; `messages` is a list of
    [user, assistant] pairs where the last pair may hold only the user turn."""
    if params is None:
        params = dict()
    messages_dict = [
        {
            "role": "system",
            "content": "You are a friendly and helpful chatbot",
        }
    ]
    for step in messages:
        messages_dict.append({'role': 'user', 'content': step[0]})
        if len(step) >= 2:
            messages_dict.append({'role': 'assistant', 'content': step[1]})
    prompt = model.tokenizer.apply_chat_template(messages_dict, tokenize=False, add_generation_prompt=True)
    outputs = model(
        prompt,
        max_new_tokens=params.get("max_tokens") or 512,
        temperature=params.get("temperature") or 1,
        top_p=params.get("top_p") or 0.9,
        repetition_penalty=params.get("repetition_penalty") or 1
    )
    # take the text after the last assistant tag, so earlier history turns are not returned
    return outputs[0]['generated_text'].split('<|assistant|>')[-1].strip()
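
# Illustrative sketch (not invoked here): response_tinyllama() expects a
# Gradio-style chat history, i.e. a list of [user, assistant] pairs where the
# last pair may contain only the new user turn. The generation parameters are
# example assumptions.
def _example_tinyllama_turn():
    pipe = get_tinyllama()
    history = [
        ['What is the capital of France?', 'Paris.'],
        ['And of Germany?'],  # pending turn, no assistant reply yet
    ]
    return response_tinyllama(model=pipe, messages=history, params={'max_tokens': 64, 'temperature': 0.7})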
def response_qwen2ins1b(
        model=None,
        messages=None,
        params=None
):
    """Generate a reply with Qwen2-1.5B-Instruct; `messages` uses the same
    [user, assistant] pair format as response_tinyllama()."""
    if params is None:
        params = dict()
    messages_dict = [
        {
            "role": "system",
            "content": "You are a friendly and helpful chatbot",
        }
    ]
    for step in messages:
        messages_dict.append({'role': 'user', 'content': step[0]})
        if len(step) >= 2:
            messages_dict.append({'role': 'assistant', 'content': step[1]})
    text = model['tokenizer'].apply_chat_template(
        messages_dict,
        tokenize=False,
        add_generation_prompt=True
    )
    # move the encoded prompt to the model's device (device_map="auto" may place the model on GPU)
    model_inputs = model['tokenizer']([text], return_tensors="pt").to(model['model'].device)
    generated_ids = model['model'].generate(
        model_inputs.input_ids,
        max_new_tokens=params.get("max_tokens") or 512,
        temperature=params.get("temperature") or 1,
        top_p=params.get("top_p") or 0.9,
        repetition_penalty=params.get("repetition_penalty") or 1
    )
    # strip the prompt tokens so only the newly generated continuation is decoded
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = model['tokenizer'].batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response
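
# Illustrative sketch (not invoked here): the same history format works for the
# Qwen2 wrapper; the params keys mirror the .get() defaults used above, and the
# concrete values are example assumptions.
def _example_qwen2_turn():
    qwen = get_qwen2ins1b()
    history = [['Summarise the plot of Hamlet in one sentence.']]
    params = {'max_tokens': 128, 'temperature': 0.7, 'top_p': 0.9, 'repetition_penalty': 1.05}
    return response_qwen2ins1b(model=qwen, messages=history, params=params)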
def response_gigachat(
        model=None,
        messages=None,
        model_params=None
):
    """Generate a reply via the GigaChat API, refreshing the cached OAuth token
    when it is about to expire; `messages` uses the same pair format as above."""
    with open(model.auth_file) as f:
        auth_token = json.load(f)
    # expires_at is in milliseconds; refresh if the token expires within the next minute
    if datetime.fromtimestamp(auth_token['expires_at'] / 1000) <= datetime.now() + timedelta(seconds=60):
        model.gen_giga_token(model.auth_file)
        with open(model.auth_file) as f:
            auth_token = json.load(f)
    content = []
    for step in messages:
        content.append({'role': 'user', 'content': step[0]})
        if len(step) >= 2:
            content.append({'role': 'assistant', 'content': step[1]})
    resp = model.get_text(content, auth_token['access_token'], model_params)
    return resp["choices"][0]["message"]["content"]
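
# Illustrative end-to-end sketch (not invoked here): a single GigaChat turn via
# the helpers above. Token refresh is handled inside response_gigachat(); the
# auth file name matches the class default and is otherwise an assumption.
def _example_gigachat_turn():
    giga = GigaChat.get_giga('auth_token.json')
    history = [['Tell me a short joke about databases.']]
    return response_gigachat(model=giga, messages=history, model_params={'max_tokens': 64})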