import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
from peft import PeftModel
from typing import Iterator

from variables import SYSTEM, HUMAN, AI


def load_tokenizer_and_model(base_model, adapter_model, load_8bit=True):
    """
    Load the tokenizer and chatbot model.

    Args:
        base_model (str): Path to the base model.
        adapter_model (str): Path to the LoRA adapter model, or None to skip adapter loading.
        load_8bit (bool): Whether to load the model in 8-bit mode (CUDA only).
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    try:
        if torch.backends.mps.is_available():
            device = "mps"
    except AttributeError:
        # Older torch builds do not expose torch.backends.mps.
        pass

    tokenizer = LlamaTokenizer.from_pretrained(base_model)

    if device == "cuda":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map="auto",  # 8-bit loading requires a device map
        )
        if adapter_model is not None:
            model = PeftModel.from_pretrained(
                model,
                adapter_model,
                torch_dtype=torch.float16,
            )
    elif device == "mps":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
        )
        if adapter_model is not None:
            model = PeftModel.from_pretrained(
                model,
                adapter_model,
                device_map={"": device},
                torch_dtype=torch.float16,
            )
    else:
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
            low_cpu_mem_usage=True,
            torch_dtype=torch.bfloat16,
            offload_folder=".",
        )
        if adapter_model is not None:
            model = PeftModel.from_pretrained(
                model,
                adapter_model,
                torch_dtype=torch.bfloat16,
                offload_folder=".",
            )

    model.eval()
    return tokenizer, model, device


class State:
    interrupted = False

    def interrupt(self):
        self.interrupted = True

    def recover(self):
        self.interrupted = False


shared_state = State()


def decode(
    input_ids: torch.Tensor,
    model: PeftModel,
    tokenizer: LlamaTokenizer,
    stop_words: list,
    max_length: int,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> Iterator[str]:
    """Generate up to `max_length` tokens, yielding the decoded text after every step."""
    generated_tokens = []
    past_key_values = None

    for _ in range(max_length):
        with torch.no_grad():
            if past_key_values is None:
                outputs = model(input_ids)
            else:
                # Reuse the KV cache: only the newest token needs a forward pass.
                outputs = model(input_ids[:, -1:], past_key_values=past_key_values)
            logits = outputs.logits[:, -1, :]
            past_key_values = outputs.past_key_values

        # Temperature scaling followed by top-p (nucleus) sampling.
        logits /= temperature
        probs = torch.softmax(logits, dim=-1)

        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0.0

        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        next_token = torch.multinomial(probs_sort, num_samples=1)
        next_token = torch.gather(probs_idx, -1, next_token)

        input_ids = torch.cat((input_ids, next_token), dim=-1)

        generated_tokens.append(next_token[0].item())
        text = tokenizer.decode(generated_tokens)

        yield text
        if any(x in text for x in stop_words):
            return


def get_prompt_with_history(text, history, tokenizer, max_length=2048):
    """Build a prompt from the system text plus as many of the newest turns as fit in `max_length` tokens."""
    prompt = SYSTEM
    history = [f"\n{HUMAN} {x[0]}\n{AI} {x[1]}" for x in history]
    history.append(f"\n{HUMAN} {text}\n{AI}")
    history_text = ""
    flag = False
    # Walk the history from newest to oldest, keeping turns while they fit.
    for x in history[::-1]:
        if (
            tokenizer(prompt + history_text + x, return_tensors="pt")["input_ids"].size(-1)
            <= max_length
        ):
            history_text = x + history_text
            flag = True
        else:
            break
    if flag:
        return prompt + history_text, tokenizer(prompt + history_text, return_tensors="pt")
    # Even the newest turn does not fit within the token budget.
    return None


def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
    """Return True if `s` ends with a stop word or with a prefix of one."""
    for stop_word in stop_words:
        if s.endswith(stop_word):
            return True
        for i in range(1, len(stop_word)):
            if s.endswith(stop_word[:i]):
                return True
    return False
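

# Illustrative usage sketch (not part of the original module): the checkpoint paths
# below are placeholders, and the stop words simply reuse the HUMAN/AI turn markers
# imported from `variables`. It only shows how the helpers above fit together.
if __name__ == "__main__":
    # Placeholder locations; substitute real paths or hub IDs.
    BASE_MODEL = "path/to/base-llama-model"
    ADAPTER_MODEL = "path/to/lora-adapter"  # or None to run the base model only

    tokenizer, model, device = load_tokenizer_and_model(BASE_MODEL, ADAPTER_MODEL)

    result = get_prompt_with_history("Hello, who are you?", history=[], tokenizer=tokenizer)
    if result is not None:
        prompt, inputs = result
        input_ids = inputs["input_ids"].to(device)
        stop_words = [HUMAN, AI]

        response = ""
        # Stream the reply, honoring the shared interrupt flag used by the UI layer.
        for partial_text in decode(
            input_ids,
            model,
            tokenizer,
            stop_words=stop_words,
            max_length=512,
            temperature=0.7,
            top_p=0.9,
        ):
            if shared_state.interrupted:
                shared_state.recover()
                break
            # Hold back text that might still turn into a stop word.
            if not is_stop_word_or_prefix(partial_text, stop_words):
                response = partial_text
        print(response)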