gemma-2-2b-it-PromptEnhancing
gemma-2-2b-it-PromptEnhancing is an instruction-tuned text-generation model fine-tuned with LoRA.
This model was released alongside three other models in the 2-3B parameter range, all trained on the same dataset with the same training arguments.
Model Details
Model Description
This model is a LoRA fine-tune of google/gemma-2-2b-it. Its goal is to provide a lightweight prompt-enhancing model for Stable Diffusion (or other diffusion models that share the same prompting conventions), making image generation more accessible to everyone.
- Developed by: groloch
- Model type: LoRA
- Language(s) (NLP): English
- License: gemma
- Finetuned from model: google/gemma-2-2b-it
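For readers unfamiliar with LoRA, the standard formulation is sketched below. The base weights W_0 stay frozen and only the low-rank factors A and B are trained. The rank r and scaling factor α used for this adapter are not stated in this card, so they are left symbolic. The dimensions in the parameter-count example (d = k = 2048, r = 16) are illustrative assumptions, not this model's actual settings.

```latex
h = W_0 x + \Delta W x = W_0 x + \frac{\alpha}{r} B A x,
\qquad W_0 \in \mathbb{R}^{d \times k},\;
B \in \mathbb{R}^{d \times r},\;
A \in \mathbb{R}^{r \times k},\;
r \ll \min(d, k)

\text{trainable parameters per adapted matrix: } r(d + k) \text{ instead of } dk,
\quad \text{e.g. } d = k = 2048,\ r = 16:\ 16 \cdot 4096 = 65{,}536 \ll 2048^2 = 4{,}194{,}304
```

This is why the released artifact is a small adapter that must be loaded on top of google/gemma-2-2b-it rather than a full copy of the model.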
Model Sources
- Paper: Coming soon
- Demo: Coming soon
Uses
This model is intended to be used as a prompt enhancer for text-to-image diffusion models. The simplest way to try it is the official demo (coming soon).
Direct Use
If you want to use it locally, refer to the following code snippet:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Base model and LoRA adapter (loading the adapter requires `peft` to be installed)
base_repo_id = 'google/gemma-2-2b-it'
adapter_repo_id = 'groloch/gemma-2-2b-it-PromptEnhancing'
tokenizer = AutoTokenizer.from_pretrained(base_repo_id)
model = AutoModelForCausalLM.from_pretrained(base_repo_id, torch_dtype=torch.bfloat16).to('cuda')
model.load_adapter(adapter_repo_id)
# Short idea to be expanded into a detailed image-generation prompt
prompt_to_enhance = 'Sinister crocodile eating a jolly rabbit'
chat = [
    {'role': 'user', 'content': prompt_to_enhance}
]
# Format the conversation with Gemma's chat template and open the model's turn
prompt = tokenizer.apply_chat_template(chat,
                                       tokenize=False,
                                       add_generation_prompt=True)
# The templated string already starts with <bos>, so don't let the tokenizer add another one
encoding = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to('cuda')
# Sampling settings (see Recommendations below)
generation_config = model.generation_config
generation_config.do_sample = True
generation_config.max_new_tokens = 96
generation_config.temperature = 0.3
generation_config.top_p = 0.7
generation_config.num_return_sequences = 1
generation_config.pad_token_id = tokenizer.eos_token_id
generation_config.eos_token_id = tokenizer.eos_token_id
generation_config.repetition_penalty = 2.0
with torch.inference_mode():
    outputs = model.generate(
        input_ids=encoding.input_ids,
        attention_mask=encoding.attention_mask,
        generation_config=generation_config
    )
# Decode only the newly generated tokens, i.e. the enhanced prompt
new_tokens = outputs[0][encoding.input_ids.shape[-1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))
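To make the chat-templating step less opaque, the helper below approximates the string that, to our understanding, Gemma-2's chat template produces for a single user turn with add_generation_prompt=True. `build_gemma_prompt` is a hypothetical name, and the exact format is an assumption that should be checked against `tokenizer.apply_chat_template(...)`. The string already begins with `<bos>`, so tokenizing it again with the tokenizer's default settings may prepend a second `<bos>`. Passing `add_special_tokens=False` when tokenizing avoids this.

```python
# Hypothetical helper: approximates Gemma-2's chat template for one user turn.
# Verify against tokenizer.apply_chat_template(...) before relying on it.


def build_gemma_prompt(user_text: str) -> str:
    """Return the approximate chat-formatted prompt for a single user message."""
    # The user's turn, wrapped in Gemma's turn markers
    user_turn = "<start_of_turn>user\n" + user_text.strip() + "<end_of_turn>\n"
    # add_generation_prompt=True opens the model's turn so generation starts there
    generation_prompt = "<start_of_turn>model\n"
    # The template itself inserts <bos>, hence add_special_tokens=False later on
    return "<bos>" + user_turn + generation_prompt


print(build_gemma_prompt("Sinister crocodile eating a jolly rabbit"))
```

Because `<start_of_turn>` and `<end_of_turn>` are special tokens, decoding the full output with `skip_special_tokens=True` leaves the plain words "user" and "model" in the text. Slicing off the prompt tokens before decoding returns only the enhanced prompt.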
Out-of-Scope Use
This model is meant to be used as a prompt enhancer. Inputs should be short, concise ideas rather than detailed or complete prompts.
Using the model for other purposes may produce unexpected behavior.
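If the model sits behind an application, a simple pre-check can steer users toward the short inputs it expects. The sketch below is hypothetical and not part of the released model. It uses a word count as a rough proxy for "concise", and the 20-word threshold is an arbitrary assumption to tune for your use case.

```python
# Hypothetical pre-check, not part of the released model.
# MAX_WORDS is an assumed threshold, chosen only for illustration.
MAX_WORDS = 20


def is_short_idea(text: str, max_words: int = MAX_WORDS) -> bool:
    """Return True if `text` looks like a short idea rather than a full prompt."""
    words = text.split()
    if not words:
        return False  # empty input: nothing to enhance
    # Long inputs are likely already detailed prompts, which are out of scope
    return len(words) <= max_words


print(is_short_idea("Sinister crocodile eating a jolly rabbit"))
```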
Bias, Risks, and Limitations
This model was trained on a dataset partially generated by AI, so its outputs may reflect biases present in that data.
As a lightweight model, it may have significant limitations in output quality and consistency.
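One way to make "lightweight" concrete is the memory needed for the weights alone. The sketch below assumes roughly 2.6 billion parameters for the base model and 2 bytes per parameter in bfloat16, the dtype used in the loading example. It ignores activations, the KV cache, and the small LoRA adapter, so real usage will be somewhat higher.

```python
def weight_memory_gb(n_params: float, bytes_per_param: int) -> float:
    """Weights-only memory in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bytes_per_param / 1e9


# Assumed approximate parameter count of google/gemma-2-2b-it
N_PARAMS = 2.6e9

for dtype, nbytes in [("float32", 4), ("bfloat16", 2)]:
    print(f"{dtype}: ~{weight_memory_gb(N_PARAMS, nbytes):.1f} GB")
```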
Recommendations
Use a high repetition penalty (2.0 or above) and a low temperature (below 0.4) for generation, and do not generate more than 128 new tokens.
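The recommendations can be encoded as a quick check before calling `generate`. The sketch below is a hypothetical helper, not part of the model or of transformers. It treats the repetition penalty of 2.0 from the usage example as the lower bound and reads the other thresholds (temperature below 0.4, at most 128 new tokens) from the recommendations. Missing keys are assumed to be 1.0 for penalty and temperature, and are not checked for token count.

```python
# Hypothetical helper encoding the card's generation recommendations.
RECOMMENDED = {
    "min_repetition_penalty": 2.0,  # value used in the usage example
    "max_temperature": 0.4,         # temperature must stay below this
    "max_new_tokens": 128,          # upper bound on generated tokens
}


def check_generation_settings(settings: dict) -> list:
    """Return warnings for settings that fall outside the recommendations."""
    warnings = []
    if settings.get("repetition_penalty", 1.0) < RECOMMENDED["min_repetition_penalty"]:
        warnings.append("repetition_penalty is below 2.0")
    if settings.get("temperature", 1.0) >= RECOMMENDED["max_temperature"]:
        warnings.append("temperature should be below 0.4")
    if settings.get("max_new_tokens", 0) > RECOMMENDED["max_new_tokens"]:
        warnings.append("max_new_tokens exceeds 128")
    return warnings


example = {"repetition_penalty": 2.0, "temperature": 0.3, "max_new_tokens": 96}
print(check_generation_settings(example))  # → []
```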
Training Details
Training Data
This model was trained for one epoch on the groloch/stable_diffusion_prompts_instruct dataset.
Training Hyperparameters
Coming soon.
Framework versions
- PEFT 0.13.2