How to introduce stop_strings in llama3?
#132
by Srinjoy
Hey all, I want to stop my text generation when I encounter certain strings. The code I am using is as follows:
import os

import torch
from dotenv import load_dotenv, find_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Module-level cache so the tokenizer and model are only loaded once.
model_tokenizer = None
generator_model = None


def get_tokenizer_model(config):
    global generator_model
    global model_tokenizer
    if model_tokenizer is None:
        load_dotenv(find_dotenv())
        HF_TOKEN = os.getenv('HF_TOKEN')
        model_tokenizer = AutoTokenizer.from_pretrained(
            config['model_name'],
            token=HF_TOKEN
        )
        print(f'======================>The model name is :{config["model_name"]}')
        bnb_config = None
        if config['quantisation']:
            print('======================>The model is quantised?', config["quantisation"])
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4',
                bnb_4bit_compute_dtype=torch.bfloat16
            )
        generator_model = AutoModelForCausalLM.from_pretrained(
            config['model_name'],
            device_map='auto',
            quantization_config=bnb_config
        )
        model_tokenizer.pad_token = model_tokenizer.eos_token
        # text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device='auto')
        print(f'======================>The device used by pipeline is:{generator_model.device}')
    return model_tokenizer, generator_model
and the model is used here:
def get_hf_chat(prompt: str, model: Model = 'llama3-8b-8192', temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None, is_batched: bool = False, args=None) -> str:
    config = {
        'quantisation': args.quantised,
        'model_name': args.model,
        'stop_strings': stop_strs
    }
    global generator_model
    global model_tokenizer
    model_tokenizer, generator_model = get_tokenizer_model(config=config)
    inputs = model_tokenizer(prompt, return_tensors='pt')
    gen_out = generator_model.generate(**inputs,
                                       temperature=temperature,
                                       max_new_tokens=max_tokens,
                                       stop_strings=stop_strs,
                                       tokenizer=model_tokenizer)
    output_text = model_tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0][len(prompt):]
    return output_text
When I call the above function, I encounter the following error:
File "/data/home/srinjoym/reflexion/alfworld_runs/demo.py", line 207, in <module>
print(get_hf_chat(prompt=prompt1, model='meta-llama/Meta-Llama-3-8B-Instruct', temperature=0.2, max_tokens=512, stop_strs=['\n'], args = args))
File "/data/home/srinjoym/reflexion/alfworld_runs/utils.py", line 146, in get_hf_chat
gen_out = generator_model.generate(**inputs,
File "/data/home/srinjoym/miniconda3/envs/alfworld/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/data/home/srinjoym/miniconda3/envs/alfworld/lib/python3.9/site-packages/transformers/generation/utils.py", line 1384, in generate
self._validate_model_kwargs(model_kwargs.copy())
File "/data/home/srinjoym/miniconda3/envs/alfworld/lib/python3.9/site-packages/transformers/generation/utils.py", line 1130, in _validate_model_kwargs
raise ValueError(
ValueError: The following `model_kwargs` are not used by the model: ['stop_strings', 'tokenizer'] (note: typos in the generate arguments will also show up in this list)
I used the code from this place.
My question is: how can I use/introduce stop strings with this model? Also, do we need to check whether the stop_strings argument can be used with each model, or is there some standard that all pretrained Hugging Face models follow?
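For reference, here is a rough sketch of the workaround I am considering in case stop_strings is not supported in my transformers version: a custom StoppingCriteria that decodes the newly generated tokens and stops once any stop string appears. The class StopOnStrings and the usage below are my own sketch, not part of the transformers API, so please correct me if there is a better approach.

from transformers import StoppingCriteria, StoppingCriteriaList


class StopOnStrings(StoppingCriteria):
    """Stop generation once any of the given strings appears in the new text."""

    def __init__(self, stop_strings, tokenizer, prompt_length):
        self.stop_strings = stop_strings
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length  # number of prompt tokens to skip

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the generated continuation and check for stop strings.
        generated = self.tokenizer.decode(
            input_ids[0][self.prompt_length:], skip_special_tokens=True
        )
        return any(s in generated for s in self.stop_strings)


# Hypothetical usage with the variables from get_hf_chat above:
# criteria = StoppingCriteriaList(
#     [StopOnStrings(stop_strs, model_tokenizer, inputs['input_ids'].shape[1])]
# )
# gen_out = generator_model.generate(**inputs,
#                                    temperature=temperature,
#                                    max_new_tokens=max_tokens,
#                                    stopping_criteria=criteria)

Note that this sketch only inspects the first sequence in the batch, so it would need adjusting for batched generation.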