# lemur-7B/utils/inference.py

from typing import Iterator

import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

from variables import SYSTEM, HUMAN, AI


def load_tokenizer_and_model(base_model, adapter_model, load_8bit=True):
    """
    Load the tokenizer and chatbot model.

    Args:
        base_model (str): Path to the base model.
        adapter_model (str): Path to the LoRA adapter, or None to skip it.
        load_8bit (bool): Whether to load the model in 8-bit mode (CUDA only).
    """
    # Pick the best available device: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    try:
        if torch.backends.mps.is_available():
            device = "mps"
    except Exception:
        pass

    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    if device == "cuda":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
        )
    elif device == "mps":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
        )
        if adapter_model is not None:
            model = PeftModel.from_pretrained(
                model,
                adapter_model,
                device_map={"": device},
                torch_dtype=torch.float16,
            )
    else:
        # CPU fallback: bfloat16 with low memory usage and disk offload.
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
            low_cpu_mem_usage=True,
            torch_dtype=torch.bfloat16,
            offload_folder=".",
        )
        if adapter_model is not None:
            model = PeftModel.from_pretrained(
                model,
                adapter_model,
                torch_dtype=torch.bfloat16,
                offload_folder=".",
            )

    model.eval()
    return tokenizer, model, device
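

# Usage sketch (illustrative only): the checkpoint paths below are placeholders,
# assuming a LLaMA-style base model and an optional LoRA adapter on disk.
#
#   tokenizer, model, device = load_tokenizer_and_model(
#       base_model="path/to/llama-base",        # placeholder path
#       adapter_model="path/to/lora-adapter",   # placeholder path; None skips PEFT
#       load_8bit=False,
#   )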


class State:
    """Shared flag used to signal that streaming generation should stop."""

    interrupted = False

    def interrupt(self):
        self.interrupted = True

    def recover(self):
        self.interrupted = False


shared_state = State()


def decode(
    input_ids: torch.Tensor,
    model: PeftModel,
    tokenizer: LlamaTokenizer,
    stop_words: list,
    max_length: int,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> Iterator[str]:
    """Sample tokens one at a time, yielding the decoded text after each step."""
    generated_tokens = []
    past_key_values = None
    for _ in range(max_length):
        with torch.no_grad():
            if past_key_values is None:
                # First step: run the full prompt through the model.
                outputs = model(input_ids)
            else:
                # Later steps: feed only the last token and reuse the KV cache.
                outputs = model(input_ids[:, -1:], past_key_values=past_key_values)
            logits = outputs.logits[:, -1, :]
            past_key_values = outputs.past_key_values

            # Apply temperature scaling.
            logits /= temperature
            probs = torch.softmax(logits, dim=-1)

            # Apply nucleus (top-p) sampling: keep the smallest set of tokens
            # whose cumulative probability exceeds top_p, then renormalize.
            probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
            probs_sum = torch.cumsum(probs_sort, dim=-1)
            mask = probs_sum - probs_sort > top_p
            probs_sort[mask] = 0.0
            probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
            next_token = torch.multinomial(probs_sort, num_samples=1)
            next_token = torch.gather(probs_idx, -1, next_token)

            input_ids = torch.cat((input_ids, next_token), dim=-1)

            generated_tokens.append(next_token[0].item())
            text = tokenizer.decode(generated_tokens)
            yield text
            if any(x in text for x in stop_words):
                return
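

# Streaming sketch (illustrative only): `tokenizer`, `model`, and `device` are
# assumed to come from load_tokenizer_and_model above, and [HUMAN, AI] is only
# an example stop-word list.
#
#   input_ids = tokenizer(
#       SYSTEM + f"\n{HUMAN} Hello!\n{AI}", return_tensors="pt"
#   )["input_ids"].to(device)
#   for partial_text in decode(
#       input_ids, model, tokenizer, stop_words=[HUMAN, AI], max_length=256
#   ):
#       print(partial_text)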


def get_prompt_with_history(text, history, tokenizer, max_length=2048):
    """
    Build the chat prompt from the system message, the conversation history,
    and the new user message, keeping only as many of the most recent turns
    as fit within max_length tokens.

    Returns (prompt, tokenized_inputs), or None if even the newest turn does
    not fit.
    """
    prompt = SYSTEM
    history = [f"\n{HUMAN} {x[0]}\n{AI} {x[1]}" for x in history]
    history.append(f"\n{HUMAN} {text}\n{AI}")
    history_text = ""
    flag = False
    # Walk the history from newest to oldest, prepending turns while they fit.
    for x in history[::-1]:
        if (
            tokenizer(prompt + history_text + x, return_tensors="pt")["input_ids"].size(
                -1
            )
            <= max_length
        ):
            history_text = x + history_text
            flag = True
        else:
            break
    if flag:
        return prompt + history_text, tokenizer(
            prompt + history_text, return_tensors="pt"
        )
    else:
        return None
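

# Example of the return contract (illustrative): with an empty history and a
# short message, the full prompt plus its tokenization comes back; if even the
# newest turn exceeds max_length, the result is None.
#
#   result = get_prompt_with_history("Hi there", history=[], tokenizer=tokenizer)
#   if result is not None:
#       prompt, inputs = result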


def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
    """Return True if s ends with a stop word or with any prefix of one."""
    for stop_word in stop_words:
        if s.endswith(stop_word):
            return True
        for i in range(1, len(stop_word)):
            if s.endswith(stop_word[:i]):
                return True
    return False
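

# End-to-end sketch (illustrative only): the checkpoint path is a placeholder
# and [HUMAN, AI] is only an example stop-word list; the hosting app supplies
# its own values. Guarded so it never runs on import.
if __name__ == "__main__":
    tokenizer, model, device = load_tokenizer_and_model(
        base_model="path/to/llama-base",  # placeholder path
        adapter_model=None,
        load_8bit=False,
    )
    result = get_prompt_with_history("Hello!", history=[], tokenizer=tokenizer)
    if result is not None:
        prompt, inputs = result
        input_ids = inputs["input_ids"].to(device)
        for partial_text in decode(
            input_ids,
            model,
            tokenizer,
            stop_words=[HUMAN, AI],  # example stop words
            max_length=256,
        ):
            if is_stop_word_or_prefix(partial_text, [HUMAN, AI]):
                break
            print(partial_text)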