|
import os |
|
from threading import Thread |
|
from typing import Iterator |
|
|
|
import gradio as gr |
|
import spaces |
|
import torch |
|
from transformers import BitsAndBytesConfig, AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer |
|
|
|
|
|
MAX_MAX_NEW_TOKENS = 2048 |
|
DEFAULT_MAX_NEW_TOKENS = 1024 |
|
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096")) |
|
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
model_path = "vinai/PhoGPT-4B-Chat" |
|
|
|
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) |
|
config.init_device = device |
|
|
|
quantization = BitsAndBytesConfig(load_in_8bit=True) |
|
|
|
model = AutoModelForCausalLM.from_pretrained(model_path, |
|
config=config, |
|
quantization_config =quantization, |
|
torch_dtype=torch.bfloat16, |
|
trust_remote_code=True, |
|
low_cpu_mem_usage=True) |
|
|
|
model.eval() |
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) |
|
|
|
|
|
@spaces.GPU(duration=120) |
|
def generate( |
|
message: str, |
|
chat_history: list[tuple[str, str]], |
|
max_new_tokens: int = 1024, |
|
temperature: float = 0.6, |
|
top_p: float = 0.9, |
|
top_k: int = 50, |
|
repetition_penalty: float = 1.2, |
|
) -> Iterator[str]: |
|
conversation = [] |
|
for user, assistant in chat_history: |
|
conversation.extend( |
|
[ |
|
{"role": "user", "content": user}, |
|
{"role": "assistant", "content": assistant}, |
|
] |
|
) |
|
conversation.append({"role": "user", "content": message}) |
|
|
|
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt") |
|
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH: |
|
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] |
|
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.") |
|
input_ids = input_ids.to(model.device) |
|
|
|
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True) |
|
generate_kwargs ={ |
|
"input_ids": input_ids, |
|
"streamer":streamer, |
|
"max_new_tokens":max_new_tokens, |
|
"do_sample":True, |
|
"top_p":top_p, |
|
"top_k":top_k, |
|
"temperature":temperature, |
|
"num_beams":1, |
|
"repetition_penalty":repetition_penalty, |
|
"eos_token_id":tokenizer.eos_token_id, |
|
"pad_token_id":tokenizer.pad_token_id |
|
} |
|
t = Thread(target=model.generate, kwargs=generate_kwargs) |
|
t.start() |
|
|
|
outputs = [] |
|
for text in streamer: |
|
outputs.append(text) |
|
yield "".join(outputs) |
|
|
|
|
|
chat_interface = gr.ChatInterface( |
|
fn=generate, |
|
chatbot=gr.Chatbot(height=500, label = "VN GPT", show_label=True), |
|
textbox=gr.Textbox(placeholder="Nhập hội thoại tại đây", container=False, scale=7), |
|
additional_inputs=[ |
|
gr.Slider( |
|
label="Độ dài token", |
|
minimum=1, |
|
maximum=MAX_MAX_NEW_TOKENS, |
|
step=1, |
|
value=DEFAULT_MAX_NEW_TOKENS, |
|
), |
|
gr.Slider( |
|
label="Độ sáng tạo", |
|
minimum=0.1, |
|
maximum=4.0, |
|
step=0.1, |
|
value=0.6, |
|
), |
|
gr.Slider( |
|
label="Lựa chọn từ dựa trên xác suất tích lũy", |
|
minimum=0.05, |
|
maximum=1.0, |
|
step=0.05, |
|
value=0.9, |
|
), |
|
gr.Slider( |
|
label="Lựa chọn k từ có xác suất cao nhất", |
|
minimum=1, |
|
maximum=1000, |
|
step=1, |
|
value=50, |
|
), |
|
gr.Slider( |
|
label="Phạt lặp lại", |
|
minimum=1.0, |
|
maximum=2.0, |
|
step=0.05, |
|
value=1.2, |
|
), |
|
], |
|
theme="soft", |
|
stop_btn=None, |
|
examples = [ |
|
["Lợi ích của sữa mẹ ?"], |
|
["Sữa non là gì ?"], |
|
["Trẻ sơ sinh cần ngủ bao nhiêu giờ mỗi ngày?"], |
|
["Bao lâu nên cho trẻ sơ sinh bú một lần?"], |
|
["Khi nào nên bắt đầu cho trẻ ăn dặm?"], |
|
["Làm thế nào để giúp trẻ ngủ ngon vào ban đêm?"] |
|
], |
|
|
|
cache_examples=False, |
|
title = "VN-GPT", |
|
clear_btn="🗑️ Xóa", |
|
undo_btn="↩️ Hoàn tác", |
|
submit_btn="🚀 Gửi", |
|
retry_btn="🔄 Thử lại", |
|
additional_inputs_accordion="Tùy chỉnh nâng cao", |
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
chat_interface.queue(max_size=20).launch() |