|
--- |
|
license: apache-2.0 |
|
library_name: peft |
|
pipeline_tag: text-generation |
|
inference: false |
|
tags: |
|
- finetuned |
|
- prompt generator |
|
- stable-diffusion |
|
- mistralai/Mistral-7B-v0.1 |
|
--- |
|
|
|
<br> |
|
|
|
<h1 style="font-size: 2em;">Prompt Diffusion π</h1> |
|
|
|
|<img src="coffee.png" width="550"/>|<img src="skii.png" width="550"/>|<img src="dance_rain.png" width="550"/>|<img src="man_beach.png" width="550"/>| |
|
|---------------------------------------|----------------------------------|-----------------------------------------|--------------------------------------| |
|
|
|
|
|
This model was built via parameter-efficient QLoRA finetuning of [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) on a custom dataset.
|
The goal of this project is to build a model that takes a short, simple natural-language input and generates a highly detailed, descriptive prompt for [Stable Diffusion](https://stability.ai/stable-diffusion) and other similar image generation models, making it easier for everyone to get high-quality images.
|
|
|
## Note |
|
|
|
>This is the first iteration of this project, built to explore the capabilities of the Mistral-7B model.

>At present, the model is fine-tuned mostly on prompts that feature human subjects, with descriptions of the subject and the scene.

>This is a work in progress and still a prototype, so occasional issues with the model's responses are to be expected.
|
|
|
## Input Template |
|
|
|
``` |
|
###Human:
generate a stable diffusion prompt for {input}

###Assistant:
|
``` |
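
A minimal sketch of filling this template in Python before tokenizing (the variable names here are illustrative, not part of the repository):

```python
# Fill the template with the user's short description (illustrative names)
user_input = "a man skiing"
input_context = f"###Human:\ngenerate a stable diffusion prompt for {user_input}\n\n###Assistant:\n"
```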
|
|
|
## Output Template |
|
|
|
``` |
|
###Human:
generate a stable diffusion prompt for a model in a cafe sipping coffee

###Assistant:
A model in a cafe sipping coffee, with a warm glow illuminating her face. She wears a stylish dress and has a book in her hand. The background is a cozy cafe with dim lights.
|
``` |
|
|
|
## Example prompts and responses |
|
|
|
Example 1: |
|
|
|
**User**: |
|
|
|
>a model in a cafe sipping coffee |
|
|
|
**Prompt Diffusion**:

>A model in a cafe sipping coffee, with a warm glow illuminating her face.
>She wears a stylish dress and has a book in her hand. The background is a cozy cafe with dim lights.
|
|
|
<br> |
|
|
|
Example 2: |
|
|
|
**User**: |
|
|
|
>a man skiing |
|
|
|
**Prompt Diffusion**:

>A man in a ski suit, with a helmet and goggles, is skiing down a snowy mountain, with the wind blowing in his face and the sun shining brightly.
>The man is enjoying the thrill of skiing, as he glides down the slopes, leaving behind a trail of fresh powder.
|
|
|
<br> |
|
|
|
Example 3: |
|
|
|
**User**: |
|
|
|
>a beautiful woman dancing in rain |
|
|
|
**Prompt Diffusion**:

>A beautiful woman, dressed in a flowing white gown, dances gracefully in the rain.
>The raindrops fall gently around her, creating a magical and ethereal atmosphere.
>She moves with grace and elegance, her long hair flowing in the wind.
>The rain creates a shimmering effect on her gown, making her look like a goddess.
|
|
|
<br> |
|
|
|
Example 4: |
|
|
|
**User**: |
|
|
|
>a man on a beach |
|
|
|
**Prompt Diffusion**:

>A man sits on a beach, the waves crashing against the shore.
>The sun is setting, casting a warm glow over the sand and water.
>The man is lost in thought, his mind wandering as he takes in the beauty of the scene.
|
|
|
<br> |
|
|
|
## Demo |
|
|
|
>This code runs even on the free tier of Google Colab. Set the runtime to **T4 GPU** and run the notebook below:
|
|
|
## <a href="https://colab.research.google.com/drive/1NfNA6sYWUG0Nqy42lhyJfUIqDxsrV64d?usp=sharing" target="_blank" style="color: orange; font-weight: bold;">Try on Colab</a> π |
|
|
|
|
|
## Basic usage |
|
|
|
```python |
|
!pip install git+https://github.com/huggingface/transformers
!pip install git+https://github.com/huggingface/peft.git
!pip install torch
!pip install -q bitsandbytes accelerate
|
``` |
|
|
|
```python |
|
# Import libraries
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import re
|
``` |
|
|
|
```python |
|
# Load the adapter and merge it with the base model for inference

torch.set_default_device('cuda')

peft_model_id = "abhishek7/Prompt_diffusion-v0.1"
config = PeftConfig.from_pretrained(peft_model_id)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    low_cpu_mem_usage=True,
    quantization_config=bnb_config,  # 4-bit loading is configured here; a separate load_in_4bit kwarg would be redundant
    torch_dtype=torch.float16,
    device_map="auto",
)

model = PeftModel.from_pretrained(model, peft_model_id)
model = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path, trust_remote_code=True)
tokenizer.padding_side = "right"
|
``` |
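
If you want to confirm that the 4-bit model fits in GPU memory (a Colab T4 has 16 GB), you can optionally check its footprint with the standard `transformers` helper:

```python
# Optional: report the loaded model's memory footprint in GB
print(f"Memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
```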
|
|
|
```python |
|
# Truncate text after a maximum number of punctuation marks
def truncate_text(text, max_punctuation):
    punctuation_count = 0
    truncated_text = ""
    for char in text:
        truncated_text += char
        if char in [',', '.']:
            punctuation_count += 1
            if punctuation_count >= max_punctuation:
                break

    # If the truncated text ends with a comma, replace it with a full stop
    stripped = truncated_text.rstrip()
    if stripped and stripped[-1] == ',':
        truncated_text = stripped[:-1] + '.'

    return truncated_text


# Generate a detailed Stable Diffusion prompt from a short description
def generate_prompt(user_input, max_length, temperature):
    # The template is left-aligned so no stray indentation enters the prompt
    input_context = f'''###Human:
generate a stable diffusion prompt for {user_input}

###Assistant:
'''

    inputs = tokenizer.encode(input_context, return_tensors="pt").to(model.device)
    outputs = model.generate(
        inputs,
        max_length=max_length,
        do_sample=True,  # temperature only takes effect when sampling is enabled
        temperature=temperature,
        num_return_sequences=1,
    )
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract the Assistant's response with a regex
    match = re.search(r'###Assistant:(.*?)(###Human:|$)', output_text, re.DOTALL)
    if match:
        assistant_response = match.group(1)
    else:
        raise ValueError("No Assistant response found")

    # Truncate the response after at most 10 punctuation marks
    return truncate_text(assistant_response, max_punctuation=10)
|
``` |
|
|
|
```python |
|
# Usage: generate a detailed prompt from a short description
input_text = "a beautiful woman dancing in rain"
prompt = generate_prompt(input_text, max_length=150, temperature=0.3)
print("\nPrompt: " + prompt)
|
``` |
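
The generated prompt can then be fed to any text-to-image model. Below is a minimal sketch using the `diffusers` library; the pipeline and checkpoint (`runwayml/stable-diffusion-v1-5`) are illustrative choices, not part of this repository, and require `pip install diffusers` first:

```python
# A sketch: pass the generated prompt to a Stable Diffusion pipeline
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

image = pipe(prompt).images[0]  # `prompt` comes from generate_prompt above
image.save("output.png")
```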
|
|
|
## Contributing |
|
|
|
Contributions are welcome! If you find any bugs, create an issue or submit a pull request with your proposed changes. |
|
|
|
## Acknowledgements |
|
|
|
This model was fine-tuned by [Abhishek Kalra](https://github.com/abhishek7kalra) on Sep 29, 2023, and is intended for research applications only.
|
|
|
[![Gmail](https://img.shields.io/badge/Gmail-D14836?style=for-the-badge&logo=gmail&logoColor=white)](mailto:[email protected]) |
|
|
|
## mistralai/Mistral-7B-v0.1 citation |
|
|
|
``` |
|
@misc{jiang2023mistral,
  title={Mistral 7B},
  author={Jiang, Albert Q. and Sablayrolles, Alexandre and Mensch, Arthur and Bamford, Chris and Chaplot, Devendra Singh and de las Casas, Diego and Bressand, Florian and Lengyel, Gianna and Lample, Guillaume and Saulnier, Lucile and Lavaud, L{\'e}lio Renard and Lachaux, Marie-Anne and Stock, Pierre and Le Scao, Teven and Lavril, Thibaut and Wang, Thomas and Lacroix, Timoth{\'e}e and El Sayed, William},
  year={2023},
  eprint={2310.06825},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
|
``` |
|
|
|
## Framework versions |
|
|
|
- PEFT 0.6.0.dev0 |