
gemma-2-2b-it-PromptEnhancing

gemma-2-2b-it-PromptEnhancing is an instruction-tuned text-generation model fine-tuned with LoRA.

This model was released alongside three other models in the 2-3B parameter range, all trained on the same dataset with the same training arguments.

Model Details

Model Description

This model is a LoRA fine-tune of google/gemma-2-2b-it. Its goal is to provide a lightweight prompt-enhancing model for Stable Diffusion (or other text-to-image diffusion models that share the same prompting conventions), making image generation more accessible to everyone.
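
For readers unfamiliar with LoRA, the standard formulation is summarized below. The base weights W_0 stay frozen, and only a low-rank update BA, scaled by α/r, is trained. The rank r, the scaling α, and the set of adapted layers used for this adapter have not been published yet ("coming soon"), so they are left symbolic here.

```latex
h = W_0 x + \Delta W x = W_0 x + \frac{\alpha}{r} B A x,
\qquad W_0 \in \mathbb{R}^{d \times k},\;
B \in \mathbb{R}^{d \times r},\;
A \in \mathbb{R}^{r \times k},\;
r \ll \min(d, k)
```

An adapter repository typically stores only the low-rank matrices, not the full base weights. The local usage code therefore loads google/gemma-2-2b-it first and then attaches the adapter on top of it.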

Model Sources

  • Paper: Coming soon
  • Demo: Coming soon

Uses

This model is meant to be used as a prompt enhancer for text-to-image diffusion models. The simplest way to try it is the official demo (coming soon).

Direct Use

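To run the model locally, use the following snippet. It loads the base model, attaches the LoRA adapter (this requires the peft package alongside transformers), and runs on a CUDA GPU: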

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


base_repo_id = 'google/gemma-2-2b-it'
adapter_repo_id = 'groloch/gemma-2-2b-it-PromptEnhancing'

tokenizer = AutoTokenizer.from_pretrained(base_repo_id)
model = AutoModelForCausalLM.from_pretrained(base_repo_id, torch_dtype=torch.bfloat16).to('cuda')
# Attach the LoRA adapter on top of the base model (requires the peft package)
model.load_adapter(adapter_repo_id)

prompt_to_enhance = 'Sinister crocodile eating a jolly rabbit'

chat = [
    {'role' : 'user', 'content': prompt_to_enhance}
]

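# Render the chat as text; the Gemma 2 chat template already prepends <bos>
prompt = tokenizer.apply_chat_template(chat,
                                       tokenize=False,
                                       add_generation_prompt=True)

# add_special_tokens=False avoids adding a second <bos> token
encoding = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to('cuda')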

# Sampling settings (see Recommendations below)
generation_config = model.generation_config
generation_config.do_sample = True
generation_config.max_new_tokens = 96
generation_config.temperature = 0.3
generation_config.top_p = 0.7
generation_config.num_return_sequences = 1
generation_config.pad_token_id = tokenizer.eos_token_id
generation_config.eos_token_id = tokenizer.eos_token_id
generation_config.repetition_penalty = 2.0

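with torch.inference_mode():
    outputs = model.generate(
        input_ids=encoding.input_ids,
        attention_mask=encoding.attention_mask,
        generation_config=generation_config
    )

# Decode only the newly generated tokens, not the echoed prompt
new_tokens = outputs[0][encoding.input_ids.shape[1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))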
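
To show what apply_chat_template feeds the model, the sketch below renders a single user turn by hand. It assumes the Gemma 2 chat template bundled with google/gemma-2-2b-it. The helper gemma_chat_prompt is hypothetical and for illustration only. Real code should always call tokenizer.apply_chat_template, which tracks the template shipped with the tokenizer.

```python
# Hypothetical helper: hand-renders one user turn plus the generation prompt.
# ASSUMPTION: mirrors the Gemma 2 chat template of google/gemma-2-2b-it;
# use tokenizer.apply_chat_template in real code.

BOS = "<bos>"
START, END = "<start_of_turn>", "<end_of_turn>"


def gemma_chat_prompt(user_text: str) -> str:
    """Render a single user message followed by the model's turn opener."""
    user_turn = f"{START}user\n{user_text.strip()}{END}\n"
    generation_prompt = f"{START}model\n"  # what add_generation_prompt=True appends
    return BOS + user_turn + generation_prompt


prompt = gemma_chat_prompt("Sinister crocodile eating a jolly rabbit")
print(repr(prompt))
```

The rendered string already starts with <bos>. For that reason, it should be tokenized with add_special_tokens=False. Otherwise the tokenizer prepends a second BOS token.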

Out-of-Scope Use

This model is meant to be used as a prompt enhancer. Inputs should be short, concise ideas rather than full, detailed prompts.

Using this model for other purposes may yield unexpected behavior.
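
The guidance that inputs should be short ideas, not full prompts, can be enforced with a simple pre-check before calling the model. The sketch below is a hypothetical helper. Its 20-word cutoff is an arbitrary illustrative assumption and does not come from the model authors.

```python
# Hypothetical pre-check for inputs to the prompt enhancer.
# ASSUMPTION: the 20-word limit is illustrative only; the model card just says
# inputs should be concise and not full prompts.

MAX_WORDS = 20  # hypothetical cutoff, not published by the model authors


def looks_like_short_idea(text: str, max_words: int = MAX_WORDS) -> bool:
    """Return True if `text` is non-empty and at most `max_words` words long."""
    words = text.split()
    return 0 < len(words) <= max_words


print(looks_like_short_idea("Sinister crocodile eating a jolly rabbit"))  # True

# A full, tag-heavy prompt (30 words) is rejected by this heuristic
long_prompt = ", ".join(["highly detailed"] * 15)
print(looks_like_short_idea(long_prompt))  # False
```

A word count is only a rough proxy. A real application might instead count comma-separated tags, which is how many Stable Diffusion prompts are written.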

Bias, Risks, and Limitations

This model was trained on a partially AI-generated dataset, which may contain biases.

As a lightweight model based on Gemma 2 2B, it may have significant limitations.

Recommendations

Use a high repetition penalty (2.0 or higher, as in the snippet above) and a low temperature (below 0.4) for generation. Keep max_new_tokens at 128 or less.
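
These recommendations can be checked programmatically before generation. The sketch below uses a hypothetical helper name and hypothetical warning messages. The bounds follow this section, and the repetition penalty is treated as acceptable from 2.0 upward, matching the value used in the Direct Use snippet.

```python
# Hypothetical helper: flags generation settings outside the recommended ranges.
# Bounds: temperature < 0.4, repetition_penalty >= 2.0, max_new_tokens <= 128.

def check_generation_settings(temperature: float,
                              repetition_penalty: float,
                              max_new_tokens: int) -> list[str]:
    """Return a list of warnings; an empty list means all settings look fine."""
    warnings = []
    if temperature >= 0.4:
        warnings.append(f"temperature={temperature} (recommended < 0.4)")
    if repetition_penalty < 2.0:
        warnings.append(f"repetition_penalty={repetition_penalty} (recommended >= 2.0)")
    if max_new_tokens > 128:
        warnings.append(f"max_new_tokens={max_new_tokens} (recommended <= 128)")
    return warnings


# Settings from the Direct Use snippet: no warnings
print(check_generation_settings(temperature=0.3, repetition_penalty=2.0, max_new_tokens=96))  # []

# Typical "creative" defaults trip all three checks
for w in check_generation_settings(temperature=0.8, repetition_penalty=1.0, max_new_tokens=256):
    print("warning:", w)
```

Values that pass the check can also be given directly to model.generate as keyword arguments (for example temperature=0.3). These override the corresponding fields of the generation config for that call.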

Training Details

Training Data

This model was trained for one epoch on groloch/stable_diffusion_prompts_instruct.

Training Hyperparameters

Coming soon.

Framework versions

  • PEFT 0.13.2