Spaces:
Running
on
Zero
Running
on
Zero
from contextlib import nullcontext | |
from random import shuffle | |
import torch | |
from llama_cpp import Llama | |
from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase | |
def generate(
    model: PreTrainedModel | Llama,
    tokenizer: PreTrainedTokenizerBase,
    prompt="",
    temperature=0.5,
    top_p=0.95,
    top_k=45,
    repetition_penalty=1.17,
    max_new_tokens=128,
    autocast_gen=lambda: torch.autocast("cpu", enabled=False),
    **kwargs,
):
    """Sample text from ``model`` for ``prompt`` and return it as a string.

    Two backends are handled:

    * ``Llama`` instances: ``create_completion`` is called and the result is
      ``prompt`` concatenated with the text of the first returned choice.
      ``tokenizer``, ``autocast_gen`` and ``**kwargs`` are not used on this
      path.
    * Any other model: the prompt is tokenized, moved to the device of the
      model's first parameter, sampled with ``model.generate`` and the first
      returned sequence is decoded with ``tokenizer.decode``.

    Args:
        model: The model to sample from.
        tokenizer: Tokenizer used to encode the prompt and decode the output
            (non-``Llama`` path only).
        prompt: Text to continue.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        top_k: Top-k sampling cutoff.
        repetition_penalty: Repetition penalty. On the ``Llama`` path a falsy
            value (e.g. ``None``) is replaced by ``1``.
        max_new_tokens: Upper bound on the number of generated tokens.
        autocast_gen: Zero-argument callable returning a context manager that
            wraps the ``model.generate`` call.
        **kwargs: Extra options forwarded to ``GenerationConfig``.

    Returns:
        The generated text as a ``str``.
    """
    # llama.cpp backend: one call, then glue the completion onto the prompt.
    if isinstance(model, Llama):
        completion = model.create_completion(
            prompt,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_tokens=max_new_tokens,
            repeat_penalty=repetition_penalty or 1,
        )
        first_choice = completion["choices"][0]
        return prompt + first_choice["text"]

    # Release cached CUDA memory before generating.
    torch.cuda.empty_cache()

    encoded = tokenizer(prompt, return_tensors="pt")
    target_device = next(model.parameters()).device
    prompt_ids = encoded["input_ids"].to(target_device)

    sampling_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        **kwargs,
    )

    # Inference only: no gradients, caller-selected autocast context.
    with torch.no_grad():
        with autocast_gen():
            generated = model.generate(
                input_ids=prompt_ids,
                generation_config=sampling_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
            )

    decoded = tokenizer.decode(generated.sequences[0])

    # Release cached CUDA memory again once the output tensors are decoded.
    torch.cuda.empty_cache()
    return decoded
def tag_gen(
    text_model,
    tokenizer,
    prompt,
    prompt_tags,
    len_target,
    black_list,
    temperature=0.5,
    top_p=0.95,
    top_k=100,
    max_new_tokens=256,
    max_retry=5,
):
    """Repeatedly extend ``prompt`` with generated tags until a target count is met.

    Each round calls ``generate`` on the current prompt, takes everything
    after the last ``"<|input_end|>"`` marker as a comma-separated tag list,
    drops tags found in ``black_list``, de-duplicates the rest, and rewrites
    the generated text with the cleaned tag list. The cleaned text then
    becomes the prompt for the next round.

    The loop ends when ``len(prompt_tags) + len(extra_tokens)`` reaches
    ``len_target``, or when the tag count has stopped changing for more
    consecutive rounds than the retry budget allows.

    Args:
        text_model: Model passed through to ``generate``.
        tokenizer: Tokenizer passed through to ``generate``; its
            ``eos_token_id`` is used as both pad and EOS token id.
        prompt: Initial prompt text.
        prompt_tags: Tags already present in the prompt; only its length is
            used.
        len_target: Desired total number of tags.
        black_list: Container of tags to exclude from the generated list.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        top_k: Top-k sampling cutoff.
        max_new_tokens: Per-round cap on generated tokens.
        max_retry: Retry budget for consecutive rounds in which the number of
            extra tags did not change. The budget is restored whenever the
            count changes.

    Yields:
        ``(llm_gen, extra_tokens)`` after every round, and once more with the
        final values after the loop ends.
    """
    prev_len = 0
    retry = max_retry
    llm_gen = ""

    while True:
        llm_gen = generate(
            model=text_model,
            tokenizer=tokenizer,
            prompt=prompt,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=None,
            max_new_tokens=max_new_tokens,
            stream_output=False,
            autocast_gen=nullcontext,
            prompt_lookup_num_tokens=10,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        # Drop sentence boundary markers from the decoded text.
        llm_gen = llm_gen.replace("</s>", "").replace("<s>", "")

        # Everything after the last input-end marker is the generated tag list.
        extra = llm_gen.split("<|input_end|>")[-1].strip().strip(",")
        stripped_tags = (tok.strip() for tok in extra.split(","))
        extra_tokens = list({tok for tok in stripped_tags if tok not in black_list})
        llm_gen = llm_gen.replace(extra, ", ".join(extra_tokens))

        yield llm_gen, extra_tokens

        # Enough tags collected: stop.
        if len(prompt_tags) + len(extra_tokens) >= len_target:
            break

        if len(extra_tokens) == prev_len and prev_len > 0:
            # No new tags since the previous round: spend one retry.
            if retry < 0:
                break
            retry -= 1
        else:
            # Progress was made, so restore the budget. Restoring it
            # unconditionally would make the `retry < 0` exit unreachable
            # and let a stalled model loop forever.
            retry = max_retry

        # NOTE: shuffles in place the same list object that was just yielded.
        shuffle(extra_tokens)
        prev_len = len(extra_tokens)
        # NOTE(review): as written this replace is a no-op (pattern and
        # replacement are identical); presumably it was meant to collapse a
        # double space before "<|" - confirm against the upstream source.
        prompt = llm_gen.strip().replace(" <|", " <|")

    yield llm_gen, extra_tokens